From 02c765f903a81444fc980c33056671db1a9249d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Thu, 1 Oct 2020 16:31:24 +0200 Subject: [PATCH] crypto: Don't mark outbound group sessions automatically as shared. --- matrix_sdk/src/client.rs | 18 ++-- matrix_sdk_base/src/client.rs | 15 ++-- matrix_sdk_crypto/src/group_manager.rs | 38 ++++++--- matrix_sdk_crypto/src/key_request.rs | 5 +- matrix_sdk_crypto/src/machine.rs | 8 +- matrix_sdk_crypto/src/olm/account.rs | 4 +- .../src/olm/group_sessions/mod.rs | 4 +- .../src/olm/group_sessions/outbound.rs | 82 +++++++++++++++---- 8 files changed, 122 insertions(+), 52 deletions(-) diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index 3a33a60a..f9f02172 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -1303,10 +1303,13 @@ impl Client { } #[cfg(feature = "encryption")] - async fn send_to_device(&self, request: ToDeviceRequest) -> Result { + async fn send_to_device(&self, request: &ToDeviceRequest) -> Result { let txn_id_string = request.txn_id_string(); - let request = - RumaToDeviceRequest::new(request.event_type, &txn_id_string, request.messages); + let request = RumaToDeviceRequest::new( + request.event_type.clone(), + &txn_id_string, + request.messages.clone(), + ); self.send(request).await } @@ -1468,7 +1471,8 @@ impl Client { } } OutgoingRequests::ToDeviceRequest(request) => { - if let Ok(resp) = self.send_to_device(request.clone()).await { + // TODO remove this unwrap + if let Ok(resp) = self.send_to_device(&request).await { self.base_client .mark_request_as_sent(&r.request_id(), &resp) .await @@ -1551,7 +1555,11 @@ impl Client { .expect("Keys don't need to be uploaded"); for request in requests.drain(..) { - self.send_to_device(request).await?; + let response = self.send_to_device(&request).await?; + + self.base_client + .mark_request_as_sent(&request.txn_id, &response) + .await?; } Ok(()) diff --git a/matrix_sdk_base/src/client.rs b/matrix_sdk_base/src/client.rs index 97965999..9161e376 100644 --- a/matrix_sdk_base/src/client.rs +++ b/matrix_sdk_base/src/client.rs @@ -1304,7 +1304,7 @@ impl BaseClient { /// Get a to-device request that will share a group session for a room. #[cfg(feature = "encryption")] #[cfg_attr(feature = "docs", doc(cfg(encryption)))] - pub async fn share_group_session(&self, room_id: &RoomId) -> Result> { + pub async fn share_group_session(&self, room_id: &RoomId) -> Result>> { let room = self.get_joined_room(room_id).await.expect("No room found"); let olm = self.olm.lock().await; @@ -1881,18 +1881,19 @@ impl BaseClient { #[cfg(test)] mod test { + use serde_json::json; + use std::convert::TryFrom; + #[cfg(feature = "messages")] use crate::{ events::AnySyncRoomEvent, identifiers::event_id, BaseClientConfig, JsonStore, Raw, }; - use crate::{ - identifiers::{room_id, user_id, RoomId}, - BaseClient, Session, - }; + use crate::{BaseClient, Session}; + + use matrix_sdk_common::identifiers::{room_id, user_id, RoomId}; + use matrix_sdk_common_macros::async_trait; use matrix_sdk_test::{async_test, test_json, EventBuilder, EventsJson}; - use serde_json::json; - use std::convert::TryFrom; #[cfg(not(target_arch = "wasm32"))] use tempfile::tempdir; diff --git a/matrix_sdk_crypto/src/group_manager.rs b/matrix_sdk_crypto/src/group_manager.rs index 015df4da..1e6b4c55 100644 --- a/matrix_sdk_crypto/src/group_manager.rs +++ b/matrix_sdk_crypto/src/group_manager.rs @@ -21,6 +21,7 @@ use matrix_sdk_common::{ identifiers::{RoomId, UserId}, uuid::Uuid, }; +use tracing::debug; use crate::{ error::{EventError, MegolmResult, OlmResult}, @@ -57,6 +58,12 @@ impl GroupSessionManager { self.outbound_group_sessions.remove(room_id).is_some() } + pub fn mark_request_as_sent(&self, request_id: &Uuid) { + self.outbound_sessions_being_shared + .remove(request_id) + .map(|(_, s)| s.mark_request_as_sent(request_id)); + } + /// Get an outbound group session for a room, if one exists. /// /// # Arguments @@ -111,11 +118,10 @@ impl GroupSessionManager { &self, room_id: &RoomId, settings: EncryptionSettings, - users_to_share_with: impl Iterator, ) -> OlmResult<()> { let (outbound, inbound) = self .account - .create_group_session_pair(room_id, settings, users_to_share_with) + .create_group_session_pair(room_id, settings) .await .map_err(|_| EventError::UnsupportedAlgorithm)?; @@ -140,8 +146,8 @@ impl GroupSessionManager { room_id: &RoomId, users: impl Iterator, encryption_settings: impl Into, - ) -> OlmResult> { - self.create_outbound_group_session(room_id, encryption_settings.into(), users) + ) -> OlmResult>> { + self.create_outbound_group_session(room_id, encryption_settings.into()) .await?; let session = self.outbound_group_sessions.get(room_id).unwrap(); @@ -149,15 +155,9 @@ impl GroupSessionManager { panic!("Session is already shared"); } - // TODO don't mark the session as shared automatically, only when all - // the requests are done, failure to send these requests will likely end - // up in wedged sessions. We'll need to store the requests and let the - // caller mark them as sent using an UUID. - session.mark_as_shared(); - let mut devices: Vec = Vec::new(); - for user_id in session.users_to_share_with() { + for user_id in users { let user_devices = self.store.get_user_devices(&user_id).await?; devices.extend(user_devices.devices().filter(|d| !d.is_blacklisted())); } @@ -193,11 +193,25 @@ impl GroupSessionManager { let id = Uuid::new_v4(); - requests.push(ToDeviceRequest { + let request = Arc::new(ToDeviceRequest { event_type: EventType::RoomEncrypted, txn_id: id, messages, }); + + session.add_request(id, request.clone()); + self.outbound_sessions_being_shared + .insert(id, session.clone()); + requests.push(request); + } + + if requests.is_empty() { + debug!( + "Session {} for room {} doesn't need to be shared with anyone, marking as shared", + session.session_id(), + session.room_id() + ); + session.mark_as_shared(); } Ok(requests) diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index 831ce0e0..67618038 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -323,10 +323,7 @@ impl KeyRequestMachine { Err(KeyshareDecision::UntrustedDevice) } } else if let Some(outbound) = outbound_session { - if outbound - .shared_with() - .contains(&(device.user_id().to_owned(), device.device_id().to_owned())) - { + if outbound.is_shared_with(device.user_id(), device.device_id()) { Ok(()) } else { Err(KeyshareDecision::OutboundSessionNotShared) diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index b173da6a..d2b5bf49 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -574,7 +574,7 @@ impl OlmMachine { room_id: &RoomId, ) -> OlmResult<()> { self.group_session_manager - .create_outbound_group_session(room_id, EncryptionSettings::default(), [].iter()) + .create_outbound_group_session(room_id, EncryptionSettings::default()) .await } @@ -645,7 +645,7 @@ impl OlmMachine { room_id: &RoomId, users: impl Iterator, encryption_settings: impl Into, - ) -> OlmResult> { + ) -> OlmResult>> { self.group_session_manager .share_group_session(room_id, users, encryption_settings) .await @@ -705,6 +705,7 @@ impl OlmMachine { self.key_request_machine .mark_outgoing_request_as_sent(request_id) .await?; + self.group_session_manager.mark_request_as_sent(request_id); Ok(()) } @@ -1037,6 +1038,7 @@ pub(crate) mod test { use std::{ collections::BTreeMap, convert::{TryFrom, TryInto}, + sync::Arc, time::SystemTime, }; @@ -1103,7 +1105,7 @@ pub(crate) mod test { get_keys::Response::try_from(data).expect("Can't parse the keys upload response") } - fn to_device_requests_to_content(requests: Vec) -> EncryptedEventContent { + fn to_device_requests_to_content(requests: Vec>) -> EncryptedEventContent { let to_device_request = &requests[0]; let content: Raw = serde_json::from_str( diff --git a/matrix_sdk_crypto/src/olm/account.rs b/matrix_sdk_crypto/src/olm/account.rs index de20784b..86a9dfa1 100644 --- a/matrix_sdk_crypto/src/olm/account.rs +++ b/matrix_sdk_crypto/src/olm/account.rs @@ -878,7 +878,6 @@ impl ReadOnlyAccount { &self, room_id: &RoomId, settings: EncryptionSettings, - users_to_share_with: impl Iterator, ) -> Result<(OutboundGroupSession, InboundGroupSession), ()> { if settings.algorithm != EventEncryptionAlgorithm::MegolmV1AesSha2 { return Err(()); @@ -889,7 +888,6 @@ impl ReadOnlyAccount { self.identity_keys.clone(), room_id, settings, - users_to_share_with, ); let identity_keys = self.identity_keys(); @@ -912,7 +910,7 @@ impl ReadOnlyAccount { &self, room_id: &RoomId, ) -> Result<(OutboundGroupSession, InboundGroupSession), ()> { - self.create_group_session_pair(room_id, EncryptionSettings::default(), [].iter()) + self.create_group_session_pair(room_id, EncryptionSettings::default()) .await } diff --git a/matrix_sdk_crypto/src/olm/group_sessions/mod.rs b/matrix_sdk_crypto/src/olm/group_sessions/mod.rs index 74acfbbf..a78bde53 100644 --- a/matrix_sdk_crypto/src/olm/group_sessions/mod.rs +++ b/matrix_sdk_crypto/src/olm/group_sessions/mod.rs @@ -147,7 +147,7 @@ mod test { let account = ReadOnlyAccount::new(&user_id!("@alice:example.org"), "DEVICEID".into()); let (session, _) = account - .create_group_session_pair(&room_id!("!test_room:example.org"), settings, [].iter()) + .create_group_session_pair(&room_id!("!test_room:example.org"), settings) .await .unwrap(); @@ -165,7 +165,7 @@ mod test { }; let (mut session, _) = account - .create_group_session_pair(&room_id!("!test_room:example.org"), settings, [].iter()) + .create_group_session_pair(&room_id!("!test_room:example.org"), settings) .await .unwrap(); diff --git a/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs b/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs index 64b1229b..532881a9 100644 --- a/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs +++ b/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -use dashmap::{setref::multiple::RefMulti, DashSet}; +use dashmap::{DashMap, DashSet}; +use matrix_sdk_common::{api::r0::to_device::DeviceIdOrAllDevices, uuid::Uuid}; use std::{ cmp::min, fmt, @@ -22,6 +23,7 @@ use std::{ }, time::Duration, }; +use tracing::debug; use matrix_sdk_common::{ events::{ @@ -32,15 +34,17 @@ use matrix_sdk_common::{ instant::Instant, locks::Mutex, }; -use olm_rs::outbound_group_session::OlmOutboundGroupSession; use serde_json::{json, Value}; +use olm_rs::outbound_group_session::OlmOutboundGroupSession; pub use olm_rs::{ account::IdentityKeys, session::{OlmMessage, PreKeyMessage}, utility::OlmUtility, }; +use crate::ToDeviceRequest; + use super::GroupSessionKey; const ROTATION_PERIOD: Duration = Duration::from_millis(604800000); @@ -101,8 +105,8 @@ pub struct OutboundGroupSession { message_count: Arc, shared: Arc, settings: Arc, - shared_with_set: Arc>, - to_share_with_set: Arc>, + shared_with_set: Arc>>, + to_share_with_set: Arc>>, } impl OutboundGroupSession { @@ -121,18 +125,15 @@ impl OutboundGroupSession { /// /// * `settings` - Settings determining the algorithm and rotation period of /// the outbound group session. - pub fn new<'a>( + pub fn new( device_id: Arc, identity_keys: Arc, room_id: &RoomId, settings: EncryptionSettings, - users_to_share_with: impl Iterator, ) -> Self { let session = OlmOutboundGroupSession::new(); let session_id = session.session_id(); - let users_to_share_with = users_to_share_with.cloned().collect(); - OutboundGroupSession { inner: Arc::new(Mutex::new(session)), room_id: Arc::new(room_id.to_owned()), @@ -143,13 +144,52 @@ impl OutboundGroupSession { message_count: Arc::new(AtomicU64::new(0)), shared: Arc::new(AtomicBool::new(false)), settings: Arc::new(settings), - shared_with_set: Arc::new(DashSet::new()), - to_share_with_set: Arc::new(users_to_share_with), + shared_with_set: Arc::new(DashMap::new()), + to_share_with_set: Arc::new(DashMap::new()), } } - pub(crate) fn users_to_share_with(&self) -> impl Iterator> + '_ { - self.to_share_with_set.iter() + pub fn add_request(&self, request_id: Uuid, request: Arc) { + self.to_share_with_set.insert(request_id, request); + } + + /// Mark the request with the given request id as sent. + /// + /// This removes the request from the queue and marks the set of + /// users/devices that received the session. + pub fn mark_request_as_sent(&self, request_id: &Uuid) { + let request = self.to_share_with_set.remove(request_id); + + request.map(|(_, r)| { + let user_pairs = r.messages.iter().map(|(u, v)| { + ( + u.clone(), + v.keys().filter_map(|d| { + if let DeviceIdOrAllDevices::DeviceId(d) = d { + Some(d.clone()) + } else { + None + } + }), + ) + }); + + user_pairs.for_each(|(u, d)| { + self.shared_with_set + .entry(u) + .or_insert_with(DashSet::new) + .extend(d); + }) + }); + + if self.to_share_with_set.is_empty() { + debug!( + "Marking session {} for room {} as shared.", + self.session_id(), + self.room_id + ); + self.mark_as_shared(); + } } /// Encrypt the given plaintext using this session. @@ -246,6 +286,11 @@ impl OutboundGroupSession { GroupSessionKey(session.session_key()) } + /// Get the room id of the room this session belongs to. + pub fn room_id(&self) -> &RoomId { + &self.room_id + } + /// Returns the unique identifier for this session. pub fn session_id(&self) -> &str { &self.session_id @@ -273,15 +318,20 @@ impl OutboundGroupSession { } /// The set of users this session is shared with. - pub(crate) fn shared_with(&self) -> &DashSet<(UserId, DeviceIdBox)> { - &self.shared_with_set + pub(crate) fn is_shared_with(&self, user_id: &UserId, device_id: &DeviceId) -> bool { + self.shared_with_set + .get(user_id) + .map(|d| d.contains(device_id)) + .unwrap_or(false) } /// Mark that the session was shared with the given user/device pair. - #[allow(dead_code)] + #[cfg(test)] pub fn mark_shared_with(&self, user_id: &UserId, device_id: &DeviceId) { self.shared_with_set - .insert((user_id.to_owned(), device_id.to_owned())); + .entry(user_id.to_owned()) + .or_insert_with(DashSet::new) + .insert(device_id.to_owned()); } }