diff --git a/matrix_sdk_crypto/src/session_manager/group_sessions.rs b/matrix_sdk_crypto/src/session_manager/group_sessions.rs index 6b3dde99..3c68e562 100644 --- a/matrix_sdk_crypto/src/session_manager/group_sessions.rs +++ b/matrix_sdk_crypto/src/session_manager/group_sessions.rs @@ -235,7 +235,12 @@ impl GroupSessionManager { changed_sessions.push(session); } - messages.extend(message); + for (user, device_messages) in message.into_iter() { + messages + .entry(user) + .or_insert_with(BTreeMap::new) + .extend(device_messages); + } } let id = Uuid::new_v4(); @@ -515,3 +520,86 @@ impl GroupSessionManager { Ok(requests) } } + +#[cfg(test)] +mod test { + use std::convert::TryFrom; + + use matrix_sdk_common::{ + api::r0::keys::{claim_keys, get_keys}, + identifiers::{room_id, user_id, DeviceIdBox, UserId}, + uuid::Uuid, + }; + use matrix_sdk_test::response_from_file; + use serde_json::Value; + + use crate::{EncryptionSettings, OlmMachine}; + + fn alice_id() -> UserId { + user_id!("@alice:example.org") + } + + fn alice_device_id() -> DeviceIdBox { + "JLAFKJWSCS".into() + } + + fn keys_query_response() -> get_keys::Response { + let data = include_bytes!("../../benches/keys_query.json"); + let data: Value = serde_json::from_slice(data).unwrap(); + let data = response_from_file(&data); + get_keys::Response::try_from(data).expect("Can't parse the keys upload response") + } + + fn keys_claim_response() -> claim_keys::Response { + let data = include_bytes!("../../benches/keys_claim.json"); + let data: Value = serde_json::from_slice(data).unwrap(); + let data = response_from_file(&data); + claim_keys::Response::try_from(data).expect("Can't parse the keys upload response") + } + + async fn machine() -> OlmMachine { + let keys_query = keys_query_response(); + let keys_claim = keys_claim_response(); + let uuid = Uuid::new_v4(); + + let machine = OlmMachine::new(&alice_id(), &alice_device_id()); + + machine + .mark_request_as_sent(&uuid, &keys_query) + .await + .unwrap(); + machine + .mark_request_as_sent(&uuid, &keys_claim) + .await + .unwrap(); + + machine + } + + #[tokio::test] + async fn test_sharing() { + let machine = machine().await; + let room_id = room_id!("!test:localhost"); + let keys_claim = keys_claim_response(); + + let users: Vec<_> = keys_claim.one_time_keys.keys().collect(); + + let requests = machine + .share_group_session( + &room_id, + users.clone().into_iter(), + EncryptionSettings::default(), + ) + .await + .unwrap(); + + let event_count = requests.iter().fold(0, |acc, r| { + acc + r.messages.values().fold(0, |acc, v| acc + v.len()) + }); + + // The keys claim response has a couple of one-time keys with invalid + // signatures, thus only 148 sessions are actually created, we check + // that all 148 valid sessions get an room key. + assert_eq!(event_count, 148); + } +}