diff --git a/matrix_sdk_crypto/src/file_encryption/key_export.rs b/matrix_sdk_crypto/src/file_encryption/key_export.rs index e1cde309..dcbdd440 100644 --- a/matrix_sdk_crypto/src/file_encryption/key_export.rs +++ b/matrix_sdk_crypto/src/file_encryption/key_export.rs @@ -300,7 +300,7 @@ mod test { let room_id = room_id!("!test:localhost"); machine - .create_outnbound_group_session_with_defaults(&room_id) + .create_outbound_group_session_with_defaults(&room_id) .await .unwrap(); let export = machine diff --git a/matrix_sdk_crypto/src/identities/manager.rs b/matrix_sdk_crypto/src/identities/manager.rs new file mode 100644 index 00000000..6779fd1f --- /dev/null +++ b/matrix_sdk_crypto/src/identities/manager.rs @@ -0,0 +1,349 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + collections::{BTreeMap, HashSet}, + convert::TryFrom, + sync::Arc, +}; +use tracing::{info, trace, warn}; + +use matrix_sdk_common::{ + api::r0::keys::get_keys::Response as KeysQueryResponse, + encryption::DeviceKeys, + identifiers::{DeviceId, DeviceIdBox, UserId}, +}; + +use crate::{ + error::OlmResult, + identities::{ + MasterPubkey, OwnUserIdentity, ReadOnlyDevice, SelfSigningPubkey, UserIdentities, + UserIdentity, UserSigningPubkey, + }, + requests::KeysQueryRequest, + store::{Result as StoreResult, Store}, +}; + +#[derive(Debug, Clone)] +pub(crate) struct IdentityManager { + user_id: Arc, + device_id: Arc, + store: Store, +} + +impl IdentityManager { + pub fn new(user_id: Arc, device_id: Arc, store: Store) -> Self { + IdentityManager { + user_id, + device_id, + store, + } + } + + fn user_id(&self) -> &UserId { + &self.user_id + } + + fn device_id(&self) -> &DeviceId { + &self.device_id + } + + /// Receive a successful keys query response. + /// + /// Returns a list of devices newly discovered devices and devices that + /// changed. + /// + /// # Arguments + /// + /// * `response` - The keys query response of the request that the client + /// performed. + pub async fn receive_keys_query_response( + &self, + response: &KeysQueryResponse, + ) -> OlmResult<(Vec, Vec)> { + // TODO create a enum that tells us how the device/identity changed, + // e.g. new/deleted/display name change. + // + // TODO create a struct that will hold the device/identity and the + // change enum and return the struct. + // + // TODO once outbound group sessions hold on to the set of users that + // received the session, invalidate the session if a user device + // got added/deleted. + let changed_devices = self + .handle_devices_from_key_query(&response.device_keys) + .await?; + self.store.save_devices(&changed_devices).await?; + let changed_identities = self.handle_cross_singing_keys(response).await?; + self.store.save_user_identities(&changed_identities).await?; + + Ok((changed_devices, changed_identities)) + } + + /// Handle the device keys part of a key query response. + /// + /// # Arguments + /// + /// * `device_keys_map` - A map holding the device keys of the users for + /// which the key query was done. + /// + /// Returns a list of devices that changed. Changed here means either + /// they are new, one of their properties has changed or they got deleted. + async fn handle_devices_from_key_query( + &self, + device_keys_map: &BTreeMap>, + ) -> StoreResult> { + let mut changed_devices = Vec::new(); + + for (user_id, device_map) in device_keys_map { + // TODO move this out into the handle keys query response method + // since we might fail handle the new device at any point here or + // when updating the user identities. + self.store.update_tracked_user(user_id, false).await?; + + for (device_id, device_keys) in device_map.iter() { + // We don't need our own device in the device store. + if user_id == self.user_id() && &**device_id == self.device_id() { + continue; + } + + if user_id != &device_keys.user_id || device_id != &device_keys.device_id { + warn!( + "Mismatch in device keys payload of device {} from user {}", + device_keys.device_id, device_keys.user_id + ); + continue; + } + + let device = self.store.get_device(&user_id, device_id).await?; + + let device = if let Some(mut device) = device { + if let Err(e) = device.update_device(device_keys) { + warn!( + "Failed to update the device keys for {} {}: {:?}", + user_id, device_id, e + ); + continue; + } + device + } else { + let device = match ReadOnlyDevice::try_from(device_keys) { + Ok(d) => d, + Err(e) => { + warn!( + "Failed to create a new device for {} {}: {:?}", + user_id, device_id, e + ); + continue; + } + }; + info!("Adding a new device to the device store {:?}", device); + device + }; + + changed_devices.push(device); + } + + let current_devices: HashSet<&DeviceId> = + device_map.keys().map(|id| id.as_ref()).collect(); + let stored_devices = self.store.get_user_devices(&user_id).await.unwrap(); + let stored_devices_set: HashSet<&DeviceId> = stored_devices.keys().collect(); + + let deleted_devices = stored_devices_set.difference(¤t_devices); + + for device_id in deleted_devices { + if let Some(device) = stored_devices.get(device_id) { + device.mark_as_deleted(); + self.store.delete_device(device).await?; + } + } + } + + Ok(changed_devices) + } + + /// Handle the device keys part of a key query response. + /// + /// # Arguments + /// + /// * `response` - The keys query response. + /// + /// Returns a list of identities that changed. Changed here means either + /// they are new, one of their properties has changed or they got deleted. + async fn handle_cross_singing_keys( + &self, + response: &KeysQueryResponse, + ) -> StoreResult> { + let mut changed = Vec::new(); + + for (user_id, master_key) in &response.master_keys { + let master_key = MasterPubkey::from(master_key); + + let self_signing = if let Some(s) = response.self_signing_keys.get(user_id) { + SelfSigningPubkey::from(s) + } else { + warn!( + "User identity for user {} didn't contain a self signing pubkey", + user_id + ); + continue; + }; + + let identity = if let Some(mut i) = self.store.get_user_identity(user_id).await? { + match &mut i { + UserIdentities::Own(ref mut identity) => { + let user_signing = if let Some(s) = response.user_signing_keys.get(user_id) + { + UserSigningPubkey::from(s) + } else { + warn!( + "User identity for our own user {} didn't \ + contain a user signing pubkey", + user_id + ); + continue; + }; + + identity + .update(master_key, self_signing, user_signing) + .map(|_| i) + } + UserIdentities::Other(ref mut identity) => { + identity.update(master_key, self_signing).map(|_| i) + } + } + } else if user_id == self.user_id() { + if let Some(s) = response.user_signing_keys.get(user_id) { + let user_signing = UserSigningPubkey::from(s); + + if master_key.user_id() != user_id + || self_signing.user_id() != user_id + || user_signing.user_id() != user_id + { + warn!( + "User id mismatch in one of the cross signing keys for user {}", + user_id + ); + continue; + } + + OwnUserIdentity::new(master_key, self_signing, user_signing) + .map(UserIdentities::Own) + } else { + warn!( + "User identity for our own user {} didn't contain a \ + user signing pubkey", + user_id + ); + continue; + } + } else if master_key.user_id() != user_id || self_signing.user_id() != user_id { + warn!( + "User id mismatch in one of the cross signing keys for user {}", + user_id + ); + continue; + } else { + UserIdentity::new(master_key, self_signing).map(UserIdentities::Other) + }; + + match identity { + Ok(i) => { + trace!( + "Updated or created new user identity for {}: {:?}", + user_id, + i + ); + changed.push(i); + } + Err(e) => { + warn!( + "Couldn't update or create new user identity for {}: {:?}", + user_id, e + ); + continue; + } + } + } + + Ok(changed) + } + + /// Get a key query request if one is needed. + /// + /// Returns a key query reqeust if the client should query E2E keys, + /// otherwise None. + /// + /// The response of a successful key query requests needs to be passed to + /// the [`OlmMachine`] with the [`receive_keys_query_response`]. + /// + /// [`OlmMachine`]: struct.OlmMachine.html + /// [`receive_keys_query_response`]: #method.receive_keys_query_response + pub async fn users_for_key_query(&self) -> Option { + let mut users = self.store.users_for_key_query(); + + if users.is_empty() { + None + } else { + let mut device_keys: BTreeMap>> = BTreeMap::new(); + + for user in users.drain() { + device_keys.insert(user, Vec::new()); + } + + Some(KeysQueryRequest::new(device_keys)) + } + } + + /// Mark that the given user has changed his devices. + /// + /// This will queue up the given user for a key query. + /// + /// Note: The user already needs to be tracked for it to be queued up for a + /// key query. + /// + /// Returns true if the user was queued up for a key query, false otherwise. + pub async fn mark_user_as_changed(&self, user_id: &UserId) -> StoreResult { + if self.store.is_user_tracked(user_id) { + self.store.update_tracked_user(user_id, true).await?; + Ok(true) + } else { + Ok(false) + } + } + + /// Update the tracked users. + /// + /// # Arguments + /// + /// * `users` - An iterator over user ids that should be marked for + /// tracking. + /// + /// This will mark users that weren't seen before for a key query and + /// tracking. + /// + /// If the user is already known to the Olm machine it will not be + /// considered for a key query. + pub async fn update_tracked_users(&self, users: impl IntoIterator) { + for user in users { + if self.store.is_user_tracked(user) { + continue; + } + + if let Err(e) = self.store.update_tracked_user(user, true).await { + warn!("Error storing users for tracking {}", e); + } + } + } +} diff --git a/matrix_sdk_crypto/src/identities/mod.rs b/matrix_sdk_crypto/src/identities/mod.rs index c1941b50..c1392eba 100644 --- a/matrix_sdk_crypto/src/identities/mod.rs +++ b/matrix_sdk_crypto/src/identities/mod.rs @@ -41,9 +41,11 @@ //! Both identity sets need to reqularly fetched from the server using the //! `/keys/query` API call. pub(crate) mod device; +mod manager; pub(crate) mod user; pub use device::{Device, LocalTrust, ReadOnlyDevice, UserDevices}; +pub(crate) use manager::IdentityManager; pub use user::{ MasterPubkey, OwnUserIdentity, SelfSigningPubkey, UserIdentities, UserIdentity, UserSigningPubkey, diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index 2bcc5214..5b900ce4 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -374,7 +374,7 @@ mod test { let account = account(); let (_, session) = account - .create_group_session_pair(&room_id(), Default::default()) + .create_group_session_pair_with_defaults(&room_id()) .await .unwrap(); @@ -415,7 +415,7 @@ mod test { let account = account(); let (_, session) = account - .create_group_session_pair(&room_id(), Default::default()) + .create_group_session_pair_with_defaults(&room_id()) .await .unwrap(); machine diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 69c00dc8..6c0bffcc 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -14,13 +14,7 @@ #[cfg(feature = "sqlite_cryptostore")] use std::path::Path; -use std::{ - collections::{BTreeMap, HashSet}, - convert::{TryFrom, TryInto}, - mem, - sync::Arc, - time::Duration, -}; +use std::{collections::BTreeMap, convert::TryInto, mem, sync::Arc, time::Duration}; use dashmap::DashMap; use serde_json::Value; @@ -37,7 +31,6 @@ use matrix_sdk_common::{ to_device::DeviceIdOrAllDevices, }, assign, - encryption::DeviceKeys, events::{ room::encrypted::EncryptedEventContent, room_key::RoomKeyEventContent, room_key_request::RoomKeyRequestEventContent, AnyMessageEventContent, AnySyncRoomEvent, @@ -54,16 +47,13 @@ use matrix_sdk_common::{ use super::store::sqlite::SqliteStore; use super::{ error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult}, - identities::{ - Device, MasterPubkey, OwnUserIdentity, ReadOnlyDevice, SelfSigningPubkey, UserDevices, - UserIdentities, UserIdentity, UserSigningPubkey, - }, + identities::{Device, IdentityManager, ReadOnlyDevice, UserDevices, UserIdentities}, key_request::KeyRequestMachine, olm::{ Account, EncryptionSettings, ExportedRoomKey, GroupSessionKey, IdentityKeys, InboundGroupSession, OlmMessage, OutboundGroupSession, }, - requests::{IncomingResponse, KeysQueryRequest, OutgoingRequest, ToDeviceRequest}, + requests::{IncomingResponse, OutgoingRequest, ToDeviceRequest}, store::{CryptoStore, MemoryStore, Result as StoreResult, Store}, verification::{Sas, VerificationMachine}, }; @@ -90,6 +80,9 @@ pub struct OlmMachine { /// The state machine that is responsible to handle outgoing and incoming /// key requests. key_request_machine: KeyRequestMachine, + /// State machine handling public user identities and devices, keeping track + /// of when a key query needs to be done and handling one. + identity_manager: IdentityManager, } #[cfg(not(tarpaulin_include))] @@ -118,19 +111,36 @@ impl OlmMachine { /// * `device_id` - The unique id of the device that owns this machine. pub fn new(user_id: &UserId, device_id: &DeviceId) -> Self { let store: Box = Box::new(MemoryStore::new()); + let device_id: DeviceIdBox = device_id.into(); + let account = Account::new(&user_id, &device_id); + + OlmMachine::new_helper(user_id, device_id, store, account) + } + + fn new_helper( + user_id: &UserId, + device_id: DeviceIdBox, + store: Box, + account: Account, + ) -> Self { let store = Store::new(store); - let account = Account::new(user_id, device_id); + let verification_machine = VerificationMachine::new(account.clone(), store.clone()); let user_id = Arc::new(user_id.clone()); - let device_id: Arc = Arc::new(device_id.into()); + let device_id: Arc = Arc::new(device_id); + let key_request_machine = + KeyRequestMachine::new(user_id.clone(), device_id.clone(), store.clone()); + let identity_manager = + IdentityManager::new(user_id.clone(), device_id.clone(), store.clone()); OlmMachine { - user_id: user_id.clone(), - device_id: device_id.clone(), - account: account.clone(), - store: store.clone(), + user_id, + device_id, + account, + store, outbound_group_sessions: Arc::new(DashMap::new()), - verification_machine: VerificationMachine::new(account, store.clone()), - key_request_machine: KeyRequestMachine::new(user_id, device_id, store), + verification_machine, + key_request_machine, + identity_manager, } } @@ -155,7 +165,7 @@ impl OlmMachine { /// [`Cryptostore`]: trait.CryptoStore.html pub async fn new_with_store( user_id: UserId, - device_id: Box, + device_id: DeviceIdBox, store: Box, ) -> StoreResult { let account = match store.load_account().await? { @@ -169,22 +179,7 @@ impl OlmMachine { } }; - let store = Store::new(store); - let verification_machine = VerificationMachine::new(account.clone(), store.clone()); - let user_id = Arc::new(user_id.clone()); - let device_id: Arc = Arc::new(device_id); - let key_request_machine = - KeyRequestMachine::new(user_id.clone(), device_id.clone(), store.clone()); - - Ok(OlmMachine { - user_id, - device_id, - account, - store, - outbound_group_sessions: Arc::new(DashMap::new()), - verification_machine, - key_request_machine, - }) + Ok(OlmMachine::new_helper(&user_id, device_id, store, account)) } /// Create a new machine with the default crypto store. @@ -243,10 +238,15 @@ impl OlmMachine { requests.push(r); } - if let Some(r) = self.users_for_key_query().await.map(|r| OutgoingRequest { - request_id: Uuid::new_v4(), - request: Arc::new(r.into()), - }) { + if let Some(r) = + self.identity_manager + .users_for_key_query() + .await + .map(|r| OutgoingRequest { + request_id: Uuid::new_v4(), + request: Arc::new(r.into()), + }) + { requests.push(r); } @@ -491,195 +491,6 @@ impl OlmMachine { Ok(()) } - /// Handle the device keys part of a key query response. - /// - /// # Arguments - /// - /// * `device_keys_map` - A map holding the device keys of the users for - /// which the key query was done. - /// - /// Returns a list of devices that changed. Changed here means either - /// they are new, one of their properties has changed or they got deleted. - async fn handle_devices_from_key_query( - &self, - device_keys_map: &BTreeMap, DeviceKeys>>, - ) -> StoreResult> { - let mut changed_devices = Vec::new(); - - for (user_id, device_map) in device_keys_map { - // TODO move this out into the handle keys query response method - // since we might fail handle the new device at any point here or - // when updating the user identities. - self.store.update_tracked_user(user_id, false).await?; - - for (device_id, device_keys) in device_map.iter() { - // We don't need our own device in the device store. - if user_id == self.user_id() && &**device_id == self.device_id() { - continue; - } - - if user_id != &device_keys.user_id || device_id != &device_keys.device_id { - warn!( - "Mismatch in device keys payload of device {} from user {}", - device_keys.device_id, device_keys.user_id - ); - continue; - } - - let device = self.store.get_device(&user_id, device_id).await?; - - let device = if let Some(mut device) = device { - if let Err(e) = device.update_device(device_keys) { - warn!( - "Failed to update the device keys for {} {}: {:?}", - user_id, device_id, e - ); - continue; - } - device - } else { - let device = match ReadOnlyDevice::try_from(device_keys) { - Ok(d) => d, - Err(e) => { - warn!( - "Failed to create a new device for {} {}: {:?}", - user_id, device_id, e - ); - continue; - } - }; - info!("Adding a new device to the device store {:?}", device); - device - }; - - changed_devices.push(device); - } - - let current_devices: HashSet<&DeviceId> = - device_map.keys().map(|id| id.as_ref()).collect(); - let stored_devices = self.store.get_user_devices(&user_id).await.unwrap(); - let stored_devices_set: HashSet<&DeviceId> = stored_devices.keys().collect(); - - let deleted_devices = stored_devices_set.difference(¤t_devices); - - for device_id in deleted_devices { - if let Some(device) = stored_devices.get(device_id) { - device.mark_as_deleted(); - self.store.delete_device(device).await?; - } - } - } - - Ok(changed_devices) - } - - /// Handle the device keys part of a key query response. - /// - /// # Arguments - /// - /// * `response` - The keys query response. - /// - /// Returns a list of identities that changed. Changed here means either - /// they are new, one of their properties has changed or they got deleted. - async fn handle_cross_singing_keys( - &self, - response: &KeysQueryResponse, - ) -> StoreResult> { - let mut changed = Vec::new(); - - for (user_id, master_key) in &response.master_keys { - let master_key = MasterPubkey::from(master_key); - - let self_signing = if let Some(s) = response.self_signing_keys.get(user_id) { - SelfSigningPubkey::from(s) - } else { - warn!( - "User identity for user {} didn't contain a self signing pubkey", - user_id - ); - continue; - }; - - let identity = if let Some(mut i) = self.store.get_user_identity(user_id).await? { - match &mut i { - UserIdentities::Own(ref mut identity) => { - let user_signing = if let Some(s) = response.user_signing_keys.get(user_id) - { - UserSigningPubkey::from(s) - } else { - warn!( - "User identity for our own user {} didn't \ - contain a user signing pubkey", - user_id - ); - continue; - }; - - identity - .update(master_key, self_signing, user_signing) - .map(|_| i) - } - UserIdentities::Other(ref mut identity) => { - identity.update(master_key, self_signing).map(|_| i) - } - } - } else if user_id == self.user_id() { - if let Some(s) = response.user_signing_keys.get(user_id) { - let user_signing = UserSigningPubkey::from(s); - - if master_key.user_id() != user_id - || self_signing.user_id() != user_id - || user_signing.user_id() != user_id - { - warn!( - "User id mismatch in one of the cross signing keys for user {}", - user_id - ); - continue; - } - - OwnUserIdentity::new(master_key, self_signing, user_signing) - .map(UserIdentities::Own) - } else { - warn!( - "User identity for our own user {} didn't contain a \ - user signing pubkey", - user_id - ); - continue; - } - } else if master_key.user_id() != user_id || self_signing.user_id() != user_id { - warn!( - "User id mismatch in one of the cross signing keys for user {}", - user_id - ); - continue; - } else { - UserIdentity::new(master_key, self_signing).map(UserIdentities::Other) - }; - - match identity { - Ok(i) => { - trace!( - "Updated or created new user identity for {}: {:?}", - user_id, - i - ); - changed.push(i); - } - Err(e) => { - warn!( - "Couldn't update or create new user identity for {}: {:?}", - user_id, e - ); - continue; - } - } - } - - Ok(changed) - } - /// Receive a successful keys query response. /// /// Returns a list of devices newly discovered devices and devices that @@ -693,23 +504,9 @@ impl OlmMachine { &self, response: &KeysQueryResponse, ) -> OlmResult<(Vec, Vec)> { - // TODO create a enum that tells us how the device/identity changed, - // e.g. new/deleted/display name change. - // - // TODO create a struct that will hold the device/identity and the - // change enum and return the struct. - // - // TODO once outbound group sessions hold on to the set of users that - // received the session, invalidate the session if a user device - // got added/deleted. - let changed_devices = self - .handle_devices_from_key_query(&response.device_keys) - .await?; - self.store.save_devices(&changed_devices).await?; - let changed_identities = self.handle_cross_singing_keys(response).await?; - self.store.save_user_identities(&changed_identities).await?; - - Ok((changed_devices, changed_identities)) + self.identity_manager + .receive_keys_query_response(response) + .await } /// Get a request to upload E2EE keys to the server. @@ -1026,10 +823,11 @@ impl OlmMachine { &self, room_id: &RoomId, settings: EncryptionSettings, + users_to_share_with: impl Iterator, ) -> OlmResult<()> { let (outbound, inbound) = self .account - .create_group_session_pair(room_id, settings) + .create_group_session_pair(room_id, settings, users_to_share_with) .await .map_err(|_| EventError::UnsupportedAlgorithm)?; @@ -1042,11 +840,11 @@ impl OlmMachine { } #[cfg(test)] - pub(crate) async fn create_outnbound_group_session_with_defaults( + pub(crate) async fn create_outbound_group_session_with_defaults( &self, room_id: &RoomId, ) -> OlmResult<()> { - self.create_outbound_group_session(room_id, EncryptionSettings::default()) + self.create_outbound_group_session(room_id, EncryptionSettings::default(), [].iter()) .await } @@ -1143,7 +941,7 @@ impl OlmMachine { users: impl Iterator, encryption_settings: impl Into, ) -> OlmResult> { - self.create_outbound_group_session(room_id, encryption_settings.into()) + self.create_outbound_group_session(room_id, encryption_settings.into(), users) .await?; let session = self.outbound_group_sessions.get(room_id).unwrap(); @@ -1159,8 +957,8 @@ impl OlmMachine { let mut devices = Vec::new(); - for user_id in users { - for device in self.get_user_devices(user_id).await?.devices() { + for user_id in session.users_to_share_with() { + for device in self.get_user_devices(&user_id).await?.devices() { if !device.is_blacklisted() { devices.push(device.clone()); } @@ -1299,8 +1097,12 @@ impl OlmMachine { let count: u64 = one_time_key_count.map_or(0, |c| (*c).into()); self.update_key_count(count); + if let Err(e) = self.store.save_account(self.account.clone()).await { + error!("Error updating the one-time key count {:?}", e); + } + for user_id in &response.device_lists.changed { - if let Err(e) = self.mark_user_as_changed(&user_id).await { + if let Err(e) = self.identity_manager.mark_user_as_changed(&user_id).await { error!("Error marking a tracked user as changed {:?}", e); } } @@ -1389,23 +1191,6 @@ impl OlmMachine { Ok(decrypted_event) } - /// Mark that the given user has changed his devices. - /// - /// This will queue up the given user for a key query. - /// - /// Note: The user already needs to be tracked for it to be queued up for a - /// key query. - /// - /// Returns true if the user was queued up for a key query, false otherwise. - async fn mark_user_as_changed(&self, user_id: &UserId) -> StoreResult { - if self.store.is_user_tracked(user_id) { - self.store.update_tracked_user(user_id, true).await?; - Ok(true) - } else { - Ok(false) - } - } - /// Update the tracked users. /// /// # Arguments @@ -1419,41 +1204,7 @@ impl OlmMachine { /// If the user is already known to the Olm machine it will not be /// considered for a key query. pub async fn update_tracked_users(&self, users: impl IntoIterator) { - for user in users { - if self.store.is_user_tracked(user) { - continue; - } - - if let Err(e) = self.store.update_tracked_user(user, true).await { - warn!("Error storing users for tracking {}", e); - } - } - } - - /// Get a key query request if one is needed. - /// - /// Returns a key query reqeust if the client should query E2E keys, - /// otherwise None. - /// - /// The response of a successful key query requests needs to be passed to - /// the [`OlmMachine`] with the [`receive_keys_query_response`]. - /// - /// [`OlmMachine`]: struct.OlmMachine.html - /// [`receive_keys_query_response`]: #method.receive_keys_query_response - async fn users_for_key_query(&self) -> Option { - let mut users = self.store.users_for_key_query(); - - if users.is_empty() { - None - } else { - let mut device_keys: BTreeMap>> = BTreeMap::new(); - - for user in users.drain() { - device_keys.insert(user, Vec::new()); - } - - Some(KeysQueryRequest::new(device_keys)) - } + self.identity_manager.update_tracked_users(users).await } /// Get a specific device of a user. @@ -1928,7 +1679,7 @@ pub(crate) mod test { let room_id = room_id!("!test:example.org"); machine - .create_outnbound_group_session_with_defaults(&room_id) + .create_outbound_group_session_with_defaults(&room_id) .await .unwrap(); assert!(machine.outbound_group_sessions.get(&room_id).is_some()); diff --git a/matrix_sdk_crypto/src/olm/account.rs b/matrix_sdk_crypto/src/olm/account.rs index e46c7142..8be10151 100644 --- a/matrix_sdk_crypto/src/olm/account.rs +++ b/matrix_sdk_crypto/src/olm/account.rs @@ -580,6 +580,7 @@ impl Account { &self, room_id: &RoomId, settings: EncryptionSettings, + users_to_share_with: impl Iterator, ) -> Result<(OutboundGroupSession, InboundGroupSession), ()> { if settings.algorithm != EventEncryptionAlgorithm::MegolmV1AesSha2 { return Err(()); @@ -590,6 +591,7 @@ impl Account { self.identity_keys.clone(), room_id, settings, + users_to_share_with, ); let identity_keys = self.identity_keys(); @@ -606,6 +608,15 @@ impl Account { Ok((outbound, inbound)) } + + #[cfg(test)] + pub(crate) async fn create_group_session_pair_with_defaults( + &self, + room_id: &RoomId, + ) -> Result<(OutboundGroupSession, InboundGroupSession), ()> { + self.create_group_session_pair(room_id, EncryptionSettings::default(), [].iter()) + .await + } } impl PartialEq for Account { diff --git a/matrix_sdk_crypto/src/olm/group_sessions/mod.rs b/matrix_sdk_crypto/src/olm/group_sessions/mod.rs index b0e69599..a30125d2 100644 --- a/matrix_sdk_crypto/src/olm/group_sessions/mod.rs +++ b/matrix_sdk_crypto/src/olm/group_sessions/mod.rs @@ -147,7 +147,7 @@ mod test { let account = Account::new(&user_id!("@alice:example.org"), "DEVICEID".into()); let (session, _) = account - .create_group_session_pair(&room_id!("!test_room:example.org"), settings) + .create_group_session_pair(&room_id!("!test_room:example.org"), settings, [].iter()) .await .unwrap(); @@ -165,7 +165,7 @@ mod test { }; let (mut session, _) = account - .create_group_session_pair(&room_id!("!test_room:example.org"), settings) + .create_group_session_pair(&room_id!("!test_room:example.org"), settings, [].iter()) .await .unwrap(); diff --git a/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs b/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs index ac521272..11334eb2 100644 --- a/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs +++ b/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use dashmap::{setref::multiple::RefMulti, DashSet}; use std::{ cmp::min, fmt, @@ -27,7 +28,7 @@ use matrix_sdk_common::{ room::{encrypted::EncryptedEventContent, encryption::EncryptionEventContent}, AnyMessageEventContent, EventContent, }, - identifiers::{DeviceId, EventEncryptionAlgorithm, RoomId}, + identifiers::{DeviceIdBox, EventEncryptionAlgorithm, RoomId, UserId}, instant::Instant, locks::Mutex, }; @@ -92,7 +93,7 @@ impl From<&EncryptionEventContent> for EncryptionSettings { #[derive(Clone)] pub struct OutboundGroupSession { inner: Arc>, - device_id: Arc>, + device_id: Arc, account_identity_keys: Arc, session_id: Arc, room_id: Arc, @@ -100,6 +101,8 @@ pub struct OutboundGroupSession { message_count: Arc, shared: Arc, settings: Arc, + shared_with_set: Arc>, + to_share_with_set: Arc>, } impl OutboundGroupSession { @@ -118,15 +121,18 @@ impl OutboundGroupSession { /// /// * `settings` - Settings determining the algorithm and rotation period of /// the outbound group session. - pub fn new( - device_id: Arc>, + pub fn new<'a>( + device_id: Arc, identity_keys: Arc, room_id: &RoomId, settings: EncryptionSettings, + users_to_share_with: impl Iterator, ) -> Self { let session = OlmOutboundGroupSession::new(); let session_id = session.session_id(); + let users_to_share_with = users_to_share_with.cloned().collect(); + OutboundGroupSession { inner: Arc::new(Mutex::new(session)), room_id: Arc::new(room_id.to_owned()), @@ -137,9 +143,15 @@ impl OutboundGroupSession { message_count: Arc::new(AtomicU64::new(0)), shared: Arc::new(AtomicBool::new(false)), settings: Arc::new(settings), + shared_with_set: Arc::new(DashSet::new()), + to_share_with_set: Arc::new(users_to_share_with), } } + pub(crate) fn users_to_share_with(&self) -> impl Iterator> + '_ { + self.to_share_with_set.iter() + } + /// Encrypt the given plaintext using this session. /// /// Returns the encrypted ciphertext. diff --git a/matrix_sdk_crypto/src/olm/mod.rs b/matrix_sdk_crypto/src/olm/mod.rs index e88b4959..7a48b9fd 100644 --- a/matrix_sdk_crypto/src/olm/mod.rs +++ b/matrix_sdk_crypto/src/olm/mod.rs @@ -194,7 +194,7 @@ pub(crate) mod test { let room_id = room_id!("!test:localhost"); let (outbound, _) = alice - .create_group_session_pair(&room_id, Default::default()) + .create_group_session_pair_with_defaults(&room_id) .await .unwrap(); @@ -230,7 +230,7 @@ pub(crate) mod test { let room_id = room_id!("!test:localhost"); let (_, inbound) = alice - .create_group_session_pair(&room_id, Default::default()) + .create_group_session_pair_with_defaults(&room_id) .await .unwrap(); diff --git a/matrix_sdk_crypto/src/store/caches.rs b/matrix_sdk_crypto/src/store/caches.rs index cc56c100..eb66198d 100644 --- a/matrix_sdk_crypto/src/store/caches.rs +++ b/matrix_sdk_crypto/src/store/caches.rs @@ -265,7 +265,7 @@ mod test { let room_id = room_id!("!test:localhost"); let (outbound, _) = account - .create_group_session_pair(&room_id, Default::default()) + .create_group_session_pair_with_defaults(&room_id) .await .unwrap(); diff --git a/matrix_sdk_crypto/src/store/memorystore.rs b/matrix_sdk_crypto/src/store/memorystore.rs index 05746102..85e4a696 100644 --- a/matrix_sdk_crypto/src/store/memorystore.rs +++ b/matrix_sdk_crypto/src/store/memorystore.rs @@ -219,7 +219,7 @@ mod test { let room_id = room_id!("!test:localhost"); let (outbound, _) = account - .create_group_session_pair(&room_id, Default::default()) + .create_group_session_pair_with_defaults(&room_id) .await .unwrap(); let inbound = InboundGroupSession::new(