From 804bd221b2f095af5fe51d721be3812b70baa7d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Wed, 2 Dec 2020 11:12:46 +0100 Subject: [PATCH] crypto: Improve key imports. This patch changes so key imports load all existing sessions at once instead loading a single session for each session we are importing. It removes the need to lock the session when we check the first known index and exposes the total number of sessions the key export contained. --- matrix_sdk/src/client.rs | 6 +- .../src/file_encryption/key_export.rs | 2 +- matrix_sdk_crypto/src/key_request.rs | 8 +-- matrix_sdk_crypto/src/machine.rs | 59 ++++++++++++------- .../src/olm/group_sessions/inbound.rs | 15 ++++- matrix_sdk_crypto/src/olm/mod.rs | 2 +- 6 files changed, 61 insertions(+), 31 deletions(-) diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index be63cc41..41ff0cc3 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -2042,6 +2042,10 @@ impl Client { /// * `passphrase` - The passphrase that should be used to decrypt the /// exported room keys. /// + /// Returns a tuple of numbers that represent the number of sessions that + /// were imported and the total number of sessions that were found in the + /// key export. + /// /// # Panics /// /// This method will panic if it isn't run on a Tokio runtime. @@ -2071,7 +2075,7 @@ impl Client { feature = "docs", doc(cfg(all(encryption, not(target_arch = "wasm32")))) )] - pub async fn import_keys(&self, path: PathBuf, passphrase: &str) -> Result { + pub async fn import_keys(&self, path: PathBuf, passphrase: &str) -> Result<(usize, usize)> { let olm = self .base_client .olm_machine() diff --git a/matrix_sdk_crypto/src/file_encryption/key_export.rs b/matrix_sdk_crypto/src/file_encryption/key_export.rs index f41d5825..b881ec64 100644 --- a/matrix_sdk_crypto/src/file_encryption/key_export.rs +++ b/matrix_sdk_crypto/src/file_encryption/key_export.rs @@ -314,7 +314,7 @@ mod test { let decrypted = decrypt_key_export(Cursor::new(encrypted), "1234").unwrap(); assert_eq!(export, decrypted); - assert_eq!(machine.import_keys(decrypted).await.unwrap(), 0); + assert_eq!(machine.import_keys(decrypted).await.unwrap(), (0, 1)); } #[test] diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index 32cbd69f..69ea629d 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -637,8 +637,8 @@ impl KeyRequestMachine { // If we have a previous session, check if we have a better version // and store the new one if so. let session = if let Some(old_session) = old_session { - let first_old_index = old_session.first_known_index().await; - let first_index = session.first_known_index().await; + let first_old_index = old_session.first_known_index(); + let first_index = session.first_known_index(); if first_old_index > first_index { self.mark_as_done(info).await?; @@ -855,7 +855,7 @@ mod test { .unwrap(); let first_session = first_session.unwrap(); - assert_eq!(first_session.first_known_index().await, 10); + assert_eq!(first_session.first_known_index(), 10); machine .store @@ -914,7 +914,7 @@ mod test { .await .unwrap(); - assert_eq!(second_session.unwrap().first_known_index().await, 0); + assert_eq!(second_session.unwrap().first_known_index(), 0); } #[async_test] diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 9252e678..2cdab741 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -1012,7 +1012,9 @@ impl OlmMachine { /// imported into our store. If we already have a better version of a key /// the key will *not* be imported. /// - /// Returns the number of sessions that were imported to the store. + /// Returns a tuple of numbers that represent the number of sessions that + /// were imported and the total number of sessions that were found in the + /// key export. /// /// # Examples /// ```no_run @@ -1028,32 +1030,47 @@ impl OlmMachine { /// machine.import_keys(exported_keys).await.unwrap(); /// # }); /// ``` - pub async fn import_keys(&self, mut exported_keys: Vec) -> StoreResult { + pub async fn import_keys( + &self, + exported_keys: Vec, + ) -> StoreResult<(usize, usize)> { + struct ShallowSessions { + inner: BTreeMap, u32>, + } + + impl ShallowSessions { + fn has_better_session(&self, session: &InboundGroupSession) -> bool { + self.inner + .get(&session.room_id) + .map(|existing| existing <= &session.first_known_index()) + .unwrap_or(false) + } + } + let mut sessions = Vec::new(); - for key in exported_keys.drain(..) { + let existing_sessions = ShallowSessions { + inner: self + .store + .get_inbound_group_sessions() + .await? + .into_iter() + .map(|s| { + let index = s.first_known_index(); + (s.room_id, index) + }) + .collect(), + }; + + let total_sessions = exported_keys.len(); + + for key in exported_keys.into_iter() { let session = InboundGroupSession::from_export(key)?; // Only import the session if we didn't have this session or if it's // a better version of the same session, that is the first known // index is lower. - // TODO load all sessions so we don't do a thousand small loads. - if let Some(existing_session) = self - .store - .get_inbound_group_session( - &session.room_id, - &session.sender_key, - session.session_id(), - ) - .await? - { - let first_index = session.first_known_index().await; - let existing_index = existing_session.first_known_index().await; - - if first_index < existing_index { - sessions.push(session) - } - } else { + if !existing_sessions.has_better_session(&session) { sessions.push(session) } } @@ -1072,7 +1089,7 @@ impl OlmMachine { num_sessions ); - Ok(num_sessions) + Ok((num_sessions, total_sessions)) } /// Export the keys that match the given predicate. diff --git a/matrix_sdk_crypto/src/olm/group_sessions/inbound.rs b/matrix_sdk_crypto/src/olm/group_sessions/inbound.rs index a7249563..e5fdc9bf 100644 --- a/matrix_sdk_crypto/src/olm/group_sessions/inbound.rs +++ b/matrix_sdk_crypto/src/olm/group_sessions/inbound.rs @@ -57,6 +57,7 @@ use crate::error::{EventError, MegolmResult}; pub struct InboundGroupSession { inner: Arc>, session_id: Arc, + first_known_index: u32, pub(crate) sender_key: Arc, pub(crate) signing_key: Arc>, pub(crate) room_id: Arc, @@ -89,6 +90,7 @@ impl InboundGroupSession { ) -> Result { let session = OlmInboundGroupSession::new(&session_key.0)?; let session_id = session.session_id(); + let first_known_index = session.first_known_index(); let mut keys: BTreeMap = BTreeMap::new(); keys.insert(DeviceKeyAlgorithm::Ed25519, signing_key.to_owned()); @@ -97,6 +99,7 @@ impl InboundGroupSession { inner: Arc::new(Mutex::new(session)), session_id: session_id.into(), sender_key: sender_key.to_owned().into(), + first_known_index, signing_key: Arc::new(keys), room_id: Arc::new(room_id.clone()), forwarding_chains: Arc::new(Mutex::new(None)), @@ -134,6 +137,7 @@ impl InboundGroupSession { let key = Zeroizing::from(mem::take(&mut content.session_key)); let session = OlmInboundGroupSession::import(&key)?; + let first_known_index = session.first_known_index(); let mut forwarding_chains = content.forwarding_curve25519_key_chain.clone(); forwarding_chains.push(sender_key.to_owned()); @@ -147,6 +151,7 @@ impl InboundGroupSession { inner: Arc::new(Mutex::new(session)), session_id: content.session_id.as_str().into(), sender_key: content.sender_key.as_str().into(), + first_known_index, signing_key: Arc::new(sender_claimed_key), room_id: Arc::new(content.room_id.clone()), forwarding_chains: Arc::new(Mutex::new(Some(forwarding_chains))), @@ -178,7 +183,7 @@ impl InboundGroupSession { /// If only a limited part of this session should be exported use /// [`export_at_index()`](#method.export_at_index). pub async fn export(&self) -> ExportedRoomKey { - self.export_at_index(self.first_known_index().await) + self.export_at_index(self.first_known_index()) .await .expect("Can't export at the first known index") } @@ -221,12 +226,14 @@ impl InboundGroupSession { pickle_mode: PicklingMode, ) -> Result { let session = OlmInboundGroupSession::unpickle(pickle.pickle.0, pickle_mode)?; + let first_known_index = session.first_known_index(); let session_id = session.session_id(); Ok(InboundGroupSession { inner: Arc::new(Mutex::new(session)), session_id: session_id.into(), sender_key: pickle.sender_key.into(), + first_known_index, signing_key: Arc::new(pickle.signing_key), room_id: Arc::new(pickle.room_id), forwarding_chains: Arc::new(Mutex::new(pickle.forwarding_chains)), @@ -245,8 +252,8 @@ impl InboundGroupSession { } /// Get the first message index we know how to decrypt. - pub async fn first_known_index(&self) -> u32 { - self.inner.lock().await.first_known_index() + pub fn first_known_index(&self) -> u32 { + self.first_known_index } /// Decrypt the given ciphertext. @@ -368,6 +375,7 @@ impl TryFrom for InboundGroupSession { fn try_from(key: ExportedRoomKey) -> Result { let session = OlmInboundGroupSession::import(&key.session_key.0)?; + let first_known_index = session.first_known_index(); let forwarding_chains = if key.forwarding_curve25519_key_chain.is_empty() { None @@ -379,6 +387,7 @@ impl TryFrom for InboundGroupSession { inner: Arc::new(Mutex::new(session)), session_id: key.session_id.into(), sender_key: key.sender_key.into(), + first_known_index, signing_key: Arc::new(key.sender_claimed_keys), room_id: Arc::new(key.room_id), forwarding_chains: Arc::new(Mutex::new(forwarding_chains)), diff --git a/matrix_sdk_crypto/src/olm/mod.rs b/matrix_sdk_crypto/src/olm/mod.rs index 2ec94b23..407d360c 100644 --- a/matrix_sdk_crypto/src/olm/mod.rs +++ b/matrix_sdk_crypto/src/olm/mod.rs @@ -213,7 +213,7 @@ pub(crate) mod test { ) .unwrap(); - assert_eq!(0, inbound.first_known_index().await); + assert_eq!(0, inbound.first_known_index()); assert_eq!(outbound.session_id(), inbound.session_id());