From ada71586ac0a7c6281103a3cd6fcfc5f1322b7d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Tue, 8 Jun 2021 16:13:14 +0200 Subject: [PATCH] crypto: Scope the verifications per sender --- matrix_sdk/examples/emoji_verification.rs | 16 ++++-- matrix_sdk/src/client.rs | 4 +- matrix_sdk_base/src/client.rs | 9 ++- matrix_sdk_crypto/src/machine.rs | 14 +++-- matrix_sdk_crypto/src/verification/cache.rs | 57 +++++++++++++------ matrix_sdk_crypto/src/verification/machine.rs | 31 ++++++---- matrix_sdk_crypto/src/verification/mod.rs | 22 +++++++ .../src/verification/requests.rs | 6 +- 8 files changed, 114 insertions(+), 45 deletions(-) diff --git a/matrix_sdk/examples/emoji_verification.rs b/matrix_sdk/examples/emoji_verification.rs index 72e8b7cd..967154f4 100644 --- a/matrix_sdk/examples/emoji_verification.rs +++ b/matrix_sdk/examples/emoji_verification.rs @@ -81,7 +81,7 @@ async fn login( match event { AnyToDeviceEvent::KeyVerificationStart(e) => { let sas = client - .get_verification(&e.content.transaction_id) + .get_verification(&e.sender, &e.content.transaction_id) .await .expect("Sas object wasn't created"); println!( @@ -95,7 +95,7 @@ async fn login( AnyToDeviceEvent::KeyVerificationKey(e) => { let sas = client - .get_verification(&e.content.transaction_id) + .get_verification(&e.sender, &e.content.transaction_id) .await .expect("Sas object wasn't created"); @@ -104,7 +104,7 @@ async fn login( AnyToDeviceEvent::KeyVerificationMac(e) => { let sas = client - .get_verification(&e.content.transaction_id) + .get_verification(&e.sender, &e.content.transaction_id) .await .expect("Sas object wasn't created"); @@ -141,7 +141,10 @@ async fn login( } AnySyncMessageEvent::KeyVerificationKey(e) => { let sas = client - .get_verification(e.content.relation.event_id.as_str()) + .get_verification( + &e.sender, + e.content.relation.event_id.as_str(), + ) .await .expect("Sas object wasn't created"); @@ -149,7 +152,10 @@ async fn login( } AnySyncMessageEvent::KeyVerificationMac(e) => { let sas = client - .get_verification(e.content.relation.event_id.as_str()) + .get_verification( + &e.sender, + e.content.relation.event_id.as_str(), + ) .await .expect("Sas object wasn't created"); diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index 46359771..a5efaf9d 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -2185,9 +2185,9 @@ impl Client { /// Get a `Sas` verification object with the given flow id. #[cfg(feature = "encryption")] #[cfg_attr(feature = "docs", doc(cfg(encryption)))] - pub async fn get_verification(&self, flow_id: &str) -> Option { + pub async fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option { self.base_client - .get_verification(flow_id) + .get_verification(user_id, flow_id) .await .map(|sas| Sas { inner: sas, client: self.clone() }) } diff --git a/matrix_sdk_base/src/client.rs b/matrix_sdk_base/src/client.rs index 78853077..0b16beff 100644 --- a/matrix_sdk_base/src/client.rs +++ b/matrix_sdk_base/src/client.rs @@ -1213,8 +1213,13 @@ impl BaseClient { /// *m.key.verification.start* event. #[cfg(feature = "encryption")] #[cfg_attr(feature = "docs", doc(cfg(encryption)))] - pub async fn get_verification(&self, flow_id: &str) -> Option { - self.olm.lock().await.as_ref().and_then(|o| o.get_verification(flow_id)) + pub async fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option { + self.olm + .lock() + .await + .as_ref() + .and_then(|o| o.get_verification(user_id, flow_id).map(|v| v.sas_v1())) + .flatten() } /// Get a specific device of a user. diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 807d13cf..6abde4de 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -59,7 +59,7 @@ use crate::{ Changes, CryptoStore, DeviceChanges, IdentityChanges, MemoryStore, Result as StoreResult, Store, }, - verification::{Sas, VerificationMachine, VerificationRequest}, + verification::{Verification, VerificationMachine, VerificationRequest}, ToDeviceRequest, }; @@ -717,9 +717,9 @@ impl OlmMachine { Ok(()) } - /// Get a `Sas` verification object with the given flow id. - pub fn get_verification(&self, flow_id: &str) -> Option { - self.verification_machine.get_sas(flow_id) + /// Get a verification object for the given user id with the given flow id. + pub fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option { + self.verification_machine.get_verification(user_id, flow_id) } /// Get a verification request object with the given flow id. @@ -1761,7 +1761,11 @@ pub(crate) mod test { let event = request_to_event(alice.user_id(), &request.into()); bob.handle_verification_event(&event).await; - let bob_sas = bob.get_verification(alice_sas.flow_id().as_str()).unwrap(); + let bob_sas = bob + .get_verification(alice.user_id(), alice_sas.flow_id().as_str()) + .unwrap() + .sas_v1() + .unwrap(); assert!(alice_sas.emoji().is_none()); assert!(bob_sas.emoji().is_none()); diff --git a/matrix_sdk_crypto/src/verification/cache.rs b/matrix_sdk_crypto/src/verification/cache.rs index 975f9fd6..8d54a165 100644 --- a/matrix_sdk_crypto/src/verification/cache.rs +++ b/matrix_sdk_crypto/src/verification/cache.rs @@ -23,7 +23,7 @@ use crate::{OutgoingRequest, RoomMessageRequest}; #[derive(Clone, Debug)] pub struct VerificationCache { - verification: Arc>, + verification: Arc>>, outgoing_requests: Arc>, } @@ -35,11 +35,24 @@ impl VerificationCache { #[cfg(test)] #[allow(dead_code)] pub fn is_empty(&self) -> bool { - self.verification.is_empty() + self.verification.iter().all(|m| m.is_empty()) + } + + pub fn insert(&self, verification: impl Into) { + let verification = verification.into(); + + self.verification + .entry(verification.other_user().to_owned()) + .or_insert_with(DashMap::new) + .insert(verification.flow_id().to_owned(), verification); } pub fn insert_sas(&self, sas: Sas) { - self.verification.insert(sas.flow_id().as_str().to_string(), sas.into()); + self.insert(sas); + } + + pub fn get(&self, sender: &UserId, flow_id: &str) -> Option { + self.verification.get(sender).and_then(|m| m.get(flow_id).map(|v| v.clone())) } pub fn outgoing_requests(&self) -> Vec { @@ -47,28 +60,38 @@ impl VerificationCache { } pub fn garbage_collect(&self) -> Vec { - self.verification.retain(|_, s| !(s.is_done() || s.is_cancelled())); + for user_verification in self.verification.iter() { + user_verification.retain(|_, s| !(s.is_done() || s.is_cancelled())); + } + + self.verification.retain(|_, m| !m.is_empty()); self.verification .iter() - .filter_map(|s| { - #[allow(irrefutable_let_patterns)] - if let Verification::SasV1(s) = s.value() { - s.cancel_if_timed_out().map(|r| OutgoingRequest { - request_id: r.request_id(), - request: Arc::new(r.into()), + .flat_map(|v| { + let requests: Vec = v + .value() + .iter() + .filter_map(|s| { + if let Verification::SasV1(s) = s.value() { + s.cancel_if_timed_out().map(|r| OutgoingRequest { + request_id: r.request_id(), + request: Arc::new(r.into()), + }) + } else { + None + } }) - } else { - None - } + .collect(); + + requests }) .collect() } - pub fn get_sas(&self, transaction_id: &str) -> Option { - self.verification.get(transaction_id).and_then(|v| { - #[allow(irrefutable_let_patterns)] - if let Verification::SasV1(sas) = v.value() { + pub fn get_sas(&self, user_id: &UserId, flow_id: &str) -> Option { + self.get(user_id, flow_id).and_then(|v| { + if let Verification::SasV1(sas) = v { Some(sas.clone()) } else { None diff --git a/matrix_sdk_crypto/src/verification/machine.rs b/matrix_sdk_crypto/src/verification/machine.rs index b4a7da22..c84e39dc 100644 --- a/matrix_sdk_crypto/src/verification/machine.rs +++ b/matrix_sdk_crypto/src/verification/machine.rs @@ -24,7 +24,7 @@ use super::{ event_enums::{AnyEvent, AnyVerificationContent, OutgoingContent}, requests::VerificationRequest, sas::{content_to_request, Sas}, - FlowId, VerificationResult, + FlowId, Verification, VerificationResult, }; use crate::{ olm::PrivateCrossSigningIdentity, @@ -94,8 +94,12 @@ impl VerificationMachine { self.requests.get(flow_id.as_ref()).map(|s| s.clone()) } - pub fn get_sas(&self, transaction_id: &str) -> Option { - self.verifications.get_sas(transaction_id) + pub fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option { + self.verifications.get(user_id, flow_id) + } + + pub fn get_sas(&self, user_id: &UserId, flow_id: &str) -> Option { + self.verifications.get_sas(user_id, flow_id) } #[cfg(not(target_arch = "wasm32"))] @@ -242,7 +246,7 @@ impl VerificationMachine { verification.receive_cancel(event.sender(), c); } - if let Some(sas) = self.verifications.get_sas(flow_id.as_str()) { + if let Some(sas) = self.get_sas(event.sender(), flow_id.as_str()) { // This won't produce an outgoing content let _ = sas.receive_any_event(event.sender(), &content); } @@ -296,7 +300,7 @@ impl VerificationMachine { } } AnyVerificationContent::Accept(_) | AnyVerificationContent::Key(_) => { - if let Some(sas) = self.verifications.get_sas(flow_id.as_str()) { + if let Some(sas) = self.get_sas(event.sender(), flow_id.as_str()) { if sas.flow_id() == &flow_id { if let Some(content) = sas.receive_any_event(event.sender(), &content) { self.queue_up_content( @@ -311,7 +315,7 @@ impl VerificationMachine { } } AnyVerificationContent::Mac(_) => { - if let Some(s) = self.verifications.get_sas(flow_id.as_str()) { + if let Some(s) = self.get_sas(event.sender(), flow_id.as_str()) { if s.flow_id() == &flow_id { let content = s.receive_any_event(event.sender(), &content); @@ -328,12 +332,15 @@ impl VerificationMachine { verification.receive_done(event.sender(), c); } - if let Some(s) = self.verifications.get_sas(flow_id.as_str()) { - let content = s.receive_any_event(event.sender(), &content); + match self.get_verification(event.sender(), flow_id.as_str()) { + Some(Verification::SasV1(sas)) => { + let content = sas.receive_any_event(event.sender(), &content); - if s.is_done() { - self.mark_sas_as_done(s, content).await?; + if sas.is_done() { + self.mark_sas_as_done(sas, content).await?; + } } + None => (), } } } @@ -426,7 +433,7 @@ mod test { async fn full_flow() { let (alice_machine, bob) = setup_verification_machine().await; - let alice = alice_machine.get_sas(bob.flow_id().as_str()).unwrap(); + let alice = alice_machine.get_sas(bob.user_id(), bob.flow_id().as_str()).unwrap(); let request = alice.accept().unwrap(); @@ -472,7 +479,7 @@ mod test { #[tokio::test] async fn timing_out() { let (alice_machine, bob) = setup_verification_machine().await; - let alice = alice_machine.get_sas(bob.flow_id().as_str()).unwrap(); + let alice = alice_machine.get_sas(bob.user_id(), bob.flow_id().as_str()).unwrap(); assert!(!alice.timed_out()); assert!(alice_machine.verifications.outgoing_requests().is_empty()); diff --git a/matrix_sdk_crypto/src/verification/mod.rs b/matrix_sdk_crypto/src/verification/mod.rs index 3335ec72..52b2d097 100644 --- a/matrix_sdk_crypto/src/verification/mod.rs +++ b/matrix_sdk_crypto/src/verification/mod.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#![allow(missing_docs)] + mod cache; mod event_enums; mod machine; @@ -57,11 +59,31 @@ impl Verification { } } + pub fn sas_v1(self) -> Option { + if let Verification::SasV1(sas) = self { + Some(sas) + } else { + None + } + } + + pub fn flow_id(&self) -> &str { + match self { + Verification::SasV1(s) => s.flow_id().as_str(), + } + } + pub fn is_cancelled(&self) -> bool { match self { Verification::SasV1(s) => s.is_cancelled(), } } + + pub fn other_user(&self) -> &UserId { + match self { + Verification::SasV1(s) => s.other_user_id(), + } + } } impl From for Verification { diff --git a/matrix_sdk_crypto/src/verification/requests.rs b/matrix_sdk_crypto/src/verification/requests.rs index 819d2dc4..8e25cf3b 100644 --- a/matrix_sdk_crypto/src/verification/requests.rs +++ b/matrix_sdk_crypto/src/verification/requests.rs @@ -810,7 +810,8 @@ mod test { let content = StartContent::try_from(&start_content).unwrap(); let flow_id = content.flow_id().to_owned(); alice_request.receive_start(bob_device.user_id(), &content).await.unwrap(); - let alice_sas = alice_request.verification_cache.get_sas(&flow_id).unwrap(); + let alice_sas = + alice_request.verification_cache.get_sas(bob_device.user_id(), &flow_id).unwrap(); assert!(!bob_sas.is_cancelled()); assert!(!alice_sas.is_cancelled()); @@ -867,7 +868,8 @@ mod test { let content = StartContent::try_from(&start_content).unwrap(); let flow_id = content.flow_id().to_owned(); alice_request.receive_start(bob_device.user_id(), &content).await.unwrap(); - let alice_sas = alice_request.verification_cache.get_sas(&flow_id).unwrap(); + let alice_sas = + alice_request.verification_cache.get_sas(bob_device.user_id(), &flow_id).unwrap(); assert!(!bob_sas.is_cancelled()); assert!(!alice_sas.is_cancelled());