crypto: Scope the verification requests behind the other user id
parent
58d3b42a60
commit
5d38bc3802
|
@ -129,7 +129,7 @@ async fn login(
|
||||||
if let MessageType::VerificationRequest(_) = &m.content.msgtype
|
if let MessageType::VerificationRequest(_) = &m.content.msgtype
|
||||||
{
|
{
|
||||||
let request = client
|
let request = client
|
||||||
.get_verification_request(&m.event_id)
|
.get_verification_request(&m.sender, &m.event_id)
|
||||||
.await
|
.await
|
||||||
.expect("Request object wasn't created");
|
.expect("Request object wasn't created");
|
||||||
|
|
||||||
|
|
|
@ -2192,16 +2192,18 @@ impl Client {
|
||||||
.map(|sas| Sas { inner: sas, client: self.clone() })
|
.map(|sas| Sas { inner: sas, client: self.clone() })
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a `VerificationRequest` object with the given flow id.
|
/// Get a `VerificationRequest` object for the given user with the given
|
||||||
|
/// flow id.
|
||||||
#[cfg(feature = "encryption")]
|
#[cfg(feature = "encryption")]
|
||||||
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
|
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
|
||||||
pub async fn get_verification_request(
|
pub async fn get_verification_request(
|
||||||
&self,
|
&self,
|
||||||
|
user_id: &UserId,
|
||||||
flow_id: impl AsRef<str>,
|
flow_id: impl AsRef<str>,
|
||||||
) -> Option<VerificationRequest> {
|
) -> Option<VerificationRequest> {
|
||||||
let olm = self.base_client.olm_machine().await?;
|
let olm = self.base_client.olm_machine().await?;
|
||||||
|
|
||||||
olm.get_verification_request(flow_id)
|
olm.get_verification_request(user_id, flow_id)
|
||||||
.map(|r| VerificationRequest { inner: r, client: self.clone() })
|
.map(|r| VerificationRequest { inner: r, client: self.clone() })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -725,9 +725,10 @@ impl OlmMachine {
|
||||||
/// Get a verification request object with the given flow id.
|
/// Get a verification request object with the given flow id.
|
||||||
pub fn get_verification_request(
|
pub fn get_verification_request(
|
||||||
&self,
|
&self,
|
||||||
|
user_id: &UserId,
|
||||||
flow_id: impl AsRef<str>,
|
flow_id: impl AsRef<str>,
|
||||||
) -> Option<VerificationRequest> {
|
) -> Option<VerificationRequest> {
|
||||||
self.verification_machine.get_request(flow_id)
|
self.verification_machine.get_request(user_id, flow_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn update_one_time_key_count(&self, key_count: &BTreeMap<DeviceKeyAlgorithm, UInt>) {
|
async fn update_one_time_key_count(&self, key_count: &BTreeMap<DeviceKeyAlgorithm, UInt>) {
|
||||||
|
|
|
@ -40,7 +40,7 @@ pub struct VerificationMachine {
|
||||||
private_identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
|
private_identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
|
||||||
pub(crate) store: Arc<dyn CryptoStore>,
|
pub(crate) store: Arc<dyn CryptoStore>,
|
||||||
verifications: VerificationCache,
|
verifications: VerificationCache,
|
||||||
requests: Arc<DashMap<String, VerificationRequest>>,
|
requests: Arc<DashMap<UserId, DashMap<String, VerificationRequest>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl VerificationMachine {
|
impl VerificationMachine {
|
||||||
|
@ -91,8 +91,19 @@ impl VerificationMachine {
|
||||||
Ok((sas, request))
|
Ok((sas, request))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_request(&self, flow_id: impl AsRef<str>) -> Option<VerificationRequest> {
|
pub fn get_request(
|
||||||
self.requests.get(flow_id.as_ref()).map(|s| s.clone())
|
&self,
|
||||||
|
user_id: &UserId,
|
||||||
|
flow_id: impl AsRef<str>,
|
||||||
|
) -> Option<VerificationRequest> {
|
||||||
|
self.requests.get(user_id).and_then(|v| v.get(flow_id.as_ref()).map(|s| s.clone()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn insert_request(&self, request: VerificationRequest) {
|
||||||
|
self.requests
|
||||||
|
.entry(request.other_user().to_owned())
|
||||||
|
.or_insert_with(DashMap::new)
|
||||||
|
.insert(request.flow_id().as_str().to_owned(), request);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Verification> {
|
pub fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Verification> {
|
||||||
|
@ -145,7 +156,10 @@ impl VerificationMachine {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn garbage_collect(&self) {
|
pub fn garbage_collect(&self) {
|
||||||
self.requests.retain(|_, r| !(r.is_done() || r.is_cancelled()));
|
for user_verification in self.requests.iter() {
|
||||||
|
user_verification.retain(|_, v| !(v.is_done() || v.is_cancelled()));
|
||||||
|
}
|
||||||
|
self.requests.retain(|_, v| !v.is_empty());
|
||||||
|
|
||||||
for request in self.verifications.garbage_collect() {
|
for request in self.verifications.garbage_collect() {
|
||||||
self.verifications.add_request(request)
|
self.verifications.add_request(request)
|
||||||
|
@ -234,8 +248,7 @@ impl VerificationMachine {
|
||||||
r,
|
r,
|
||||||
);
|
);
|
||||||
|
|
||||||
self.requests
|
self.insert_request(request);
|
||||||
.insert(request.flow_id().as_str().to_owned(), request);
|
|
||||||
} else {
|
} else {
|
||||||
trace!(
|
trace!(
|
||||||
sender = event.sender().as_str(),
|
sender = event.sender().as_str(),
|
||||||
|
@ -260,7 +273,7 @@ impl VerificationMachine {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
AnyVerificationContent::Cancel(c) => {
|
AnyVerificationContent::Cancel(c) => {
|
||||||
if let Some(verification) = self.get_request(flow_id.as_str()) {
|
if let Some(verification) = self.get_request(event.sender(), flow_id.as_str()) {
|
||||||
verification.receive_cancel(event.sender(), c);
|
verification.receive_cancel(event.sender(), c);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -270,7 +283,7 @@ impl VerificationMachine {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
AnyVerificationContent::Ready(c) => {
|
AnyVerificationContent::Ready(c) => {
|
||||||
if let Some(request) = self.requests.get(flow_id.as_str()) {
|
if let Some(request) = self.get_request(event.sender(), flow_id.as_str()) {
|
||||||
if request.flow_id() == &flow_id {
|
if request.flow_id() == &flow_id {
|
||||||
request.receive_ready(event.sender(), c);
|
request.receive_ready(event.sender(), c);
|
||||||
} else {
|
} else {
|
||||||
|
@ -279,7 +292,7 @@ impl VerificationMachine {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
AnyVerificationContent::Start(c) => {
|
AnyVerificationContent::Start(c) => {
|
||||||
if let Some(request) = self.requests.get(flow_id.as_str()) {
|
if let Some(request) = self.get_request(event.sender(), flow_id.as_str()) {
|
||||||
if request.flow_id() == &flow_id {
|
if request.flow_id() == &flow_id {
|
||||||
request.receive_start(event.sender(), c).await?
|
request.receive_start(event.sender(), c).await?
|
||||||
} else {
|
} else {
|
||||||
|
@ -345,7 +358,7 @@ impl VerificationMachine {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
AnyVerificationContent::Done(c) => {
|
AnyVerificationContent::Done(c) => {
|
||||||
if let Some(verification) = self.get_request(flow_id.as_str()) {
|
if let Some(verification) = self.get_request(event.sender(), flow_id.as_str()) {
|
||||||
verification.receive_done(event.sender(), c);
|
verification.receive_done(event.sender(), c);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue