diff --git a/matrix_sdk_crypto/src/verification/machine.rs b/matrix_sdk_crypto/src/verification/machine.rs index 691c8bcb..68e4870c 100644 --- a/matrix_sdk_crypto/src/verification/machine.rs +++ b/matrix_sdk_crypto/src/verification/machine.rs @@ -41,6 +41,7 @@ use crate::{ pub struct VerificationCache { sas_verification: Arc>, room_sas_verifications: Arc>, + outgoing_requests: Arc>, } impl VerificationCache { @@ -48,6 +49,7 @@ impl VerificationCache { Self { sas_verification: DashMap::new().into(), room_sas_verifications: DashMap::new().into(), + outgoing_requests: DashMap::new().into(), } } @@ -103,6 +105,45 @@ impl VerificationCache { self.sas_verification.get(transaction_id).map(|s| s.clone()) } } + + pub fn add_request(&self, request: OutgoingRequest) { + self.outgoing_requests.insert(request.request_id, request); + } + + pub fn queue_up_content( + &self, + recipient: &UserId, + recipient_device: &DeviceId, + content: OutgoingContent, + ) { + match content { + OutgoingContent::ToDevice(c) => { + let request = content_to_request(recipient, recipient_device.to_owned(), c); + let request_id = request.txn_id; + + let request = OutgoingRequest { request_id, request: Arc::new(request.into()) }; + + self.outgoing_requests.insert(request_id, request); + } + + OutgoingContent::Room(r, c) => { + let request_id = Uuid::new_v4(); + + let request = OutgoingRequest { + request: Arc::new( + RoomMessageRequest { room_id: r, txn_id: request_id, content: c }.into(), + ), + request_id, + }; + + self.outgoing_requests.insert(request_id, request); + } + } + } + + pub fn mark_request_as_sent(&self, uuid: &Uuid) { + self.outgoing_requests.remove(uuid); + } } #[derive(Clone, Debug)] @@ -112,7 +153,6 @@ pub struct VerificationMachine { pub(crate) store: Arc>, verifications: VerificationCache, requests: Arc>, - outgoing_messages: Arc>, } impl VerificationMachine { @@ -127,7 +167,6 @@ impl VerificationMachine { store, verifications: VerificationCache::new(), requests: DashMap::new().into(), - outgoing_messages: DashMap::new().into(), } } @@ -180,29 +219,7 @@ impl VerificationMachine { recipient_device: &DeviceId, content: OutgoingContent, ) { - match content { - OutgoingContent::ToDevice(c) => { - let request = content_to_request(recipient, recipient_device.to_owned(), c); - let request_id = request.txn_id; - - let request = OutgoingRequest { request_id, request: Arc::new(request.into()) }; - - self.outgoing_messages.insert(request_id, request); - } - - OutgoingContent::Room(r, c) => { - let request_id = Uuid::new_v4(); - - let request = OutgoingRequest { - request: Arc::new( - RoomMessageRequest { room_id: r, txn_id: request_id, content: c }.into(), - ), - request_id, - }; - - self.outgoing_messages.insert(request_id, request); - } - } + self.verifications.queue_up_content(recipient, recipient_device, content) } fn receive_room_event_helper(&self, sas: &Sas, event: &AnyMessageEvent) { @@ -218,16 +235,16 @@ impl VerificationMachine { } pub fn mark_request_as_sent(&self, uuid: &Uuid) { - self.outgoing_messages.remove(uuid); + self.verifications.mark_request_as_sent(uuid); } pub fn outgoing_messages(&self) -> Vec { - self.outgoing_messages.iter().map(|r| (*r).clone()).collect() + self.verifications.outgoing_requests.iter().map(|r| (*r).clone()).collect() } pub fn garbage_collect(&self) { for request in self.verifications.garbage_collect() { - self.outgoing_messages.insert(*request.request_id(), request); + self.verifications.add_request(request) } } @@ -309,8 +326,7 @@ impl VerificationMachine { .room_sas_verifications .insert(e.content.relation.event_id.clone(), s); - self.outgoing_messages - .insert(accept_request.request_id(), accept_request.into()); + self.verifications.add_request(accept_request.into()); } Err(c) => { warn!( @@ -357,12 +373,10 @@ impl VerificationMachine { } } VerificationResult::Cancel(r) => { - self.outgoing_messages.insert(r.request_id(), r.into()); + self.verifications.add_request(r.into()); } VerificationResult::SignatureUpload(r) => { - let request: OutgoingRequest = r.into(); - - self.outgoing_messages.insert(request.request_id, request); + self.verifications.add_request(r.into()); if let Some(c) = content { self.queue_up_content( @@ -465,21 +479,10 @@ impl VerificationMachine { match s.mark_as_done().await? { VerificationResult::Ok => (), VerificationResult::Cancel(r) => { - self.outgoing_messages.insert( - r.request_id(), - OutgoingRequest { - request_id: r.request_id(), - request: Arc::new(r.into()), - }, - ); + self.verifications.add_request(r.into()); } VerificationResult::SignatureUpload(r) => { - let request_id = Uuid::new_v4(); - - self.outgoing_messages.insert( - request_id, - OutgoingRequest { request_id, request: Arc::new(r.into()) }, - ); + self.verifications.add_request(r.into()); } } } @@ -586,11 +589,11 @@ mod test { .map(|c| wrap_any_to_device_content(bob.user_id(), c)) .unwrap(); - assert!(alice_machine.outgoing_messages.is_empty()); + assert!(alice_machine.verifications.outgoing_requests.is_empty()); alice_machine.receive_event(&event).await.unwrap(); - assert!(!alice_machine.outgoing_messages.is_empty()); + assert!(!alice_machine.verifications.outgoing_requests.is_empty()); - let request = alice_machine.outgoing_messages.iter().next().unwrap(); + let request = alice_machine.verifications.outgoing_requests.iter().next().unwrap(); let txn_id = *request.request_id(); @@ -635,14 +638,14 @@ mod test { let alice = alice_machine.get_sas(bob.flow_id().as_str()).unwrap(); assert!(!alice.timed_out()); - assert!(alice_machine.outgoing_messages.is_empty()); + assert!(alice_machine.verifications.outgoing_requests.is_empty()); // This line panics on macOS, so we're disabled for now. alice.set_creation_time(Instant::now() - Duration::from_secs(60 * 15)); assert!(alice.timed_out()); - assert!(alice_machine.outgoing_messages.is_empty()); + assert!(alice_machine.verifications.outgoing_requests.is_empty()); alice_machine.garbage_collect(); - assert!(!alice_machine.outgoing_messages.is_empty()); + assert!(!alice_machine.verifications.outgoing_requests.is_empty()); alice_machine.garbage_collect(); assert!(alice_machine.verifications.is_empty()); }