From d4c56cc5b30e47395652bf144c1d441e493a74c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Thu, 15 Apr 2021 15:19:21 +0200 Subject: [PATCH] crypto: Refactor the outobund group session storing This introduces a group session cache struct that can be shared between components that need to access the currently active group session. --- matrix_sdk_crypto/src/key_request.rs | 46 +++++--- matrix_sdk_crypto/src/machine.rs | 20 ++-- .../src/session_manager/group_sessions.rs | 105 +++++++++++------- matrix_sdk_crypto/src/session_manager/mod.rs | 2 +- .../src/session_manager/sessions.rs | 6 +- 5 files changed, 112 insertions(+), 67 deletions(-) diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index f8fab2ff..34a3b8a3 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -40,8 +40,9 @@ use matrix_sdk_common::{ use crate::{ error::{OlmError, OlmResult}, - olm::{InboundGroupSession, OutboundGroupSession, Session, ShareState}, + olm::{InboundGroupSession, Session, ShareState}, requests::{OutgoingRequest, ToDeviceRequest}, + session_manager::GroupSessionCache, store::{Changes, CryptoStoreError, Store}, Device, }; @@ -128,7 +129,7 @@ pub(crate) struct KeyRequestMachine { user_id: Arc, device_id: Arc, store: Store, - outbound_group_sessions: Arc>, + outbound_group_sessions: GroupSessionCache, outgoing_to_device_requests: Arc>, incoming_key_requests: Arc< DashMap<(UserId, DeviceIdBox, String), ToDeviceEvent>, @@ -188,7 +189,7 @@ impl KeyRequestMachine { user_id: Arc, device_id: Arc, store: Store, - outbound_group_sessions: Arc>, + outbound_group_sessions: GroupSessionCache, users_for_key_claim: Arc>>, ) -> Self { Self { @@ -356,7 +357,7 @@ impl KeyRequestMachine { .await?; if let Some(device) = device { - match self.should_share_session(&device, &session) { + match self.should_share_session(&device, &session).await { Err(e) => { info!( "Received a key request from {} {} that we won't serve: {}", @@ -458,14 +459,17 @@ impl KeyRequestMachine { /// * `device` - The device that is requesting a session from us. /// /// * `session` - The session that was requested to be shared. - fn should_share_session( + async fn should_share_session( &self, device: &Device, session: &InboundGroupSession, ) -> Result, KeyshareDecision> { let outbound_session = self .outbound_group_sessions - .get(session.room_id()) + .get_or_load(session.room_id()) + .await + .ok() + .flatten() .filter(|o| session.session_id() == o.session_id()); let own_device_check = || { @@ -720,6 +724,7 @@ mod test { use crate::{ identities::{LocalTrust, ReadOnlyDevice}, olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount}, + session_manager::GroupSessionCache, store::{Changes, CryptoStore, MemoryStore, Store}, verification::VerificationMachine, }; @@ -761,12 +766,13 @@ mod test { let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(bob_id()))); let verification = VerificationMachine::new(account, identity.clone(), store.clone()); let store = Store::new(user_id.clone(), identity, store, verification); + let session_cache = GroupSessionCache::new(store.clone()); KeyRequestMachine::new( user_id, Arc::new(bob_device_id()), store, - Arc::new(DashMap::new()), + session_cache, Arc::new(DashMap::new()), ) } @@ -780,12 +786,13 @@ mod test { let verification = VerificationMachine::new(account, identity.clone(), store.clone()); let store = Store::new(user_id.clone(), identity, store, verification); store.save_devices(&[device]).await.unwrap(); + let session_cache = GroupSessionCache::new(store.clone()); KeyRequestMachine::new( user_id, Arc::new(alice_device_id()), store, - Arc::new(DashMap::new()), + session_cache, Arc::new(DashMap::new()), ) } @@ -973,12 +980,16 @@ mod test { assert_eq!( machine .should_share_session(&own_device, &inbound) + .await .expect_err("Should not share with untrusted"), KeyshareDecision::UntrustedDevice ); own_device.set_trust_state(LocalTrust::Verified); // Now we do want to share the keys. - assert!(machine.should_share_session(&own_device, &inbound).is_ok()); + assert!(machine + .should_share_session(&own_device, &inbound) + .await + .is_ok()); let bob_device = ReadOnlyDevice::from_account(&bob_account()).await; machine.store.save_devices(&[bob_device]).await.unwrap(); @@ -995,6 +1006,7 @@ mod test { assert_eq!( machine .should_share_session(&bob_device, &inbound) + .await .expect_err("Should not share with other."), KeyshareDecision::MissingOutboundSession ); @@ -1004,15 +1016,14 @@ mod test { changes.outbound_group_sessions.push(outbound.clone()); changes.inbound_group_sessions.push(inbound.clone()); machine.store.save_changes(changes).await.unwrap(); - machine - .outbound_group_sessions - .insert(inbound.room_id().to_owned(), outbound.clone()); + machine.outbound_group_sessions.insert(outbound.clone()); // We don't share sessions with other user's devices if the session // wasn't shared in the first place. assert_eq!( machine .should_share_session(&bob_device, &inbound) + .await .expect_err("Should not share with other unless shared."), KeyshareDecision::OutboundSessionNotShared ); @@ -1024,13 +1035,17 @@ mod test { assert_eq!( machine .should_share_session(&bob_device, &inbound) + .await .expect_err("Should not share with other unless shared."), KeyshareDecision::OutboundSessionNotShared ); // We now share the session, since it was shared before. outbound.mark_shared_with(bob_device.user_id(), bob_device.device_id()); - assert!(machine.should_share_session(&bob_device, &inbound).is_ok()); + assert!(machine + .should_share_session(&bob_device, &inbound) + .await + .is_ok()); // But we don't share some other session that doesn't match our outbound // session @@ -1042,6 +1057,7 @@ mod test { assert_eq!( machine .should_share_session(&bob_device, &other_inbound) + .await .expect_err("Should not share with other unless shared."), KeyshareDecision::MissingOutboundSession ); @@ -1112,7 +1128,7 @@ mod test { // Put the outbound session into bobs store. bob_machine .outbound_group_sessions - .insert(room_id(), group_session.clone()); + .insert(group_session.clone()); // Get the request and convert it into a event. let request = alice_machine @@ -1278,7 +1294,7 @@ mod test { // Put the outbound session into bobs store. bob_machine .outbound_group_sessions - .insert(room_id(), group_session.clone()); + .insert(group_session.clone()); // Get the request and convert it into a event. let request = alice_machine diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 601be381..59a99457 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -156,29 +156,29 @@ impl OlmMachine { verification_machine.clone(), ); let device_id: Arc = Arc::new(device_id); - let outbound_group_sessions = Arc::new(DashMap::new()); let users_for_key_claim = Arc::new(DashMap::new()); - let key_request_machine = KeyRequestMachine::new( - user_id.clone(), - device_id.clone(), - store.clone(), - outbound_group_sessions, - users_for_key_claim.clone(), - ); - let account = Account { inner: account, store: store.clone(), }; + let group_session_manager = GroupSessionManager::new(account.clone(), store.clone()); + + let key_request_machine = KeyRequestMachine::new( + user_id.clone(), + device_id.clone(), + store.clone(), + group_session_manager.session_cache(), + users_for_key_claim.clone(), + ); + let session_manager = SessionManager::new( account.clone(), users_for_key_claim, key_request_machine.clone(), store.clone(), ); - let group_session_manager = GroupSessionManager::new(account.clone(), store.clone()); let identity_manager = IdentityManager::new(user_id.clone(), device_id.clone(), store.clone()); diff --git a/matrix_sdk_crypto/src/session_manager/group_sessions.rs b/matrix_sdk_crypto/src/session_manager/group_sessions.rs index 649664a8..a656d00e 100644 --- a/matrix_sdk_crypto/src/session_manager/group_sessions.rs +++ b/matrix_sdk_crypto/src/session_manager/group_sessions.rs @@ -40,6 +40,57 @@ use crate::{ Device, EncryptionSettings, OlmError, ToDeviceRequest, }; +#[derive(Clone, Debug)] +pub(crate) struct GroupSessionCache { + store: Store, + sessions: Arc>, + /// A map from the request id to the group session that the request belongs + /// to. Used to mark requests belonging to the session as shared. + sessions_being_shared: Arc>, +} + +impl GroupSessionCache { + pub(crate) fn new(store: Store) -> Self { + Self { + store, + sessions: DashMap::new().into(), + sessions_being_shared: Arc::new(DashMap::new()), + } + } + + pub(crate) fn insert(&self, session: OutboundGroupSession) { + self.sessions.insert(session.room_id().to_owned(), session); + } + + pub async fn get_or_load(&self, room_id: &RoomId) -> StoreResult> { + // Get the cached session, if there isn't one load one from the store + // and put it in the cache. + if let Some(s) = self.sessions.get(room_id) { + Ok(Some(s.clone())) + } else if let Some(s) = self.store.get_outbound_group_sessions(room_id).await? { + for request_id in s.pending_request_ids() { + self.sessions_being_shared.insert(request_id, s.clone()); + } + + self.sessions.insert(room_id.clone(), s.clone()); + + Ok(Some(s)) + } else { + Ok(None) + } + } + + /// Get an outbound group session for a room, if one exists. + /// + /// # Arguments + /// + /// * `room_id` - The id of the room for which we should get the outbound + /// group session. + fn get(&self, room_id: &RoomId) -> Option { + self.sessions.get(room_id).map(|s| s.clone()) + } +} + #[derive(Debug, Clone)] pub struct GroupSessionManager { account: Account, @@ -48,10 +99,7 @@ pub struct GroupSessionManager { /// without the need to create new keys. store: Store, /// The currently active outbound group sessions. - outbound_group_sessions: Arc>, - /// A map from the request id to the group session that the request belongs - /// to. Used to mark requests belonging to the session as shared. - outbound_sessions_being_shared: Arc>, + sessions: GroupSessionCache, } impl GroupSessionManager { @@ -60,14 +108,13 @@ impl GroupSessionManager { pub(crate) fn new(account: Account, store: Store) -> Self { Self { account, - store, - outbound_group_sessions: Arc::new(DashMap::new()), - outbound_sessions_being_shared: Arc::new(DashMap::new()), + store: store.clone(), + sessions: GroupSessionCache::new(store), } } pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult { - if let Some(s) = self.outbound_group_sessions.get(room_id) { + if let Some(s) = self.sessions.get(room_id) { s.invalidate_session(); let mut changes = Changes::default(); @@ -81,7 +128,7 @@ impl GroupSessionManager { } pub async fn mark_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> { - if let Some((_, s)) = self.outbound_sessions_being_shared.remove(request_id) { + if let Some((_, s)) = self.sessions.sessions_being_shared.remove(request_id) { s.mark_request_as_sent(request_id); let mut changes = Changes::default(); @@ -97,15 +144,9 @@ impl GroupSessionManager { Ok(()) } - /// Get an outbound group session for a room, if one exists. - /// - /// # Arguments - /// - /// * `room_id` - The id of the room for which we should get the outbound - /// group session. + #[cfg(test)] pub fn get_outbound_group_session(&self, room_id: &RoomId) -> Option { - #[allow(clippy::map_clone)] - self.outbound_group_sessions.get(room_id).map(|s| s.clone()) + self.sessions.get(room_id) } pub async fn encrypt( @@ -113,7 +154,7 @@ impl GroupSessionManager { room_id: &RoomId, content: AnyMessageEventContent, ) -> MegolmResult { - let session = if let Some(s) = self.get_outbound_group_session(room_id) { + let session = if let Some(s) = self.sessions.get(room_id) { s } else { panic!("Session wasn't created nor shared"); @@ -147,9 +188,7 @@ impl GroupSessionManager { .await .map_err(|_| EventError::UnsupportedAlgorithm)?; - let _ = self - .outbound_group_sessions - .insert(room_id.to_owned(), outbound.clone()); + self.sessions.insert(outbound.clone()); Ok((outbound, inbound)) } @@ -158,23 +197,7 @@ impl GroupSessionManager { room_id: &RoomId, settings: EncryptionSettings, ) -> OlmResult<(OutboundGroupSession, Option)> { - // Get the cached session, if there isn't one load one from the store - // and put it in the cache. - let outbound_session = if let Some(s) = self.outbound_group_sessions.get(room_id) { - Some(s.clone()) - } else if let Some(s) = self.store.get_outbound_group_sessions(room_id).await? { - for request_id in s.pending_request_ids() { - self.outbound_sessions_being_shared - .insert(request_id, s.clone()); - } - - self.outbound_group_sessions - .insert(room_id.clone(), s.clone()); - - Some(s) - } else { - None - }; + let outbound_session = self.sessions.get_or_load(&room_id).await?; // If there is no session or the session has expired or is invalid, // create a new one. @@ -388,6 +411,10 @@ impl GroupSessionManager { Ok(used_sessions) } + pub(crate) fn session_cache(&self) -> GroupSessionCache { + self.sessions.clone() + } + /// Get to-device requests to share a group session with users in a room. /// /// # Arguments @@ -489,7 +516,7 @@ impl GroupSessionManager { key_content.clone(), outbound.clone(), message_index, - self.outbound_sessions_being_shared.clone(), + self.sessions.sessions_being_shared.clone(), )) }) .collect(); diff --git a/matrix_sdk_crypto/src/session_manager/mod.rs b/matrix_sdk_crypto/src/session_manager/mod.rs index 7750262e..1af686ef 100644 --- a/matrix_sdk_crypto/src/session_manager/mod.rs +++ b/matrix_sdk_crypto/src/session_manager/mod.rs @@ -15,5 +15,5 @@ mod group_sessions; mod sessions; -pub(crate) use group_sessions::GroupSessionManager; +pub(crate) use group_sessions::{GroupSessionCache, GroupSessionManager}; pub(crate) use sessions::SessionManager; diff --git a/matrix_sdk_crypto/src/session_manager/sessions.rs b/matrix_sdk_crypto/src/session_manager/sessions.rs index 9417c3ea..274314a8 100644 --- a/matrix_sdk_crypto/src/session_manager/sessions.rs +++ b/matrix_sdk_crypto/src/session_manager/sessions.rs @@ -322,6 +322,7 @@ mod test { identities::ReadOnlyDevice, key_request::KeyRequestMachine, olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount}, + session_manager::GroupSessionCache, store::{CryptoStore, MemoryStore, Store}, verification::VerificationMachine, }; @@ -342,7 +343,6 @@ mod test { let user_id = user_id(); let device_id = device_id(); - let outbound_sessions = Arc::new(DashMap::new()); let users_for_key_claim = Arc::new(DashMap::new()); let account = ReadOnlyAccount::new(&user_id, &device_id); let store: Arc> = Arc::new(Box::new(MemoryStore::new())); @@ -363,11 +363,13 @@ mod test { store: store.clone(), }; + let session_cache = GroupSessionCache::new(store.clone()); + let key_request = KeyRequestMachine::new( user_id, device_id, store.clone(), - outbound_sessions, + session_cache, users_for_key_claim.clone(), );