crypto: Remember the users that received the outbound group session.
parent
562bb5aee3
commit
5b0457dad0
|
@ -300,7 +300,7 @@ mod test {
|
||||||
let room_id = room_id!("!test:localhost");
|
let room_id = room_id!("!test:localhost");
|
||||||
|
|
||||||
machine
|
machine
|
||||||
.create_outnbound_group_session_with_defaults(&room_id)
|
.create_outbound_group_session_with_defaults(&room_id)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let export = machine
|
let export = machine
|
||||||
|
|
|
@ -374,7 +374,7 @@ mod test {
|
||||||
let account = account();
|
let account = account();
|
||||||
|
|
||||||
let (_, session) = account
|
let (_, session) = account
|
||||||
.create_group_session_pair(&room_id(), Default::default())
|
.create_group_session_pair_with_defaults(&room_id())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -415,7 +415,7 @@ mod test {
|
||||||
let account = account();
|
let account = account();
|
||||||
|
|
||||||
let (_, session) = account
|
let (_, session) = account
|
||||||
.create_group_session_pair(&room_id(), Default::default())
|
.create_group_session_pair_with_defaults(&room_id())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
machine
|
machine
|
||||||
|
|
|
@ -1026,10 +1026,11 @@ impl OlmMachine {
|
||||||
&self,
|
&self,
|
||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
settings: EncryptionSettings,
|
settings: EncryptionSettings,
|
||||||
|
users_to_share_with: impl Iterator<Item = &UserId>,
|
||||||
) -> OlmResult<()> {
|
) -> OlmResult<()> {
|
||||||
let (outbound, inbound) = self
|
let (outbound, inbound) = self
|
||||||
.account
|
.account
|
||||||
.create_group_session_pair(room_id, settings)
|
.create_group_session_pair(room_id, settings, users_to_share_with)
|
||||||
.await
|
.await
|
||||||
.map_err(|_| EventError::UnsupportedAlgorithm)?;
|
.map_err(|_| EventError::UnsupportedAlgorithm)?;
|
||||||
|
|
||||||
|
@ -1042,11 +1043,11 @@ impl OlmMachine {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub(crate) async fn create_outnbound_group_session_with_defaults(
|
pub(crate) async fn create_outbound_group_session_with_defaults(
|
||||||
&self,
|
&self,
|
||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
) -> OlmResult<()> {
|
) -> OlmResult<()> {
|
||||||
self.create_outbound_group_session(room_id, EncryptionSettings::default())
|
self.create_outbound_group_session(room_id, EncryptionSettings::default(), [].iter())
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1143,7 +1144,7 @@ impl OlmMachine {
|
||||||
users: impl Iterator<Item = &UserId>,
|
users: impl Iterator<Item = &UserId>,
|
||||||
encryption_settings: impl Into<EncryptionSettings>,
|
encryption_settings: impl Into<EncryptionSettings>,
|
||||||
) -> OlmResult<Vec<ToDeviceRequest>> {
|
) -> OlmResult<Vec<ToDeviceRequest>> {
|
||||||
self.create_outbound_group_session(room_id, encryption_settings.into())
|
self.create_outbound_group_session(room_id, encryption_settings.into(), users)
|
||||||
.await?;
|
.await?;
|
||||||
let session = self.outbound_group_sessions.get(room_id).unwrap();
|
let session = self.outbound_group_sessions.get(room_id).unwrap();
|
||||||
|
|
||||||
|
@ -1159,8 +1160,8 @@ impl OlmMachine {
|
||||||
|
|
||||||
let mut devices = Vec::new();
|
let mut devices = Vec::new();
|
||||||
|
|
||||||
for user_id in users {
|
for user_id in session.users_to_share_with() {
|
||||||
for device in self.get_user_devices(user_id).await?.devices() {
|
for device in self.get_user_devices(&user_id).await?.devices() {
|
||||||
if !device.is_blacklisted() {
|
if !device.is_blacklisted() {
|
||||||
devices.push(device.clone());
|
devices.push(device.clone());
|
||||||
}
|
}
|
||||||
|
@ -1928,7 +1929,7 @@ pub(crate) mod test {
|
||||||
let room_id = room_id!("!test:example.org");
|
let room_id = room_id!("!test:example.org");
|
||||||
|
|
||||||
machine
|
machine
|
||||||
.create_outnbound_group_session_with_defaults(&room_id)
|
.create_outbound_group_session_with_defaults(&room_id)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert!(machine.outbound_group_sessions.get(&room_id).is_some());
|
assert!(machine.outbound_group_sessions.get(&room_id).is_some());
|
||||||
|
|
|
@ -580,6 +580,7 @@ impl Account {
|
||||||
&self,
|
&self,
|
||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
settings: EncryptionSettings,
|
settings: EncryptionSettings,
|
||||||
|
users_to_share_with: impl Iterator<Item = &UserId>,
|
||||||
) -> Result<(OutboundGroupSession, InboundGroupSession), ()> {
|
) -> Result<(OutboundGroupSession, InboundGroupSession), ()> {
|
||||||
if settings.algorithm != EventEncryptionAlgorithm::MegolmV1AesSha2 {
|
if settings.algorithm != EventEncryptionAlgorithm::MegolmV1AesSha2 {
|
||||||
return Err(());
|
return Err(());
|
||||||
|
@ -590,6 +591,7 @@ impl Account {
|
||||||
self.identity_keys.clone(),
|
self.identity_keys.clone(),
|
||||||
room_id,
|
room_id,
|
||||||
settings,
|
settings,
|
||||||
|
users_to_share_with,
|
||||||
);
|
);
|
||||||
let identity_keys = self.identity_keys();
|
let identity_keys = self.identity_keys();
|
||||||
|
|
||||||
|
@ -606,6 +608,15 @@ impl Account {
|
||||||
|
|
||||||
Ok((outbound, inbound))
|
Ok((outbound, inbound))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub(crate) async fn create_group_session_pair_with_defaults(
|
||||||
|
&self,
|
||||||
|
room_id: &RoomId,
|
||||||
|
) -> Result<(OutboundGroupSession, InboundGroupSession), ()> {
|
||||||
|
self.create_group_session_pair(room_id, EncryptionSettings::default(), [].iter())
|
||||||
|
.await
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PartialEq for Account {
|
impl PartialEq for Account {
|
||||||
|
|
|
@ -147,7 +147,7 @@ mod test {
|
||||||
|
|
||||||
let account = Account::new(&user_id!("@alice:example.org"), "DEVICEID".into());
|
let account = Account::new(&user_id!("@alice:example.org"), "DEVICEID".into());
|
||||||
let (session, _) = account
|
let (session, _) = account
|
||||||
.create_group_session_pair(&room_id!("!test_room:example.org"), settings)
|
.create_group_session_pair(&room_id!("!test_room:example.org"), settings, [].iter())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -165,7 +165,7 @@ mod test {
|
||||||
};
|
};
|
||||||
|
|
||||||
let (mut session, _) = account
|
let (mut session, _) = account
|
||||||
.create_group_session_pair(&room_id!("!test_room:example.org"), settings)
|
.create_group_session_pair(&room_id!("!test_room:example.org"), settings, [].iter())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
use dashmap::{setref::multiple::RefMulti, DashSet};
|
||||||
use std::{
|
use std::{
|
||||||
cmp::min,
|
cmp::min,
|
||||||
fmt,
|
fmt,
|
||||||
|
@ -27,7 +28,7 @@ use matrix_sdk_common::{
|
||||||
room::{encrypted::EncryptedEventContent, encryption::EncryptionEventContent},
|
room::{encrypted::EncryptedEventContent, encryption::EncryptionEventContent},
|
||||||
AnyMessageEventContent, EventContent,
|
AnyMessageEventContent, EventContent,
|
||||||
},
|
},
|
||||||
identifiers::{DeviceId, EventEncryptionAlgorithm, RoomId},
|
identifiers::{DeviceIdBox, EventEncryptionAlgorithm, RoomId, UserId},
|
||||||
instant::Instant,
|
instant::Instant,
|
||||||
locks::Mutex,
|
locks::Mutex,
|
||||||
};
|
};
|
||||||
|
@ -92,7 +93,7 @@ impl From<&EncryptionEventContent> for EncryptionSettings {
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct OutboundGroupSession {
|
pub struct OutboundGroupSession {
|
||||||
inner: Arc<Mutex<OlmOutboundGroupSession>>,
|
inner: Arc<Mutex<OlmOutboundGroupSession>>,
|
||||||
device_id: Arc<Box<DeviceId>>,
|
device_id: Arc<DeviceIdBox>,
|
||||||
account_identity_keys: Arc<IdentityKeys>,
|
account_identity_keys: Arc<IdentityKeys>,
|
||||||
session_id: Arc<String>,
|
session_id: Arc<String>,
|
||||||
room_id: Arc<RoomId>,
|
room_id: Arc<RoomId>,
|
||||||
|
@ -100,6 +101,8 @@ pub struct OutboundGroupSession {
|
||||||
message_count: Arc<AtomicU64>,
|
message_count: Arc<AtomicU64>,
|
||||||
shared: Arc<AtomicBool>,
|
shared: Arc<AtomicBool>,
|
||||||
settings: Arc<EncryptionSettings>,
|
settings: Arc<EncryptionSettings>,
|
||||||
|
shared_with_set: Arc<DashSet<UserId>>,
|
||||||
|
to_share_with_set: Arc<DashSet<UserId>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OutboundGroupSession {
|
impl OutboundGroupSession {
|
||||||
|
@ -118,15 +121,18 @@ impl OutboundGroupSession {
|
||||||
///
|
///
|
||||||
/// * `settings` - Settings determining the algorithm and rotation period of
|
/// * `settings` - Settings determining the algorithm and rotation period of
|
||||||
/// the outbound group session.
|
/// the outbound group session.
|
||||||
pub fn new(
|
pub fn new<'a>(
|
||||||
device_id: Arc<Box<DeviceId>>,
|
device_id: Arc<DeviceIdBox>,
|
||||||
identity_keys: Arc<IdentityKeys>,
|
identity_keys: Arc<IdentityKeys>,
|
||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
settings: EncryptionSettings,
|
settings: EncryptionSettings,
|
||||||
|
users_to_share_with: impl Iterator<Item = &'a UserId>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let session = OlmOutboundGroupSession::new();
|
let session = OlmOutboundGroupSession::new();
|
||||||
let session_id = session.session_id();
|
let session_id = session.session_id();
|
||||||
|
|
||||||
|
let users_to_share_with = users_to_share_with.cloned().collect();
|
||||||
|
|
||||||
OutboundGroupSession {
|
OutboundGroupSession {
|
||||||
inner: Arc::new(Mutex::new(session)),
|
inner: Arc::new(Mutex::new(session)),
|
||||||
room_id: Arc::new(room_id.to_owned()),
|
room_id: Arc::new(room_id.to_owned()),
|
||||||
|
@ -137,9 +143,15 @@ impl OutboundGroupSession {
|
||||||
message_count: Arc::new(AtomicU64::new(0)),
|
message_count: Arc::new(AtomicU64::new(0)),
|
||||||
shared: Arc::new(AtomicBool::new(false)),
|
shared: Arc::new(AtomicBool::new(false)),
|
||||||
settings: Arc::new(settings),
|
settings: Arc::new(settings),
|
||||||
|
shared_with_set: Arc::new(DashSet::new()),
|
||||||
|
to_share_with_set: Arc::new(users_to_share_with),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn users_to_share_with(&self) -> impl Iterator<Item = RefMulti<'_, UserId>> + '_ {
|
||||||
|
self.to_share_with_set.iter()
|
||||||
|
}
|
||||||
|
|
||||||
/// Encrypt the given plaintext using this session.
|
/// Encrypt the given plaintext using this session.
|
||||||
///
|
///
|
||||||
/// Returns the encrypted ciphertext.
|
/// Returns the encrypted ciphertext.
|
||||||
|
|
|
@ -194,7 +194,7 @@ pub(crate) mod test {
|
||||||
let room_id = room_id!("!test:localhost");
|
let room_id = room_id!("!test:localhost");
|
||||||
|
|
||||||
let (outbound, _) = alice
|
let (outbound, _) = alice
|
||||||
.create_group_session_pair(&room_id, Default::default())
|
.create_group_session_pair_with_defaults(&room_id)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -230,7 +230,7 @@ pub(crate) mod test {
|
||||||
let room_id = room_id!("!test:localhost");
|
let room_id = room_id!("!test:localhost");
|
||||||
|
|
||||||
let (_, inbound) = alice
|
let (_, inbound) = alice
|
||||||
.create_group_session_pair(&room_id, Default::default())
|
.create_group_session_pair_with_defaults(&room_id)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
|
@ -265,7 +265,7 @@ mod test {
|
||||||
let room_id = room_id!("!test:localhost");
|
let room_id = room_id!("!test:localhost");
|
||||||
|
|
||||||
let (outbound, _) = account
|
let (outbound, _) = account
|
||||||
.create_group_session_pair(&room_id, Default::default())
|
.create_group_session_pair_with_defaults(&room_id)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
|
@ -219,7 +219,7 @@ mod test {
|
||||||
let room_id = room_id!("!test:localhost");
|
let room_id = room_id!("!test:localhost");
|
||||||
|
|
||||||
let (outbound, _) = account
|
let (outbound, _) = account
|
||||||
.create_group_session_pair(&room_id, Default::default())
|
.create_group_session_pair_with_defaults(&room_id)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let inbound = InboundGroupSession::new(
|
let inbound = InboundGroupSession::new(
|
||||||
|
|
Loading…
Reference in New Issue