crypto: Refactor the outobund group session storing

This introduces a group session cache struct that can be shared between
components that need to access the currently active group session.
master
Damir Jelić 2021-04-15 15:19:21 +02:00
parent 9e817a623b
commit d4c56cc5b3
5 changed files with 112 additions and 67 deletions

View File

@ -40,8 +40,9 @@ use matrix_sdk_common::{
use crate::{
error::{OlmError, OlmResult},
olm::{InboundGroupSession, OutboundGroupSession, Session, ShareState},
olm::{InboundGroupSession, Session, ShareState},
requests::{OutgoingRequest, ToDeviceRequest},
session_manager::GroupSessionCache,
store::{Changes, CryptoStoreError, Store},
Device,
};
@ -128,7 +129,7 @@ pub(crate) struct KeyRequestMachine {
user_id: Arc<UserId>,
device_id: Arc<DeviceIdBox>,
store: Store,
outbound_group_sessions: Arc<DashMap<RoomId, OutboundGroupSession>>,
outbound_group_sessions: GroupSessionCache,
outgoing_to_device_requests: Arc<DashMap<Uuid, OutgoingRequest>>,
incoming_key_requests: Arc<
DashMap<(UserId, DeviceIdBox, String), ToDeviceEvent<RoomKeyRequestToDeviceEventContent>>,
@ -188,7 +189,7 @@ impl KeyRequestMachine {
user_id: Arc<UserId>,
device_id: Arc<DeviceIdBox>,
store: Store,
outbound_group_sessions: Arc<DashMap<RoomId, OutboundGroupSession>>,
outbound_group_sessions: GroupSessionCache,
users_for_key_claim: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>,
) -> Self {
Self {
@ -356,7 +357,7 @@ impl KeyRequestMachine {
.await?;
if let Some(device) = device {
match self.should_share_session(&device, &session) {
match self.should_share_session(&device, &session).await {
Err(e) => {
info!(
"Received a key request from {} {} that we won't serve: {}",
@ -458,14 +459,17 @@ impl KeyRequestMachine {
/// * `device` - The device that is requesting a session from us.
///
/// * `session` - The session that was requested to be shared.
fn should_share_session(
async fn should_share_session(
&self,
device: &Device,
session: &InboundGroupSession,
) -> Result<Option<u32>, KeyshareDecision> {
let outbound_session = self
.outbound_group_sessions
.get(session.room_id())
.get_or_load(session.room_id())
.await
.ok()
.flatten()
.filter(|o| session.session_id() == o.session_id());
let own_device_check = || {
@ -720,6 +724,7 @@ mod test {
use crate::{
identities::{LocalTrust, ReadOnlyDevice},
olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount},
session_manager::GroupSessionCache,
store::{Changes, CryptoStore, MemoryStore, Store},
verification::VerificationMachine,
};
@ -761,12 +766,13 @@ mod test {
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(bob_id())));
let verification = VerificationMachine::new(account, identity.clone(), store.clone());
let store = Store::new(user_id.clone(), identity, store, verification);
let session_cache = GroupSessionCache::new(store.clone());
KeyRequestMachine::new(
user_id,
Arc::new(bob_device_id()),
store,
Arc::new(DashMap::new()),
session_cache,
Arc::new(DashMap::new()),
)
}
@ -780,12 +786,13 @@ mod test {
let verification = VerificationMachine::new(account, identity.clone(), store.clone());
let store = Store::new(user_id.clone(), identity, store, verification);
store.save_devices(&[device]).await.unwrap();
let session_cache = GroupSessionCache::new(store.clone());
KeyRequestMachine::new(
user_id,
Arc::new(alice_device_id()),
store,
Arc::new(DashMap::new()),
session_cache,
Arc::new(DashMap::new()),
)
}
@ -973,12 +980,16 @@ mod test {
assert_eq!(
machine
.should_share_session(&own_device, &inbound)
.await
.expect_err("Should not share with untrusted"),
KeyshareDecision::UntrustedDevice
);
own_device.set_trust_state(LocalTrust::Verified);
// Now we do want to share the keys.
assert!(machine.should_share_session(&own_device, &inbound).is_ok());
assert!(machine
.should_share_session(&own_device, &inbound)
.await
.is_ok());
let bob_device = ReadOnlyDevice::from_account(&bob_account()).await;
machine.store.save_devices(&[bob_device]).await.unwrap();
@ -995,6 +1006,7 @@ mod test {
assert_eq!(
machine
.should_share_session(&bob_device, &inbound)
.await
.expect_err("Should not share with other."),
KeyshareDecision::MissingOutboundSession
);
@ -1004,15 +1016,14 @@ mod test {
changes.outbound_group_sessions.push(outbound.clone());
changes.inbound_group_sessions.push(inbound.clone());
machine.store.save_changes(changes).await.unwrap();
machine
.outbound_group_sessions
.insert(inbound.room_id().to_owned(), outbound.clone());
machine.outbound_group_sessions.insert(outbound.clone());
// We don't share sessions with other user's devices if the session
// wasn't shared in the first place.
assert_eq!(
machine
.should_share_session(&bob_device, &inbound)
.await
.expect_err("Should not share with other unless shared."),
KeyshareDecision::OutboundSessionNotShared
);
@ -1024,13 +1035,17 @@ mod test {
assert_eq!(
machine
.should_share_session(&bob_device, &inbound)
.await
.expect_err("Should not share with other unless shared."),
KeyshareDecision::OutboundSessionNotShared
);
// We now share the session, since it was shared before.
outbound.mark_shared_with(bob_device.user_id(), bob_device.device_id());
assert!(machine.should_share_session(&bob_device, &inbound).is_ok());
assert!(machine
.should_share_session(&bob_device, &inbound)
.await
.is_ok());
// But we don't share some other session that doesn't match our outbound
// session
@ -1042,6 +1057,7 @@ mod test {
assert_eq!(
machine
.should_share_session(&bob_device, &other_inbound)
.await
.expect_err("Should not share with other unless shared."),
KeyshareDecision::MissingOutboundSession
);
@ -1112,7 +1128,7 @@ mod test {
// Put the outbound session into bobs store.
bob_machine
.outbound_group_sessions
.insert(room_id(), group_session.clone());
.insert(group_session.clone());
// Get the request and convert it into a event.
let request = alice_machine
@ -1278,7 +1294,7 @@ mod test {
// Put the outbound session into bobs store.
bob_machine
.outbound_group_sessions
.insert(room_id(), group_session.clone());
.insert(group_session.clone());
// Get the request and convert it into a event.
let request = alice_machine

View File

@ -156,29 +156,29 @@ impl OlmMachine {
verification_machine.clone(),
);
let device_id: Arc<DeviceIdBox> = Arc::new(device_id);
let outbound_group_sessions = Arc::new(DashMap::new());
let users_for_key_claim = Arc::new(DashMap::new());
let key_request_machine = KeyRequestMachine::new(
user_id.clone(),
device_id.clone(),
store.clone(),
outbound_group_sessions,
users_for_key_claim.clone(),
);
let account = Account {
inner: account,
store: store.clone(),
};
let group_session_manager = GroupSessionManager::new(account.clone(), store.clone());
let key_request_machine = KeyRequestMachine::new(
user_id.clone(),
device_id.clone(),
store.clone(),
group_session_manager.session_cache(),
users_for_key_claim.clone(),
);
let session_manager = SessionManager::new(
account.clone(),
users_for_key_claim,
key_request_machine.clone(),
store.clone(),
);
let group_session_manager = GroupSessionManager::new(account.clone(), store.clone());
let identity_manager =
IdentityManager::new(user_id.clone(), device_id.clone(), store.clone());

View File

@ -40,6 +40,57 @@ use crate::{
Device, EncryptionSettings, OlmError, ToDeviceRequest,
};
#[derive(Clone, Debug)]
pub(crate) struct GroupSessionCache {
store: Store,
sessions: Arc<DashMap<RoomId, OutboundGroupSession>>,
/// A map from the request id to the group session that the request belongs
/// to. Used to mark requests belonging to the session as shared.
sessions_being_shared: Arc<DashMap<Uuid, OutboundGroupSession>>,
}
impl GroupSessionCache {
pub(crate) fn new(store: Store) -> Self {
Self {
store,
sessions: DashMap::new().into(),
sessions_being_shared: Arc::new(DashMap::new()),
}
}
pub(crate) fn insert(&self, session: OutboundGroupSession) {
self.sessions.insert(session.room_id().to_owned(), session);
}
pub async fn get_or_load(&self, room_id: &RoomId) -> StoreResult<Option<OutboundGroupSession>> {
// Get the cached session, if there isn't one load one from the store
// and put it in the cache.
if let Some(s) = self.sessions.get(room_id) {
Ok(Some(s.clone()))
} else if let Some(s) = self.store.get_outbound_group_sessions(room_id).await? {
for request_id in s.pending_request_ids() {
self.sessions_being_shared.insert(request_id, s.clone());
}
self.sessions.insert(room_id.clone(), s.clone());
Ok(Some(s))
} else {
Ok(None)
}
}
/// Get an outbound group session for a room, if one exists.
///
/// # Arguments
///
/// * `room_id` - The id of the room for which we should get the outbound
/// group session.
fn get(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
self.sessions.get(room_id).map(|s| s.clone())
}
}
#[derive(Debug, Clone)]
pub struct GroupSessionManager {
account: Account,
@ -48,10 +99,7 @@ pub struct GroupSessionManager {
/// without the need to create new keys.
store: Store,
/// The currently active outbound group sessions.
outbound_group_sessions: Arc<DashMap<RoomId, OutboundGroupSession>>,
/// A map from the request id to the group session that the request belongs
/// to. Used to mark requests belonging to the session as shared.
outbound_sessions_being_shared: Arc<DashMap<Uuid, OutboundGroupSession>>,
sessions: GroupSessionCache,
}
impl GroupSessionManager {
@ -60,14 +108,13 @@ impl GroupSessionManager {
pub(crate) fn new(account: Account, store: Store) -> Self {
Self {
account,
store,
outbound_group_sessions: Arc::new(DashMap::new()),
outbound_sessions_being_shared: Arc::new(DashMap::new()),
store: store.clone(),
sessions: GroupSessionCache::new(store),
}
}
pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
if let Some(s) = self.outbound_group_sessions.get(room_id) {
if let Some(s) = self.sessions.get(room_id) {
s.invalidate_session();
let mut changes = Changes::default();
@ -81,7 +128,7 @@ impl GroupSessionManager {
}
pub async fn mark_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> {
if let Some((_, s)) = self.outbound_sessions_being_shared.remove(request_id) {
if let Some((_, s)) = self.sessions.sessions_being_shared.remove(request_id) {
s.mark_request_as_sent(request_id);
let mut changes = Changes::default();
@ -97,15 +144,9 @@ impl GroupSessionManager {
Ok(())
}
/// Get an outbound group session for a room, if one exists.
///
/// # Arguments
///
/// * `room_id` - The id of the room for which we should get the outbound
/// group session.
#[cfg(test)]
pub fn get_outbound_group_session(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
#[allow(clippy::map_clone)]
self.outbound_group_sessions.get(room_id).map(|s| s.clone())
self.sessions.get(room_id)
}
pub async fn encrypt(
@ -113,7 +154,7 @@ impl GroupSessionManager {
room_id: &RoomId,
content: AnyMessageEventContent,
) -> MegolmResult<EncryptedEventContent> {
let session = if let Some(s) = self.get_outbound_group_session(room_id) {
let session = if let Some(s) = self.sessions.get(room_id) {
s
} else {
panic!("Session wasn't created nor shared");
@ -147,9 +188,7 @@ impl GroupSessionManager {
.await
.map_err(|_| EventError::UnsupportedAlgorithm)?;
let _ = self
.outbound_group_sessions
.insert(room_id.to_owned(), outbound.clone());
self.sessions.insert(outbound.clone());
Ok((outbound, inbound))
}
@ -158,23 +197,7 @@ impl GroupSessionManager {
room_id: &RoomId,
settings: EncryptionSettings,
) -> OlmResult<(OutboundGroupSession, Option<InboundGroupSession>)> {
// Get the cached session, if there isn't one load one from the store
// and put it in the cache.
let outbound_session = if let Some(s) = self.outbound_group_sessions.get(room_id) {
Some(s.clone())
} else if let Some(s) = self.store.get_outbound_group_sessions(room_id).await? {
for request_id in s.pending_request_ids() {
self.outbound_sessions_being_shared
.insert(request_id, s.clone());
}
self.outbound_group_sessions
.insert(room_id.clone(), s.clone());
Some(s)
} else {
None
};
let outbound_session = self.sessions.get_or_load(&room_id).await?;
// If there is no session or the session has expired or is invalid,
// create a new one.
@ -388,6 +411,10 @@ impl GroupSessionManager {
Ok(used_sessions)
}
pub(crate) fn session_cache(&self) -> GroupSessionCache {
self.sessions.clone()
}
/// Get to-device requests to share a group session with users in a room.
///
/// # Arguments
@ -489,7 +516,7 @@ impl GroupSessionManager {
key_content.clone(),
outbound.clone(),
message_index,
self.outbound_sessions_being_shared.clone(),
self.sessions.sessions_being_shared.clone(),
))
})
.collect();

View File

@ -15,5 +15,5 @@
mod group_sessions;
mod sessions;
pub(crate) use group_sessions::GroupSessionManager;
pub(crate) use group_sessions::{GroupSessionCache, GroupSessionManager};
pub(crate) use sessions::SessionManager;

View File

@ -322,6 +322,7 @@ mod test {
identities::ReadOnlyDevice,
key_request::KeyRequestMachine,
olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount},
session_manager::GroupSessionCache,
store::{CryptoStore, MemoryStore, Store},
verification::VerificationMachine,
};
@ -342,7 +343,6 @@ mod test {
let user_id = user_id();
let device_id = device_id();
let outbound_sessions = Arc::new(DashMap::new());
let users_for_key_claim = Arc::new(DashMap::new());
let account = ReadOnlyAccount::new(&user_id, &device_id);
let store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
@ -363,11 +363,13 @@ mod test {
store: store.clone(),
};
let session_cache = GroupSessionCache::new(store.clone());
let key_request = KeyRequestMachine::new(
user_id,
device_id,
store.clone(),
outbound_sessions,
session_cache,
users_for_key_claim.clone(),
);