crypto: Refactor the outobund group session storing
This introduces a group session cache struct that can be shared between components that need to access the currently active group session.master
parent
9e817a623b
commit
d4c56cc5b3
|
@ -40,8 +40,9 @@ use matrix_sdk_common::{
|
|||
|
||||
use crate::{
|
||||
error::{OlmError, OlmResult},
|
||||
olm::{InboundGroupSession, OutboundGroupSession, Session, ShareState},
|
||||
olm::{InboundGroupSession, Session, ShareState},
|
||||
requests::{OutgoingRequest, ToDeviceRequest},
|
||||
session_manager::GroupSessionCache,
|
||||
store::{Changes, CryptoStoreError, Store},
|
||||
Device,
|
||||
};
|
||||
|
@ -128,7 +129,7 @@ pub(crate) struct KeyRequestMachine {
|
|||
user_id: Arc<UserId>,
|
||||
device_id: Arc<DeviceIdBox>,
|
||||
store: Store,
|
||||
outbound_group_sessions: Arc<DashMap<RoomId, OutboundGroupSession>>,
|
||||
outbound_group_sessions: GroupSessionCache,
|
||||
outgoing_to_device_requests: Arc<DashMap<Uuid, OutgoingRequest>>,
|
||||
incoming_key_requests: Arc<
|
||||
DashMap<(UserId, DeviceIdBox, String), ToDeviceEvent<RoomKeyRequestToDeviceEventContent>>,
|
||||
|
@ -188,7 +189,7 @@ impl KeyRequestMachine {
|
|||
user_id: Arc<UserId>,
|
||||
device_id: Arc<DeviceIdBox>,
|
||||
store: Store,
|
||||
outbound_group_sessions: Arc<DashMap<RoomId, OutboundGroupSession>>,
|
||||
outbound_group_sessions: GroupSessionCache,
|
||||
users_for_key_claim: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
|
@ -356,7 +357,7 @@ impl KeyRequestMachine {
|
|||
.await?;
|
||||
|
||||
if let Some(device) = device {
|
||||
match self.should_share_session(&device, &session) {
|
||||
match self.should_share_session(&device, &session).await {
|
||||
Err(e) => {
|
||||
info!(
|
||||
"Received a key request from {} {} that we won't serve: {}",
|
||||
|
@ -458,14 +459,17 @@ impl KeyRequestMachine {
|
|||
/// * `device` - The device that is requesting a session from us.
|
||||
///
|
||||
/// * `session` - The session that was requested to be shared.
|
||||
fn should_share_session(
|
||||
async fn should_share_session(
|
||||
&self,
|
||||
device: &Device,
|
||||
session: &InboundGroupSession,
|
||||
) -> Result<Option<u32>, KeyshareDecision> {
|
||||
let outbound_session = self
|
||||
.outbound_group_sessions
|
||||
.get(session.room_id())
|
||||
.get_or_load(session.room_id())
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
.filter(|o| session.session_id() == o.session_id());
|
||||
|
||||
let own_device_check = || {
|
||||
|
@ -720,6 +724,7 @@ mod test {
|
|||
use crate::{
|
||||
identities::{LocalTrust, ReadOnlyDevice},
|
||||
olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount},
|
||||
session_manager::GroupSessionCache,
|
||||
store::{Changes, CryptoStore, MemoryStore, Store},
|
||||
verification::VerificationMachine,
|
||||
};
|
||||
|
@ -761,12 +766,13 @@ mod test {
|
|||
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(bob_id())));
|
||||
let verification = VerificationMachine::new(account, identity.clone(), store.clone());
|
||||
let store = Store::new(user_id.clone(), identity, store, verification);
|
||||
let session_cache = GroupSessionCache::new(store.clone());
|
||||
|
||||
KeyRequestMachine::new(
|
||||
user_id,
|
||||
Arc::new(bob_device_id()),
|
||||
store,
|
||||
Arc::new(DashMap::new()),
|
||||
session_cache,
|
||||
Arc::new(DashMap::new()),
|
||||
)
|
||||
}
|
||||
|
@ -780,12 +786,13 @@ mod test {
|
|||
let verification = VerificationMachine::new(account, identity.clone(), store.clone());
|
||||
let store = Store::new(user_id.clone(), identity, store, verification);
|
||||
store.save_devices(&[device]).await.unwrap();
|
||||
let session_cache = GroupSessionCache::new(store.clone());
|
||||
|
||||
KeyRequestMachine::new(
|
||||
user_id,
|
||||
Arc::new(alice_device_id()),
|
||||
store,
|
||||
Arc::new(DashMap::new()),
|
||||
session_cache,
|
||||
Arc::new(DashMap::new()),
|
||||
)
|
||||
}
|
||||
|
@ -973,12 +980,16 @@ mod test {
|
|||
assert_eq!(
|
||||
machine
|
||||
.should_share_session(&own_device, &inbound)
|
||||
.await
|
||||
.expect_err("Should not share with untrusted"),
|
||||
KeyshareDecision::UntrustedDevice
|
||||
);
|
||||
own_device.set_trust_state(LocalTrust::Verified);
|
||||
// Now we do want to share the keys.
|
||||
assert!(machine.should_share_session(&own_device, &inbound).is_ok());
|
||||
assert!(machine
|
||||
.should_share_session(&own_device, &inbound)
|
||||
.await
|
||||
.is_ok());
|
||||
|
||||
let bob_device = ReadOnlyDevice::from_account(&bob_account()).await;
|
||||
machine.store.save_devices(&[bob_device]).await.unwrap();
|
||||
|
@ -995,6 +1006,7 @@ mod test {
|
|||
assert_eq!(
|
||||
machine
|
||||
.should_share_session(&bob_device, &inbound)
|
||||
.await
|
||||
.expect_err("Should not share with other."),
|
||||
KeyshareDecision::MissingOutboundSession
|
||||
);
|
||||
|
@ -1004,15 +1016,14 @@ mod test {
|
|||
changes.outbound_group_sessions.push(outbound.clone());
|
||||
changes.inbound_group_sessions.push(inbound.clone());
|
||||
machine.store.save_changes(changes).await.unwrap();
|
||||
machine
|
||||
.outbound_group_sessions
|
||||
.insert(inbound.room_id().to_owned(), outbound.clone());
|
||||
machine.outbound_group_sessions.insert(outbound.clone());
|
||||
|
||||
// We don't share sessions with other user's devices if the session
|
||||
// wasn't shared in the first place.
|
||||
assert_eq!(
|
||||
machine
|
||||
.should_share_session(&bob_device, &inbound)
|
||||
.await
|
||||
.expect_err("Should not share with other unless shared."),
|
||||
KeyshareDecision::OutboundSessionNotShared
|
||||
);
|
||||
|
@ -1024,13 +1035,17 @@ mod test {
|
|||
assert_eq!(
|
||||
machine
|
||||
.should_share_session(&bob_device, &inbound)
|
||||
.await
|
||||
.expect_err("Should not share with other unless shared."),
|
||||
KeyshareDecision::OutboundSessionNotShared
|
||||
);
|
||||
|
||||
// We now share the session, since it was shared before.
|
||||
outbound.mark_shared_with(bob_device.user_id(), bob_device.device_id());
|
||||
assert!(machine.should_share_session(&bob_device, &inbound).is_ok());
|
||||
assert!(machine
|
||||
.should_share_session(&bob_device, &inbound)
|
||||
.await
|
||||
.is_ok());
|
||||
|
||||
// But we don't share some other session that doesn't match our outbound
|
||||
// session
|
||||
|
@ -1042,6 +1057,7 @@ mod test {
|
|||
assert_eq!(
|
||||
machine
|
||||
.should_share_session(&bob_device, &other_inbound)
|
||||
.await
|
||||
.expect_err("Should not share with other unless shared."),
|
||||
KeyshareDecision::MissingOutboundSession
|
||||
);
|
||||
|
@ -1112,7 +1128,7 @@ mod test {
|
|||
// Put the outbound session into bobs store.
|
||||
bob_machine
|
||||
.outbound_group_sessions
|
||||
.insert(room_id(), group_session.clone());
|
||||
.insert(group_session.clone());
|
||||
|
||||
// Get the request and convert it into a event.
|
||||
let request = alice_machine
|
||||
|
@ -1278,7 +1294,7 @@ mod test {
|
|||
// Put the outbound session into bobs store.
|
||||
bob_machine
|
||||
.outbound_group_sessions
|
||||
.insert(room_id(), group_session.clone());
|
||||
.insert(group_session.clone());
|
||||
|
||||
// Get the request and convert it into a event.
|
||||
let request = alice_machine
|
||||
|
|
|
@ -156,29 +156,29 @@ impl OlmMachine {
|
|||
verification_machine.clone(),
|
||||
);
|
||||
let device_id: Arc<DeviceIdBox> = Arc::new(device_id);
|
||||
let outbound_group_sessions = Arc::new(DashMap::new());
|
||||
let users_for_key_claim = Arc::new(DashMap::new());
|
||||
|
||||
let key_request_machine = KeyRequestMachine::new(
|
||||
user_id.clone(),
|
||||
device_id.clone(),
|
||||
store.clone(),
|
||||
outbound_group_sessions,
|
||||
users_for_key_claim.clone(),
|
||||
);
|
||||
|
||||
let account = Account {
|
||||
inner: account,
|
||||
store: store.clone(),
|
||||
};
|
||||
|
||||
let group_session_manager = GroupSessionManager::new(account.clone(), store.clone());
|
||||
|
||||
let key_request_machine = KeyRequestMachine::new(
|
||||
user_id.clone(),
|
||||
device_id.clone(),
|
||||
store.clone(),
|
||||
group_session_manager.session_cache(),
|
||||
users_for_key_claim.clone(),
|
||||
);
|
||||
|
||||
let session_manager = SessionManager::new(
|
||||
account.clone(),
|
||||
users_for_key_claim,
|
||||
key_request_machine.clone(),
|
||||
store.clone(),
|
||||
);
|
||||
let group_session_manager = GroupSessionManager::new(account.clone(), store.clone());
|
||||
let identity_manager =
|
||||
IdentityManager::new(user_id.clone(), device_id.clone(), store.clone());
|
||||
|
||||
|
|
|
@ -40,6 +40,57 @@ use crate::{
|
|||
Device, EncryptionSettings, OlmError, ToDeviceRequest,
|
||||
};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct GroupSessionCache {
|
||||
store: Store,
|
||||
sessions: Arc<DashMap<RoomId, OutboundGroupSession>>,
|
||||
/// A map from the request id to the group session that the request belongs
|
||||
/// to. Used to mark requests belonging to the session as shared.
|
||||
sessions_being_shared: Arc<DashMap<Uuid, OutboundGroupSession>>,
|
||||
}
|
||||
|
||||
impl GroupSessionCache {
|
||||
pub(crate) fn new(store: Store) -> Self {
|
||||
Self {
|
||||
store,
|
||||
sessions: DashMap::new().into(),
|
||||
sessions_being_shared: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn insert(&self, session: OutboundGroupSession) {
|
||||
self.sessions.insert(session.room_id().to_owned(), session);
|
||||
}
|
||||
|
||||
pub async fn get_or_load(&self, room_id: &RoomId) -> StoreResult<Option<OutboundGroupSession>> {
|
||||
// Get the cached session, if there isn't one load one from the store
|
||||
// and put it in the cache.
|
||||
if let Some(s) = self.sessions.get(room_id) {
|
||||
Ok(Some(s.clone()))
|
||||
} else if let Some(s) = self.store.get_outbound_group_sessions(room_id).await? {
|
||||
for request_id in s.pending_request_ids() {
|
||||
self.sessions_being_shared.insert(request_id, s.clone());
|
||||
}
|
||||
|
||||
self.sessions.insert(room_id.clone(), s.clone());
|
||||
|
||||
Ok(Some(s))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get an outbound group session for a room, if one exists.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `room_id` - The id of the room for which we should get the outbound
|
||||
/// group session.
|
||||
fn get(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
|
||||
self.sessions.get(room_id).map(|s| s.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GroupSessionManager {
|
||||
account: Account,
|
||||
|
@ -48,10 +99,7 @@ pub struct GroupSessionManager {
|
|||
/// without the need to create new keys.
|
||||
store: Store,
|
||||
/// The currently active outbound group sessions.
|
||||
outbound_group_sessions: Arc<DashMap<RoomId, OutboundGroupSession>>,
|
||||
/// A map from the request id to the group session that the request belongs
|
||||
/// to. Used to mark requests belonging to the session as shared.
|
||||
outbound_sessions_being_shared: Arc<DashMap<Uuid, OutboundGroupSession>>,
|
||||
sessions: GroupSessionCache,
|
||||
}
|
||||
|
||||
impl GroupSessionManager {
|
||||
|
@ -60,14 +108,13 @@ impl GroupSessionManager {
|
|||
pub(crate) fn new(account: Account, store: Store) -> Self {
|
||||
Self {
|
||||
account,
|
||||
store,
|
||||
outbound_group_sessions: Arc::new(DashMap::new()),
|
||||
outbound_sessions_being_shared: Arc::new(DashMap::new()),
|
||||
store: store.clone(),
|
||||
sessions: GroupSessionCache::new(store),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
|
||||
if let Some(s) = self.outbound_group_sessions.get(room_id) {
|
||||
if let Some(s) = self.sessions.get(room_id) {
|
||||
s.invalidate_session();
|
||||
|
||||
let mut changes = Changes::default();
|
||||
|
@ -81,7 +128,7 @@ impl GroupSessionManager {
|
|||
}
|
||||
|
||||
pub async fn mark_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> {
|
||||
if let Some((_, s)) = self.outbound_sessions_being_shared.remove(request_id) {
|
||||
if let Some((_, s)) = self.sessions.sessions_being_shared.remove(request_id) {
|
||||
s.mark_request_as_sent(request_id);
|
||||
|
||||
let mut changes = Changes::default();
|
||||
|
@ -97,15 +144,9 @@ impl GroupSessionManager {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Get an outbound group session for a room, if one exists.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `room_id` - The id of the room for which we should get the outbound
|
||||
/// group session.
|
||||
#[cfg(test)]
|
||||
pub fn get_outbound_group_session(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
|
||||
#[allow(clippy::map_clone)]
|
||||
self.outbound_group_sessions.get(room_id).map(|s| s.clone())
|
||||
self.sessions.get(room_id)
|
||||
}
|
||||
|
||||
pub async fn encrypt(
|
||||
|
@ -113,7 +154,7 @@ impl GroupSessionManager {
|
|||
room_id: &RoomId,
|
||||
content: AnyMessageEventContent,
|
||||
) -> MegolmResult<EncryptedEventContent> {
|
||||
let session = if let Some(s) = self.get_outbound_group_session(room_id) {
|
||||
let session = if let Some(s) = self.sessions.get(room_id) {
|
||||
s
|
||||
} else {
|
||||
panic!("Session wasn't created nor shared");
|
||||
|
@ -147,9 +188,7 @@ impl GroupSessionManager {
|
|||
.await
|
||||
.map_err(|_| EventError::UnsupportedAlgorithm)?;
|
||||
|
||||
let _ = self
|
||||
.outbound_group_sessions
|
||||
.insert(room_id.to_owned(), outbound.clone());
|
||||
self.sessions.insert(outbound.clone());
|
||||
Ok((outbound, inbound))
|
||||
}
|
||||
|
||||
|
@ -158,23 +197,7 @@ impl GroupSessionManager {
|
|||
room_id: &RoomId,
|
||||
settings: EncryptionSettings,
|
||||
) -> OlmResult<(OutboundGroupSession, Option<InboundGroupSession>)> {
|
||||
// Get the cached session, if there isn't one load one from the store
|
||||
// and put it in the cache.
|
||||
let outbound_session = if let Some(s) = self.outbound_group_sessions.get(room_id) {
|
||||
Some(s.clone())
|
||||
} else if let Some(s) = self.store.get_outbound_group_sessions(room_id).await? {
|
||||
for request_id in s.pending_request_ids() {
|
||||
self.outbound_sessions_being_shared
|
||||
.insert(request_id, s.clone());
|
||||
}
|
||||
|
||||
self.outbound_group_sessions
|
||||
.insert(room_id.clone(), s.clone());
|
||||
|
||||
Some(s)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let outbound_session = self.sessions.get_or_load(&room_id).await?;
|
||||
|
||||
// If there is no session or the session has expired or is invalid,
|
||||
// create a new one.
|
||||
|
@ -388,6 +411,10 @@ impl GroupSessionManager {
|
|||
Ok(used_sessions)
|
||||
}
|
||||
|
||||
pub(crate) fn session_cache(&self) -> GroupSessionCache {
|
||||
self.sessions.clone()
|
||||
}
|
||||
|
||||
/// Get to-device requests to share a group session with users in a room.
|
||||
///
|
||||
/// # Arguments
|
||||
|
@ -489,7 +516,7 @@ impl GroupSessionManager {
|
|||
key_content.clone(),
|
||||
outbound.clone(),
|
||||
message_index,
|
||||
self.outbound_sessions_being_shared.clone(),
|
||||
self.sessions.sessions_being_shared.clone(),
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
|
|
|
@ -15,5 +15,5 @@
|
|||
mod group_sessions;
|
||||
mod sessions;
|
||||
|
||||
pub(crate) use group_sessions::GroupSessionManager;
|
||||
pub(crate) use group_sessions::{GroupSessionCache, GroupSessionManager};
|
||||
pub(crate) use sessions::SessionManager;
|
||||
|
|
|
@ -322,6 +322,7 @@ mod test {
|
|||
identities::ReadOnlyDevice,
|
||||
key_request::KeyRequestMachine,
|
||||
olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount},
|
||||
session_manager::GroupSessionCache,
|
||||
store::{CryptoStore, MemoryStore, Store},
|
||||
verification::VerificationMachine,
|
||||
};
|
||||
|
@ -342,7 +343,6 @@ mod test {
|
|||
let user_id = user_id();
|
||||
let device_id = device_id();
|
||||
|
||||
let outbound_sessions = Arc::new(DashMap::new());
|
||||
let users_for_key_claim = Arc::new(DashMap::new());
|
||||
let account = ReadOnlyAccount::new(&user_id, &device_id);
|
||||
let store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
|
||||
|
@ -363,11 +363,13 @@ mod test {
|
|||
store: store.clone(),
|
||||
};
|
||||
|
||||
let session_cache = GroupSessionCache::new(store.clone());
|
||||
|
||||
let key_request = KeyRequestMachine::new(
|
||||
user_id,
|
||||
device_id,
|
||||
store.clone(),
|
||||
outbound_sessions,
|
||||
session_cache,
|
||||
users_for_key_claim.clone(),
|
||||
);
|
||||
|
||||
|
|
Loading…
Reference in New Issue