crypto: Refactor the outobund group session storing

This introduces a group session cache struct that can be shared between
components that need to access the currently active group session.
master
Damir Jelić 2021-04-15 15:19:21 +02:00
parent 9e817a623b
commit d4c56cc5b3
5 changed files with 112 additions and 67 deletions

View File

@ -40,8 +40,9 @@ use matrix_sdk_common::{
use crate::{ use crate::{
error::{OlmError, OlmResult}, error::{OlmError, OlmResult},
olm::{InboundGroupSession, OutboundGroupSession, Session, ShareState}, olm::{InboundGroupSession, Session, ShareState},
requests::{OutgoingRequest, ToDeviceRequest}, requests::{OutgoingRequest, ToDeviceRequest},
session_manager::GroupSessionCache,
store::{Changes, CryptoStoreError, Store}, store::{Changes, CryptoStoreError, Store},
Device, Device,
}; };
@ -128,7 +129,7 @@ pub(crate) struct KeyRequestMachine {
user_id: Arc<UserId>, user_id: Arc<UserId>,
device_id: Arc<DeviceIdBox>, device_id: Arc<DeviceIdBox>,
store: Store, store: Store,
outbound_group_sessions: Arc<DashMap<RoomId, OutboundGroupSession>>, outbound_group_sessions: GroupSessionCache,
outgoing_to_device_requests: Arc<DashMap<Uuid, OutgoingRequest>>, outgoing_to_device_requests: Arc<DashMap<Uuid, OutgoingRequest>>,
incoming_key_requests: Arc< incoming_key_requests: Arc<
DashMap<(UserId, DeviceIdBox, String), ToDeviceEvent<RoomKeyRequestToDeviceEventContent>>, DashMap<(UserId, DeviceIdBox, String), ToDeviceEvent<RoomKeyRequestToDeviceEventContent>>,
@ -188,7 +189,7 @@ impl KeyRequestMachine {
user_id: Arc<UserId>, user_id: Arc<UserId>,
device_id: Arc<DeviceIdBox>, device_id: Arc<DeviceIdBox>,
store: Store, store: Store,
outbound_group_sessions: Arc<DashMap<RoomId, OutboundGroupSession>>, outbound_group_sessions: GroupSessionCache,
users_for_key_claim: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>, users_for_key_claim: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>,
) -> Self { ) -> Self {
Self { Self {
@ -356,7 +357,7 @@ impl KeyRequestMachine {
.await?; .await?;
if let Some(device) = device { if let Some(device) = device {
match self.should_share_session(&device, &session) { match self.should_share_session(&device, &session).await {
Err(e) => { Err(e) => {
info!( info!(
"Received a key request from {} {} that we won't serve: {}", "Received a key request from {} {} that we won't serve: {}",
@ -458,14 +459,17 @@ impl KeyRequestMachine {
/// * `device` - The device that is requesting a session from us. /// * `device` - The device that is requesting a session from us.
/// ///
/// * `session` - The session that was requested to be shared. /// * `session` - The session that was requested to be shared.
fn should_share_session( async fn should_share_session(
&self, &self,
device: &Device, device: &Device,
session: &InboundGroupSession, session: &InboundGroupSession,
) -> Result<Option<u32>, KeyshareDecision> { ) -> Result<Option<u32>, KeyshareDecision> {
let outbound_session = self let outbound_session = self
.outbound_group_sessions .outbound_group_sessions
.get(session.room_id()) .get_or_load(session.room_id())
.await
.ok()
.flatten()
.filter(|o| session.session_id() == o.session_id()); .filter(|o| session.session_id() == o.session_id());
let own_device_check = || { let own_device_check = || {
@ -720,6 +724,7 @@ mod test {
use crate::{ use crate::{
identities::{LocalTrust, ReadOnlyDevice}, identities::{LocalTrust, ReadOnlyDevice},
olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount}, olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount},
session_manager::GroupSessionCache,
store::{Changes, CryptoStore, MemoryStore, Store}, store::{Changes, CryptoStore, MemoryStore, Store},
verification::VerificationMachine, verification::VerificationMachine,
}; };
@ -761,12 +766,13 @@ mod test {
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(bob_id()))); let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(bob_id())));
let verification = VerificationMachine::new(account, identity.clone(), store.clone()); let verification = VerificationMachine::new(account, identity.clone(), store.clone());
let store = Store::new(user_id.clone(), identity, store, verification); let store = Store::new(user_id.clone(), identity, store, verification);
let session_cache = GroupSessionCache::new(store.clone());
KeyRequestMachine::new( KeyRequestMachine::new(
user_id, user_id,
Arc::new(bob_device_id()), Arc::new(bob_device_id()),
store, store,
Arc::new(DashMap::new()), session_cache,
Arc::new(DashMap::new()), Arc::new(DashMap::new()),
) )
} }
@ -780,12 +786,13 @@ mod test {
let verification = VerificationMachine::new(account, identity.clone(), store.clone()); let verification = VerificationMachine::new(account, identity.clone(), store.clone());
let store = Store::new(user_id.clone(), identity, store, verification); let store = Store::new(user_id.clone(), identity, store, verification);
store.save_devices(&[device]).await.unwrap(); store.save_devices(&[device]).await.unwrap();
let session_cache = GroupSessionCache::new(store.clone());
KeyRequestMachine::new( KeyRequestMachine::new(
user_id, user_id,
Arc::new(alice_device_id()), Arc::new(alice_device_id()),
store, store,
Arc::new(DashMap::new()), session_cache,
Arc::new(DashMap::new()), Arc::new(DashMap::new()),
) )
} }
@ -973,12 +980,16 @@ mod test {
assert_eq!( assert_eq!(
machine machine
.should_share_session(&own_device, &inbound) .should_share_session(&own_device, &inbound)
.await
.expect_err("Should not share with untrusted"), .expect_err("Should not share with untrusted"),
KeyshareDecision::UntrustedDevice KeyshareDecision::UntrustedDevice
); );
own_device.set_trust_state(LocalTrust::Verified); own_device.set_trust_state(LocalTrust::Verified);
// Now we do want to share the keys. // Now we do want to share the keys.
assert!(machine.should_share_session(&own_device, &inbound).is_ok()); assert!(machine
.should_share_session(&own_device, &inbound)
.await
.is_ok());
let bob_device = ReadOnlyDevice::from_account(&bob_account()).await; let bob_device = ReadOnlyDevice::from_account(&bob_account()).await;
machine.store.save_devices(&[bob_device]).await.unwrap(); machine.store.save_devices(&[bob_device]).await.unwrap();
@ -995,6 +1006,7 @@ mod test {
assert_eq!( assert_eq!(
machine machine
.should_share_session(&bob_device, &inbound) .should_share_session(&bob_device, &inbound)
.await
.expect_err("Should not share with other."), .expect_err("Should not share with other."),
KeyshareDecision::MissingOutboundSession KeyshareDecision::MissingOutboundSession
); );
@ -1004,15 +1016,14 @@ mod test {
changes.outbound_group_sessions.push(outbound.clone()); changes.outbound_group_sessions.push(outbound.clone());
changes.inbound_group_sessions.push(inbound.clone()); changes.inbound_group_sessions.push(inbound.clone());
machine.store.save_changes(changes).await.unwrap(); machine.store.save_changes(changes).await.unwrap();
machine machine.outbound_group_sessions.insert(outbound.clone());
.outbound_group_sessions
.insert(inbound.room_id().to_owned(), outbound.clone());
// We don't share sessions with other user's devices if the session // We don't share sessions with other user's devices if the session
// wasn't shared in the first place. // wasn't shared in the first place.
assert_eq!( assert_eq!(
machine machine
.should_share_session(&bob_device, &inbound) .should_share_session(&bob_device, &inbound)
.await
.expect_err("Should not share with other unless shared."), .expect_err("Should not share with other unless shared."),
KeyshareDecision::OutboundSessionNotShared KeyshareDecision::OutboundSessionNotShared
); );
@ -1024,13 +1035,17 @@ mod test {
assert_eq!( assert_eq!(
machine machine
.should_share_session(&bob_device, &inbound) .should_share_session(&bob_device, &inbound)
.await
.expect_err("Should not share with other unless shared."), .expect_err("Should not share with other unless shared."),
KeyshareDecision::OutboundSessionNotShared KeyshareDecision::OutboundSessionNotShared
); );
// We now share the session, since it was shared before. // We now share the session, since it was shared before.
outbound.mark_shared_with(bob_device.user_id(), bob_device.device_id()); outbound.mark_shared_with(bob_device.user_id(), bob_device.device_id());
assert!(machine.should_share_session(&bob_device, &inbound).is_ok()); assert!(machine
.should_share_session(&bob_device, &inbound)
.await
.is_ok());
// But we don't share some other session that doesn't match our outbound // But we don't share some other session that doesn't match our outbound
// session // session
@ -1042,6 +1057,7 @@ mod test {
assert_eq!( assert_eq!(
machine machine
.should_share_session(&bob_device, &other_inbound) .should_share_session(&bob_device, &other_inbound)
.await
.expect_err("Should not share with other unless shared."), .expect_err("Should not share with other unless shared."),
KeyshareDecision::MissingOutboundSession KeyshareDecision::MissingOutboundSession
); );
@ -1112,7 +1128,7 @@ mod test {
// Put the outbound session into bobs store. // Put the outbound session into bobs store.
bob_machine bob_machine
.outbound_group_sessions .outbound_group_sessions
.insert(room_id(), group_session.clone()); .insert(group_session.clone());
// Get the request and convert it into a event. // Get the request and convert it into a event.
let request = alice_machine let request = alice_machine
@ -1278,7 +1294,7 @@ mod test {
// Put the outbound session into bobs store. // Put the outbound session into bobs store.
bob_machine bob_machine
.outbound_group_sessions .outbound_group_sessions
.insert(room_id(), group_session.clone()); .insert(group_session.clone());
// Get the request and convert it into a event. // Get the request and convert it into a event.
let request = alice_machine let request = alice_machine

View File

@ -156,29 +156,29 @@ impl OlmMachine {
verification_machine.clone(), verification_machine.clone(),
); );
let device_id: Arc<DeviceIdBox> = Arc::new(device_id); let device_id: Arc<DeviceIdBox> = Arc::new(device_id);
let outbound_group_sessions = Arc::new(DashMap::new());
let users_for_key_claim = Arc::new(DashMap::new()); let users_for_key_claim = Arc::new(DashMap::new());
let key_request_machine = KeyRequestMachine::new(
user_id.clone(),
device_id.clone(),
store.clone(),
outbound_group_sessions,
users_for_key_claim.clone(),
);
let account = Account { let account = Account {
inner: account, inner: account,
store: store.clone(), store: store.clone(),
}; };
let group_session_manager = GroupSessionManager::new(account.clone(), store.clone());
let key_request_machine = KeyRequestMachine::new(
user_id.clone(),
device_id.clone(),
store.clone(),
group_session_manager.session_cache(),
users_for_key_claim.clone(),
);
let session_manager = SessionManager::new( let session_manager = SessionManager::new(
account.clone(), account.clone(),
users_for_key_claim, users_for_key_claim,
key_request_machine.clone(), key_request_machine.clone(),
store.clone(), store.clone(),
); );
let group_session_manager = GroupSessionManager::new(account.clone(), store.clone());
let identity_manager = let identity_manager =
IdentityManager::new(user_id.clone(), device_id.clone(), store.clone()); IdentityManager::new(user_id.clone(), device_id.clone(), store.clone());

View File

@ -40,6 +40,57 @@ use crate::{
Device, EncryptionSettings, OlmError, ToDeviceRequest, Device, EncryptionSettings, OlmError, ToDeviceRequest,
}; };
#[derive(Clone, Debug)]
pub(crate) struct GroupSessionCache {
store: Store,
sessions: Arc<DashMap<RoomId, OutboundGroupSession>>,
/// A map from the request id to the group session that the request belongs
/// to. Used to mark requests belonging to the session as shared.
sessions_being_shared: Arc<DashMap<Uuid, OutboundGroupSession>>,
}
impl GroupSessionCache {
pub(crate) fn new(store: Store) -> Self {
Self {
store,
sessions: DashMap::new().into(),
sessions_being_shared: Arc::new(DashMap::new()),
}
}
pub(crate) fn insert(&self, session: OutboundGroupSession) {
self.sessions.insert(session.room_id().to_owned(), session);
}
pub async fn get_or_load(&self, room_id: &RoomId) -> StoreResult<Option<OutboundGroupSession>> {
// Get the cached session, if there isn't one load one from the store
// and put it in the cache.
if let Some(s) = self.sessions.get(room_id) {
Ok(Some(s.clone()))
} else if let Some(s) = self.store.get_outbound_group_sessions(room_id).await? {
for request_id in s.pending_request_ids() {
self.sessions_being_shared.insert(request_id, s.clone());
}
self.sessions.insert(room_id.clone(), s.clone());
Ok(Some(s))
} else {
Ok(None)
}
}
/// Get an outbound group session for a room, if one exists.
///
/// # Arguments
///
/// * `room_id` - The id of the room for which we should get the outbound
/// group session.
fn get(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
self.sessions.get(room_id).map(|s| s.clone())
}
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct GroupSessionManager { pub struct GroupSessionManager {
account: Account, account: Account,
@ -48,10 +99,7 @@ pub struct GroupSessionManager {
/// without the need to create new keys. /// without the need to create new keys.
store: Store, store: Store,
/// The currently active outbound group sessions. /// The currently active outbound group sessions.
outbound_group_sessions: Arc<DashMap<RoomId, OutboundGroupSession>>, sessions: GroupSessionCache,
/// A map from the request id to the group session that the request belongs
/// to. Used to mark requests belonging to the session as shared.
outbound_sessions_being_shared: Arc<DashMap<Uuid, OutboundGroupSession>>,
} }
impl GroupSessionManager { impl GroupSessionManager {
@ -60,14 +108,13 @@ impl GroupSessionManager {
pub(crate) fn new(account: Account, store: Store) -> Self { pub(crate) fn new(account: Account, store: Store) -> Self {
Self { Self {
account, account,
store, store: store.clone(),
outbound_group_sessions: Arc::new(DashMap::new()), sessions: GroupSessionCache::new(store),
outbound_sessions_being_shared: Arc::new(DashMap::new()),
} }
} }
pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> { pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
if let Some(s) = self.outbound_group_sessions.get(room_id) { if let Some(s) = self.sessions.get(room_id) {
s.invalidate_session(); s.invalidate_session();
let mut changes = Changes::default(); let mut changes = Changes::default();
@ -81,7 +128,7 @@ impl GroupSessionManager {
} }
pub async fn mark_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> { pub async fn mark_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> {
if let Some((_, s)) = self.outbound_sessions_being_shared.remove(request_id) { if let Some((_, s)) = self.sessions.sessions_being_shared.remove(request_id) {
s.mark_request_as_sent(request_id); s.mark_request_as_sent(request_id);
let mut changes = Changes::default(); let mut changes = Changes::default();
@ -97,15 +144,9 @@ impl GroupSessionManager {
Ok(()) Ok(())
} }
/// Get an outbound group session for a room, if one exists. #[cfg(test)]
///
/// # Arguments
///
/// * `room_id` - The id of the room for which we should get the outbound
/// group session.
pub fn get_outbound_group_session(&self, room_id: &RoomId) -> Option<OutboundGroupSession> { pub fn get_outbound_group_session(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
#[allow(clippy::map_clone)] self.sessions.get(room_id)
self.outbound_group_sessions.get(room_id).map(|s| s.clone())
} }
pub async fn encrypt( pub async fn encrypt(
@ -113,7 +154,7 @@ impl GroupSessionManager {
room_id: &RoomId, room_id: &RoomId,
content: AnyMessageEventContent, content: AnyMessageEventContent,
) -> MegolmResult<EncryptedEventContent> { ) -> MegolmResult<EncryptedEventContent> {
let session = if let Some(s) = self.get_outbound_group_session(room_id) { let session = if let Some(s) = self.sessions.get(room_id) {
s s
} else { } else {
panic!("Session wasn't created nor shared"); panic!("Session wasn't created nor shared");
@ -147,9 +188,7 @@ impl GroupSessionManager {
.await .await
.map_err(|_| EventError::UnsupportedAlgorithm)?; .map_err(|_| EventError::UnsupportedAlgorithm)?;
let _ = self self.sessions.insert(outbound.clone());
.outbound_group_sessions
.insert(room_id.to_owned(), outbound.clone());
Ok((outbound, inbound)) Ok((outbound, inbound))
} }
@ -158,23 +197,7 @@ impl GroupSessionManager {
room_id: &RoomId, room_id: &RoomId,
settings: EncryptionSettings, settings: EncryptionSettings,
) -> OlmResult<(OutboundGroupSession, Option<InboundGroupSession>)> { ) -> OlmResult<(OutboundGroupSession, Option<InboundGroupSession>)> {
// Get the cached session, if there isn't one load one from the store let outbound_session = self.sessions.get_or_load(&room_id).await?;
// and put it in the cache.
let outbound_session = if let Some(s) = self.outbound_group_sessions.get(room_id) {
Some(s.clone())
} else if let Some(s) = self.store.get_outbound_group_sessions(room_id).await? {
for request_id in s.pending_request_ids() {
self.outbound_sessions_being_shared
.insert(request_id, s.clone());
}
self.outbound_group_sessions
.insert(room_id.clone(), s.clone());
Some(s)
} else {
None
};
// If there is no session or the session has expired or is invalid, // If there is no session or the session has expired or is invalid,
// create a new one. // create a new one.
@ -388,6 +411,10 @@ impl GroupSessionManager {
Ok(used_sessions) Ok(used_sessions)
} }
pub(crate) fn session_cache(&self) -> GroupSessionCache {
self.sessions.clone()
}
/// Get to-device requests to share a group session with users in a room. /// Get to-device requests to share a group session with users in a room.
/// ///
/// # Arguments /// # Arguments
@ -489,7 +516,7 @@ impl GroupSessionManager {
key_content.clone(), key_content.clone(),
outbound.clone(), outbound.clone(),
message_index, message_index,
self.outbound_sessions_being_shared.clone(), self.sessions.sessions_being_shared.clone(),
)) ))
}) })
.collect(); .collect();

View File

@ -15,5 +15,5 @@
mod group_sessions; mod group_sessions;
mod sessions; mod sessions;
pub(crate) use group_sessions::GroupSessionManager; pub(crate) use group_sessions::{GroupSessionCache, GroupSessionManager};
pub(crate) use sessions::SessionManager; pub(crate) use sessions::SessionManager;

View File

@ -322,6 +322,7 @@ mod test {
identities::ReadOnlyDevice, identities::ReadOnlyDevice,
key_request::KeyRequestMachine, key_request::KeyRequestMachine,
olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount}, olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount},
session_manager::GroupSessionCache,
store::{CryptoStore, MemoryStore, Store}, store::{CryptoStore, MemoryStore, Store},
verification::VerificationMachine, verification::VerificationMachine,
}; };
@ -342,7 +343,6 @@ mod test {
let user_id = user_id(); let user_id = user_id();
let device_id = device_id(); let device_id = device_id();
let outbound_sessions = Arc::new(DashMap::new());
let users_for_key_claim = Arc::new(DashMap::new()); let users_for_key_claim = Arc::new(DashMap::new());
let account = ReadOnlyAccount::new(&user_id, &device_id); let account = ReadOnlyAccount::new(&user_id, &device_id);
let store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new())); let store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
@ -363,11 +363,13 @@ mod test {
store: store.clone(), store: store.clone(),
}; };
let session_cache = GroupSessionCache::new(store.clone());
let key_request = KeyRequestMachine::new( let key_request = KeyRequestMachine::new(
user_id, user_id,
device_id, device_id,
store.clone(), store.clone(),
outbound_sessions, session_cache,
users_for_key_claim.clone(), users_for_key_claim.clone(),
); );