diff --git a/matrix_sdk_crypto/src/verification/machine.rs b/matrix_sdk_crypto/src/verification/machine.rs index 7bc5c666..d57c38f3 100644 --- a/matrix_sdk_crypto/src/verification/machine.rs +++ b/matrix_sdk_crypto/src/verification/machine.rs @@ -37,13 +37,80 @@ use crate::{ OutgoingVerificationRequest, ReadOnlyAccount, ReadOnlyDevice, RoomMessageRequest, }; +#[derive(Clone, Debug)] +pub struct VerificationCache { + sas_verification: Arc>, + room_sas_verifications: Arc>, +} + +impl VerificationCache { + fn new() -> Self { + Self { + sas_verification: DashMap::new().into(), + room_sas_verifications: DashMap::new().into(), + } + } + + #[cfg(test)] + fn is_empty(&self) -> bool { + self.room_sas_verifications.is_empty() && self.sas_verification.is_empty() + } + + pub fn get_room_sas(&self, event_id: &EventId) -> Option { + self.room_sas_verifications.get(event_id).map(|s| s.clone()) + } + + pub fn garbage_collect(&self) -> Vec { + self.sas_verification.retain(|_, s| !(s.is_done() || s.is_canceled())); + self.room_sas_verifications.retain(|_, s| !(s.is_done() || s.is_canceled())); + + let mut requests: Vec = self + .sas_verification + .iter() + .filter_map(|s| { + s.cancel_if_timed_out().map(|r| OutgoingRequest { + request_id: r.request_id(), + request: Arc::new(r.into()), + }) + }) + .collect(); + let room_requests: Vec = self + .room_sas_verifications + .iter() + .filter_map(|s| { + s.cancel_if_timed_out().map(|r| OutgoingRequest { + request_id: r.request_id(), + request: Arc::new(r.into()), + }) + }) + .collect(); + + requests.extend(room_requests); + + requests + } + + pub fn get_sas(&self, transaction_id: &str) -> Option { + let sas = if let Ok(e) = EventId::try_from(transaction_id) { + self.room_sas_verifications.get(&e).map(|s| s.clone()) + } else { + None + }; + + if sas.is_some() { + sas + } else { + self.sas_verification.get(transaction_id).map(|s| s.clone()) + } + } +} + #[derive(Clone, Debug)] pub struct VerificationMachine { account: ReadOnlyAccount, private_identity: Arc>, pub(crate) store: Arc>, - verifications: Arc>, - room_verifications: Arc>, + verifications: VerificationCache, requests: Arc>, outgoing_messages: Arc>, } @@ -58,9 +125,8 @@ impl VerificationMachine { account, private_identity: identity, store, - verifications: DashMap::new().into(), + verifications: VerificationCache::new(), requests: DashMap::new().into(), - room_verifications: DashMap::new().into(), outgoing_messages: DashMap::new().into(), } } @@ -89,7 +155,9 @@ impl VerificationMachine { let request = content_to_request(device.user_id(), device.device_id().to_owned(), c); - self.verifications.insert(sas.flow_id().as_str().to_owned(), sas.clone()); + self.verifications + .sas_verification + .insert(sas.flow_id().as_str().to_owned(), sas.clone()); request.into() } @@ -103,19 +171,7 @@ impl VerificationMachine { } pub fn get_sas(&self, transaction_id: &str) -> Option { - let sas = if let Ok(e) = EventId::try_from(transaction_id) { - #[allow(clippy::map_clone)] - self.room_verifications.get(&e).map(|s| s.clone()) - } else { - None - }; - - if sas.is_some() { - sas - } else { - #[allow(clippy::map_clone)] - self.verifications.get(transaction_id).map(|s| s.clone()) - } + self.verifications.get_sas(transaction_id) } fn queue_up_content( @@ -170,15 +226,8 @@ impl VerificationMachine { } pub fn garbage_collect(&self) { - self.verifications.retain(|_, s| !(s.is_done() || s.is_canceled())); - - for sas in self.verifications.iter() { - if let Some(r) = sas.cancel_if_timed_out() { - self.outgoing_messages.insert( - r.request_id(), - OutgoingRequest { request_id: r.request_id(), request: Arc::new(r.into()) }, - ); - } + for request in self.verifications.garbage_collect() { + self.outgoing_messages.insert(*request.request_id(), request); } } @@ -255,7 +304,8 @@ impl VerificationMachine { // TODO remove this unwrap let accept_request = s.accept().unwrap(); - self.room_verifications + self.verifications + .room_sas_verifications .insert(e.content.relation.event_id.clone(), s); self.outgoing_messages @@ -273,7 +323,7 @@ impl VerificationMachine { } } AnySyncMessageEvent::KeyVerificationKey(e) => { - if let Some(s) = self.room_verifications.get(&e.content.relation.event_id) { + if let Some(s) = self.verifications.get_room_sas(&e.content.relation.event_id) { self.receive_room_event_helper( &s, &m.clone().into_full_event(room_id.clone()), @@ -281,7 +331,7 @@ impl VerificationMachine { }; } AnySyncMessageEvent::KeyVerificationMac(e) => { - if let Some(s) = self.room_verifications.get(&e.content.relation.event_id) { + if let Some(s) = self.verifications.get_room_sas(&e.content.relation.event_id) { self.receive_room_event_helper( &s, &m.clone().into_full_event(room_id.clone()), @@ -290,7 +340,7 @@ impl VerificationMachine { } AnySyncMessageEvent::KeyVerificationDone(e) => { - if let Some(s) = self.room_verifications.get(&e.content.relation.event_id) { + if let Some(s) = self.verifications.get_room_sas(&e.content.relation.event_id) { let content = s.receive_room_event(&m.clone().into_full_event(room_id.clone())); @@ -373,7 +423,9 @@ impl VerificationMachine { self.store.get_user_identity(&e.sender).await?, ) { Ok(s) => { - self.verifications.insert(e.content.transaction_id.clone(), s); + self.verifications + .sas_verification + .insert(e.content.transaction_id.clone(), s); } Err(c) => { warn!( @@ -391,7 +443,7 @@ impl VerificationMachine { } } AnyToDeviceEvent::KeyVerificationCancel(e) => { - self.verifications.remove(&e.content.transaction_id); + self.verifications.sas_verification.remove(&e.content.transaction_id); } AnyToDeviceEvent::KeyVerificationAccept(e) => { if let Some(s) = self.get_sas(&e.content.transaction_id) {