diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index ea74ef1e..a222f424 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -107,6 +107,7 @@ pub(crate) struct KeyRequestMachine { user_id: Arc, device_id: Arc, store: Store, + outbound_group_sessions: Arc>, outgoing_to_device_requests: Arc>, incoming_key_requests: Arc>>, @@ -167,11 +168,17 @@ fn wrap_key_request_content( } impl KeyRequestMachine { - pub fn new(user_id: Arc, device_id: Arc, store: Store) -> Self { + pub fn new( + user_id: Arc, + device_id: Arc, + store: Store, + outbound_group_sessions: Arc>, + ) -> Self { Self { user_id, device_id, store, + outbound_group_sessions, outgoing_to_device_requests: Arc::new(DashMap::new()), incoming_key_requests: Arc::new(DashMap::new()), } @@ -267,7 +274,12 @@ impl KeyRequestMachine { if let Some(device) = device { // TODO get the matching outbound session. - if let Err(e) = self.should_share_session(&device, None) { + if let Err(e) = self.should_share_session( + &device, + self.outbound_group_sessions + .get(&key_info.room_id) + .as_deref(), + ) { info!( "Received a key request from {} {} that we won't serve: {}", device.user_id(), @@ -572,6 +584,7 @@ impl KeyRequestMachine { #[cfg(test)] mod test { + use dashmap::DashMap; use matrix_sdk_common::{ events::{forwarded_room_key::ForwardedRoomKeyEventContent, ToDeviceEvent}, identifiers::{room_id, user_id, DeviceIdBox, RoomId, UserId}, @@ -619,7 +632,12 @@ mod test { let user_id = Arc::new(alice_id()); let store = Store::new(user_id.clone(), Box::new(MemoryStore::new())); - KeyRequestMachine::new(user_id, Arc::new(alice_device_id()), store) + KeyRequestMachine::new( + user_id, + Arc::new(alice_device_id()), + store, + Arc::new(DashMap::new()), + ) } #[test] diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index d55708a6..47741239 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -129,8 +129,13 @@ impl OlmMachine { let store = Store::new(user_id.clone(), store); let verification_machine = VerificationMachine::new(account.clone(), store.clone()); let device_id: Arc = Arc::new(device_id); - let key_request_machine = - KeyRequestMachine::new(user_id.clone(), device_id.clone(), store.clone()); + let outbound_group_sessions = Arc::new(DashMap::new()); + let key_request_machine = KeyRequestMachine::new( + user_id.clone(), + device_id.clone(), + store.clone(), + outbound_group_sessions.clone(), + ); let identity_manager = IdentityManager::new(user_id.clone(), device_id.clone(), store.clone()); @@ -139,7 +144,7 @@ impl OlmMachine { device_id, account, store, - outbound_group_sessions: Arc::new(DashMap::new()), + outbound_group_sessions, verification_machine, key_request_machine, identity_manager,