diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index 8d7aa708..fa63b5f4 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -1139,40 +1139,38 @@ impl Client { async fn preshare_group_session(&self, room_id: &RoomId) -> Result<()> { // TODO expose this publicly so people can pre-share a group session if // e.g. a user starts to type a message for a room. - if self.base_client.should_share_group_session(room_id).await { - #[allow(clippy::map_clone)] - if let Some(mutex) = self.group_session_locks.get(room_id).map(|m| m.clone()) { - // If a group session share request is already going on, - // await the release of the lock. - mutex.lock().await; - } else { - // Otherwise create a new lock and share the group - // session. - let mutex = Arc::new(Mutex::new(())); - self.group_session_locks - .insert(room_id.clone(), mutex.clone()); + #[allow(clippy::map_clone)] + if let Some(mutex) = self.group_session_locks.get(room_id).map(|m| m.clone()) { + // If a group session share request is already going on, + // await the release of the lock. + mutex.lock().await; + } else { + // Otherwise create a new lock and share the group + // session. + let mutex = Arc::new(Mutex::new(())); + self.group_session_locks + .insert(room_id.clone(), mutex.clone()); - let _guard = mutex.lock().await; + let _guard = mutex.lock().await; - { - let room = self.get_joined_room(room_id).unwrap(); - let members = room.joined_user_ids().await; - // TODO don't collect here. - let members_iter: Vec = members.collect().await; - self.claim_one_time_keys(&mut members_iter.iter()).await?; - }; + { + let room = self.get_joined_room(room_id).unwrap(); + let members = room.joined_user_ids().await; + // TODO don't collect here. + let members_iter: Vec = members.collect().await; + self.claim_one_time_keys(&mut members_iter.iter()).await?; + }; - let response = self.share_group_session(room_id).await; + let response = self.share_group_session(room_id).await; - self.group_session_locks.remove(room_id); + self.group_session_locks.remove(room_id); - // If one of the responses failed invalidate the group - // session as using it would end up in undecryptable - // messages. - if let Err(r) = response { - self.base_client.invalidate_group_session(room_id).await; - return Err(r); - } + // If one of the responses failed invalidate the group + // session as using it would end up in undecryptable + // messages. + if let Err(r) = response { + self.base_client.invalidate_group_session(room_id).await; + return Err(r); } } @@ -1858,11 +1856,7 @@ impl Client { #[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[instrument] async fn share_group_session(&self, room_id: &RoomId) -> Result<()> { - let mut requests = self - .base_client - .share_group_session(room_id) - .await - .expect("Keys don't need to be uploaded"); + let mut requests = self.base_client.share_group_session(room_id).await?; for request in requests.drain(..) { let response = self.send_to_device(&request).await?; diff --git a/matrix_sdk_base/src/client.rs b/matrix_sdk_base/src/client.rs index 1b0695b1..5d35f95c 100644 --- a/matrix_sdk_base/src/client.rs +++ b/matrix_sdk_base/src/client.rs @@ -969,25 +969,6 @@ impl BaseClient { self.store.get_filter(filter_name).await } - /// Should the client share a group session for the given room. - /// - /// Returns true if a session needs to be shared before room messages can be - /// encrypted, false if one is already shared and ready to encrypt room - /// messages. - /// - /// This should be called every time a new room message wants to be sent out - /// since group sessions can expire at any time. - #[cfg(feature = "encryption")] - #[cfg_attr(feature = "docs", doc(cfg(encryption)))] - pub async fn should_share_group_session(&self, room_id: &RoomId) -> bool { - let olm = self.olm.lock().await; - - match &*olm { - Some(o) => o.should_share_group_session(room_id), - None => false, - } - } - /// Get the outgoing requests that need to be sent out. /// /// This returns a list of `OutGoingRequest`, those requests need to be sent diff --git a/matrix_sdk_crypto/src/identities/manager.rs b/matrix_sdk_crypto/src/identities/manager.rs index 48d5faf1..940009a7 100644 --- a/matrix_sdk_crypto/src/identities/manager.rs +++ b/matrix_sdk_crypto/src/identities/manager.rs @@ -32,7 +32,6 @@ use crate::{ UserIdentity, UserSigningPubkey, }, requests::KeysQueryRequest, - session_manager::GroupSessionManager, store::{Changes, DeviceChanges, IdentityChanges, Result as StoreResult, Store}, }; @@ -40,22 +39,15 @@ use crate::{ pub(crate) struct IdentityManager { user_id: Arc, device_id: Arc, - group_manager: GroupSessionManager, store: Store, } impl IdentityManager { - pub fn new( - user_id: Arc, - device_id: Arc, - store: Store, - group_manager: GroupSessionManager, - ) -> Self { + pub fn new(user_id: Arc, device_id: Arc, store: Store) -> Self { IdentityManager { user_id, device_id, store, - group_manager, } } @@ -118,8 +110,6 @@ impl IdentityManager { &self, device_keys_map: &BTreeMap>, ) -> StoreResult { - let mut users_with_new_or_deleted_devices = HashSet::new(); - let mut changes = DeviceChanges::default(); for (user_id, device_map) in device_keys_map { @@ -165,7 +155,6 @@ impl IdentityManager { } }; info!("Adding a new device to the device store {:?}", device); - users_with_new_or_deleted_devices.insert(user_id); changes.new.push(device); } } @@ -177,7 +166,6 @@ impl IdentityManager { let deleted_devices_set = stored_devices_set.difference(¤t_devices); for device_id in deleted_devices_set { - users_with_new_or_deleted_devices.insert(user_id); if let Some(device) = stored_devices.get(*device_id) { device.mark_as_deleted(); changes.deleted.push(device.clone()); @@ -185,9 +173,6 @@ impl IdentityManager { } } - self.group_manager - .invalidate_sessions_new_devices(&users_with_new_or_deleted_devices); - Ok(changes) } @@ -377,7 +362,7 @@ pub(crate) mod test { use matrix_sdk_common::{ api::r0::keys::get_keys::Response as KeyQueryResponse, - identifiers::{room_id, user_id, DeviceIdBox, RoomId, UserId}, + identifiers::{user_id, DeviceIdBox, UserId}, locks::Mutex, }; @@ -388,8 +373,7 @@ pub(crate) mod test { use crate::{ identities::IdentityManager, machine::test::response_from_file, - olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount}, - session_manager::GroupSessionManager, + olm::{PrivateCrossSigningIdentity, ReadOnlyAccount}, store::{CryptoStore, MemoryStore, Store}, verification::VerificationMachine, }; @@ -406,10 +390,6 @@ pub(crate) mod test { "WSKKLTJZCL".into() } - fn room_id() -> RoomId { - room_id!("!test:localhost") - } - fn manager() -> IdentityManager { let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(user_id()))); let user_id = Arc::new(user_id()); @@ -422,12 +402,7 @@ pub(crate) mod test { Arc::new(Box::new(MemoryStore::new())), verification, ); - let account = Account { - inner: account, - store: store.clone(), - }; - let group = GroupSessionManager::new(account, store.clone()); - IdentityManager::new(user_id, Arc::new(device_id()), store, group) + IdentityManager::new(user_id, Arc::new(device_id()), store) } pub(crate) fn other_key_query() -> KeyQueryResponse { @@ -657,78 +632,4 @@ pub(crate) mod test { assert!(identity.is_device_signed(&device).is_ok()) } - - #[async_test] - async fn test_session_invalidation() { - let manager = manager(); - let room_id = room_id(); - let user_id = other_user_id(); - let device_id: DeviceIdBox = "SKISMLNIMH".into(); - - manager - .group_manager - .create_outbound_group_session(&room_id, Default::default()) - .await - .unwrap(); - let session = manager - .group_manager - .get_outbound_group_session(&room_id) - .unwrap(); - - session.add_recipient(&user_id); - session.mark_as_shared(); - - assert!(!session.invalidated()); - assert!(!session.expired()); - - // Receiving a new device invalidates the session. - manager - .receive_keys_query_response(&other_key_query()) - .await - .unwrap(); - - assert!(session.invalidated()); - - manager - .group_manager - .create_outbound_group_session(&room_id, Default::default()) - .await - .unwrap(); - let session = manager - .group_manager - .get_outbound_group_session(&room_id) - .unwrap(); - - session.add_recipient(&user_id); - session.mark_as_shared(); - - assert!(!session.invalidated()); - assert!(!session.expired()); - - let device = manager - .store - .get_device(&user_id, &device_id) - .await - .unwrap() - .unwrap(); - - assert!(!device.deleted()); - - let response = KeyQueryResponse::try_from(response_from_file(&json!({ - "device_keys": { - user_id: {} - }, - "failures": {}, - }))) - .unwrap(); - - // Noticing that a device got deleted invalidates the session as well - manager - .receive_keys_query_response(&response) - .await - .unwrap(); - - assert!(device.deleted()); - assert!(session.invalidated()); - } } diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index a09b9fe4..b7921aed 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -178,12 +178,8 @@ impl OlmMachine { store.clone(), ); let group_session_manager = GroupSessionManager::new(account.clone(), store.clone()); - let identity_manager = IdentityManager::new( - user_id.clone(), - device_id.clone(), - store.clone(), - group_session_manager.clone(), - ); + let identity_manager = + IdentityManager::new(user_id.clone(), device_id.clone(), store.clone()); OlmMachine { user_id, @@ -662,19 +658,6 @@ impl OlmMachine { self.group_session_manager.encrypt(room_id, content).await } - /// Should the client share a group session for the given room. - /// - /// Returns true if a session needs to be shared before room messages can be - /// encrypted, false if one is already shared and ready to encrypt room - /// messages. - /// - /// This should be called every time a new room message wants to be sent out - /// since group sessions can expire at any time. - pub fn should_share_group_session(&self, room_id: &RoomId) -> bool { - self.group_session_manager - .should_share_group_session(room_id) - } - /// Invalidate the currently active outbound group session for the given /// room. /// @@ -1447,7 +1430,8 @@ pub(crate) mod test { assert!(machine .group_session_manager .get_outbound_group_session(&room_id) - .is_none()); + .unwrap() + .invalidated()); } #[tokio::test] diff --git a/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs b/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs index dd11f2f9..6b3c9d86 100644 --- a/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs +++ b/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs @@ -60,7 +60,7 @@ const ROTATION_MESSAGES: u64 = 100; /// Settings for an encrypted room. /// /// This determines the algorithm and rotation periods of a group session. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct EncryptionSettings { /// The encryption algorithm that should be used in the room. pub algorithm: EventEncryptionAlgorithm, @@ -113,7 +113,7 @@ pub struct OutboundGroupSession { shared: Arc, invalidated: Arc, settings: Arc, - shared_with_set: Arc>>, + pub(crate) shared_with_set: Arc>>, to_share_with_set: Arc>>, } @@ -162,14 +162,9 @@ impl OutboundGroupSession { self.to_share_with_set.insert(request_id, request); } - pub fn add_recipient(&self, user_id: &UserId) { - self.shared_with_set - .entry(user_id.to_owned()) - .or_insert_with(DashSet::new); - } - - pub fn contains_recipient(&self, user_id: &UserId) -> bool { - self.shared_with_set.contains_key(user_id) + /// This should be called if an the user wishes to rotate this session. + pub fn invalidate_session(&self) { + self.invalidated.store(true, Ordering::Relaxed) } /// Mark the request with the given request id as sent. @@ -346,25 +341,6 @@ impl OutboundGroupSession { }) } - /// Mark the session as invalid. - /// - /// This should be called if an user/device deletes a device that received - /// this session. - pub fn invalidate_session(&self) { - self.invalidated.store(true, Ordering::Relaxed) - } - - /// Clear out the requests returning the request ids. - pub fn clear_requests(&self) -> Vec { - let request_ids = self - .to_share_with_set - .iter() - .map(|item| *item.key()) - .collect(); - self.to_share_with_set.clear(); - request_ids - } - /// Has or will the session be shared with the given user/device pair. pub(crate) fn is_shared_with(&self, user_id: &UserId, device_id: &DeviceId) -> bool { let shared_with = self diff --git a/matrix_sdk_crypto/src/session_manager/group_sessions.rs b/matrix_sdk_crypto/src/session_manager/group_sessions.rs index 374610fd..b8f27aa7 100644 --- a/matrix_sdk_crypto/src/session_manager/group_sessions.rs +++ b/matrix_sdk_crypto/src/session_manager/group_sessions.rs @@ -13,7 +13,7 @@ // limitations under the License. use std::{ - collections::{BTreeMap, HashSet}, + collections::{BTreeMap, HashMap, HashSet}, sync::Arc, }; @@ -21,7 +21,7 @@ use dashmap::DashMap; use matrix_sdk_common::{ api::r0::to_device::DeviceIdOrAllDevices, events::{room::encrypted::EncryptedEventContent, AnyMessageEventContent, EventType}, - identifiers::{RoomId, UserId}, + identifiers::{DeviceId, DeviceIdBox, RoomId, UserId}, uuid::Uuid, }; use tracing::{debug, info}; @@ -58,7 +58,12 @@ impl GroupSessionManager { } pub fn invalidate_group_session(&self, room_id: &RoomId) -> bool { - self.outbound_group_sessions.remove(room_id).is_some() + if let Some(s) = self.outbound_group_sessions.get(room_id) { + s.invalidate_session(); + true + } else { + false + } } pub fn mark_request_as_sent(&self, request_id: &Uuid) { @@ -67,25 +72,6 @@ impl GroupSessionManager { } } - pub fn invalidate_sessions_new_devices(&self, users: &HashSet<&UserId>) { - for session in self.outbound_group_sessions.iter() { - if users.iter().any(|u| session.contains_recipient(u)) { - info!( - "Invalidating outobund session {} for room {}", - session.session_id(), - session.room_id() - ); - session.invalidate_session(); - - if !session.shared() { - for request_id in session.clear_requests() { - self.outbound_sessions_being_shared.remove(&request_id); - } - } - } - } - } - /// Get an outbound group session for a room, if one exists. /// /// # Arguments @@ -115,23 +101,6 @@ impl GroupSessionManager { Ok(session.encrypt(content).await) } - /// Should the client share a group session for the given room. - /// - /// Returns true if a session needs to be shared before room messages can be - /// encrypted, false if one is already shared and ready to encrypt room - /// messages. - /// - /// This should be called every time a new room message wants to be sent out - /// since group sessions can expire at any time. - pub fn should_share_group_session(&self, room_id: &RoomId) -> bool { - let session = self.outbound_group_sessions.get(room_id); - - match session { - Some(s) => !s.shared() || s.expired() || s.invalidated(), - None => true, - } - } - /// Create a new outbound group session. /// /// This also creates a matching inbound group session and saves that one in @@ -153,6 +122,26 @@ impl GroupSessionManager { Ok((outbound, inbound)) } + pub async fn get_or_create_outbound_session( + &self, + room_id: &RoomId, + settings: EncryptionSettings, + ) -> OlmResult<(OutboundGroupSession, Option)> { + if let Some(s) = self.outbound_group_sessions.get(room_id).map(|s| s.clone()) { + if s.expired() || s.invalidated() { + self.create_outbound_group_session(room_id, settings) + .await + .map(|(o, i)| (o, i.into())) + } else { + Ok((s, None)) + } + } else { + self.create_outbound_group_session(room_id, settings) + .await + .map(|(o, i)| (o, i.into())) + } + } + /// Get to-device requests to share a group session with users in a room. /// /// # Arguments @@ -167,23 +156,96 @@ impl GroupSessionManager { users: impl Iterator, encryption_settings: impl Into, ) -> OlmResult>> { + let users: HashSet<&UserId> = users.collect(); + let encryption_settings = encryption_settings.into(); let mut changes = Changes::default(); - let (session, inbound_session) = self - .create_outbound_group_session(room_id, encryption_settings.into()) + let (outbound, inbound) = self + .get_or_create_outbound_session(room_id, encryption_settings.clone()) .await?; - changes.inbound_group_sessions.push(inbound_session); - let mut devices: Vec = Vec::new(); - - for user_id in users { - session.add_recipient(user_id); - let user_devices = self.store.get_user_devices(&user_id).await?; - devices.extend(user_devices.devices().filter(|d| !d.is_blacklisted())); + if let Some(inbound) = inbound { + changes.inbound_group_sessions.push(inbound); } + let mut devices: HashMap> = HashMap::new(); + + let users_shared_with: HashSet = outbound + .shared_with_set + .iter() + .map(|k| k.key().clone()) + .collect(); + + let users_shared_with: HashSet<&UserId> = users_shared_with.iter().collect(); + + let user_left = !users_shared_with + .difference(&users) + .collect::>() + .is_empty(); + + let mut device_got_deleted = false; + + for user_id in users { + let user_devices = self.store.get_user_devices(&user_id).await?; + + if !device_got_deleted { + let device_ids: HashSet = + user_devices.keys().map(|d| d.clone()).collect(); + + device_got_deleted = if let Some(shared) = outbound.shared_with_set.get(user_id) { + let shared: HashSet = shared.iter().map(|d| d.clone()).collect(); + !shared + .difference(&device_ids) + .collect::>() + .is_empty() + } else { + false + }; + } + + devices + .entry(user_id.clone()) + .or_insert_with(Vec::new) + .extend(user_devices.devices().filter(|d| !d.is_blacklisted())); + } + + let outbound = if user_left || device_got_deleted { + let (outbound, inbound) = self + .create_outbound_group_session(room_id, encryption_settings) + .await?; + changes.inbound_group_sessions.push(inbound); + + debug!( + "A user/device has left the group {} since we last sent a message, \ + rotating the outbound session.", + room_id + ); + + outbound + } else { + outbound + }; + + let devices: Vec = devices + .into_iter() + .map(|(_, d)| { + d.into_iter() + .filter(|d| !outbound.is_shared_with(d.user_id(), d.device_id())) + }) + .flatten() + .collect(); + + info!( + "Sharing outbound session at index {} with {:?}", + outbound.message_index().await, + devices + .iter() + .map(|d| (d.user_id(), d.device_id())) + .collect::>() + ); + let mut requests = Vec::new(); - let key_content = session.as_json().await; + let key_content = outbound.as_json().await; for device_map_chunk in devices.chunks(Self::MAX_TO_DEVICE_MESSAGES) { let mut messages = BTreeMap::new(); @@ -221,19 +283,19 @@ impl GroupSessionManager { messages, }); - session.add_request(id, request.clone()); + outbound.add_request(id, request.clone()); self.outbound_sessions_being_shared - .insert(id, session.clone()); + .insert(id, outbound.clone()); requests.push(request); } if requests.is_empty() { debug!( "Session {} for room {} doesn't need to be shared with anyone, marking as shared", - session.session_id(), - session.room_id() + outbound.session_id(), + outbound.room_id() ); - session.mark_as_shared(); + outbound.mark_as_shared(); } self.store.save_changes(changes).await?;