From 8c007510cd14f5a935807b674d974396a018722f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Thu, 15 Apr 2021 19:40:24 +0200 Subject: [PATCH] crypto: Only load the outgoing key requests when we want to send them out --- matrix_sdk/src/client.rs | 11 +- matrix_sdk_base/src/client.rs | 4 +- matrix_sdk_crypto/src/key_request.rs | 156 ++++++++++++--------- matrix_sdk_crypto/src/machine.rs | 18 ++- matrix_sdk_crypto/src/store/memorystore.rs | 3 +- matrix_sdk_crypto/src/store/mod.rs | 2 +- matrix_sdk_crypto/src/store/sled.rs | 109 ++++++++++---- 7 files changed, 198 insertions(+), 105 deletions(-) diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index a6b6164b..7453f365 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -1839,7 +1839,16 @@ impl Client { warn!("Error while claiming one-time keys {:?}", e); } - for r in self.base_client.outgoing_requests().await { + // TODO we should probably abort if we get an cryptostore error here + let outgoing_requests = match self.base_client.outgoing_requests().await { + Ok(r) => r, + Err(e) => { + warn!("Could not fetch the outgoing requests {:?}", e); + vec![] + } + }; + + for r in outgoing_requests { match r.request() { OutgoingRequests::KeysQuery(request) => { if let Err(e) = self diff --git a/matrix_sdk_base/src/client.rs b/matrix_sdk_base/src/client.rs index 9e98489c..eedd48b0 100644 --- a/matrix_sdk_base/src/client.rs +++ b/matrix_sdk_base/src/client.rs @@ -1069,12 +1069,12 @@ impl BaseClient { /// [`mark_request_as_sent`]: #method.mark_request_as_sent #[cfg(feature = "encryption")] #[cfg_attr(feature = "docs", doc(cfg(encryption)))] - pub async fn outgoing_requests(&self) -> Vec { + pub async fn outgoing_requests(&self) -> Result, CryptoStoreError> { let olm = self.olm.lock().await; match &*olm { Some(o) => o.outgoing_requests().await, - None => vec![], + None => Ok(vec![]), } } diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index b00e8fbb..52ee3ade 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -150,7 +150,7 @@ pub struct OutgoingKeyRequest { } impl OutgoingKeyRequest { - fn into_request( + fn to_request( &self, recipient: &UserId, own_device_id: &DeviceId, @@ -214,30 +214,25 @@ impl KeyRequestMachine { device_id, store, outbound_group_sessions, - outgoing_to_device_requests: Arc::new(DashMap::new()), - incoming_key_requests: Arc::new(DashMap::new()), + outgoing_to_device_requests: DashMap::new().into(), + incoming_key_requests: DashMap::new().into(), wait_queue: WaitQueue::new(), users_for_key_claim, } } - /// Load stored non-sent out outgoing requests - pub async fn load_outgoing_requests(&mut self) -> Result<(), CryptoStoreError> { - let infos: Vec = vec![]; - let requests: DashMap = infos - .iter() + /// Load stored outgoing requests that were not yet sent out. + async fn load_outgoing_requests(&self) -> Result, CryptoStoreError> { + self.store + .get_unsent_key_requests() + .await? + .into_iter() .filter(|i| !i.sent_out) - .filter_map(|info| { - Some(( - info.request_id, - info.into_request(self.user_id(), self.device_id()).ok()?, - )) + .map(|info| { + info.to_request(self.user_id(), self.device_id()) + .map_err(CryptoStoreError::from) }) - .collect(); - - self.outgoing_to_device_requests = requests.into(); - - Ok(()) + .collect() } /// Our own user id. @@ -250,12 +245,18 @@ impl KeyRequestMachine { &self.device_id } - pub fn outgoing_to_device_requests(&self) -> Vec { - #[allow(clippy::map_clone)] - self.outgoing_to_device_requests + pub async fn outgoing_to_device_requests( + &self, + ) -> Result, CryptoStoreError> { + let mut key_requests = self.load_outgoing_requests().await?; + let key_forwards: Vec = self + .outgoing_to_device_requests .iter() - .map(|r| (*r).clone()) - .collect() + .map(|i| i.value().clone()) + .collect(); + key_requests.extend(key_forwards); + + Ok(key_requests) } /// Receive a room key request event. @@ -584,10 +585,7 @@ impl KeyRequestMachine { sent_out: false, }; - let request = info.into_request(self.user_id(), self.device_id())?; - self.save_outgoing_key_info(info).await?; - self.outgoing_to_device_requests.insert(id, request); Ok(()) } @@ -830,7 +828,11 @@ mod test { async fn create_machine() { let machine = get_machine().await; - assert!(machine.outgoing_to_device_requests().is_empty()); + assert!(machine + .outgoing_to_device_requests() + .await + .unwrap() + .is_empty()); } #[async_test] @@ -843,7 +845,11 @@ mod test { .await .unwrap(); - assert!(machine.outgoing_to_device_requests().is_empty()); + assert!(machine + .outgoing_to_device_requests() + .await + .unwrap() + .is_empty()); machine .create_outgoing_key_request( session.room_id(), @@ -852,8 +858,15 @@ mod test { ) .await .unwrap(); - assert!(!machine.outgoing_to_device_requests().is_empty()); - assert_eq!(machine.outgoing_to_device_requests().len(), 1); + assert!(!machine + .outgoing_to_device_requests() + .await + .unwrap() + .is_empty()); + assert_eq!( + machine.outgoing_to_device_requests().await.unwrap().len(), + 1 + ); machine .create_outgoing_key_request( @@ -863,15 +876,21 @@ mod test { ) .await .unwrap(); - assert_eq!(machine.outgoing_to_device_requests.len(), 1); - let request = machine.outgoing_to_device_requests.iter().next().unwrap(); + let requests = machine.outgoing_to_device_requests().await.unwrap(); + assert_eq!(requests.len(), 1); - let id = request.request_id; - drop(request); + let request = requests.get(0).unwrap(); - machine.mark_outgoing_request_as_sent(id).await.unwrap(); - assert!(machine.outgoing_to_device_requests.is_empty()); + machine + .mark_outgoing_request_as_sent(request.request_id) + .await + .unwrap(); + assert!(machine + .outgoing_to_device_requests() + .await + .unwrap() + .is_empty()); } #[async_test] @@ -892,9 +911,9 @@ mod test { .await .unwrap(); - let request = machine.outgoing_to_device_requests.iter().next().unwrap(); + let requests = machine.outgoing_to_device_requests().await.unwrap(); + let request = requests.get(0).unwrap(); let id = request.request_id; - drop(request); machine.mark_outgoing_request_as_sent(id).await.unwrap(); @@ -949,11 +968,13 @@ mod test { .await .unwrap(); - let request = machine.outgoing_to_device_requests.iter().next().unwrap(); - let id = request.request_id; - drop(request); + let requests = machine.outgoing_to_device_requests().await.unwrap(); + let request = &requests[0]; - machine.mark_outgoing_request_as_sent(id).await.unwrap(); + machine + .mark_outgoing_request_as_sent(request.request_id) + .await + .unwrap(); let export = session.export_at_index(15).await; @@ -1160,11 +1181,8 @@ mod test { .insert(group_session.clone()); // Get the request and convert it into a event. - let request = alice_machine - .outgoing_to_device_requests - .iter() - .next() - .unwrap(); + let requests = alice_machine.outgoing_to_device_requests().await.unwrap(); + let request = &requests[0]; let id = request.request_id; let content = request .request @@ -1178,7 +1196,6 @@ mod test { let content: RoomKeyRequestToDeviceEventContent = serde_json::from_str(content.get()).unwrap(); - drop(request); alice_machine .mark_outgoing_request_as_sent(id) .await @@ -1199,11 +1216,8 @@ mod test { assert!(!bob_machine.outgoing_to_device_requests.is_empty()); // Get the request and convert it to a encrypted to-device event. - let request = bob_machine - .outgoing_to_device_requests - .iter() - .next() - .unwrap(); + let requests = bob_machine.outgoing_to_device_requests().await.unwrap(); + let request = &requests[0]; let id = request.request_id; let content = request @@ -1217,7 +1231,6 @@ mod test { .unwrap(); let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap(); - drop(request); bob_machine.mark_outgoing_request_as_sent(id).await.unwrap(); let event = ToDeviceEvent { @@ -1326,11 +1339,8 @@ mod test { .insert(group_session.clone()); // Get the request and convert it into a event. - let request = alice_machine - .outgoing_to_device_requests - .iter() - .next() - .unwrap(); + let requests = alice_machine.outgoing_to_device_requests().await.unwrap(); + let request = &requests[0]; let id = request.request_id; let content = request .request @@ -1344,7 +1354,6 @@ mod test { let content: RoomKeyRequestToDeviceEventContent = serde_json::from_str(content.get()).unwrap(); - drop(request); alice_machine .mark_outgoing_request_as_sent(id) .await @@ -1356,7 +1365,11 @@ mod test { }; // Bob doesn't have any outgoing requests. - assert!(bob_machine.outgoing_to_device_requests.is_empty()); + assert!(bob_machine + .outgoing_to_device_requests() + .await + .unwrap() + .is_empty()); assert!(bob_machine.users_for_key_claim.is_empty()); assert!(bob_machine.wait_queue.is_empty()); @@ -1364,7 +1377,11 @@ mod test { bob_machine.receive_incoming_key_request(&event); bob_machine.collect_incoming_key_requests().await.unwrap(); // Bob doens't have an outgoing requests since we're lacking a session. - assert!(bob_machine.outgoing_to_device_requests.is_empty()); + assert!(bob_machine + .outgoing_to_device_requests() + .await + .unwrap() + .is_empty()); assert!(!bob_machine.users_for_key_claim.is_empty()); assert!(!bob_machine.wait_queue.is_empty()); @@ -1384,15 +1401,17 @@ mod test { assert!(bob_machine.users_for_key_claim.is_empty()); bob_machine.collect_incoming_key_requests().await.unwrap(); // Bob now has an outgoing requests. - assert!(!bob_machine.outgoing_to_device_requests.is_empty()); + assert!(!bob_machine + .outgoing_to_device_requests() + .await + .unwrap() + .is_empty()); assert!(bob_machine.wait_queue.is_empty()); // Get the request and convert it to a encrypted to-device event. - let request = bob_machine - .outgoing_to_device_requests - .iter() - .next() - .unwrap(); + let requests = bob_machine.outgoing_to_device_requests().await.unwrap(); + + let request = &requests[0]; let id = request.request_id; let content = request @@ -1406,7 +1425,6 @@ mod test { .unwrap(); let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap(); - drop(request); bob_machine.mark_outgoing_request_as_sent(id).await.unwrap(); let event = ToDeviceEvent { diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 746523f3..e7203043 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -245,10 +245,9 @@ impl OlmMachine { } }; - let mut machine = OlmMachine::new_helper(&user_id, device_id, store, account, identity); - machine.key_request_machine.load_outgoing_requests().await?; - - Ok(machine) + Ok(OlmMachine::new_helper( + &user_id, device_id, store, account, identity, + )) } /// Create a new machine with the default crypto store. @@ -295,7 +294,7 @@ impl OlmMachine { /// machine using [`mark_request_as_sent`]. /// /// [`mark_request_as_sent`]: #method.mark_request_as_sent - pub async fn outgoing_requests(&self) -> Vec { + pub async fn outgoing_requests(&self) -> StoreResult> { let mut requests = Vec::new(); if let Some(r) = self.keys_for_upload().await.map(|r| OutgoingRequest { @@ -320,9 +319,14 @@ impl OlmMachine { requests.append(&mut self.outgoing_to_device_requests()); requests.append(&mut self.verification_machine.outgoing_room_message_requests()); - requests.append(&mut self.key_request_machine.outgoing_to_device_requests()); + requests.append( + &mut self + .key_request_machine + .outgoing_to_device_requests() + .await?, + ); - requests + Ok(requests) } /// Mark the request with the given request id as sent. diff --git a/matrix_sdk_crypto/src/store/memorystore.rs b/matrix_sdk_crypto/src/store/memorystore.rs index 150354ca..610262a5 100644 --- a/matrix_sdk_crypto/src/store/memorystore.rs +++ b/matrix_sdk_crypto/src/store/memorystore.rs @@ -268,10 +268,11 @@ impl CryptoStore for MemoryStore { .and_then(|i| self.outgoing_key_requests.get(&i).map(|r| r.clone()))) } - async fn get_outgoing_key_requests(&self) -> Result> { + async fn get_unsent_key_requests(&self) -> Result> { Ok(self .outgoing_key_requests .iter() + .filter(|i| !i.value().sent_out) .map(|i| i.value().clone()) .collect()) } diff --git a/matrix_sdk_crypto/src/store/mod.rs b/matrix_sdk_crypto/src/store/mod.rs index 6f5b1338..1982e1cc 100644 --- a/matrix_sdk_crypto/src/store/mod.rs +++ b/matrix_sdk_crypto/src/store/mod.rs @@ -450,7 +450,7 @@ pub trait CryptoStore: AsyncTraitDeps { ) -> Result>; /// Get all outgoing key requests that we have in the store. - async fn get_outgoing_key_requests(&self) -> Result>; + async fn get_unsent_key_requests(&self) -> Result>; /// Delete an outoing key request that we created that matches the given /// request id. diff --git a/matrix_sdk_crypto/src/store/sled.rs b/matrix_sdk_crypto/src/store/sled.rs index 0bda2de0..d82f62fc 100644 --- a/matrix_sdk_crypto/src/store/sled.rs +++ b/matrix_sdk_crypto/src/store/sled.rs @@ -149,6 +149,7 @@ pub struct SledStore { outbound_group_sessions: Tree, outgoing_key_requests: Tree, + unsent_key_requests: Tree, key_requests_by_info: Tree, devices: Tree, @@ -215,6 +216,7 @@ impl SledStore { let identities = db.open_tree("identities")?; let outgoing_key_requests = db.open_tree("outgoing_key_requests")?; + let unsent_key_requests = db.open_tree("unsent_key_requests")?; let key_requests_by_info = db.open_tree("key_requests_by_info")?; let session_cache = SessionStore::new(); @@ -240,6 +242,7 @@ impl SledStore { inbound_group_sessions, outbound_group_sessions, outgoing_key_requests, + unsent_key_requests, key_requests_by_info, devices, tracked_users, @@ -376,6 +379,7 @@ impl SledStore { &self.outbound_group_sessions, &self.olm_hashes, &self.outgoing_key_requests, + &self.unsent_key_requests, &self.key_requests_by_info, ) .transaction( @@ -389,6 +393,7 @@ impl SledStore { outbound_sessions, hashes, outgoing_key_requests, + unsent_key_requests, key_requests_by_info, )| { if let Some(a) = &account_pickle { @@ -463,11 +468,23 @@ impl SledStore { key_request.request_id.encode(), )?; - outgoing_key_requests.insert( - key_request.request_id.encode(), - serde_json::to_vec(&key_request) - .map_err(ConflictableTransactionError::Abort)?, - )?; + let key_request_id = key_request.request_id.encode(); + + if key_request.sent_out { + unsent_key_requests.remove(key_request_id.clone())?; + outgoing_key_requests.insert( + key_request_id, + serde_json::to_vec(&key_request) + .map_err(ConflictableTransactionError::Abort)?, + )?; + } else { + outgoing_key_requests.remove(key_request_id.clone())?; + unsent_key_requests.insert( + key_request_id, + serde_json::to_vec(&key_request) + .map_err(ConflictableTransactionError::Abort)?, + )?; + } } Ok(()) @@ -479,6 +496,28 @@ impl SledStore { Ok(()) } + + async fn get_outgoing_key_request_helper( + &self, + id: &[u8], + ) -> Result> { + let request = self + .outgoing_key_requests + .get(id)? + .map(|r| serde_json::from_slice(&r)) + .transpose()?; + + let request = if request.is_none() { + self.unsent_key_requests + .get(id)? + .map(|r| serde_json::from_slice(&r)) + .transpose()? + } else { + request + }; + + Ok(request) + } } #[async_trait] @@ -685,11 +724,9 @@ impl CryptoStore for SledStore { &self, request_id: Uuid, ) -> Result> { - Ok(self - .outgoing_key_requests - .get(request_id.encode())? - .map(|r| serde_json::from_slice(&r)) - .transpose()?) + let request_id = request_id.encode(); + + self.get_outgoing_key_request_helper(&request_id).await } async fn get_key_request_by_info( @@ -699,19 +736,15 @@ impl CryptoStore for SledStore { let id = self.key_requests_by_info.get(key_info.encode())?; if let Some(id) = id { - Ok(self - .outgoing_key_requests - .get(id)? - .map(|r| serde_json::from_slice(&r)) - .transpose()?) + self.get_outgoing_key_request_helper(&id).await } else { Ok(None) } } - async fn get_outgoing_key_requests(&self) -> Result> { + async fn get_unsent_key_requests(&self) -> Result> { let requests: Result> = self - .outgoing_key_requests + .unsent_key_requests .iter() .map(|i| serde_json::from_slice(&i?.1).map_err(CryptoStoreError::from)) .collect(); @@ -720,16 +753,30 @@ impl CryptoStore for SledStore { } async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> { - let ret: Result<(), TransactionError> = - (&self.outgoing_key_requests, &self.key_requests_by_info).transaction( - |(outgoing_key_requests, key_requests_by_info)| { - let request: Option = outgoing_key_requests + let ret: Result<(), TransactionError> = ( + &self.outgoing_key_requests, + &self.unsent_key_requests, + &self.key_requests_by_info, + ) + .transaction( + |(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| { + let sent_request: Option = outgoing_key_requests .remove(request_id.encode())? .map(|r| serde_json::from_slice(&r)) .transpose() .map_err(ConflictableTransactionError::Abort)?; - if let Some(request) = request { + let unsent_request: Option = unsent_key_requests + .remove(request_id.encode())? + .map(|r| serde_json::from_slice(&r)) + .transpose() + .map_err(ConflictableTransactionError::Abort)?; + + if let Some(request) = sent_request { + key_requests_by_info.remove((&request.info).encode())?; + } + + if let Some(request) = unsent_request { key_requests_by_info.remove((&request.info).encode())?; } @@ -1328,7 +1375,21 @@ mod test { let stored_request = store.get_key_request_by_info(&info).await.unwrap(); assert_eq!(request, stored_request); - assert!(!store.get_outgoing_key_requests().await.unwrap().is_empty()); + assert!(!store.get_unsent_key_requests().await.unwrap().is_empty()); + + let request = OutgoingKeyRequest { + request_id: id, + info: info.clone(), + sent_out: true, + }; + + let mut changes = Changes::default(); + changes.key_requests.push(request.clone()); + store.save_changes(changes).await.unwrap(); + + assert!(store.get_unsent_key_requests().await.unwrap().is_empty()); + let stored_request = store.get_outgoing_key_request(id).await.unwrap(); + assert_eq!(Some(request), stored_request); store.delete_outgoing_key_request(id).await.unwrap(); @@ -1337,6 +1398,6 @@ mod test { let stored_request = store.get_key_request_by_info(&info).await.unwrap(); assert_eq!(None, stored_request); - assert!(store.get_outgoing_key_requests().await.unwrap().is_empty()); + assert!(store.get_unsent_key_requests().await.unwrap().is_empty()); } }