diff --git a/matrix_sdk_base/src/client.rs b/matrix_sdk_base/src/client.rs index 59f9b24c..21168bf0 100644 --- a/matrix_sdk_base/src/client.rs +++ b/matrix_sdk_base/src/client.rs @@ -1288,9 +1288,9 @@ impl BaseClient { #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] pub async fn share_group_session(&self, room_id: &RoomId) -> Result> { let room = self.get_joined_room(room_id).await.expect("No room found"); - let mut olm = self.olm.lock().await; + let olm = self.olm.lock().await; - match &mut *olm { + match &*olm { Some(o) => { let room = room.write().await; @@ -1417,9 +1417,9 @@ impl BaseClient { #[cfg(feature = "encryption")] #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] pub async fn invalidate_group_session(&self, room_id: &RoomId) -> bool { - let mut olm = self.olm.lock().await; + let olm = self.olm.lock().await; - match &mut *olm { + match &*olm { Some(o) => o.invalidate_group_session(room_id), None => false, } diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index b06ee92e..340af00f 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -15,13 +15,15 @@ #[cfg(feature = "sqlite-cryptostore")] use std::path::Path; use std::{ - collections::{BTreeMap, HashMap, HashSet}, + collections::{BTreeMap, HashSet}, convert::{TryFrom, TryInto}, mem, result::Result as StdResult, sync::Arc, }; +use dashmap::DashMap; + use api::r0::{ keys::{claim_keys, get_keys, upload_keys, DeviceKeys, OneTimeKey}, sync::sync_events::Response as SyncResponse, @@ -67,6 +69,7 @@ pub type OneTimeKeys = BTreeMap; /// State machine implementation of the Olm/Megolm encryption protocol used for /// Matrix end to end encryption. +#[derive(Clone)] pub struct OlmMachine { /// The unique user id that owns this account. user_id: UserId, @@ -79,7 +82,7 @@ pub struct OlmMachine { /// without the need to create new keys. store: Arc>>, /// The currently active outbound group sessions. - outbound_group_sessions: HashMap, + outbound_group_sessions: Arc>, /// A state machine that is responsible to handle and keep track of SAS /// verification flows. verification_machine: VerificationMachine, @@ -119,7 +122,7 @@ impl OlmMachine { device_id: device_id.into(), account: account.clone(), store: store.clone(), - outbound_group_sessions: HashMap::new(), + outbound_group_sessions: Arc::new(DashMap::new()), verification_machine: VerificationMachine::new(account, store), } } @@ -165,7 +168,7 @@ impl OlmMachine { device_id, account, store, - outbound_group_sessions: HashMap::new(), + outbound_group_sessions: Arc::new(DashMap::new()), verification_machine, }) } @@ -802,7 +805,7 @@ impl OlmMachine { /// /// This also creates a matching inbound group session and saves that one in /// the store. - async fn create_outbound_group_session(&mut self, room_id: &RoomId) -> OlmResult<()> { + async fn create_outbound_group_session(&self, room_id: &RoomId) -> OlmResult<()> { let (outbound, inbound) = self.account.create_group_session_pair(room_id).await; let _ = self @@ -818,6 +821,17 @@ impl OlmMachine { Ok(()) } + /// Get an outbound group session for a room, if one exists. + /// + /// # Arguments + /// + /// * `room_id` - The id of the room for which we should get the outbound + /// group session. + fn get_outbound_group_session(&self, room_id: &RoomId) -> Option { + #[allow(clippy::map_clone)] + self.outbound_group_sessions.get(room_id).map(|s| s.clone()) + } + /// Encrypt a room message for the given room. /// /// Beware that a group session needs to be shared before this method can be @@ -844,9 +858,7 @@ impl OlmMachine { room_id: &RoomId, content: MessageEventContent, ) -> MegolmResult { - let session = self.outbound_group_sessions.get(room_id); - - let session = if let Some(s) = session { + let session = if let Some(s) = self.get_outbound_group_session(room_id) { s } else { panic!("Session wasn't created nor shared"); @@ -929,7 +941,7 @@ impl OlmMachine { /// /// Returns true if a session was invalidated, false if there was no session /// to invalidate. - pub fn invalidate_group_session(&mut self, room_id: &RoomId) -> bool { + pub fn invalidate_group_session(&self, room_id: &RoomId) -> bool { self.outbound_group_sessions.remove(room_id).is_some() } @@ -943,7 +955,7 @@ impl OlmMachine { /// /// `users` - The list of users that should receive the group session. pub async fn share_group_session<'a, I>( - &mut self, + &self, room_id: &RoomId, users: I, ) -> OlmResult> @@ -1538,7 +1550,7 @@ mod test { #[tokio::test] async fn tests_session_invalidation() { - let mut machine = OlmMachine::new(&user_id(), &alice_device_id()); + let machine = OlmMachine::new(&user_id(), &alice_device_id()); let room_id = room_id!("!test:example.org"); machine @@ -1754,7 +1766,7 @@ mod test { #[tokio::test] async fn test_room_key_sharing() { - let (mut alice, bob) = get_machine_pair_with_session().await; + let (alice, bob) = get_machine_pair_with_session().await; let room_id = room_id!("!test:example.org"); @@ -1800,7 +1812,7 @@ mod test { #[tokio::test] async fn test_megolm_encryption() { - let (mut alice, bob) = get_machine_pair_with_setup_sessions().await; + let (alice, bob) = get_machine_pair_with_setup_sessions().await; let room_id = room_id!("!test:example.org"); let to_device_requests = alice