diff --git a/matrix_sdk_crypto/src/olm/account.rs b/matrix_sdk_crypto/src/olm/account.rs index a6522ba1..0960038b 100644 --- a/matrix_sdk_crypto/src/olm/account.rs +++ b/matrix_sdk_crypto/src/olm/account.rs @@ -683,14 +683,9 @@ impl ReadOnlyAccount { self.sign(&canonical_json).await } - /// Generate, sign and prepare one-time keys to be uploaded. - /// - /// If no one-time keys need to be uploaded returns an empty error. - pub(crate) async fn signed_one_time_keys( + pub(crate) async fn signed_one_time_keys_helper( &self, ) -> Result, ()> { - let _ = self.generate_one_time_keys().await?; - let one_time_keys = self.one_time_keys().await; let mut one_time_key_map = BTreeMap::new(); @@ -728,6 +723,16 @@ impl ReadOnlyAccount { Ok(one_time_key_map) } + /// Generate, sign and prepare one-time keys to be uploaded. + /// + /// If no one-time keys need to be uploaded returns an empty error. + pub(crate) async fn signed_one_time_keys( + &self, + ) -> Result, ()> { + let _ = self.generate_one_time_keys().await?; + self.signed_one_time_keys_helper().await + } + /// Create a new session with another account given a one-time key. /// /// Returns the newly created session or a `OlmSessionError` if creating a diff --git a/matrix_sdk_crypto/src/session_manager.rs b/matrix_sdk_crypto/src/session_manager.rs index 792caa81..69a86722 100644 --- a/matrix_sdk_crypto/src/session_manager.rs +++ b/matrix_sdk_crypto/src/session_manager.rs @@ -303,3 +303,113 @@ impl SessionManager { Ok(()) } } + +#[cfg(test)] +mod test { + use dashmap::DashMap; + use std::{collections::BTreeMap, sync::Arc}; + + use matrix_sdk_common::{ + api::r0::keys::claim_keys::Response as KeyClaimResponse, + identifiers::{user_id, DeviceIdBox, UserId}, + }; + use matrix_sdk_test::async_test; + + use super::SessionManager; + use crate::{ + identities::ReadOnlyDevice, + key_request::KeyRequestMachine, + olm::{Account, ReadOnlyAccount}, + store::{CryptoStore, MemoryStore, Store}, + verification::VerificationMachine, + }; + + fn user_id() -> UserId { + user_id!("@example:localhost") + } + + fn device_id() -> DeviceIdBox { + "DEVICEID".into() + } + + fn bob_account() -> ReadOnlyAccount { + ReadOnlyAccount::new(&user_id!("@bob:localhost"), "BOBDEVICE".into()) + } + + async fn session_manager() -> SessionManager { + let user_id = user_id(); + let device_id = device_id(); + + let outbound_sessions = Arc::new(DashMap::new()); + let users_for_key_claim = Arc::new(DashMap::new()); + let account = ReadOnlyAccount::new(&user_id, &device_id); + let store: Arc> = Arc::new(Box::new(MemoryStore::new())); + store.save_account(account.clone()).await.unwrap(); + + let verification = VerificationMachine::new(account.clone(), store.clone()); + + let user_id = Arc::new(user_id); + let device_id = Arc::new(device_id); + + let store = Store::new(user_id.clone(), store, verification); + + let account = Account { + inner: account, + store: store.clone(), + }; + + let key_request = KeyRequestMachine::new( + user_id, + device_id, + store.clone(), + outbound_sessions, + users_for_key_claim.clone(), + ); + + SessionManager::new(account, users_for_key_claim, key_request, store) + } + + #[async_test] + async fn session_creation() { + let manager = session_manager().await; + let bob = bob_account(); + + let bob_device = ReadOnlyDevice::from_account(&bob).await; + + manager.store.save_devices(&[bob_device]).await.unwrap(); + + let (_, request) = manager + .get_missing_sessions(&mut [bob.user_id().clone()].iter()) + .await + .unwrap() + .unwrap(); + + assert!(request.one_time_keys.contains_key(bob.user_id())); + + bob.generate_one_time_keys_helper(1).await; + let one_time = bob.signed_one_time_keys_helper().await.unwrap(); + bob.mark_keys_as_published().await; + + let mut one_time_keys = BTreeMap::new(); + one_time_keys + .entry(bob.user_id().clone()) + .or_insert_with(BTreeMap::new) + .insert(bob.device_id().into(), one_time); + + let response = KeyClaimResponse { + failures: BTreeMap::new(), + one_time_keys, + }; + + manager + .receive_keys_claim_response(&response) + .await + .unwrap(); + + assert!(manager + .get_missing_sessions(&mut [bob.user_id().clone()].iter()) + .await + .unwrap() + .is_none()); + } +}