diff --git a/matrix_sdk/examples/emoji_verification.rs b/matrix_sdk/examples/emoji_verification.rs index 967154f4..4ff70983 100644 --- a/matrix_sdk/examples/emoji_verification.rs +++ b/matrix_sdk/examples/emoji_verification.rs @@ -129,7 +129,7 @@ async fn login( if let MessageType::VerificationRequest(_) = &m.content.msgtype { let request = client - .get_verification_request(&m.event_id) + .get_verification_request(&m.sender, &m.event_id) .await .expect("Request object wasn't created"); diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index a5efaf9d..2f449684 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -2192,16 +2192,18 @@ impl Client { .map(|sas| Sas { inner: sas, client: self.clone() }) } - /// Get a `VerificationRequest` object with the given flow id. + /// Get a `VerificationRequest` object for the given user with the given + /// flow id. #[cfg(feature = "encryption")] #[cfg_attr(feature = "docs", doc(cfg(encryption)))] pub async fn get_verification_request( &self, + user_id: &UserId, flow_id: impl AsRef, ) -> Option { let olm = self.base_client.olm_machine().await?; - olm.get_verification_request(flow_id) + olm.get_verification_request(user_id, flow_id) .map(|r| VerificationRequest { inner: r, client: self.clone() }) } diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 258c61b1..5b1cfa00 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -725,9 +725,10 @@ impl OlmMachine { /// Get a verification request object with the given flow id. pub fn get_verification_request( &self, + user_id: &UserId, flow_id: impl AsRef, ) -> Option { - self.verification_machine.get_request(flow_id) + self.verification_machine.get_request(user_id, flow_id) } async fn update_one_time_key_count(&self, key_count: &BTreeMap) { diff --git a/matrix_sdk_crypto/src/verification/machine.rs b/matrix_sdk_crypto/src/verification/machine.rs index b8b21200..5478a16e 100644 --- a/matrix_sdk_crypto/src/verification/machine.rs +++ b/matrix_sdk_crypto/src/verification/machine.rs @@ -40,7 +40,7 @@ pub struct VerificationMachine { private_identity: Arc>, pub(crate) store: Arc, verifications: VerificationCache, - requests: Arc>, + requests: Arc>>, } impl VerificationMachine { @@ -91,8 +91,19 @@ impl VerificationMachine { Ok((sas, request)) } - pub fn get_request(&self, flow_id: impl AsRef) -> Option { - self.requests.get(flow_id.as_ref()).map(|s| s.clone()) + pub fn get_request( + &self, + user_id: &UserId, + flow_id: impl AsRef, + ) -> Option { + self.requests.get(user_id).and_then(|v| v.get(flow_id.as_ref()).map(|s| s.clone())) + } + + fn insert_request(&self, request: VerificationRequest) { + self.requests + .entry(request.other_user().to_owned()) + .or_insert_with(DashMap::new) + .insert(request.flow_id().as_str().to_owned(), request); } pub fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option { @@ -145,7 +156,10 @@ impl VerificationMachine { } pub fn garbage_collect(&self) { - self.requests.retain(|_, r| !(r.is_done() || r.is_cancelled())); + for user_verification in self.requests.iter() { + user_verification.retain(|_, v| !(v.is_done() || v.is_cancelled())); + } + self.requests.retain(|_, v| !v.is_empty()); for request in self.verifications.garbage_collect() { self.verifications.add_request(request) @@ -234,8 +248,7 @@ impl VerificationMachine { r, ); - self.requests - .insert(request.flow_id().as_str().to_owned(), request); + self.insert_request(request); } else { trace!( sender = event.sender().as_str(), @@ -260,7 +273,7 @@ impl VerificationMachine { } } AnyVerificationContent::Cancel(c) => { - if let Some(verification) = self.get_request(flow_id.as_str()) { + if let Some(verification) = self.get_request(event.sender(), flow_id.as_str()) { verification.receive_cancel(event.sender(), c); } @@ -270,7 +283,7 @@ impl VerificationMachine { } } AnyVerificationContent::Ready(c) => { - if let Some(request) = self.requests.get(flow_id.as_str()) { + if let Some(request) = self.get_request(event.sender(), flow_id.as_str()) { if request.flow_id() == &flow_id { request.receive_ready(event.sender(), c); } else { @@ -279,7 +292,7 @@ impl VerificationMachine { } } AnyVerificationContent::Start(c) => { - if let Some(request) = self.requests.get(flow_id.as_str()) { + if let Some(request) = self.get_request(event.sender(), flow_id.as_str()) { if request.flow_id() == &flow_id { request.receive_start(event.sender(), c).await? } else { @@ -345,7 +358,7 @@ impl VerificationMachine { } } AnyVerificationContent::Done(c) => { - if let Some(verification) = self.get_request(flow_id.as_str()) { + if let Some(verification) = self.get_request(event.sender(), flow_id.as_str()) { verification.receive_done(event.sender(), c); }