diff --git a/matrix_sdk_crypto/src/file_encryption/key_export.rs b/matrix_sdk_crypto/src/file_encryption/key_export.rs index e1cde309..dcbdd440 100644 --- a/matrix_sdk_crypto/src/file_encryption/key_export.rs +++ b/matrix_sdk_crypto/src/file_encryption/key_export.rs @@ -300,7 +300,7 @@ mod test { let room_id = room_id!("!test:localhost"); machine - .create_outnbound_group_session_with_defaults(&room_id) + .create_outbound_group_session_with_defaults(&room_id) .await .unwrap(); let export = machine diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index 2bcc5214..5b900ce4 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -374,7 +374,7 @@ mod test { let account = account(); let (_, session) = account - .create_group_session_pair(&room_id(), Default::default()) + .create_group_session_pair_with_defaults(&room_id()) .await .unwrap(); @@ -415,7 +415,7 @@ mod test { let account = account(); let (_, session) = account - .create_group_session_pair(&room_id(), Default::default()) + .create_group_session_pair_with_defaults(&room_id()) .await .unwrap(); machine diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 69c00dc8..ebe524b3 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -1026,10 +1026,11 @@ impl OlmMachine { &self, room_id: &RoomId, settings: EncryptionSettings, + users_to_share_with: impl Iterator, ) -> OlmResult<()> { let (outbound, inbound) = self .account - .create_group_session_pair(room_id, settings) + .create_group_session_pair(room_id, settings, users_to_share_with) .await .map_err(|_| EventError::UnsupportedAlgorithm)?; @@ -1042,11 +1043,11 @@ impl OlmMachine { } #[cfg(test)] - pub(crate) async fn create_outnbound_group_session_with_defaults( + pub(crate) async fn create_outbound_group_session_with_defaults( &self, room_id: &RoomId, ) -> OlmResult<()> { - self.create_outbound_group_session(room_id, EncryptionSettings::default()) + self.create_outbound_group_session(room_id, EncryptionSettings::default(), [].iter()) .await } @@ -1143,7 +1144,7 @@ impl OlmMachine { users: impl Iterator, encryption_settings: impl Into, ) -> OlmResult> { - self.create_outbound_group_session(room_id, encryption_settings.into()) + self.create_outbound_group_session(room_id, encryption_settings.into(), users) .await?; let session = self.outbound_group_sessions.get(room_id).unwrap(); @@ -1159,8 +1160,8 @@ impl OlmMachine { let mut devices = Vec::new(); - for user_id in users { - for device in self.get_user_devices(user_id).await?.devices() { + for user_id in session.users_to_share_with() { + for device in self.get_user_devices(&user_id).await?.devices() { if !device.is_blacklisted() { devices.push(device.clone()); } @@ -1928,7 +1929,7 @@ pub(crate) mod test { let room_id = room_id!("!test:example.org"); machine - .create_outnbound_group_session_with_defaults(&room_id) + .create_outbound_group_session_with_defaults(&room_id) .await .unwrap(); assert!(machine.outbound_group_sessions.get(&room_id).is_some()); diff --git a/matrix_sdk_crypto/src/olm/account.rs b/matrix_sdk_crypto/src/olm/account.rs index e46c7142..8be10151 100644 --- a/matrix_sdk_crypto/src/olm/account.rs +++ b/matrix_sdk_crypto/src/olm/account.rs @@ -580,6 +580,7 @@ impl Account { &self, room_id: &RoomId, settings: EncryptionSettings, + users_to_share_with: impl Iterator, ) -> Result<(OutboundGroupSession, InboundGroupSession), ()> { if settings.algorithm != EventEncryptionAlgorithm::MegolmV1AesSha2 { return Err(()); @@ -590,6 +591,7 @@ impl Account { self.identity_keys.clone(), room_id, settings, + users_to_share_with, ); let identity_keys = self.identity_keys(); @@ -606,6 +608,15 @@ impl Account { Ok((outbound, inbound)) } + + #[cfg(test)] + pub(crate) async fn create_group_session_pair_with_defaults( + &self, + room_id: &RoomId, + ) -> Result<(OutboundGroupSession, InboundGroupSession), ()> { + self.create_group_session_pair(room_id, EncryptionSettings::default(), [].iter()) + .await + } } impl PartialEq for Account { diff --git a/matrix_sdk_crypto/src/olm/group_sessions/mod.rs b/matrix_sdk_crypto/src/olm/group_sessions/mod.rs index b0e69599..a30125d2 100644 --- a/matrix_sdk_crypto/src/olm/group_sessions/mod.rs +++ b/matrix_sdk_crypto/src/olm/group_sessions/mod.rs @@ -147,7 +147,7 @@ mod test { let account = Account::new(&user_id!("@alice:example.org"), "DEVICEID".into()); let (session, _) = account - .create_group_session_pair(&room_id!("!test_room:example.org"), settings) + .create_group_session_pair(&room_id!("!test_room:example.org"), settings, [].iter()) .await .unwrap(); @@ -165,7 +165,7 @@ mod test { }; let (mut session, _) = account - .create_group_session_pair(&room_id!("!test_room:example.org"), settings) + .create_group_session_pair(&room_id!("!test_room:example.org"), settings, [].iter()) .await .unwrap(); diff --git a/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs b/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs index ac521272..11334eb2 100644 --- a/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs +++ b/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use dashmap::{setref::multiple::RefMulti, DashSet}; use std::{ cmp::min, fmt, @@ -27,7 +28,7 @@ use matrix_sdk_common::{ room::{encrypted::EncryptedEventContent, encryption::EncryptionEventContent}, AnyMessageEventContent, EventContent, }, - identifiers::{DeviceId, EventEncryptionAlgorithm, RoomId}, + identifiers::{DeviceIdBox, EventEncryptionAlgorithm, RoomId, UserId}, instant::Instant, locks::Mutex, }; @@ -92,7 +93,7 @@ impl From<&EncryptionEventContent> for EncryptionSettings { #[derive(Clone)] pub struct OutboundGroupSession { inner: Arc>, - device_id: Arc>, + device_id: Arc, account_identity_keys: Arc, session_id: Arc, room_id: Arc, @@ -100,6 +101,8 @@ pub struct OutboundGroupSession { message_count: Arc, shared: Arc, settings: Arc, + shared_with_set: Arc>, + to_share_with_set: Arc>, } impl OutboundGroupSession { @@ -118,15 +121,18 @@ impl OutboundGroupSession { /// /// * `settings` - Settings determining the algorithm and rotation period of /// the outbound group session. - pub fn new( - device_id: Arc>, + pub fn new<'a>( + device_id: Arc, identity_keys: Arc, room_id: &RoomId, settings: EncryptionSettings, + users_to_share_with: impl Iterator, ) -> Self { let session = OlmOutboundGroupSession::new(); let session_id = session.session_id(); + let users_to_share_with = users_to_share_with.cloned().collect(); + OutboundGroupSession { inner: Arc::new(Mutex::new(session)), room_id: Arc::new(room_id.to_owned()), @@ -137,9 +143,15 @@ impl OutboundGroupSession { message_count: Arc::new(AtomicU64::new(0)), shared: Arc::new(AtomicBool::new(false)), settings: Arc::new(settings), + shared_with_set: Arc::new(DashSet::new()), + to_share_with_set: Arc::new(users_to_share_with), } } + pub(crate) fn users_to_share_with(&self) -> impl Iterator> + '_ { + self.to_share_with_set.iter() + } + /// Encrypt the given plaintext using this session. /// /// Returns the encrypted ciphertext. diff --git a/matrix_sdk_crypto/src/olm/mod.rs b/matrix_sdk_crypto/src/olm/mod.rs index e88b4959..7a48b9fd 100644 --- a/matrix_sdk_crypto/src/olm/mod.rs +++ b/matrix_sdk_crypto/src/olm/mod.rs @@ -194,7 +194,7 @@ pub(crate) mod test { let room_id = room_id!("!test:localhost"); let (outbound, _) = alice - .create_group_session_pair(&room_id, Default::default()) + .create_group_session_pair_with_defaults(&room_id) .await .unwrap(); @@ -230,7 +230,7 @@ pub(crate) mod test { let room_id = room_id!("!test:localhost"); let (_, inbound) = alice - .create_group_session_pair(&room_id, Default::default()) + .create_group_session_pair_with_defaults(&room_id) .await .unwrap(); diff --git a/matrix_sdk_crypto/src/store/caches.rs b/matrix_sdk_crypto/src/store/caches.rs index cc56c100..eb66198d 100644 --- a/matrix_sdk_crypto/src/store/caches.rs +++ b/matrix_sdk_crypto/src/store/caches.rs @@ -265,7 +265,7 @@ mod test { let room_id = room_id!("!test:localhost"); let (outbound, _) = account - .create_group_session_pair(&room_id, Default::default()) + .create_group_session_pair_with_defaults(&room_id) .await .unwrap(); diff --git a/matrix_sdk_crypto/src/store/memorystore.rs b/matrix_sdk_crypto/src/store/memorystore.rs index 05746102..85e4a696 100644 --- a/matrix_sdk_crypto/src/store/memorystore.rs +++ b/matrix_sdk_crypto/src/store/memorystore.rs @@ -219,7 +219,7 @@ mod test { let room_id = room_id!("!test:localhost"); let (outbound, _) = account - .create_group_session_pair(&room_id, Default::default()) + .create_group_session_pair_with_defaults(&room_id) .await .unwrap(); let inbound = InboundGroupSession::new(