From 06b9c71dbc526b87ff5aa3b2c4a12571cef95510 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Wed, 7 Oct 2020 12:42:39 +0200 Subject: [PATCH] crypto: Refactor out the key share wait queue. --- matrix_sdk_crypto/src/key_request.rs | 121 +++++++++++++++++++-------- 1 file changed, 85 insertions(+), 36 deletions(-) diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index b40797c5..b1d067ab 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -20,7 +20,7 @@ // If we don't trust the device store an object that remembers the request and // let the users introspect that object. -use dashmap::{DashMap, DashSet}; +use dashmap::{mapref::entry::Entry, DashMap, DashSet}; use serde::{Deserialize, Serialize}; use serde_json::value::to_raw_value; use std::{collections::BTreeMap, sync::Arc}; @@ -63,6 +63,66 @@ pub enum KeyshareDecision { UntrustedDevice, } +/// A queue where we store room key requests that we want to serve but the +/// device that requested the key doesn't share an Olm session with us. +#[derive(Debug, Clone)] +struct WaitQueue { + requests_waiting_for_session: + Arc>>, + requests_ids_waiting: Arc>>, +} + +impl WaitQueue { + fn new() -> Self { + Self { + requests_waiting_for_session: Arc::new(DashMap::new()), + requests_ids_waiting: Arc::new(DashMap::new()), + } + } + + #[cfg(test)] + fn is_empty(&self) -> bool { + self.requests_ids_waiting.is_empty() && self.requests_waiting_for_session.is_empty() + } + + fn insert(&self, device: &Device, event: &ToDeviceEvent) { + let key = ( + device.user_id().to_owned(), + device.device_id().into(), + event.content.request_id.to_owned(), + ); + self.requests_waiting_for_session.insert(key, event.clone()); + + let key = (device.user_id().to_owned(), device.device_id().into()); + self.requests_ids_waiting + .entry(key) + .or_insert_with(DashSet::new) + .insert(event.content.request_id.clone()); + } + + fn remove( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Vec<( + (UserId, DeviceIdBox, String), + ToDeviceEvent, + )> { + self.requests_ids_waiting + .remove(&(user_id.to_owned(), device_id.into())) + .map(|(_, request_ids)| { + request_ids + .iter() + .filter_map(|id| { + let key = (user_id.to_owned(), device_id.into(), id.to_owned()); + self.requests_waiting_for_session.remove(&key) + }) + .collect() + }) + .unwrap_or_default() + } +} + #[derive(Debug, Clone)] pub(crate) struct KeyRequestMachine { user_id: Arc, @@ -72,10 +132,7 @@ pub(crate) struct KeyRequestMachine { outgoing_to_device_requests: Arc>, incoming_key_requests: Arc>>, - // TODO group these hashmaps into a logical unit. - requests_waiting_for_session: - Arc>>, - requests_ids_waiting: Arc>>, + wait_queue: WaitQueue, users_for_key_claim: Arc>>, } @@ -147,9 +204,8 @@ impl KeyRequestMachine { outbound_group_sessions, outgoing_to_device_requests: Arc::new(DashMap::new()), incoming_key_requests: Arc::new(DashMap::new()), - requests_waiting_for_session: Arc::new(DashMap::new()), + wait_queue: WaitQueue::new(), users_for_key_claim: Arc::new(DashMap::new()), - requests_ids_waiting: Arc::new(DashMap::new()), } } @@ -189,6 +245,9 @@ impl KeyRequestMachine { Ok(()) } + /// Store the key share request for later, once we get an Olm session with + /// the given device [`retry_keyshare`](#method.retry_keyshare) should be + /// called. fn handle_key_share_without_session( &self, device: Device, @@ -198,18 +257,7 @@ impl KeyRequestMachine { .entry(device.user_id().to_owned()) .or_insert_with(DashSet::new) .insert(device.device_id().into()); - let key = ( - device.user_id().to_owned(), - device.device_id().into(), - event.content.request_id.to_owned(), - ); - self.requests_waiting_for_session.insert(key, event.clone()); - - let key = (device.user_id().to_owned(), device.device_id().into()); - self.requests_ids_waiting - .entry(key) - .or_insert_with(DashSet::new) - .insert(event.content.request_id.clone()); + self.wait_queue.insert(&device, event); } /// Retry keyshares for a device that previously didn't have an Olm session @@ -225,20 +273,21 @@ impl KeyRequestMachine { /// /// * `device_id` - The device id of the device that got the Olm session. pub fn retry_keyshare(&self, user_id: &UserId, device_id: &DeviceId) { - if let Some((_, request_ids)) = self - .requests_ids_waiting - .remove(&(user_id.to_owned(), device_id.into())) - { - for id in request_ids { - let key = (user_id.to_owned(), device_id.into(), id.to_owned()); - let content = self.requests_waiting_for_session.remove(&key); + match self.users_for_key_claim.entry(user_id.to_owned()) { + Entry::Occupied(e) => { + e.get().remove(device_id); - if let Some((_, c)) = content { - if !self.incoming_key_requests.contains_key(&key) { - self.incoming_key_requests.insert(key, c); - } + if e.get().is_empty() { + e.remove(); } } + _ => (), + } + + for (key, event) in self.wait_queue.remove(user_id, device_id) { + if !self.incoming_key_requests.contains_key(&key) { + self.incoming_key_requests.insert(key, event); + } } } @@ -1188,16 +1237,16 @@ mod test { // Bob doesn't have any outgoing requests. assert!(bob_machine.outgoing_to_device_requests.is_empty()); - assert!(bob_machine.requests_ids_waiting.is_empty()); - assert!(bob_machine.requests_waiting_for_session.is_empty()); + assert!(bob_machine.users_for_key_claim.is_empty()); + assert!(bob_machine.wait_queue.is_empty()); // Receive the room key request from alice. bob_machine.receive_incoming_key_request(&event); bob_machine.collect_incoming_key_requests().await.unwrap(); // Bob doens't have an outgoing requests since we're lacking a session. assert!(bob_machine.outgoing_to_device_requests.is_empty()); - assert!(!bob_machine.requests_ids_waiting.is_empty()); - assert!(!bob_machine.requests_waiting_for_session.is_empty()); + assert!(!bob_machine.users_for_key_claim.is_empty()); + assert!(!bob_machine.wait_queue.is_empty()); // We create a session now. alice_machine @@ -1212,11 +1261,11 @@ mod test { .unwrap(); bob_machine.retry_keyshare(&alice_id(), &alice_device_id()); + assert!(bob_machine.users_for_key_claim.is_empty()); bob_machine.collect_incoming_key_requests().await.unwrap(); // Bob now has an outgoing requests. assert!(!bob_machine.outgoing_to_device_requests.is_empty()); - assert!(bob_machine.requests_ids_waiting.is_empty()); - assert!(bob_machine.requests_waiting_for_session.is_empty()); + assert!(bob_machine.wait_queue.is_empty()); // Get the request and convert it to a encrypted to-device event. let request = bob_machine