diff --git a/matrix_sdk_crypto/src/group_manager.rs b/matrix_sdk_crypto/src/group_manager.rs index b9b070d8..d8391d50 100644 --- a/matrix_sdk_crypto/src/group_manager.rs +++ b/matrix_sdk_crypto/src/group_manager.rs @@ -12,13 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::BTreeMap, sync::Arc}; +use std::{ + collections::{BTreeMap, HashSet}, + sync::Arc, +}; use dashmap::DashMap; use matrix_sdk_common::{ api::r0::to_device::DeviceIdOrAllDevices, events::{room::encrypted::EncryptedEventContent, AnyMessageEventContent, EventType}, - identifiers::{RoomId, UserId}, + identifiers::{DeviceId, RoomId, UserId}, uuid::Uuid, }; use tracing::debug; @@ -30,7 +33,7 @@ use crate::{ Device, EncryptionSettings, OlmError, ToDeviceRequest, }; -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct GroupSessionManager { account: Account, /// Store for the encryption keys. @@ -64,6 +67,29 @@ impl GroupSessionManager { } } + pub fn invalidate_sessions_new_devices(&self, users: &HashSet<&UserId>) { + for session in self.outbound_group_sessions.iter() { + if users.iter().any(|u| session.contains_recipient(u)) { + session.invalidate_session() + } + } + } + + /// Invalidate the sessions that were sent to the given user/device pair. + pub fn invalidate_sessions(&self, user_id: &UserId, device_id: &DeviceId) { + for session in self.outbound_group_sessions.iter() { + if session.is_shared_with(user_id, device_id) { + session.invalidate_session(); + + if !session.shared() { + for request_id in session.clear_requests() { + self.outbound_sessions_being_shared.remove(&request_id); + } + } + } + } + } + /// Get an outbound group session for a room, if one exists. /// /// # Arguments @@ -105,7 +131,7 @@ impl GroupSessionManager { let session = self.outbound_group_sessions.get(room_id); match session { - Some(s) => !s.shared() || s.expired(), + Some(s) => !s.shared() || s.expired() || s.invalidated(), None => true, } } @@ -158,6 +184,7 @@ impl GroupSessionManager { let mut devices: Vec = Vec::new(); for user_id in users { + session.add_recipient(user_id); let user_devices = self.store.get_user_devices(&user_id).await?; devices.extend(user_devices.devices().filter(|d| !d.is_blacklisted())); } diff --git a/matrix_sdk_crypto/src/identities/manager.rs b/matrix_sdk_crypto/src/identities/manager.rs index 509c3269..99a13a4a 100644 --- a/matrix_sdk_crypto/src/identities/manager.rs +++ b/matrix_sdk_crypto/src/identities/manager.rs @@ -27,6 +27,7 @@ use matrix_sdk_common::{ use crate::{ error::OlmResult, + group_manager::GroupSessionManager, identities::{ MasterPubkey, OwnUserIdentity, ReadOnlyDevice, SelfSigningPubkey, UserIdentities, UserIdentity, UserSigningPubkey, @@ -39,15 +40,22 @@ use crate::{ pub(crate) struct IdentityManager { user_id: Arc, device_id: Arc, + group_manager: GroupSessionManager, store: Store, } impl IdentityManager { - pub fn new(user_id: Arc, device_id: Arc, store: Store) -> Self { + pub fn new( + user_id: Arc, + device_id: Arc, + store: Store, + group_manager: GroupSessionManager, + ) -> Self { IdentityManager { user_id, device_id, store, + group_manager, } } @@ -104,6 +112,7 @@ impl IdentityManager { &self, device_keys_map: &BTreeMap>, ) -> StoreResult> { + let mut users_with_new_devices = HashSet::new(); let mut changed_devices = Vec::new(); for (user_id, device_map) in device_keys_map { @@ -149,6 +158,7 @@ impl IdentityManager { } }; info!("Adding a new device to the device store {:?}", device); + users_with_new_devices.insert(user_id); device }; @@ -164,12 +174,17 @@ impl IdentityManager { for device_id in deleted_devices { if let Some(device) = stored_devices.get(device_id) { + self.group_manager + .invalidate_sessions(device.user_id(), device.device_id()); device.mark_as_deleted(); self.store.delete_device(device).await?; } } } + self.group_manager + .invalidate_sessions_new_devices(&users_with_new_devices); + Ok(changed_devices) } @@ -362,9 +377,10 @@ pub(crate) mod test { use serde_json::json; use crate::{ + group_manager::GroupSessionManager, identities::IdentityManager, machine::test::response_from_file, - olm::ReadOnlyAccount, + olm::{Account, ReadOnlyAccount}, store::{CryptoStore, MemoryStore, Store}, verification::VerificationMachine, }; @@ -385,13 +401,18 @@ pub(crate) mod test { let user_id = Arc::new(user_id()); let account = ReadOnlyAccount::new(&user_id, &device_id()); let store: Arc> = Arc::new(Box::new(MemoryStore::new())); - let verification = VerificationMachine::new(account, store); + let verification = VerificationMachine::new(account.clone(), store); let store = Store::new( user_id.clone(), Arc::new(Box::new(MemoryStore::new())), verification, ); - IdentityManager::new(user_id, Arc::new(device_id()), store) + let account = Account { + inner: account, + store: store.clone(), + }; + let group = GroupSessionManager::new(account.clone(), store.clone()); + IdentityManager::new(user_id, Arc::new(device_id()), store, group) } pub(crate) fn other_key_query() -> KeyQueryResponse { diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 931739ad..757f053f 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -136,14 +136,19 @@ impl OlmMachine { store.clone(), outbound_group_sessions, ); - let identity_manager = - IdentityManager::new(user_id.clone(), device_id.clone(), store.clone()); + let account = Account { inner: account, store: store.clone(), }; let group_session_manager = GroupSessionManager::new(account.clone(), store.clone()); + let identity_manager = IdentityManager::new( + user_id.clone(), + device_id.clone(), + store.clone(), + group_session_manager.clone(), + ); OlmMachine { user_id, diff --git a/matrix_sdk_crypto/src/olm/account.rs b/matrix_sdk_crypto/src/olm/account.rs index 86a9dfa1..0cf5e569 100644 --- a/matrix_sdk_crypto/src/olm/account.rs +++ b/matrix_sdk_crypto/src/olm/account.rs @@ -58,7 +58,7 @@ use crate::{ use super::{EncryptionSettings, InboundGroupSession, OutboundGroupSession, Session}; -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct Account { pub(crate) inner: ReadOnlyAccount, pub(crate) store: Store, diff --git a/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs b/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs index 54e59c53..d2daab7e 100644 --- a/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs +++ b/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs @@ -104,6 +104,7 @@ pub struct OutboundGroupSession { pub(crate) creation_time: Arc, message_count: Arc, shared: Arc, + invalidated: Arc, settings: Arc, shared_with_set: Arc>>, to_share_with_set: Arc>>, @@ -143,6 +144,7 @@ impl OutboundGroupSession { creation_time: Arc::new(Instant::now()), message_count: Arc::new(AtomicU64::new(0)), shared: Arc::new(AtomicBool::new(false)), + invalidated: Arc::new(AtomicBool::new(false)), settings: Arc::new(settings), shared_with_set: Arc::new(DashMap::new()), to_share_with_set: Arc::new(DashMap::new()), @@ -153,6 +155,16 @@ impl OutboundGroupSession { self.to_share_with_set.insert(request_id, request); } + pub fn add_recipient(&self, user_id: &UserId) { + self.shared_with_set + .entry(user_id.to_owned()) + .or_insert_with(DashSet::new); + } + + pub fn contains_recipient(&self, user_id: &UserId) -> bool { + self.shared_with_set.contains_key(user_id) + } + /// Mark the request with the given request id as sent. /// /// This removes the request from the queue and marks the set of @@ -263,6 +275,11 @@ impl OutboundGroupSession { >= min(self.settings.rotation_period, Duration::from_secs(3600)) } + /// Has the session been invalidated. + pub fn invalidated(&self) -> bool { + self.invalidated.load(Ordering::Relaxed) + } + /// Mark the session as shared. /// /// Messages shouldn't be encrypted with the session before it has been @@ -315,12 +332,48 @@ impl OutboundGroupSession { }) } - /// The set of users this session is shared with. + /// Mark the session as invalid. + /// + /// This should be called if an user/device deletes a device that received + /// this session. + pub fn invalidate_session(&self) { + self.invalidated.store(true, Ordering::Relaxed) + } + + /// Clear out the requests returning the request ids. + pub fn clear_requests(&self) -> Vec { + let request_ids = self + .to_share_with_set + .iter() + .map(|item| *item.key()) + .collect(); + self.to_share_with_set.clear(); + request_ids + } + + /// Has or will the session be shared with the given user/device pair. pub(crate) fn is_shared_with(&self, user_id: &UserId, device_id: &DeviceId) -> bool { - self.shared_with_set + let shared_with = self + .shared_with_set .get(user_id) .map(|d| d.contains(device_id)) - .unwrap_or(false) + .unwrap_or(false); + + let should_be_shared_with = if self.shared() { + false + } else { + let device_id = DeviceIdOrAllDevices::DeviceId(device_id.into()); + + self.to_share_with_set.iter().any(|item| { + if let Some(e) = item.value().messages.get(user_id) { + e.contains_key(&device_id) + } else { + false + } + }) + }; + + shared_with || should_be_shared_with } /// Mark that the session was shared with the given user/device pair.