crypto: Remember the users that received the outbound group session.

master
Damir Jelić 2020-09-18 18:55:17 +02:00
parent 562bb5aee3
commit 5b0457dad0
9 changed files with 44 additions and 20 deletions

View File

@ -300,7 +300,7 @@ mod test {
let room_id = room_id!("!test:localhost"); let room_id = room_id!("!test:localhost");
machine machine
.create_outnbound_group_session_with_defaults(&room_id) .create_outbound_group_session_with_defaults(&room_id)
.await .await
.unwrap(); .unwrap();
let export = machine let export = machine

View File

@ -374,7 +374,7 @@ mod test {
let account = account(); let account = account();
let (_, session) = account let (_, session) = account
.create_group_session_pair(&room_id(), Default::default()) .create_group_session_pair_with_defaults(&room_id())
.await .await
.unwrap(); .unwrap();
@ -415,7 +415,7 @@ mod test {
let account = account(); let account = account();
let (_, session) = account let (_, session) = account
.create_group_session_pair(&room_id(), Default::default()) .create_group_session_pair_with_defaults(&room_id())
.await .await
.unwrap(); .unwrap();
machine machine

View File

@ -1026,10 +1026,11 @@ impl OlmMachine {
&self, &self,
room_id: &RoomId, room_id: &RoomId,
settings: EncryptionSettings, settings: EncryptionSettings,
users_to_share_with: impl Iterator<Item = &UserId>,
) -> OlmResult<()> { ) -> OlmResult<()> {
let (outbound, inbound) = self let (outbound, inbound) = self
.account .account
.create_group_session_pair(room_id, settings) .create_group_session_pair(room_id, settings, users_to_share_with)
.await .await
.map_err(|_| EventError::UnsupportedAlgorithm)?; .map_err(|_| EventError::UnsupportedAlgorithm)?;
@ -1042,11 +1043,11 @@ impl OlmMachine {
} }
#[cfg(test)] #[cfg(test)]
pub(crate) async fn create_outnbound_group_session_with_defaults( pub(crate) async fn create_outbound_group_session_with_defaults(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
) -> OlmResult<()> { ) -> OlmResult<()> {
self.create_outbound_group_session(room_id, EncryptionSettings::default()) self.create_outbound_group_session(room_id, EncryptionSettings::default(), [].iter())
.await .await
} }
@ -1143,7 +1144,7 @@ impl OlmMachine {
users: impl Iterator<Item = &UserId>, users: impl Iterator<Item = &UserId>,
encryption_settings: impl Into<EncryptionSettings>, encryption_settings: impl Into<EncryptionSettings>,
) -> OlmResult<Vec<ToDeviceRequest>> { ) -> OlmResult<Vec<ToDeviceRequest>> {
self.create_outbound_group_session(room_id, encryption_settings.into()) self.create_outbound_group_session(room_id, encryption_settings.into(), users)
.await?; .await?;
let session = self.outbound_group_sessions.get(room_id).unwrap(); let session = self.outbound_group_sessions.get(room_id).unwrap();
@ -1159,8 +1160,8 @@ impl OlmMachine {
let mut devices = Vec::new(); let mut devices = Vec::new();
for user_id in users { for user_id in session.users_to_share_with() {
for device in self.get_user_devices(user_id).await?.devices() { for device in self.get_user_devices(&user_id).await?.devices() {
if !device.is_blacklisted() { if !device.is_blacklisted() {
devices.push(device.clone()); devices.push(device.clone());
} }
@ -1928,7 +1929,7 @@ pub(crate) mod test {
let room_id = room_id!("!test:example.org"); let room_id = room_id!("!test:example.org");
machine machine
.create_outnbound_group_session_with_defaults(&room_id) .create_outbound_group_session_with_defaults(&room_id)
.await .await
.unwrap(); .unwrap();
assert!(machine.outbound_group_sessions.get(&room_id).is_some()); assert!(machine.outbound_group_sessions.get(&room_id).is_some());

View File

@ -580,6 +580,7 @@ impl Account {
&self, &self,
room_id: &RoomId, room_id: &RoomId,
settings: EncryptionSettings, settings: EncryptionSettings,
users_to_share_with: impl Iterator<Item = &UserId>,
) -> Result<(OutboundGroupSession, InboundGroupSession), ()> { ) -> Result<(OutboundGroupSession, InboundGroupSession), ()> {
if settings.algorithm != EventEncryptionAlgorithm::MegolmV1AesSha2 { if settings.algorithm != EventEncryptionAlgorithm::MegolmV1AesSha2 {
return Err(()); return Err(());
@ -590,6 +591,7 @@ impl Account {
self.identity_keys.clone(), self.identity_keys.clone(),
room_id, room_id,
settings, settings,
users_to_share_with,
); );
let identity_keys = self.identity_keys(); let identity_keys = self.identity_keys();
@ -606,6 +608,15 @@ impl Account {
Ok((outbound, inbound)) Ok((outbound, inbound))
} }
#[cfg(test)]
pub(crate) async fn create_group_session_pair_with_defaults(
&self,
room_id: &RoomId,
) -> Result<(OutboundGroupSession, InboundGroupSession), ()> {
self.create_group_session_pair(room_id, EncryptionSettings::default(), [].iter())
.await
}
} }
impl PartialEq for Account { impl PartialEq for Account {

View File

@ -147,7 +147,7 @@ mod test {
let account = Account::new(&user_id!("@alice:example.org"), "DEVICEID".into()); let account = Account::new(&user_id!("@alice:example.org"), "DEVICEID".into());
let (session, _) = account let (session, _) = account
.create_group_session_pair(&room_id!("!test_room:example.org"), settings) .create_group_session_pair(&room_id!("!test_room:example.org"), settings, [].iter())
.await .await
.unwrap(); .unwrap();
@ -165,7 +165,7 @@ mod test {
}; };
let (mut session, _) = account let (mut session, _) = account
.create_group_session_pair(&room_id!("!test_room:example.org"), settings) .create_group_session_pair(&room_id!("!test_room:example.org"), settings, [].iter())
.await .await
.unwrap(); .unwrap();

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use dashmap::{setref::multiple::RefMulti, DashSet};
use std::{ use std::{
cmp::min, cmp::min,
fmt, fmt,
@ -27,7 +28,7 @@ use matrix_sdk_common::{
room::{encrypted::EncryptedEventContent, encryption::EncryptionEventContent}, room::{encrypted::EncryptedEventContent, encryption::EncryptionEventContent},
AnyMessageEventContent, EventContent, AnyMessageEventContent, EventContent,
}, },
identifiers::{DeviceId, EventEncryptionAlgorithm, RoomId}, identifiers::{DeviceIdBox, EventEncryptionAlgorithm, RoomId, UserId},
instant::Instant, instant::Instant,
locks::Mutex, locks::Mutex,
}; };
@ -92,7 +93,7 @@ impl From<&EncryptionEventContent> for EncryptionSettings {
#[derive(Clone)] #[derive(Clone)]
pub struct OutboundGroupSession { pub struct OutboundGroupSession {
inner: Arc<Mutex<OlmOutboundGroupSession>>, inner: Arc<Mutex<OlmOutboundGroupSession>>,
device_id: Arc<Box<DeviceId>>, device_id: Arc<DeviceIdBox>,
account_identity_keys: Arc<IdentityKeys>, account_identity_keys: Arc<IdentityKeys>,
session_id: Arc<String>, session_id: Arc<String>,
room_id: Arc<RoomId>, room_id: Arc<RoomId>,
@ -100,6 +101,8 @@ pub struct OutboundGroupSession {
message_count: Arc<AtomicU64>, message_count: Arc<AtomicU64>,
shared: Arc<AtomicBool>, shared: Arc<AtomicBool>,
settings: Arc<EncryptionSettings>, settings: Arc<EncryptionSettings>,
shared_with_set: Arc<DashSet<UserId>>,
to_share_with_set: Arc<DashSet<UserId>>,
} }
impl OutboundGroupSession { impl OutboundGroupSession {
@ -118,15 +121,18 @@ impl OutboundGroupSession {
/// ///
/// * `settings` - Settings determining the algorithm and rotation period of /// * `settings` - Settings determining the algorithm and rotation period of
/// the outbound group session. /// the outbound group session.
pub fn new( pub fn new<'a>(
device_id: Arc<Box<DeviceId>>, device_id: Arc<DeviceIdBox>,
identity_keys: Arc<IdentityKeys>, identity_keys: Arc<IdentityKeys>,
room_id: &RoomId, room_id: &RoomId,
settings: EncryptionSettings, settings: EncryptionSettings,
users_to_share_with: impl Iterator<Item = &'a UserId>,
) -> Self { ) -> Self {
let session = OlmOutboundGroupSession::new(); let session = OlmOutboundGroupSession::new();
let session_id = session.session_id(); let session_id = session.session_id();
let users_to_share_with = users_to_share_with.cloned().collect();
OutboundGroupSession { OutboundGroupSession {
inner: Arc::new(Mutex::new(session)), inner: Arc::new(Mutex::new(session)),
room_id: Arc::new(room_id.to_owned()), room_id: Arc::new(room_id.to_owned()),
@ -137,9 +143,15 @@ impl OutboundGroupSession {
message_count: Arc::new(AtomicU64::new(0)), message_count: Arc::new(AtomicU64::new(0)),
shared: Arc::new(AtomicBool::new(false)), shared: Arc::new(AtomicBool::new(false)),
settings: Arc::new(settings), settings: Arc::new(settings),
shared_with_set: Arc::new(DashSet::new()),
to_share_with_set: Arc::new(users_to_share_with),
} }
} }
pub(crate) fn users_to_share_with(&self) -> impl Iterator<Item = RefMulti<'_, UserId>> + '_ {
self.to_share_with_set.iter()
}
/// Encrypt the given plaintext using this session. /// Encrypt the given plaintext using this session.
/// ///
/// Returns the encrypted ciphertext. /// Returns the encrypted ciphertext.

View File

@ -194,7 +194,7 @@ pub(crate) mod test {
let room_id = room_id!("!test:localhost"); let room_id = room_id!("!test:localhost");
let (outbound, _) = alice let (outbound, _) = alice
.create_group_session_pair(&room_id, Default::default()) .create_group_session_pair_with_defaults(&room_id)
.await .await
.unwrap(); .unwrap();
@ -230,7 +230,7 @@ pub(crate) mod test {
let room_id = room_id!("!test:localhost"); let room_id = room_id!("!test:localhost");
let (_, inbound) = alice let (_, inbound) = alice
.create_group_session_pair(&room_id, Default::default()) .create_group_session_pair_with_defaults(&room_id)
.await .await
.unwrap(); .unwrap();

View File

@ -265,7 +265,7 @@ mod test {
let room_id = room_id!("!test:localhost"); let room_id = room_id!("!test:localhost");
let (outbound, _) = account let (outbound, _) = account
.create_group_session_pair(&room_id, Default::default()) .create_group_session_pair_with_defaults(&room_id)
.await .await
.unwrap(); .unwrap();

View File

@ -219,7 +219,7 @@ mod test {
let room_id = room_id!("!test:localhost"); let room_id = room_id!("!test:localhost");
let (outbound, _) = account let (outbound, _) = account
.create_group_session_pair(&room_id, Default::default()) .create_group_session_pair_with_defaults(&room_id)
.await .await
.unwrap(); .unwrap();
let inbound = InboundGroupSession::new( let inbound = InboundGroupSession::new(