diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index f24aae1e..a6b6164b 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -2330,7 +2330,7 @@ impl Client { // TODO remove this unwrap. let import = task.await.expect("Task join error").unwrap(); - Ok(olm.import_keys(import).await?) + Ok(olm.import_keys(import, |_, _| {}).await?) } } diff --git a/matrix_sdk_crypto/src/file_encryption/key_export.rs b/matrix_sdk_crypto/src/file_encryption/key_export.rs index 304ded6b..c6b2ae70 100644 --- a/matrix_sdk_crypto/src/file_encryption/key_export.rs +++ b/matrix_sdk_crypto/src/file_encryption/key_export.rs @@ -84,7 +84,7 @@ pub enum KeyExportError { /// # block_on(async { /// # let export = Cursor::new("".to_owned()); /// let exported_keys = decrypt_key_export(export, "1234").unwrap(); -/// machine.import_keys(exported_keys).await.unwrap(); +/// machine.import_keys(exported_keys, |_, _| {}).await.unwrap(); /// # }); /// ``` pub fn decrypt_key_export( @@ -316,7 +316,10 @@ mod test { let decrypted = decrypt_key_export(Cursor::new(encrypted), "1234").unwrap(); assert_eq!(export, decrypted); - assert_eq!(machine.import_keys(decrypted).await.unwrap(), (0, 1)); + assert_eq!( + machine.import_keys(decrypted, |_, _| {}).await.unwrap(), + (0, 1) + ); } #[test] diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index dba81260..415b80f7 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -1064,12 +1064,13 @@ impl OlmMachine { /// # block_on(async { /// # let export = Cursor::new("".to_owned()); /// let exported_keys = decrypt_key_export(export, "1234").unwrap(); - /// machine.import_keys(exported_keys).await.unwrap(); + /// machine.import_keys(exported_keys, |_, _| {}).await.unwrap(); /// # }); /// ``` pub async fn import_keys( &self, exported_keys: Vec, + progress_listener: impl Fn(usize, usize), ) -> StoreResult<(usize, usize)> { struct ShallowSessions { inner: BTreeMap, u32>, @@ -1101,7 +1102,7 @@ impl OlmMachine { let total_sessions = exported_keys.len(); - for key in exported_keys.into_iter() { + for (i, key) in exported_keys.into_iter().enumerate() { let session = InboundGroupSession::from_export(key)?; // Only import the session if we didn't have this session or if it's @@ -1110,6 +1111,8 @@ impl OlmMachine { if !existing_sessions.has_better_session(&session) { sessions.push(session) } + + progress_listener(i, total_sessions) } let num_sessions = sessions.len();