diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index be63cc41..41ff0cc3 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -2042,6 +2042,10 @@ impl Client { /// * `passphrase` - The passphrase that should be used to decrypt the /// exported room keys. /// + /// Returns a tuple of numbers that represent the number of sessions that + /// were imported and the total number of sessions that were found in the + /// key export. + /// /// # Panics /// /// This method will panic if it isn't run on a Tokio runtime. @@ -2071,7 +2075,7 @@ impl Client { feature = "docs", doc(cfg(all(encryption, not(target_arch = "wasm32")))) )] - pub async fn import_keys(&self, path: PathBuf, passphrase: &str) -> Result { + pub async fn import_keys(&self, path: PathBuf, passphrase: &str) -> Result<(usize, usize)> { let olm = self .base_client .olm_machine() diff --git a/matrix_sdk_crypto/src/file_encryption/key_export.rs b/matrix_sdk_crypto/src/file_encryption/key_export.rs index f41d5825..b881ec64 100644 --- a/matrix_sdk_crypto/src/file_encryption/key_export.rs +++ b/matrix_sdk_crypto/src/file_encryption/key_export.rs @@ -314,7 +314,7 @@ mod test { let decrypted = decrypt_key_export(Cursor::new(encrypted), "1234").unwrap(); assert_eq!(export, decrypted); - assert_eq!(machine.import_keys(decrypted).await.unwrap(), 0); + assert_eq!(machine.import_keys(decrypted).await.unwrap(), (0, 1)); } #[test] diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index 32cbd69f..69ea629d 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -637,8 +637,8 @@ impl KeyRequestMachine { // If we have a previous session, check if we have a better version // and store the new one if so. let session = if let Some(old_session) = old_session { - let first_old_index = old_session.first_known_index().await; - let first_index = session.first_known_index().await; + let first_old_index = old_session.first_known_index(); + let first_index = session.first_known_index(); if first_old_index > first_index { self.mark_as_done(info).await?; @@ -855,7 +855,7 @@ mod test { .unwrap(); let first_session = first_session.unwrap(); - assert_eq!(first_session.first_known_index().await, 10); + assert_eq!(first_session.first_known_index(), 10); machine .store @@ -914,7 +914,7 @@ mod test { .await .unwrap(); - assert_eq!(second_session.unwrap().first_known_index().await, 0); + assert_eq!(second_session.unwrap().first_known_index(), 0); } #[async_test] diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 9252e678..2cdab741 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -1012,7 +1012,9 @@ impl OlmMachine { /// imported into our store. If we already have a better version of a key /// the key will *not* be imported. /// - /// Returns the number of sessions that were imported to the store. + /// Returns a tuple of numbers that represent the number of sessions that + /// were imported and the total number of sessions that were found in the + /// key export. /// /// # Examples /// ```no_run @@ -1028,32 +1030,47 @@ impl OlmMachine { /// machine.import_keys(exported_keys).await.unwrap(); /// # }); /// ``` - pub async fn import_keys(&self, mut exported_keys: Vec) -> StoreResult { + pub async fn import_keys( + &self, + exported_keys: Vec, + ) -> StoreResult<(usize, usize)> { + struct ShallowSessions { + inner: BTreeMap, u32>, + } + + impl ShallowSessions { + fn has_better_session(&self, session: &InboundGroupSession) -> bool { + self.inner + .get(&session.room_id) + .map(|existing| existing <= &session.first_known_index()) + .unwrap_or(false) + } + } + let mut sessions = Vec::new(); - for key in exported_keys.drain(..) { + let existing_sessions = ShallowSessions { + inner: self + .store + .get_inbound_group_sessions() + .await? + .into_iter() + .map(|s| { + let index = s.first_known_index(); + (s.room_id, index) + }) + .collect(), + }; + + let total_sessions = exported_keys.len(); + + for key in exported_keys.into_iter() { let session = InboundGroupSession::from_export(key)?; // Only import the session if we didn't have this session or if it's // a better version of the same session, that is the first known // index is lower. - // TODO load all sessions so we don't do a thousand small loads. - if let Some(existing_session) = self - .store - .get_inbound_group_session( - &session.room_id, - &session.sender_key, - session.session_id(), - ) - .await? - { - let first_index = session.first_known_index().await; - let existing_index = existing_session.first_known_index().await; - - if first_index < existing_index { - sessions.push(session) - } - } else { + if !existing_sessions.has_better_session(&session) { sessions.push(session) } } @@ -1072,7 +1089,7 @@ impl OlmMachine { num_sessions ); - Ok(num_sessions) + Ok((num_sessions, total_sessions)) } /// Export the keys that match the given predicate. diff --git a/matrix_sdk_crypto/src/olm/group_sessions/inbound.rs b/matrix_sdk_crypto/src/olm/group_sessions/inbound.rs index a7249563..e5fdc9bf 100644 --- a/matrix_sdk_crypto/src/olm/group_sessions/inbound.rs +++ b/matrix_sdk_crypto/src/olm/group_sessions/inbound.rs @@ -57,6 +57,7 @@ use crate::error::{EventError, MegolmResult}; pub struct InboundGroupSession { inner: Arc>, session_id: Arc, + first_known_index: u32, pub(crate) sender_key: Arc, pub(crate) signing_key: Arc>, pub(crate) room_id: Arc, @@ -89,6 +90,7 @@ impl InboundGroupSession { ) -> Result { let session = OlmInboundGroupSession::new(&session_key.0)?; let session_id = session.session_id(); + let first_known_index = session.first_known_index(); let mut keys: BTreeMap = BTreeMap::new(); keys.insert(DeviceKeyAlgorithm::Ed25519, signing_key.to_owned()); @@ -97,6 +99,7 @@ impl InboundGroupSession { inner: Arc::new(Mutex::new(session)), session_id: session_id.into(), sender_key: sender_key.to_owned().into(), + first_known_index, signing_key: Arc::new(keys), room_id: Arc::new(room_id.clone()), forwarding_chains: Arc::new(Mutex::new(None)), @@ -134,6 +137,7 @@ impl InboundGroupSession { let key = Zeroizing::from(mem::take(&mut content.session_key)); let session = OlmInboundGroupSession::import(&key)?; + let first_known_index = session.first_known_index(); let mut forwarding_chains = content.forwarding_curve25519_key_chain.clone(); forwarding_chains.push(sender_key.to_owned()); @@ -147,6 +151,7 @@ impl InboundGroupSession { inner: Arc::new(Mutex::new(session)), session_id: content.session_id.as_str().into(), sender_key: content.sender_key.as_str().into(), + first_known_index, signing_key: Arc::new(sender_claimed_key), room_id: Arc::new(content.room_id.clone()), forwarding_chains: Arc::new(Mutex::new(Some(forwarding_chains))), @@ -178,7 +183,7 @@ impl InboundGroupSession { /// If only a limited part of this session should be exported use /// [`export_at_index()`](#method.export_at_index). pub async fn export(&self) -> ExportedRoomKey { - self.export_at_index(self.first_known_index().await) + self.export_at_index(self.first_known_index()) .await .expect("Can't export at the first known index") } @@ -221,12 +226,14 @@ impl InboundGroupSession { pickle_mode: PicklingMode, ) -> Result { let session = OlmInboundGroupSession::unpickle(pickle.pickle.0, pickle_mode)?; + let first_known_index = session.first_known_index(); let session_id = session.session_id(); Ok(InboundGroupSession { inner: Arc::new(Mutex::new(session)), session_id: session_id.into(), sender_key: pickle.sender_key.into(), + first_known_index, signing_key: Arc::new(pickle.signing_key), room_id: Arc::new(pickle.room_id), forwarding_chains: Arc::new(Mutex::new(pickle.forwarding_chains)), @@ -245,8 +252,8 @@ impl InboundGroupSession { } /// Get the first message index we know how to decrypt. - pub async fn first_known_index(&self) -> u32 { - self.inner.lock().await.first_known_index() + pub fn first_known_index(&self) -> u32 { + self.first_known_index } /// Decrypt the given ciphertext. @@ -368,6 +375,7 @@ impl TryFrom for InboundGroupSession { fn try_from(key: ExportedRoomKey) -> Result { let session = OlmInboundGroupSession::import(&key.session_key.0)?; + let first_known_index = session.first_known_index(); let forwarding_chains = if key.forwarding_curve25519_key_chain.is_empty() { None @@ -379,6 +387,7 @@ impl TryFrom for InboundGroupSession { inner: Arc::new(Mutex::new(session)), session_id: key.session_id.into(), sender_key: key.sender_key.into(), + first_known_index, signing_key: Arc::new(key.sender_claimed_keys), room_id: Arc::new(key.room_id), forwarding_chains: Arc::new(Mutex::new(forwarding_chains)), diff --git a/matrix_sdk_crypto/src/olm/mod.rs b/matrix_sdk_crypto/src/olm/mod.rs index 2ec94b23..407d360c 100644 --- a/matrix_sdk_crypto/src/olm/mod.rs +++ b/matrix_sdk_crypto/src/olm/mod.rs @@ -213,7 +213,7 @@ pub(crate) mod test { ) .unwrap(); - assert_eq!(0, inbound.first_known_index().await); + assert_eq!(0, inbound.first_known_index()); assert_eq!(outbound.session_id(), inbound.session_id());