diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index 84f4eab2..060a85cb 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -196,6 +196,7 @@ impl KeyRequestMachine { device_id: Arc, store: Store, outbound_group_sessions: Arc>, + users_for_key_claim: Arc>>, ) -> Self { Self { user_id, @@ -205,7 +206,7 @@ impl KeyRequestMachine { outgoing_to_device_requests: Arc::new(DashMap::new()), incoming_key_requests: Arc::new(DashMap::new()), wait_queue: WaitQueue::new(), - users_for_key_claim: Arc::new(DashMap::new()), + users_for_key_claim, } } @@ -214,11 +215,6 @@ impl KeyRequestMachine { &self.user_id } - /// Get the map of user/devices which we need to claim one-time for. - pub fn users_for_key_claim(&self) -> &DashMap> { - &self.users_for_key_claim - } - pub fn outgoing_to_device_requests(&self) -> Vec { #[allow(clippy::map_clone)] self.outgoing_to_device_requests @@ -719,6 +715,7 @@ mod test { Arc::new(bob_device_id()), store, Arc::new(DashMap::new()), + Arc::new(DashMap::new()), ) } @@ -736,6 +733,7 @@ mod test { Arc::new(alice_device_id()), store, Arc::new(DashMap::new()), + Arc::new(DashMap::new()), ) } diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index cbe0acbf..10a1fd4c 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -131,11 +131,14 @@ impl OlmMachine { let store = Store::new(user_id.clone(), store, verification_machine.clone()); let device_id: Arc = Arc::new(device_id); let outbound_group_sessions = Arc::new(DashMap::new()); + let users_for_key_claim = Arc::new(DashMap::new()); + let key_request_machine = KeyRequestMachine::new( user_id.clone(), device_id.clone(), store.clone(), outbound_group_sessions, + users_for_key_claim.clone(), ); let account = Account { @@ -143,8 +146,12 @@ impl OlmMachine { store: store.clone(), }; - let session_manager = - SessionManager::new(account.clone(), key_request_machine.clone(), store.clone()); + let session_manager = SessionManager::new( + account.clone(), + users_for_key_claim, + key_request_machine.clone(), + store.clone(), + ); let group_session_manager = GroupSessionManager::new(account.clone(), store.clone()); let identity_manager = IdentityManager::new( user_id.clone(), diff --git a/matrix_sdk_crypto/src/session_manager.rs b/matrix_sdk_crypto/src/session_manager.rs index e0dc57bf..e10a2e5e 100644 --- a/matrix_sdk_crypto/src/session_manager.rs +++ b/matrix_sdk_crypto/src/session_manager.rs @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::BTreeMap, time::Duration}; +use std::{collections::BTreeMap, sync::Arc, time::Duration}; +use dashmap::{DashMap, DashSet}; use matrix_sdk_common::{ api::r0::keys::claim_keys::{Request as KeysClaimRequest, Response as KeysClaimResponse}, assign, - identifiers::{DeviceKeyAlgorithm, UserId}, + identifiers::{DeviceIdBox, DeviceKeyAlgorithm, UserId}, uuid::Uuid, }; use tracing::{error, info, warn}; @@ -28,17 +29,28 @@ use crate::{error::OlmResult, key_request::KeyRequestMachine, olm::Account, stor pub(crate) struct SessionManager { account: Account, store: Store, + /// A map of user/devices that we need to automatically claim keys for. + /// Submodules can insert user/device pairs into this map and the + /// user/device paris will be added to the list of users when + /// [`get_missing_sessions`](#method.get_missing_sessions) is called. + users_for_key_claim: Arc>>, key_request_machine: KeyRequestMachine, } impl SessionManager { const KEY_CLAIM_TIMEOUT: Duration = Duration::from_secs(10); - pub fn new(account: Account, key_request_machine: KeyRequestMachine, store: Store) -> Self { + pub fn new( + account: Account, + users_for_key_claim: Arc>>, + key_request_machine: KeyRequestMachine, + store: Store, + ) -> Self { Self { account, store, key_request_machine, + users_for_key_claim, } } @@ -109,7 +121,7 @@ impl SessionManager { // Add the list of sessions that for some reason automatically need to // create an Olm session. - for item in self.key_request_machine.users_for_key_claim().iter() { + for item in self.users_for_key_claim.iter() { let user = item.key(); for device_id in item.value().iter() {