From 10da61c567e856a0360140b1d654ddcfe4c964cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Thu, 28 Jan 2021 14:07:51 +0100 Subject: [PATCH] crypto: Answer key reshare requests only at the originally shared message index --- matrix_sdk_crypto/src/identities/device.rs | 7 +- matrix_sdk_crypto/src/key_request.rs | 79 ++++++++-------- .../src/olm/group_sessions/inbound.rs | 21 +++-- .../src/olm/group_sessions/mod.rs | 4 +- .../src/olm/group_sessions/outbound.rs | 90 ++++++++++++------- matrix_sdk_crypto/src/olm/mod.rs | 2 +- .../src/session_manager/group_sessions.rs | 22 +++-- 7 files changed, 139 insertions(+), 86 deletions(-) diff --git a/matrix_sdk_crypto/src/identities/device.rs b/matrix_sdk_crypto/src/identities/device.rs index 5348f497..3e79d3f4 100644 --- a/matrix_sdk_crypto/src/identities/device.rs +++ b/matrix_sdk_crypto/src/identities/device.rs @@ -189,8 +189,13 @@ impl Device { pub async fn encrypt_session( &self, session: InboundGroupSession, + message_index: Option, ) -> OlmResult<(Session, EncryptedEventContent)> { - let export = session.export().await; + let export = if let Some(index) = message_index { + session.export_at_index(index).await + } else { + session.export().await + }; let content: ForwardedRoomKeyToDeviceEventContent = if let Ok(c) = export.try_into() { c diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index a706068e..c2be3528 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -40,7 +40,7 @@ use matrix_sdk_common::{ use crate::{ error::{OlmError, OlmResult}, - olm::{InboundGroupSession, OutboundGroupSession, Session}, + olm::{InboundGroupSession, OutboundGroupSession, Session, ShareState}, requests::{OutgoingRequest, ToDeviceRequest}, store::{CryptoStoreError, Store}, Device, @@ -347,42 +347,46 @@ impl KeyRequestMachine { .await?; if let Some(device) = device { - if let Err(e) = self.should_share_session( + match self.should_share_session( &device, self.outbound_group_sessions .get(&key_info.room_id) .as_deref(), ) { - info!( - "Received a key request from {} {} that we won't serve: {}", - device.user_id(), - device.device_id(), - e - ); + Err(e) => { + info!( + "Received a key request from {} {} that we won't serve: {}", + device.user_id(), + device.device_id(), + e + ); - Ok(None) - } else { - info!( - "Serving a key request for {} from {} {}.", - key_info.session_id, - device.user_id(), - device.device_id() - ); + Ok(None) + } + Ok(message_index) => { + info!( + "Serving a key request for {} from {} {} with message_index {:?}.", + key_info.session_id, + device.user_id(), + device.device_id(), + message_index, + ); - match self.share_session(&session, &device).await { - Ok(s) => Ok(Some(s)), - Err(OlmError::MissingSession) => { - info!( - "Key request from {} {} is missing an Olm session, \ + match self.share_session(&session, &device, message_index).await { + Ok(s) => Ok(Some(s)), + Err(OlmError::MissingSession) => { + info!( + "Key request from {} {} is missing an Olm session, \ putting the request in the wait queue", - device.user_id(), - device.device_id() - ); - self.handle_key_share_without_session(device, event); + device.user_id(), + device.device_id() + ); + self.handle_key_share_without_session(device, event); - Ok(None) + Ok(None) + } + Err(e) => Err(e), } - Err(e) => Err(e), } } } else { @@ -400,8 +404,11 @@ impl KeyRequestMachine { &self, session: &InboundGroupSession, device: &Device, + message_index: Option, ) -> OlmResult { - let (used_session, content) = device.encrypt_session(session.clone()).await?; + let (used_session, content) = device + .encrypt_session(session.clone(), message_index) + .await?; let id = Uuid::new_v4(); let mut messages = BTreeMap::new(); @@ -453,16 +460,18 @@ impl KeyRequestMachine { &self, device: &Device, outbound_session: Option<&OutboundGroupSession>, - ) -> Result<(), KeyshareDecision> { + ) -> Result, KeyshareDecision> { if device.user_id() == self.user_id() { if device.trust_state() { - Ok(()) + Ok(None) } else { Err(KeyshareDecision::UntrustedDevice) } } else if let Some(outbound) = outbound_session { - if outbound.is_shared_with(device.user_id(), device.device_id()) { - Ok(()) + if let ShareState::Shared(message_index) = + outbound.is_shared_with(device.user_id(), device.device_id()) + { + Ok(Some(message_index)) } else { Err(KeyshareDecision::OutboundSessionNotShared) } @@ -830,7 +839,7 @@ mod test { machine.mark_outgoing_request_as_sent(&id).await.unwrap(); - let export = session.export_at_index(10).await.unwrap(); + let export = session.export_at_index(10).await; let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap(); @@ -887,7 +896,7 @@ mod test { machine.mark_outgoing_request_as_sent(&id).await.unwrap(); - let export = session.export_at_index(15).await.unwrap(); + let export = session.export_at_index(15).await; let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap(); @@ -903,7 +912,7 @@ mod test { assert!(second_session.is_none()); - let export = session.export_at_index(0).await.unwrap(); + let export = session.export_at_index(0).await; let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap(); diff --git a/matrix_sdk_crypto/src/olm/group_sessions/inbound.rs b/matrix_sdk_crypto/src/olm/group_sessions/inbound.rs index 78aa9878..acfb2e35 100644 --- a/matrix_sdk_crypto/src/olm/group_sessions/inbound.rs +++ b/matrix_sdk_crypto/src/olm/group_sessions/inbound.rs @@ -183,9 +183,7 @@ impl InboundGroupSession { /// If only a limited part of this session should be exported use /// [`export_at_index()`](#method.export_at_index). pub async fn export(&self) -> ExportedRoomKey { - self.export_at_index(self.first_known_index()) - .await - .expect("Can't export at the first known index") + self.export_at_index(self.first_known_index()).await } /// Get the sender key that this session was received from. @@ -194,11 +192,18 @@ impl InboundGroupSession { } /// Export this session at the given message index. - pub async fn export_at_index(&self, message_index: u32) -> Option { - let session_key = - ExportedGroupSessionKey(self.inner.lock().await.export(message_index).ok()?); + pub async fn export_at_index(&self, message_index: u32) -> ExportedRoomKey { + let message_index = std::cmp::max(self.first_known_index(), message_index); - Some(ExportedRoomKey { + let session_key = ExportedGroupSessionKey( + self.inner + .lock() + .await + .export(message_index) + .expect("Can't export session"), + ); + + ExportedRoomKey { algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2, room_id: (&*self.room_id).clone(), sender_key: (&*self.sender_key).to_owned(), @@ -212,7 +217,7 @@ impl InboundGroupSession { .unwrap_or_default(), sender_claimed_keys: (&*self.signing_key).clone(), session_key, - }) + } } /// Restore a Session from a previously pickled string. diff --git a/matrix_sdk_crypto/src/olm/group_sessions/mod.rs b/matrix_sdk_crypto/src/olm/group_sessions/mod.rs index 98625b50..9a7dc242 100644 --- a/matrix_sdk_crypto/src/olm/group_sessions/mod.rs +++ b/matrix_sdk_crypto/src/olm/group_sessions/mod.rs @@ -24,7 +24,9 @@ mod inbound; mod outbound; pub use inbound::{InboundGroupSession, InboundGroupSessionPickle, PickledInboundGroupSession}; -pub use outbound::{EncryptionSettings, OutboundGroupSession, PickledOutboundGroupSession}; +pub use outbound::{ + EncryptionSettings, OutboundGroupSession, PickledOutboundGroupSession, ShareState, +}; /// The private session key of a group session. /// Can be used to create a new inbound group session. diff --git a/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs b/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs index 26da5e0d..868cc41f 100644 --- a/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs +++ b/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use dashmap::{DashMap, DashSet}; +use dashmap::DashMap; use matrix_sdk_common::{ api::r0::to_device::DeviceIdOrAllDevices, events::room::{ @@ -23,7 +23,7 @@ use matrix_sdk_common::{ }; use std::{ cmp::max, - collections::{BTreeMap, BTreeSet}, + collections::BTreeMap, fmt, sync::{ atomic::{AtomicBool, AtomicU64, Ordering}, @@ -64,6 +64,11 @@ use super::{ const ROTATION_PERIOD: Duration = Duration::from_millis(604800000); const ROTATION_MESSAGES: u64 = 100; +pub enum ShareState { + NotShared, + Shared(u32), +} + /// Settings for an encrypted room. /// /// This determines the algorithm and rotation periods of a group session. @@ -120,8 +125,8 @@ pub struct OutboundGroupSession { shared: Arc, invalidated: Arc, settings: Arc, - pub(crate) shared_with_set: Arc>>, - to_share_with_set: Arc>>, + pub(crate) shared_with_set: Arc>>, + to_share_with_set: Arc, u32)>>, } impl OutboundGroupSession { @@ -165,8 +170,14 @@ impl OutboundGroupSession { } } - pub(crate) fn add_request(&self, request_id: Uuid, request: Arc) { - self.to_share_with_set.insert(request_id, request); + pub(crate) fn add_request( + &self, + request_id: Uuid, + request: Arc, + message_index: u32, + ) { + self.to_share_with_set + .insert(request_id, (request, message_index)); } /// This should be called if an the user wishes to rotate this session. @@ -180,12 +191,12 @@ impl OutboundGroupSession { /// users/devices that received the session. pub fn mark_request_as_sent(&self, request_id: &Uuid) { if let Some((_, r)) = self.to_share_with_set.remove(request_id) { - let user_pairs = r.messages.iter().map(|(u, v)| { + let user_pairs = r.0.messages.iter().map(|(u, v)| { ( u.clone(), - v.keys().filter_map(|d| { - if let DeviceIdOrAllDevices::DeviceId(d) = d { - Some(d.clone()) + v.iter().filter_map(|d| { + if let DeviceIdOrAllDevices::DeviceId(d) = d.0 { + Some((d.clone(), r.1)) } else { None } @@ -196,7 +207,7 @@ impl OutboundGroupSession { user_pairs.for_each(|(u, d)| { self.shared_with_set .entry(u) - .or_insert_with(DashSet::new) + .or_insert_with(DashMap::new) .extend(d); }); @@ -349,28 +360,40 @@ impl OutboundGroupSession { } /// Has or will the session be shared with the given user/device pair. - pub(crate) fn is_shared_with(&self, user_id: &UserId, device_id: &DeviceId) -> bool { - let shared_with = self + pub(crate) fn is_shared_with(&self, user_id: &UserId, device_id: &DeviceId) -> ShareState { + // Check if we shared the session. + let shared_state = self .shared_with_set .get(user_id) - .map(|d| d.contains(device_id)) - .unwrap_or(false); + .and_then(|d| d.get(device_id).map(|m| ShareState::Shared(*m.value()))); - let should_be_shared_with = if self.shared() { - false + if let Some(state) = shared_state { + state } else { + // If we haven't shared the session, check if we're going to share + // the session. let device_id = DeviceIdOrAllDevices::DeviceId(device_id.into()); - self.to_share_with_set.iter().any(|item| { - if let Some(e) = item.value().messages.get(user_id) { - e.contains_key(&device_id) - } else { - false - } - }) - }; + // Find the first request that contains the given user id and + // device id. + let shared = self.to_share_with_set.iter().find_map(|item| { + let request = &item.value().0; + let message_index = item.value().1; - shared_with || should_be_shared_with + if request + .messages + .get(user_id) + .map(|e| e.contains_key(&device_id)) + .unwrap_or(false) + { + Some(ShareState::Shared(message_index)) + } else { + None + } + }); + + shared.unwrap_or(ShareState::NotShared) + } } /// Mark that the session was shared with the given user/device pair. @@ -378,8 +401,8 @@ impl OutboundGroupSession { pub fn mark_shared_with(&self, user_id: &UserId, device_id: &DeviceId) { self.shared_with_set .entry(user_id.to_owned()) - .or_insert_with(DashSet::new) - .insert(device_id.to_owned()); + .or_insert_with(DashMap::new) + .insert(device_id.to_owned(), 0); } /// Get the list of requests that need to be sent out for this session to be @@ -387,7 +410,7 @@ impl OutboundGroupSession { pub(crate) fn pending_requests(&self) -> Vec> { self.to_share_with_set .iter() - .map(|i| i.value().clone()) + .map(|i| i.value().0.clone()) .collect() } @@ -465,7 +488,10 @@ impl OutboundGroupSession { ( u.key().clone(), #[allow(clippy::map_clone)] - u.value().iter().map(|d| d.clone()).collect(), + u.value() + .iter() + .map(|d| (d.key().clone(), *d.value())) + .collect(), ) }) .collect(), @@ -524,9 +550,9 @@ pub struct PickledOutboundGroupSession { /// Has the session been invalidated. pub invalidated: bool, /// The set of users the session has been already shared with. - pub shared_with_set: BTreeMap>, + pub shared_with_set: BTreeMap>, /// Requests that need to be sent out to share the session. - pub requests: BTreeMap>, + pub requests: BTreeMap, u32)>, } #[cfg(test)] diff --git a/matrix_sdk_crypto/src/olm/mod.rs b/matrix_sdk_crypto/src/olm/mod.rs index 19424988..0d20746a 100644 --- a/matrix_sdk_crypto/src/olm/mod.rs +++ b/matrix_sdk_crypto/src/olm/mod.rs @@ -25,11 +25,11 @@ mod utility; pub(crate) use account::{Account, OlmDecryptionInfo, SessionType}; pub use account::{AccountPickle, OlmMessageHash, PickledAccount, ReadOnlyAccount}; -pub(crate) use group_sessions::GroupSessionKey; pub use group_sessions::{ EncryptionSettings, ExportedRoomKey, InboundGroupSession, InboundGroupSessionPickle, OutboundGroupSession, PickledInboundGroupSession, PickledOutboundGroupSession, }; +pub(crate) use group_sessions::{GroupSessionKey, ShareState}; pub use olm_rs::{account::IdentityKeys, PicklingMode}; pub use session::{PickledSession, Session, SessionPickle}; pub use signing::{PickledCrossSigningIdentity, PrivateCrossSigningIdentity}; diff --git a/matrix_sdk_crypto/src/session_manager/group_sessions.rs b/matrix_sdk_crypto/src/session_manager/group_sessions.rs index 8391f55d..e1e3d4c2 100644 --- a/matrix_sdk_crypto/src/session_manager/group_sessions.rs +++ b/matrix_sdk_crypto/src/session_manager/group_sessions.rs @@ -29,7 +29,7 @@ use tracing::{debug, info}; use crate::{ error::{EventError, MegolmResult, OlmResult}, - olm::{Account, InboundGroupSession, OutboundGroupSession, Session}, + olm::{Account, InboundGroupSession, OutboundGroupSession, Session, ShareState}, store::{Changes, Store}, Device, EncryptionSettings, OlmError, ToDeviceRequest, }; @@ -263,7 +263,8 @@ impl GroupSessionManager { { #[allow(clippy::map_clone)] // Devices that received this session - let shared: HashSet = shared.iter().map(|d| d.clone()).collect(); + let shared: HashSet = + shared.iter().map(|d| d.key().clone()).collect(); let shared: HashSet<&DeviceId> = shared.iter().map(|d| d.as_ref()).collect(); // The difference between the devices that received the @@ -341,16 +342,23 @@ impl GroupSessionManager { let devices: Vec = devices .into_iter() .map(|(_, d)| { - d.into_iter() - .filter(|d| !outbound.is_shared_with(d.user_id(), d.device_id())) + d.into_iter().filter(|d| { + matches!( + outbound.is_shared_with(d.user_id(), d.device_id()), + ShareState::NotShared + ) + }) }) .flatten() .collect(); + let key_content = outbound.as_json().await; + let message_index = outbound.message_index().await; + if !devices.is_empty() { info!( "Sharing outbound session at index {} with {:?}", - outbound.message_index().await, + message_index, devices.iter().fold(BTreeMap::new(), |mut acc, d| { acc.entry(d.user_id()) .or_insert_with(BTreeSet::new) @@ -360,15 +368,13 @@ impl GroupSessionManager { ); } - let key_content = outbound.as_json().await; - for device_map_chunk in devices.chunks(Self::MAX_TO_DEVICE_MESSAGES) { let (id, request, used_sessions) = self .encrypt_session_for(key_content.clone(), device_map_chunk) .await?; if !request.messages.is_empty() { - outbound.add_request(id, request.into()); + outbound.add_request(id, request.into(), message_index); self.outbound_sessions_being_shared .insert(id, outbound.clone()); }