From 975f9a0b41e6315d57c0ce17459f10257affd249 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Wed, 14 Apr 2021 14:30:53 +0200 Subject: [PATCH 01/11] crypto: Improve the way we decide if we honor room key requests This improves two things, use the correct outbound session to check if the session should be shared. Check first if the session has been shared if there isn't a session or it hasn't been shared check if the request is comming from our own user. --- matrix_sdk_crypto/src/key_request.rs | 73 +++++++++++++++++++--------- 1 file changed, 49 insertions(+), 24 deletions(-) diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index 5fa560a5..d091568f 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -363,12 +363,7 @@ impl KeyRequestMachine { .await?; if let Some(device) = device { - match self.should_share_session( - &device, - self.outbound_group_sessions - .get(&key_info.room_id) - .as_deref(), - ) { + match self.should_share_session(&device, session.session_id(), session.room_id()) { Err(e) => { info!( "Received a key request from {} {} that we won't serve: {}", @@ -469,28 +464,47 @@ impl KeyRequestMachine { /// /// * `device` - The device that is requesting a session from us. /// - /// * `outbound_session` - If one still exists, the matching outbound - /// session that was used to create the inbound session that is being - /// requested. + /// * `session_id` - The unique ID of the session that should be shared + /// + /// * `room_id` - The unique ID of the room where the session is being used. fn should_share_session( &self, device: &Device, - outbound_session: Option<&OutboundGroupSession>, + session_id: &str, + room_id: &RoomId, ) -> Result, KeyshareDecision> { - if device.user_id() == self.user_id() { + let outbound_session = self + .outbound_group_sessions + .get(room_id) + .filter(|o| session_id == o.session_id()); + + let own_device_check = || { if device.trust_state() { Ok(None) } else { Err(KeyshareDecision::UntrustedDevice) } - } else if let Some(outbound) = outbound_session { + }; + + // If we have a matching outbound session we can check the list of + // users/devices that received the session, if it wasn't shared check if + // it's our own device and if it's trusted. + if let Some(outbound) = outbound_session { if let ShareState::Shared(message_index) = outbound.is_shared_with(device.user_id(), device.device_id()) { Ok(Some(message_index)) + } else if device.user_id() == self.user_id() { + own_device_check() } else { Err(KeyshareDecision::OutboundSessionNotShared) } + // Else just check if it's one of our own devices that requested the key and + // check if the device is trusted. + } else if device.user_id() == self.user_id() { + own_device_check() + // Otherwise, there's not enough info to decide if we can safely share + // the session. } else { Err(KeyshareDecision::MissingOutboundSession) } @@ -722,7 +736,7 @@ mod test { use crate::{ identities::{LocalTrust, ReadOnlyDevice}, olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount}, - store::{CryptoStore, MemoryStore, Store}, + store::{Changes, CryptoStore, MemoryStore, Store}, verification::VerificationMachine, }; @@ -966,16 +980,23 @@ mod test { .unwrap() .unwrap(); + let (outbound, inbound) = account + .create_group_session_pair_with_defaults(&room_id()) + .await + .unwrap(); + // We don't share keys with untrusted devices. assert_eq!( machine - .should_share_session(&own_device, None) + .should_share_session(&own_device, inbound.session_id(), inbound.room_id()) .expect_err("Should not share with untrusted"), KeyshareDecision::UntrustedDevice ); own_device.set_trust_state(LocalTrust::Verified); // Now we do want to share the keys. - assert!(machine.should_share_session(&own_device, None).is_ok()); + assert!(machine + .should_share_session(&own_device, inbound.session_id(), inbound.room_id()) + .is_ok()); let bob_device = ReadOnlyDevice::from_account(&bob_account()).await; machine.store.save_devices(&[bob_device]).await.unwrap(); @@ -991,21 +1012,25 @@ mod test { // session was provided. assert_eq!( machine - .should_share_session(&bob_device, None) + .should_share_session(&bob_device, inbound.session_id(), inbound.room_id()) .expect_err("Should not share with other."), KeyshareDecision::MissingOutboundSession ); - let (session, _) = account - .create_group_session_pair_with_defaults(&room_id()) - .await - .unwrap(); + let mut changes = Changes::default(); + + changes.outbound_group_sessions.push(outbound.clone()); + changes.inbound_group_sessions.push(inbound.clone()); + machine.store.save_changes(changes).await.unwrap(); + machine + .outbound_group_sessions + .insert(inbound.room_id().to_owned(), outbound.clone()); // We don't share sessions with other user's devices if the session // wasn't shared in the first place. assert_eq!( machine - .should_share_session(&bob_device, Some(&session)) + .should_share_session(&bob_device, inbound.session_id(), inbound.room_id()) .expect_err("Should not share with other unless shared."), KeyshareDecision::OutboundSessionNotShared ); @@ -1016,14 +1041,14 @@ mod test { // wasn't shared in the first place even if the device is trusted. assert_eq!( machine - .should_share_session(&bob_device, Some(&session)) + .should_share_session(&bob_device, inbound.session_id(), inbound.room_id()) .expect_err("Should not share with other unless shared."), KeyshareDecision::OutboundSessionNotShared ); - session.mark_shared_with(bob_device.user_id(), bob_device.device_id()); + outbound.mark_shared_with(bob_device.user_id(), bob_device.device_id()); assert!(machine - .should_share_session(&bob_device, Some(&session)) + .should_share_session(&bob_device, inbound.session_id(), inbound.room_id()) .is_ok()); } From 5637ca30801d13b2e6e12c831329f6c85e2913f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Thu, 15 Apr 2021 11:19:14 +0200 Subject: [PATCH 02/11] crypto: Simplify the should_share_session method --- matrix_sdk_crypto/src/key_request.rs | 44 ++++++++++++++++------------ 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index d091568f..aaabc80d 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -363,7 +363,7 @@ impl KeyRequestMachine { .await?; if let Some(device) = device { - match self.should_share_session(&device, session.session_id(), session.room_id()) { + match self.should_share_session(&device, &session) { Err(e) => { info!( "Received a key request from {} {} that we won't serve: {}", @@ -464,19 +464,16 @@ impl KeyRequestMachine { /// /// * `device` - The device that is requesting a session from us. /// - /// * `session_id` - The unique ID of the session that should be shared - /// - /// * `room_id` - The unique ID of the room where the session is being used. + /// * `session` - The session that was requested to be shared. fn should_share_session( &self, device: &Device, - session_id: &str, - room_id: &RoomId, + session: &InboundGroupSession, ) -> Result, KeyshareDecision> { let outbound_session = self .outbound_group_sessions - .get(room_id) - .filter(|o| session_id == o.session_id()); + .get(session.room_id()) + .filter(|o| session.session_id() == o.session_id()); let own_device_check = || { if device.trust_state() { @@ -988,15 +985,13 @@ mod test { // We don't share keys with untrusted devices. assert_eq!( machine - .should_share_session(&own_device, inbound.session_id(), inbound.room_id()) + .should_share_session(&own_device, &inbound) .expect_err("Should not share with untrusted"), KeyshareDecision::UntrustedDevice ); own_device.set_trust_state(LocalTrust::Verified); // Now we do want to share the keys. - assert!(machine - .should_share_session(&own_device, inbound.session_id(), inbound.room_id()) - .is_ok()); + assert!(machine.should_share_session(&own_device, &inbound).is_ok()); let bob_device = ReadOnlyDevice::from_account(&bob_account()).await; machine.store.save_devices(&[bob_device]).await.unwrap(); @@ -1012,7 +1007,7 @@ mod test { // session was provided. assert_eq!( machine - .should_share_session(&bob_device, inbound.session_id(), inbound.room_id()) + .should_share_session(&bob_device, &inbound) .expect_err("Should not share with other."), KeyshareDecision::MissingOutboundSession ); @@ -1030,7 +1025,7 @@ mod test { // wasn't shared in the first place. assert_eq!( machine - .should_share_session(&bob_device, inbound.session_id(), inbound.room_id()) + .should_share_session(&bob_device, &inbound) .expect_err("Should not share with other unless shared."), KeyshareDecision::OutboundSessionNotShared ); @@ -1041,15 +1036,28 @@ mod test { // wasn't shared in the first place even if the device is trusted. assert_eq!( machine - .should_share_session(&bob_device, inbound.session_id(), inbound.room_id()) + .should_share_session(&bob_device, &inbound) .expect_err("Should not share with other unless shared."), KeyshareDecision::OutboundSessionNotShared ); + // We now share the session, since it was shared before. outbound.mark_shared_with(bob_device.user_id(), bob_device.device_id()); - assert!(machine - .should_share_session(&bob_device, inbound.session_id(), inbound.room_id()) - .is_ok()); + assert!(machine.should_share_session(&bob_device, &inbound).is_ok()); + + // But we don't share some other session that doesn't match our outbound + // session + let (_, other_inbound) = account + .create_group_session_pair_with_defaults(&room_id()) + .await + .unwrap(); + + assert_eq!( + machine + .should_share_session(&bob_device, &other_inbound) + .expect_err("Should not share with other unless shared."), + KeyshareDecision::MissingOutboundSession + ); } #[async_test] From 02331fa3258a49e8964ad88d6feea039350ec3b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Thu, 15 Apr 2021 12:45:51 +0200 Subject: [PATCH 03/11] crypto: Add specialized methods to store outgoing key requests --- matrix_sdk_crypto/src/key_request.rs | 121 +++++------ matrix_sdk_crypto/src/machine.rs | 2 +- matrix_sdk_crypto/src/store/memorystore.rs | 93 +++++--- matrix_sdk_crypto/src/store/mod.rs | 64 +++--- matrix_sdk_crypto/src/store/sled.rs | 234 +++++++++++++++------ 5 files changed, 327 insertions(+), 187 deletions(-) diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index aaabc80d..f8fab2ff 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -42,7 +42,7 @@ use crate::{ error::{OlmError, OlmResult}, olm::{InboundGroupSession, OutboundGroupSession, Session, ShareState}, requests::{OutgoingRequest, ToDeviceRequest}, - store::{CryptoStoreError, Store}, + store::{Changes, CryptoStoreError, Store}, Device, }; @@ -137,32 +137,24 @@ pub(crate) struct KeyRequestMachine { users_for_key_claim: Arc>>, } -#[derive(Debug, Serialize, Deserialize)] -struct OugoingKeyInfo { - request_id: Uuid, - info: RequestedKeyInfo, - sent_out: bool, +/// A struct describing an outgoing key request. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OutgoingKeyRequest { + /// The unique id of the key request. + pub request_id: Uuid, + /// The info of the requested key. + pub info: RequestedKeyInfo, + /// Has the request been sent out. + pub sent_out: bool, } -trait Encode { - fn encode(&self) -> String; -} - -impl Encode for RequestedKeyInfo { - fn encode(&self) -> String { - format!( - "{}|{}|{}|{}", - self.sender_key, self.room_id, self.session_id, self.algorithm - ) - } -} - -impl Encode for ForwardedRoomKeyToDeviceEventContent { - fn encode(&self) -> String { - format!( - "{}|{}|{}|{}", - self.sender_key, self.room_id, self.session_id, self.algorithm - ) +impl PartialEq for OutgoingKeyRequest { + fn eq(&self, other: &Self) -> bool { + self.request_id == other.request_id + && self.info.algorithm == other.info.algorithm + && self.info.room_id == other.info.room_id + && self.info.session_id == other.info.session_id + && self.info.sender_key == other.info.sender_key } } @@ -246,6 +238,7 @@ impl KeyRequestMachine { /// key request queue. pub async fn collect_incoming_key_requests(&self) -> OlmResult> { let mut changed_sessions = Vec::new(); + for item in self.incoming_key_requests.iter() { let event = item.value(); if let Some(s) = self.handle_key_request(event).await? { @@ -534,9 +527,9 @@ impl KeyRequestMachine { session_id: session_id.to_owned(), }; - let id: Option = self.store.get_object(&key_info.encode()).await?; + let request = self.store.get_key_request_by_info(&key_info).await?; - if id.is_some() { + if request.is_some() { // We already sent out a request for this key, nothing to do. return Ok(()); } @@ -554,13 +547,13 @@ impl KeyRequestMachine { let request = wrap_key_request_content(self.user_id().clone(), id, &content)?; - let info = OugoingKeyInfo { + let info = OutgoingKeyRequest { request_id: id, info: content.body.unwrap(), sent_out: false, }; - self.save_outgoing_key_info(id, info).await?; + self.save_outgoing_key_info(info).await?; self.outgoing_to_device_requests.insert(id, request); Ok(()) @@ -569,16 +562,11 @@ impl KeyRequestMachine { /// Save an outgoing key info. async fn save_outgoing_key_info( &self, - id: Uuid, - info: OugoingKeyInfo, + info: OutgoingKeyRequest, ) -> Result<(), CryptoStoreError> { - // TODO we'll want to use a transaction to store those atomically. - // To allow this we'll need to rework our cryptostore trait to return - // a transaction trait and the transaction trait will have the save_X - // methods. - let id_string = id.to_string(); - self.store.save_object(&id_string, &info).await?; - self.store.save_object(&info.info.encode(), &id).await?; + let mut changes = Changes::default(); + changes.key_requests.push(info); + self.store.save_changes(changes).await?; Ok(()) } @@ -587,44 +575,43 @@ impl KeyRequestMachine { async fn get_key_info( &self, content: &ForwardedRoomKeyToDeviceEventContent, - ) -> Result, CryptoStoreError> { - let id: Option = self.store.get_object(&content.encode()).await?; + ) -> Result, CryptoStoreError> { + let info = RequestedKeyInfo { + algorithm: content.algorithm.clone(), + room_id: content.room_id.clone(), + sender_key: content.sender_key.clone(), + session_id: content.session_id.clone(), + }; - if let Some(id) = id { - self.store.get_object(&id.to_string()).await - } else { - Ok(None) - } + self.store.get_key_request_by_info(&info).await } /// Delete the given outgoing key info. - async fn delete_key_info(&self, info: &OugoingKeyInfo) -> Result<(), CryptoStoreError> { + async fn delete_key_info(&self, info: &OutgoingKeyRequest) -> Result<(), CryptoStoreError> { self.store - .delete_object(&info.request_id.to_string()) - .await?; - self.store.delete_object(&info.info.encode()).await?; - - Ok(()) + .delete_outgoing_key_request(info.request_id) + .await } /// Mark the outgoing request as sent. - pub async fn mark_outgoing_request_as_sent(&self, id: &Uuid) -> Result<(), CryptoStoreError> { - self.outgoing_to_device_requests.remove(id); - let info: Option = self.store.get_object(&id.to_string()).await?; + pub async fn mark_outgoing_request_as_sent(&self, id: Uuid) -> Result<(), CryptoStoreError> { + let info = self.store.get_outgoing_key_request(id).await?; if let Some(mut info) = info { trace!("Marking outgoing key request as sent {:#?}", info); info.sent_out = true; - self.save_outgoing_key_info(*id, info).await?; + self.save_outgoing_key_info(info).await?; } + self.outgoing_to_device_requests.remove(&id); + Ok(()) } /// Mark the given outgoing key info as done. /// /// This will queue up a request cancelation. - async fn mark_as_done(&self, key_info: OugoingKeyInfo) -> Result<(), CryptoStoreError> { + async fn mark_as_done(&self, key_info: OutgoingKeyRequest) -> Result<(), CryptoStoreError> { // TODO perhaps only remove the key info if the first known index is 0. trace!( "Successfully received a forwarded room key for {:#?}", @@ -847,7 +834,7 @@ mod test { let id = request.request_id; drop(request); - machine.mark_outgoing_request_as_sent(&id).await.unwrap(); + machine.mark_outgoing_request_as_sent(id).await.unwrap(); assert!(machine.outgoing_to_device_requests.is_empty()); } @@ -873,7 +860,7 @@ mod test { let id = request.request_id; drop(request); - machine.mark_outgoing_request_as_sent(&id).await.unwrap(); + machine.mark_outgoing_request_as_sent(id).await.unwrap(); let export = session.export_at_index(10).await; @@ -915,7 +902,7 @@ mod test { let request = machine.outgoing_to_device_requests.iter().next().unwrap(); let id = request.request_id; drop(request); - machine.mark_outgoing_request_as_sent(&id).await.unwrap(); + machine.mark_outgoing_request_as_sent(id).await.unwrap(); machine .create_outgoing_key_request( @@ -930,7 +917,7 @@ mod test { let id = request.request_id; drop(request); - machine.mark_outgoing_request_as_sent(&id).await.unwrap(); + machine.mark_outgoing_request_as_sent(id).await.unwrap(); let export = session.export_at_index(15).await; @@ -1148,7 +1135,7 @@ mod test { drop(request); alice_machine - .mark_outgoing_request_as_sent(&id) + .mark_outgoing_request_as_sent(id) .await .unwrap(); @@ -1186,10 +1173,7 @@ mod test { let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap(); drop(request); - bob_machine - .mark_outgoing_request_as_sent(&id) - .await - .unwrap(); + bob_machine.mark_outgoing_request_as_sent(id).await.unwrap(); let event = ToDeviceEvent { sender: bob_id(), @@ -1317,7 +1301,7 @@ mod test { drop(request); alice_machine - .mark_outgoing_request_as_sent(&id) + .mark_outgoing_request_as_sent(id) .await .unwrap(); @@ -1378,10 +1362,7 @@ mod test { let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap(); drop(request); - bob_machine - .mark_outgoing_request_as_sent(&id) - .await - .unwrap(); + bob_machine.mark_outgoing_request_as_sent(id).await.unwrap(); let event = ToDeviceEvent { sender: bob_id(), diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 415b80f7..601be381 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -751,7 +751,7 @@ impl OlmMachine { async fn mark_to_device_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> { self.verification_machine.mark_request_as_sent(request_id); self.key_request_machine - .mark_outgoing_request_as_sent(request_id) + .mark_outgoing_request_as_sent(*request_id) .await?; self.group_session_manager .mark_request_as_sent(request_id) diff --git a/matrix_sdk_crypto/src/store/memorystore.rs b/matrix_sdk_crypto/src/store/memorystore.rs index 3d249c82..27993e25 100644 --- a/matrix_sdk_crypto/src/store/memorystore.rs +++ b/matrix_sdk_crypto/src/store/memorystore.rs @@ -20,8 +20,10 @@ use std::{ use dashmap::{DashMap, DashSet}; use matrix_sdk_common::{ async_trait, + events::room_key_request::RequestedKeyInfo, identifiers::{DeviceId, DeviceIdBox, RoomId, UserId}, locks::Mutex, + uuid::Uuid, }; use super::{ @@ -30,9 +32,17 @@ use super::{ }; use crate::{ identities::{ReadOnlyDevice, UserIdentities}, + key_request::OutgoingKeyRequest, olm::{OutboundGroupSession, PrivateCrossSigningIdentity}, }; +fn encode_key_info(info: &RequestedKeyInfo) -> String { + format!( + "{}{}{}{}", + &info.room_id, &info.sender_key, &info.algorithm, &info.session_id + ) +} + /// An in-memory only store that will forget all the E2EE key once it's dropped. #[derive(Debug, Clone)] pub struct MemoryStore { @@ -43,7 +53,8 @@ pub struct MemoryStore { olm_hashes: Arc>>, devices: DeviceStore, identities: Arc>, - values: Arc>, + outgoing_key_requests: Arc>, + key_requests_by_info: Arc>, } impl Default for MemoryStore { @@ -56,7 +67,8 @@ impl Default for MemoryStore { olm_hashes: Arc::new(DashMap::new()), devices: DeviceStore::new(), identities: Arc::new(DashMap::new()), - values: Arc::new(DashMap::new()), + outgoing_key_requests: Arc::new(DashMap::new()), + key_requests_by_info: Arc::new(DashMap::new()), } } } @@ -103,6 +115,10 @@ impl CryptoStore for MemoryStore { Ok(()) } + async fn load_identity(&self) -> Result> { + Ok(None) + } + async fn save_changes(&self, mut changes: Changes) -> Result<()> { self.save_sessions(changes.sessions).await; self.save_inbound_group_sessions(changes.inbound_group_sessions) @@ -130,6 +146,14 @@ impl CryptoStore for MemoryStore { .insert(hash.hash.clone()); } + for key_request in changes.key_requests { + let id = key_request.request_id; + let info_string = encode_key_info(&key_request.info); + + self.outgoing_key_requests.insert(id, key_request); + self.key_requests_by_info.insert(info_string, id); + } + Ok(()) } @@ -152,9 +176,11 @@ impl CryptoStore for MemoryStore { Ok(self.inbound_group_sessions.get_all()) } - fn users_for_key_query(&self) -> HashSet { - #[allow(clippy::map_clone)] - self.users_for_key_query.iter().map(|u| u.clone()).collect() + async fn get_outbound_group_sessions( + &self, + _: &RoomId, + ) -> Result> { + Ok(None) } fn is_user_tracked(&self, user_id: &UserId) -> bool { @@ -165,6 +191,11 @@ impl CryptoStore for MemoryStore { !self.users_for_key_query.is_empty() } + fn users_for_key_query(&self) -> HashSet { + #[allow(clippy::map_clone)] + self.users_for_key_query.iter().map(|u| u.clone()).collect() + } + async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result { // TODO to prevent a race between the sync and a key query in flight we // need to have an additional state to mention that the user changed. @@ -207,24 +238,6 @@ impl CryptoStore for MemoryStore { Ok(self.identities.get(user_id).map(|i| i.clone())) } - async fn save_value(&self, key: String, value: String) -> Result<()> { - self.values.insert(key, value); - Ok(()) - } - - async fn remove_value(&self, key: &str) -> Result<()> { - self.values.remove(key); - Ok(()) - } - - async fn get_value(&self, key: &str) -> Result> { - Ok(self.values.get(key).map(|v| v.to_owned())) - } - - async fn load_identity(&self) -> Result> { - Ok(None) - } - async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result { Ok(self .olm_hashes @@ -233,11 +246,37 @@ impl CryptoStore for MemoryStore { .contains(&message_hash.hash)) } - async fn get_outbound_group_sessions( + async fn get_outgoing_key_request( &self, - _: &RoomId, - ) -> Result> { - Ok(None) + request_id: Uuid, + ) -> Result> { + Ok(self + .outgoing_key_requests + .get(&request_id) + .map(|r| r.clone())) + } + + async fn get_key_request_by_info( + &self, + key_info: &RequestedKeyInfo, + ) -> Result> { + let key_info_string = encode_key_info(key_info); + + Ok(self + .key_requests_by_info + .get(&key_info_string) + .and_then(|i| self.outgoing_key_requests.get(&i).map(|r| r.clone()))) + } + + async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> { + self.outgoing_key_requests + .remove(&request_id) + .and_then(|(_, i)| { + let key_info_string = encode_key_info(&i.info); + self.key_requests_by_info.remove(&key_info_string) + }); + + Ok(()) } } diff --git a/matrix_sdk_crypto/src/store/mod.rs b/matrix_sdk_crypto/src/store/mod.rs index a837dfd0..387bd610 100644 --- a/matrix_sdk_crypto/src/store/mod.rs +++ b/matrix_sdk_crypto/src/store/mod.rs @@ -57,23 +57,25 @@ use std::{ }; use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError}; -use serde::{Deserialize, Serialize}; use serde_json::Error as SerdeError; use thiserror::Error; use matrix_sdk_common::{ async_trait, + events::room_key_request::RequestedKeyInfo, identifiers::{ DeviceId, DeviceIdBox, DeviceKeyAlgorithm, Error as IdentifierValidationError, RoomId, UserId, }, locks::Mutex, + uuid::Uuid, AsyncTraitDeps, }; use crate::{ error::SessionUnpicklingError, identities::{Device, ReadOnlyDevice, UserDevices, UserIdentities}, + key_request::OutgoingKeyRequest, olm::{ InboundGroupSession, OlmMessageHash, OutboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount, Session, @@ -108,6 +110,7 @@ pub struct Changes { pub inbound_group_sessions: Vec, pub outbound_group_sessions: Vec, pub identities: IdentityChanges, + pub key_requests: Vec, pub devices: DeviceChanges, } @@ -257,24 +260,6 @@ impl Store { device_owner_identity, })) } - - pub async fn get_object Deserialize<'b>>(&self, key: &str) -> Result> { - if let Some(value) = self.get_value(key).await? { - Ok(Some(serde_json::from_str(&value)?)) - } else { - Ok(None) - } - } - - pub async fn save_object(&self, key: &str, value: &impl Serialize) -> Result<()> { - let value = serde_json::to_string(value)?; - self.save_value(key.to_owned(), value).await - } - - pub async fn delete_object(&self, key: &str) -> Result<()> { - self.inner.remove_value(key).await?; - Ok(()) - } } impl Deref for Store { @@ -438,15 +423,38 @@ pub trait CryptoStore: AsyncTraitDeps { /// * `user_id` - The user for which we should get the identity. async fn get_user_identity(&self, user_id: &UserId) -> Result>; - /// Save a serializeable object in the store. - async fn save_value(&self, key: String, value: String) -> Result<()>; - - /// Remove a value from the store. - async fn remove_value(&self, key: &str) -> Result<()>; - - /// Load a serializeable object from the store. - async fn get_value(&self, key: &str) -> Result>; - /// Check if a hash for an Olm message stored in the database. async fn is_message_known(&self, message_hash: &OlmMessageHash) -> Result; + + /// Get an outoing key request that we created that matches the given + /// request id. + /// + /// # Arguments + /// + /// * `request_id` - The unique request id that identifies this outgoing key + /// request. + async fn get_outgoing_key_request( + &self, + request_id: Uuid, + ) -> Result>; + + /// Get an outoing key request that we created that matches the given + /// requested key info. + /// + /// # Arguments + /// + /// * `key_info` - The key info of an outgoing key request. + async fn get_key_request_by_info( + &self, + key_info: &RequestedKeyInfo, + ) -> Result>; + + /// Delete an outoing key request that we created that matches the given + /// request id. + /// + /// # Arguments + /// + /// * `request_id` - The unique request id that identifies this outgoing key + /// request. + async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()>; } diff --git a/matrix_sdk_crypto/src/store/sled.rs b/matrix_sdk_crypto/src/store/sled.rs index 00950d6f..f177f10d 100644 --- a/matrix_sdk_crypto/src/store/sled.rs +++ b/matrix_sdk_crypto/src/store/sled.rs @@ -29,9 +29,12 @@ use sled::{ use matrix_sdk_common::{ async_trait, + events::room_key_request::RequestedKeyInfo, identifiers::{DeviceId, DeviceIdBox, RoomId, UserId}, locks::Mutex, + uuid, }; +use uuid::Uuid; use super::{ caches::SessionStore, Changes, CryptoStore, CryptoStoreError, InboundGroupSession, PickleKey, @@ -39,6 +42,7 @@ use super::{ }; use crate::{ identities::{ReadOnlyDevice, UserIdentities}, + key_request::OutgoingKeyRequest, olm::{OutboundGroupSession, PickledInboundGroupSession, PrivateCrossSigningIdentity}, }; @@ -51,6 +55,28 @@ trait EncodeKey { fn encode(&self) -> Vec; } +impl EncodeKey for Uuid { + fn encode(&self) -> Vec { + self.as_u128().to_be_bytes().to_vec() + } +} + +impl EncodeKey for &RequestedKeyInfo { + fn encode(&self) -> Vec { + [ + self.room_id.as_bytes(), + &[Self::SEPARATOR], + self.sender_key.as_bytes(), + &[Self::SEPARATOR], + self.algorithm.as_ref().as_bytes(), + &[Self::SEPARATOR], + self.session_id.as_bytes(), + &[Self::SEPARATOR], + ] + .concat() + } +} + impl EncodeKey for &UserId { fn encode(&self) -> Vec { self.as_str().encode() @@ -122,12 +148,14 @@ pub struct SledStore { inbound_group_sessions: Tree, outbound_group_sessions: Tree, + outgoing_key_requests: Tree, + key_requests_by_info: Tree, + devices: Tree, identities: Tree, tracked_users: Tree, users_for_key_query: Tree, - values: Tree, } impl std::fmt::Debug for SledStore { @@ -178,13 +206,16 @@ impl SledStore { let sessions = db.open_tree("session")?; let inbound_group_sessions = db.open_tree("inbound_group_sessions")?; let outbound_group_sessions = db.open_tree("outbound_group_sessions")?; + let tracked_users = db.open_tree("tracked_users")?; let users_for_key_query = db.open_tree("users_for_key_query")?; let olm_hashes = db.open_tree("olm_hashes")?; let devices = db.open_tree("devices")?; let identities = db.open_tree("identities")?; - let values = db.open_tree("values")?; + + let outgoing_key_requests = db.open_tree("outgoing_key_requests")?; + let key_requests_by_info = db.open_tree("key_requests_by_info")?; let session_cache = SessionStore::new(); @@ -208,12 +239,13 @@ impl SledStore { users_for_key_query_cache: DashSet::new().into(), inbound_group_sessions, outbound_group_sessions, + outgoing_key_requests, + key_requests_by_info, devices, tracked_users, users_for_key_query, olm_hashes, identities, - values, }) } @@ -332,6 +364,7 @@ impl SledStore { let identity_changes = changes.identities; let olm_hashes = changes.message_hashes; + let key_requests = changes.key_requests; let ret: Result<(), TransactionError> = ( &self.account, @@ -342,6 +375,8 @@ impl SledStore { &self.inbound_group_sessions, &self.outbound_group_sessions, &self.olm_hashes, + &self.outgoing_key_requests, + &self.key_requests_by_info, ) .transaction( |( @@ -353,6 +388,8 @@ impl SledStore { inbound_sessions, outbound_sessions, hashes, + outgoing_key_requests, + key_requests_by_info, )| { if let Some(a) = &account_pickle { account.insert( @@ -420,6 +457,19 @@ impl SledStore { )?; } + for key_request in &key_requests { + key_requests_by_info.insert( + (&key_request.info).encode(), + key_request.request_id.encode(), + )?; + + outgoing_key_requests.insert( + key_request.request_id.encode(), + serde_json::to_vec(&key_request) + .map_err(ConflictableTransactionError::Abort)?, + )?; + } + Ok(()) }, ); @@ -472,6 +522,19 @@ impl CryptoStore for SledStore { self.save_changes(changes).await } + async fn load_identity(&self) -> Result> { + if let Some(i) = self.private_identity.get("identity".encode())? { + let pickle = serde_json::from_slice(&i)?; + Ok(Some( + PrivateCrossSigningIdentity::from_pickle(pickle, self.get_pickle_key()) + .await + .map_err(|_| CryptoStoreError::UnpicklingError)?, + )) + } else { + Ok(None) + } + } + async fn save_changes(&self, changes: Changes) -> Result<()> { self.save_changes(changes).await } @@ -539,12 +602,11 @@ impl CryptoStore for SledStore { .collect()) } - fn users_for_key_query(&self) -> HashSet { - #[allow(clippy::map_clone)] - self.users_for_key_query_cache - .iter() - .map(|u| u.clone()) - .collect() + async fn get_outbound_group_sessions( + &self, + room_id: &RoomId, + ) -> Result> { + self.load_outbound_group_session(room_id).await } fn is_user_tracked(&self, user_id: &UserId) -> bool { @@ -555,6 +617,14 @@ impl CryptoStore for SledStore { !self.users_for_key_query_cache.is_empty() } + fn users_for_key_query(&self) -> HashSet { + #[allow(clippy::map_clone)] + self.users_for_key_query_cache + .iter() + .map(|u| u.clone()) + .collect() + } + async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result { let already_added = self.tracked_users_cache.insert(user.clone()); @@ -605,48 +675,62 @@ impl CryptoStore for SledStore { .transpose()?) } - async fn save_value(&self, key: String, value: String) -> Result<()> { - self.values.insert(key.as_str().encode(), value.as_str())?; - self.inner.flush_async().await?; - Ok(()) - } - - async fn remove_value(&self, key: &str) -> Result<()> { - self.values.remove(key.encode())?; - Ok(()) - } - - async fn get_value(&self, key: &str) -> Result> { - Ok(self - .values - .get(key.encode())? - .map(|v| String::from_utf8_lossy(&v).to_string())) - } - - async fn load_identity(&self) -> Result> { - if let Some(i) = self.private_identity.get("identity".encode())? { - let pickle = serde_json::from_slice(&i)?; - Ok(Some( - PrivateCrossSigningIdentity::from_pickle(pickle, self.get_pickle_key()) - .await - .map_err(|_| CryptoStoreError::UnpicklingError)?, - )) - } else { - Ok(None) - } - } - async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result { Ok(self .olm_hashes .contains_key(serde_json::to_vec(message_hash)?)?) } - async fn get_outbound_group_sessions( + async fn get_outgoing_key_request( &self, - room_id: &RoomId, - ) -> Result> { - self.load_outbound_group_session(room_id).await + request_id: Uuid, + ) -> Result> { + Ok(self + .outgoing_key_requests + .get(request_id.encode())? + .map(|r| serde_json::from_slice(&r)) + .transpose()?) + } + + async fn get_key_request_by_info( + &self, + key_info: &RequestedKeyInfo, + ) -> Result> { + let id = self.key_requests_by_info.get(key_info.encode())?; + + if let Some(id) = id { + Ok(self + .outgoing_key_requests + .get(id)? + .map(|r| serde_json::from_slice(&r)) + .transpose()?) + } else { + Ok(None) + } + } + + async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> { + let ret: Result<(), TransactionError> = + (&self.outgoing_key_requests, &self.key_requests_by_info).transaction( + |(outgoing_key_requests, key_requests_by_info)| { + let request: Option = outgoing_key_requests + .remove(request_id.encode())? + .map(|r| serde_json::from_slice(&r)) + .transpose() + .map_err(ConflictableTransactionError::Abort)?; + + if let Some(request) = request { + key_requests_by_info.remove((&request.info).encode())?; + } + + Ok(()) + }, + ); + + ret?; + self.inner.flush_async().await?; + + Ok(()) } } @@ -665,14 +749,16 @@ mod test { }; use matrix_sdk_common::{ api::r0::keys::SignedKey, - identifiers::{room_id, user_id, DeviceId, UserId}, + events::room_key_request::RequestedKeyInfo, + identifiers::{room_id, user_id, DeviceId, EventEncryptionAlgorithm, UserId}, + uuid::Uuid, }; use matrix_sdk_test::async_test; use olm_rs::outbound_group_session::OlmOutboundGroupSession; use std::collections::BTreeMap; use tempfile::tempdir; - use super::{CryptoStore, SledStore}; + use super::{CryptoStore, OutgoingKeyRequest, SledStore}; fn alice_id() -> UserId { user_id!("@alice:example.org") @@ -1184,21 +1270,6 @@ mod test { assert_eq!(identity.user_id(), loaded_identity.user_id()); } - #[async_test] - async fn key_value_saving() { - let (_, store, _dir) = get_loaded_store().await; - let key = "test_key".to_string(); - let value = "secret value".to_string(); - - store.save_value(key.clone(), value.clone()).await.unwrap(); - let stored_value = store.get_value(&key).await.unwrap().unwrap(); - - assert_eq!(value, stored_value); - - store.remove_value(&key).await.unwrap(); - assert!(store.get_value(&key).await.unwrap().is_none()); - } - #[async_test] async fn olm_hash_saving() { let (_, store, _dir) = get_loaded_store().await; @@ -1215,4 +1286,45 @@ mod test { store.save_changes(changes).await.unwrap(); assert!(store.is_message_known(&hash).await.unwrap()); } + + #[async_test] + async fn key_request_saving() { + let (_, store, _dir) = get_loaded_store().await; + + let id = Uuid::new_v4(); + let info = RequestedKeyInfo { + algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2, + room_id: room_id!("!test:localhost"), + sender_key: "test_sender_key".to_string(), + session_id: "test_session_id".to_string(), + }; + + let request = OutgoingKeyRequest { + request_id: id, + info: info.clone(), + sent_out: false, + }; + + assert!(store.get_outgoing_key_request(id).await.unwrap().is_none()); + + let mut changes = Changes::default(); + changes.key_requests.push(request.clone()); + store.save_changes(changes).await.unwrap(); + + let request = Some(request); + + let stored_request = store.get_outgoing_key_request(id).await.unwrap(); + assert_eq!(request, stored_request); + + let stored_request = store.get_key_request_by_info(&info).await.unwrap(); + assert_eq!(request, stored_request); + + store.delete_outgoing_key_request(id).await.unwrap(); + + let stored_request = store.get_outgoing_key_request(id).await.unwrap(); + assert_eq!(None, stored_request); + + let stored_request = store.get_key_request_by_info(&info).await.unwrap(); + assert_eq!(None, stored_request); + } } From 9e817a623b997c78cb34cb8d48cf775f75b9d2d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Thu, 15 Apr 2021 15:01:56 +0200 Subject: [PATCH 04/11] crypto: Fix an invalid assert in the crypto bench --- matrix_sdk_crypto/benches/crypto_bench.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/matrix_sdk_crypto/benches/crypto_bench.rs b/matrix_sdk_crypto/benches/crypto_bench.rs index 271be9b7..6e656ae5 100644 --- a/matrix_sdk_crypto/benches/crypto_bench.rs +++ b/matrix_sdk_crypto/benches/crypto_bench.rs @@ -213,7 +213,7 @@ pub fn room_key_sharing(c: &mut Criterion) { .await .unwrap(); - assert!(requests.len() >= 8); + assert!(!requests.is_empty()); for request in requests { machine @@ -249,7 +249,7 @@ pub fn room_key_sharing(c: &mut Criterion) { .await .unwrap(); - assert!(requests.len() >= 8); + assert!(!requests.is_empty()); for request in requests { machine From d4c56cc5b30e47395652bf144c1d441e493a74c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Thu, 15 Apr 2021 15:19:21 +0200 Subject: [PATCH 05/11] crypto: Refactor the outobund group session storing This introduces a group session cache struct that can be shared between components that need to access the currently active group session. --- matrix_sdk_crypto/src/key_request.rs | 46 +++++--- matrix_sdk_crypto/src/machine.rs | 20 ++-- .../src/session_manager/group_sessions.rs | 105 +++++++++++------- matrix_sdk_crypto/src/session_manager/mod.rs | 2 +- .../src/session_manager/sessions.rs | 6 +- 5 files changed, 112 insertions(+), 67 deletions(-) diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index f8fab2ff..34a3b8a3 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -40,8 +40,9 @@ use matrix_sdk_common::{ use crate::{ error::{OlmError, OlmResult}, - olm::{InboundGroupSession, OutboundGroupSession, Session, ShareState}, + olm::{InboundGroupSession, Session, ShareState}, requests::{OutgoingRequest, ToDeviceRequest}, + session_manager::GroupSessionCache, store::{Changes, CryptoStoreError, Store}, Device, }; @@ -128,7 +129,7 @@ pub(crate) struct KeyRequestMachine { user_id: Arc, device_id: Arc, store: Store, - outbound_group_sessions: Arc>, + outbound_group_sessions: GroupSessionCache, outgoing_to_device_requests: Arc>, incoming_key_requests: Arc< DashMap<(UserId, DeviceIdBox, String), ToDeviceEvent>, @@ -188,7 +189,7 @@ impl KeyRequestMachine { user_id: Arc, device_id: Arc, store: Store, - outbound_group_sessions: Arc>, + outbound_group_sessions: GroupSessionCache, users_for_key_claim: Arc>>, ) -> Self { Self { @@ -356,7 +357,7 @@ impl KeyRequestMachine { .await?; if let Some(device) = device { - match self.should_share_session(&device, &session) { + match self.should_share_session(&device, &session).await { Err(e) => { info!( "Received a key request from {} {} that we won't serve: {}", @@ -458,14 +459,17 @@ impl KeyRequestMachine { /// * `device` - The device that is requesting a session from us. /// /// * `session` - The session that was requested to be shared. - fn should_share_session( + async fn should_share_session( &self, device: &Device, session: &InboundGroupSession, ) -> Result, KeyshareDecision> { let outbound_session = self .outbound_group_sessions - .get(session.room_id()) + .get_or_load(session.room_id()) + .await + .ok() + .flatten() .filter(|o| session.session_id() == o.session_id()); let own_device_check = || { @@ -720,6 +724,7 @@ mod test { use crate::{ identities::{LocalTrust, ReadOnlyDevice}, olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount}, + session_manager::GroupSessionCache, store::{Changes, CryptoStore, MemoryStore, Store}, verification::VerificationMachine, }; @@ -761,12 +766,13 @@ mod test { let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(bob_id()))); let verification = VerificationMachine::new(account, identity.clone(), store.clone()); let store = Store::new(user_id.clone(), identity, store, verification); + let session_cache = GroupSessionCache::new(store.clone()); KeyRequestMachine::new( user_id, Arc::new(bob_device_id()), store, - Arc::new(DashMap::new()), + session_cache, Arc::new(DashMap::new()), ) } @@ -780,12 +786,13 @@ mod test { let verification = VerificationMachine::new(account, identity.clone(), store.clone()); let store = Store::new(user_id.clone(), identity, store, verification); store.save_devices(&[device]).await.unwrap(); + let session_cache = GroupSessionCache::new(store.clone()); KeyRequestMachine::new( user_id, Arc::new(alice_device_id()), store, - Arc::new(DashMap::new()), + session_cache, Arc::new(DashMap::new()), ) } @@ -973,12 +980,16 @@ mod test { assert_eq!( machine .should_share_session(&own_device, &inbound) + .await .expect_err("Should not share with untrusted"), KeyshareDecision::UntrustedDevice ); own_device.set_trust_state(LocalTrust::Verified); // Now we do want to share the keys. - assert!(machine.should_share_session(&own_device, &inbound).is_ok()); + assert!(machine + .should_share_session(&own_device, &inbound) + .await + .is_ok()); let bob_device = ReadOnlyDevice::from_account(&bob_account()).await; machine.store.save_devices(&[bob_device]).await.unwrap(); @@ -995,6 +1006,7 @@ mod test { assert_eq!( machine .should_share_session(&bob_device, &inbound) + .await .expect_err("Should not share with other."), KeyshareDecision::MissingOutboundSession ); @@ -1004,15 +1016,14 @@ mod test { changes.outbound_group_sessions.push(outbound.clone()); changes.inbound_group_sessions.push(inbound.clone()); machine.store.save_changes(changes).await.unwrap(); - machine - .outbound_group_sessions - .insert(inbound.room_id().to_owned(), outbound.clone()); + machine.outbound_group_sessions.insert(outbound.clone()); // We don't share sessions with other user's devices if the session // wasn't shared in the first place. assert_eq!( machine .should_share_session(&bob_device, &inbound) + .await .expect_err("Should not share with other unless shared."), KeyshareDecision::OutboundSessionNotShared ); @@ -1024,13 +1035,17 @@ mod test { assert_eq!( machine .should_share_session(&bob_device, &inbound) + .await .expect_err("Should not share with other unless shared."), KeyshareDecision::OutboundSessionNotShared ); // We now share the session, since it was shared before. outbound.mark_shared_with(bob_device.user_id(), bob_device.device_id()); - assert!(machine.should_share_session(&bob_device, &inbound).is_ok()); + assert!(machine + .should_share_session(&bob_device, &inbound) + .await + .is_ok()); // But we don't share some other session that doesn't match our outbound // session @@ -1042,6 +1057,7 @@ mod test { assert_eq!( machine .should_share_session(&bob_device, &other_inbound) + .await .expect_err("Should not share with other unless shared."), KeyshareDecision::MissingOutboundSession ); @@ -1112,7 +1128,7 @@ mod test { // Put the outbound session into bobs store. bob_machine .outbound_group_sessions - .insert(room_id(), group_session.clone()); + .insert(group_session.clone()); // Get the request and convert it into a event. let request = alice_machine @@ -1278,7 +1294,7 @@ mod test { // Put the outbound session into bobs store. bob_machine .outbound_group_sessions - .insert(room_id(), group_session.clone()); + .insert(group_session.clone()); // Get the request and convert it into a event. let request = alice_machine diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 601be381..59a99457 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -156,29 +156,29 @@ impl OlmMachine { verification_machine.clone(), ); let device_id: Arc = Arc::new(device_id); - let outbound_group_sessions = Arc::new(DashMap::new()); let users_for_key_claim = Arc::new(DashMap::new()); - let key_request_machine = KeyRequestMachine::new( - user_id.clone(), - device_id.clone(), - store.clone(), - outbound_group_sessions, - users_for_key_claim.clone(), - ); - let account = Account { inner: account, store: store.clone(), }; + let group_session_manager = GroupSessionManager::new(account.clone(), store.clone()); + + let key_request_machine = KeyRequestMachine::new( + user_id.clone(), + device_id.clone(), + store.clone(), + group_session_manager.session_cache(), + users_for_key_claim.clone(), + ); + let session_manager = SessionManager::new( account.clone(), users_for_key_claim, key_request_machine.clone(), store.clone(), ); - let group_session_manager = GroupSessionManager::new(account.clone(), store.clone()); let identity_manager = IdentityManager::new(user_id.clone(), device_id.clone(), store.clone()); diff --git a/matrix_sdk_crypto/src/session_manager/group_sessions.rs b/matrix_sdk_crypto/src/session_manager/group_sessions.rs index 649664a8..a656d00e 100644 --- a/matrix_sdk_crypto/src/session_manager/group_sessions.rs +++ b/matrix_sdk_crypto/src/session_manager/group_sessions.rs @@ -40,6 +40,57 @@ use crate::{ Device, EncryptionSettings, OlmError, ToDeviceRequest, }; +#[derive(Clone, Debug)] +pub(crate) struct GroupSessionCache { + store: Store, + sessions: Arc>, + /// A map from the request id to the group session that the request belongs + /// to. Used to mark requests belonging to the session as shared. + sessions_being_shared: Arc>, +} + +impl GroupSessionCache { + pub(crate) fn new(store: Store) -> Self { + Self { + store, + sessions: DashMap::new().into(), + sessions_being_shared: Arc::new(DashMap::new()), + } + } + + pub(crate) fn insert(&self, session: OutboundGroupSession) { + self.sessions.insert(session.room_id().to_owned(), session); + } + + pub async fn get_or_load(&self, room_id: &RoomId) -> StoreResult> { + // Get the cached session, if there isn't one load one from the store + // and put it in the cache. + if let Some(s) = self.sessions.get(room_id) { + Ok(Some(s.clone())) + } else if let Some(s) = self.store.get_outbound_group_sessions(room_id).await? { + for request_id in s.pending_request_ids() { + self.sessions_being_shared.insert(request_id, s.clone()); + } + + self.sessions.insert(room_id.clone(), s.clone()); + + Ok(Some(s)) + } else { + Ok(None) + } + } + + /// Get an outbound group session for a room, if one exists. + /// + /// # Arguments + /// + /// * `room_id` - The id of the room for which we should get the outbound + /// group session. + fn get(&self, room_id: &RoomId) -> Option { + self.sessions.get(room_id).map(|s| s.clone()) + } +} + #[derive(Debug, Clone)] pub struct GroupSessionManager { account: Account, @@ -48,10 +99,7 @@ pub struct GroupSessionManager { /// without the need to create new keys. store: Store, /// The currently active outbound group sessions. - outbound_group_sessions: Arc>, - /// A map from the request id to the group session that the request belongs - /// to. Used to mark requests belonging to the session as shared. - outbound_sessions_being_shared: Arc>, + sessions: GroupSessionCache, } impl GroupSessionManager { @@ -60,14 +108,13 @@ impl GroupSessionManager { pub(crate) fn new(account: Account, store: Store) -> Self { Self { account, - store, - outbound_group_sessions: Arc::new(DashMap::new()), - outbound_sessions_being_shared: Arc::new(DashMap::new()), + store: store.clone(), + sessions: GroupSessionCache::new(store), } } pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult { - if let Some(s) = self.outbound_group_sessions.get(room_id) { + if let Some(s) = self.sessions.get(room_id) { s.invalidate_session(); let mut changes = Changes::default(); @@ -81,7 +128,7 @@ impl GroupSessionManager { } pub async fn mark_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> { - if let Some((_, s)) = self.outbound_sessions_being_shared.remove(request_id) { + if let Some((_, s)) = self.sessions.sessions_being_shared.remove(request_id) { s.mark_request_as_sent(request_id); let mut changes = Changes::default(); @@ -97,15 +144,9 @@ impl GroupSessionManager { Ok(()) } - /// Get an outbound group session for a room, if one exists. - /// - /// # Arguments - /// - /// * `room_id` - The id of the room for which we should get the outbound - /// group session. + #[cfg(test)] pub fn get_outbound_group_session(&self, room_id: &RoomId) -> Option { - #[allow(clippy::map_clone)] - self.outbound_group_sessions.get(room_id).map(|s| s.clone()) + self.sessions.get(room_id) } pub async fn encrypt( @@ -113,7 +154,7 @@ impl GroupSessionManager { room_id: &RoomId, content: AnyMessageEventContent, ) -> MegolmResult { - let session = if let Some(s) = self.get_outbound_group_session(room_id) { + let session = if let Some(s) = self.sessions.get(room_id) { s } else { panic!("Session wasn't created nor shared"); @@ -147,9 +188,7 @@ impl GroupSessionManager { .await .map_err(|_| EventError::UnsupportedAlgorithm)?; - let _ = self - .outbound_group_sessions - .insert(room_id.to_owned(), outbound.clone()); + self.sessions.insert(outbound.clone()); Ok((outbound, inbound)) } @@ -158,23 +197,7 @@ impl GroupSessionManager { room_id: &RoomId, settings: EncryptionSettings, ) -> OlmResult<(OutboundGroupSession, Option)> { - // Get the cached session, if there isn't one load one from the store - // and put it in the cache. - let outbound_session = if let Some(s) = self.outbound_group_sessions.get(room_id) { - Some(s.clone()) - } else if let Some(s) = self.store.get_outbound_group_sessions(room_id).await? { - for request_id in s.pending_request_ids() { - self.outbound_sessions_being_shared - .insert(request_id, s.clone()); - } - - self.outbound_group_sessions - .insert(room_id.clone(), s.clone()); - - Some(s) - } else { - None - }; + let outbound_session = self.sessions.get_or_load(&room_id).await?; // If there is no session or the session has expired or is invalid, // create a new one. @@ -388,6 +411,10 @@ impl GroupSessionManager { Ok(used_sessions) } + pub(crate) fn session_cache(&self) -> GroupSessionCache { + self.sessions.clone() + } + /// Get to-device requests to share a group session with users in a room. /// /// # Arguments @@ -489,7 +516,7 @@ impl GroupSessionManager { key_content.clone(), outbound.clone(), message_index, - self.outbound_sessions_being_shared.clone(), + self.sessions.sessions_being_shared.clone(), )) }) .collect(); diff --git a/matrix_sdk_crypto/src/session_manager/mod.rs b/matrix_sdk_crypto/src/session_manager/mod.rs index 7750262e..1af686ef 100644 --- a/matrix_sdk_crypto/src/session_manager/mod.rs +++ b/matrix_sdk_crypto/src/session_manager/mod.rs @@ -15,5 +15,5 @@ mod group_sessions; mod sessions; -pub(crate) use group_sessions::GroupSessionManager; +pub(crate) use group_sessions::{GroupSessionCache, GroupSessionManager}; pub(crate) use sessions::SessionManager; diff --git a/matrix_sdk_crypto/src/session_manager/sessions.rs b/matrix_sdk_crypto/src/session_manager/sessions.rs index 9417c3ea..274314a8 100644 --- a/matrix_sdk_crypto/src/session_manager/sessions.rs +++ b/matrix_sdk_crypto/src/session_manager/sessions.rs @@ -322,6 +322,7 @@ mod test { identities::ReadOnlyDevice, key_request::KeyRequestMachine, olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount}, + session_manager::GroupSessionCache, store::{CryptoStore, MemoryStore, Store}, verification::VerificationMachine, }; @@ -342,7 +343,6 @@ mod test { let user_id = user_id(); let device_id = device_id(); - let outbound_sessions = Arc::new(DashMap::new()); let users_for_key_claim = Arc::new(DashMap::new()); let account = ReadOnlyAccount::new(&user_id, &device_id); let store: Arc> = Arc::new(Box::new(MemoryStore::new())); @@ -363,11 +363,13 @@ mod test { store: store.clone(), }; + let session_cache = GroupSessionCache::new(store.clone()); + let key_request = KeyRequestMachine::new( user_id, device_id, store.clone(), - outbound_sessions, + session_cache, users_for_key_claim.clone(), ); From f9d290746c8d8e5710a1af232d5f041d48f3fa2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Thu, 15 Apr 2021 17:48:37 +0200 Subject: [PATCH 06/11] crypto: Load unsent outgoing key requests when we open a store --- matrix_sdk_crypto/src/key_request.rs | 49 +++++++++++++++++----- matrix_sdk_crypto/src/machine.rs | 7 ++-- matrix_sdk_crypto/src/store/memorystore.rs | 8 ++++ matrix_sdk_crypto/src/store/mod.rs | 3 ++ matrix_sdk_crypto/src/store/sled.rs | 12 ++++++ 5 files changed, 66 insertions(+), 13 deletions(-) diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index 34a3b8a3..b00e8fbb 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -149,6 +149,23 @@ pub struct OutgoingKeyRequest { pub sent_out: bool, } +impl OutgoingKeyRequest { + fn into_request( + &self, + recipient: &UserId, + own_device_id: &DeviceId, + ) -> Result { + let content = RoomKeyRequestToDeviceEventContent { + action: Action::Request, + request_id: self.request_id.to_string(), + requesting_device_id: own_device_id.to_owned(), + body: Some(self.info.clone()), + }; + + wrap_key_request_content(recipient.to_owned(), self.request_id, &content) + } +} + impl PartialEq for OutgoingKeyRequest { fn eq(&self, other: &Self) -> bool { self.request_id == other.request_id @@ -204,6 +221,25 @@ impl KeyRequestMachine { } } + /// Load stored non-sent out outgoing requests + pub async fn load_outgoing_requests(&mut self) -> Result<(), CryptoStoreError> { + let infos: Vec = vec![]; + let requests: DashMap = infos + .iter() + .filter(|i| !i.sent_out) + .filter_map(|info| { + Some(( + info.request_id, + info.into_request(self.user_id(), self.device_id()).ok()?, + )) + }) + .collect(); + + self.outgoing_to_device_requests = requests.into(); + + Ok(()) + } + /// Our own user id. pub fn user_id(&self) -> &UserId { &self.user_id @@ -542,21 +578,14 @@ impl KeyRequestMachine { let id = Uuid::new_v4(); - let content = RoomKeyRequestToDeviceEventContent { - action: Action::Request, - request_id: id.to_string(), - requesting_device_id: (&*self.device_id).clone(), - body: Some(key_info), - }; - - let request = wrap_key_request_content(self.user_id().clone(), id, &content)?; - let info = OutgoingKeyRequest { request_id: id, - info: content.body.unwrap(), + info: key_info, sent_out: false, }; + let request = info.into_request(self.user_id(), self.device_id())?; + self.save_outgoing_key_info(info).await?; self.outgoing_to_device_requests.insert(id, request); diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 59a99457..746523f3 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -245,9 +245,10 @@ impl OlmMachine { } }; - Ok(OlmMachine::new_helper( - &user_id, device_id, store, account, identity, - )) + let mut machine = OlmMachine::new_helper(&user_id, device_id, store, account, identity); + machine.key_request_machine.load_outgoing_requests().await?; + + Ok(machine) } /// Create a new machine with the default crypto store. diff --git a/matrix_sdk_crypto/src/store/memorystore.rs b/matrix_sdk_crypto/src/store/memorystore.rs index 27993e25..150354ca 100644 --- a/matrix_sdk_crypto/src/store/memorystore.rs +++ b/matrix_sdk_crypto/src/store/memorystore.rs @@ -268,6 +268,14 @@ impl CryptoStore for MemoryStore { .and_then(|i| self.outgoing_key_requests.get(&i).map(|r| r.clone()))) } + async fn get_outgoing_key_requests(&self) -> Result> { + Ok(self + .outgoing_key_requests + .iter() + .map(|i| i.value().clone()) + .collect()) + } + async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> { self.outgoing_key_requests .remove(&request_id) diff --git a/matrix_sdk_crypto/src/store/mod.rs b/matrix_sdk_crypto/src/store/mod.rs index 387bd610..6f5b1338 100644 --- a/matrix_sdk_crypto/src/store/mod.rs +++ b/matrix_sdk_crypto/src/store/mod.rs @@ -449,6 +449,9 @@ pub trait CryptoStore: AsyncTraitDeps { key_info: &RequestedKeyInfo, ) -> Result>; + /// Get all outgoing key requests that we have in the store. + async fn get_outgoing_key_requests(&self) -> Result>; + /// Delete an outoing key request that we created that matches the given /// request id. /// diff --git a/matrix_sdk_crypto/src/store/sled.rs b/matrix_sdk_crypto/src/store/sled.rs index f177f10d..0bda2de0 100644 --- a/matrix_sdk_crypto/src/store/sled.rs +++ b/matrix_sdk_crypto/src/store/sled.rs @@ -709,6 +709,16 @@ impl CryptoStore for SledStore { } } + async fn get_outgoing_key_requests(&self) -> Result> { + let requests: Result> = self + .outgoing_key_requests + .iter() + .map(|i| serde_json::from_slice(&i?.1).map_err(CryptoStoreError::from)) + .collect(); + + requests + } + async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> { let ret: Result<(), TransactionError> = (&self.outgoing_key_requests, &self.key_requests_by_info).transaction( @@ -1318,6 +1328,7 @@ mod test { let stored_request = store.get_key_request_by_info(&info).await.unwrap(); assert_eq!(request, stored_request); + assert!(!store.get_outgoing_key_requests().await.unwrap().is_empty()); store.delete_outgoing_key_request(id).await.unwrap(); @@ -1326,5 +1337,6 @@ mod test { let stored_request = store.get_key_request_by_info(&info).await.unwrap(); assert_eq!(None, stored_request); + assert!(store.get_outgoing_key_requests().await.unwrap().is_empty()); } } From 8c007510cd14f5a935807b674d974396a018722f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Thu, 15 Apr 2021 19:40:24 +0200 Subject: [PATCH 07/11] crypto: Only load the outgoing key requests when we want to send them out --- matrix_sdk/src/client.rs | 11 +- matrix_sdk_base/src/client.rs | 4 +- matrix_sdk_crypto/src/key_request.rs | 156 ++++++++++++--------- matrix_sdk_crypto/src/machine.rs | 18 ++- matrix_sdk_crypto/src/store/memorystore.rs | 3 +- matrix_sdk_crypto/src/store/mod.rs | 2 +- matrix_sdk_crypto/src/store/sled.rs | 109 ++++++++++---- 7 files changed, 198 insertions(+), 105 deletions(-) diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index a6b6164b..7453f365 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -1839,7 +1839,16 @@ impl Client { warn!("Error while claiming one-time keys {:?}", e); } - for r in self.base_client.outgoing_requests().await { + // TODO we should probably abort if we get an cryptostore error here + let outgoing_requests = match self.base_client.outgoing_requests().await { + Ok(r) => r, + Err(e) => { + warn!("Could not fetch the outgoing requests {:?}", e); + vec![] + } + }; + + for r in outgoing_requests { match r.request() { OutgoingRequests::KeysQuery(request) => { if let Err(e) = self diff --git a/matrix_sdk_base/src/client.rs b/matrix_sdk_base/src/client.rs index 9e98489c..eedd48b0 100644 --- a/matrix_sdk_base/src/client.rs +++ b/matrix_sdk_base/src/client.rs @@ -1069,12 +1069,12 @@ impl BaseClient { /// [`mark_request_as_sent`]: #method.mark_request_as_sent #[cfg(feature = "encryption")] #[cfg_attr(feature = "docs", doc(cfg(encryption)))] - pub async fn outgoing_requests(&self) -> Vec { + pub async fn outgoing_requests(&self) -> Result, CryptoStoreError> { let olm = self.olm.lock().await; match &*olm { Some(o) => o.outgoing_requests().await, - None => vec![], + None => Ok(vec![]), } } diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index b00e8fbb..52ee3ade 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -150,7 +150,7 @@ pub struct OutgoingKeyRequest { } impl OutgoingKeyRequest { - fn into_request( + fn to_request( &self, recipient: &UserId, own_device_id: &DeviceId, @@ -214,30 +214,25 @@ impl KeyRequestMachine { device_id, store, outbound_group_sessions, - outgoing_to_device_requests: Arc::new(DashMap::new()), - incoming_key_requests: Arc::new(DashMap::new()), + outgoing_to_device_requests: DashMap::new().into(), + incoming_key_requests: DashMap::new().into(), wait_queue: WaitQueue::new(), users_for_key_claim, } } - /// Load stored non-sent out outgoing requests - pub async fn load_outgoing_requests(&mut self) -> Result<(), CryptoStoreError> { - let infos: Vec = vec![]; - let requests: DashMap = infos - .iter() + /// Load stored outgoing requests that were not yet sent out. + async fn load_outgoing_requests(&self) -> Result, CryptoStoreError> { + self.store + .get_unsent_key_requests() + .await? + .into_iter() .filter(|i| !i.sent_out) - .filter_map(|info| { - Some(( - info.request_id, - info.into_request(self.user_id(), self.device_id()).ok()?, - )) + .map(|info| { + info.to_request(self.user_id(), self.device_id()) + .map_err(CryptoStoreError::from) }) - .collect(); - - self.outgoing_to_device_requests = requests.into(); - - Ok(()) + .collect() } /// Our own user id. @@ -250,12 +245,18 @@ impl KeyRequestMachine { &self.device_id } - pub fn outgoing_to_device_requests(&self) -> Vec { - #[allow(clippy::map_clone)] - self.outgoing_to_device_requests + pub async fn outgoing_to_device_requests( + &self, + ) -> Result, CryptoStoreError> { + let mut key_requests = self.load_outgoing_requests().await?; + let key_forwards: Vec = self + .outgoing_to_device_requests .iter() - .map(|r| (*r).clone()) - .collect() + .map(|i| i.value().clone()) + .collect(); + key_requests.extend(key_forwards); + + Ok(key_requests) } /// Receive a room key request event. @@ -584,10 +585,7 @@ impl KeyRequestMachine { sent_out: false, }; - let request = info.into_request(self.user_id(), self.device_id())?; - self.save_outgoing_key_info(info).await?; - self.outgoing_to_device_requests.insert(id, request); Ok(()) } @@ -830,7 +828,11 @@ mod test { async fn create_machine() { let machine = get_machine().await; - assert!(machine.outgoing_to_device_requests().is_empty()); + assert!(machine + .outgoing_to_device_requests() + .await + .unwrap() + .is_empty()); } #[async_test] @@ -843,7 +845,11 @@ mod test { .await .unwrap(); - assert!(machine.outgoing_to_device_requests().is_empty()); + assert!(machine + .outgoing_to_device_requests() + .await + .unwrap() + .is_empty()); machine .create_outgoing_key_request( session.room_id(), @@ -852,8 +858,15 @@ mod test { ) .await .unwrap(); - assert!(!machine.outgoing_to_device_requests().is_empty()); - assert_eq!(machine.outgoing_to_device_requests().len(), 1); + assert!(!machine + .outgoing_to_device_requests() + .await + .unwrap() + .is_empty()); + assert_eq!( + machine.outgoing_to_device_requests().await.unwrap().len(), + 1 + ); machine .create_outgoing_key_request( @@ -863,15 +876,21 @@ mod test { ) .await .unwrap(); - assert_eq!(machine.outgoing_to_device_requests.len(), 1); - let request = machine.outgoing_to_device_requests.iter().next().unwrap(); + let requests = machine.outgoing_to_device_requests().await.unwrap(); + assert_eq!(requests.len(), 1); - let id = request.request_id; - drop(request); + let request = requests.get(0).unwrap(); - machine.mark_outgoing_request_as_sent(id).await.unwrap(); - assert!(machine.outgoing_to_device_requests.is_empty()); + machine + .mark_outgoing_request_as_sent(request.request_id) + .await + .unwrap(); + assert!(machine + .outgoing_to_device_requests() + .await + .unwrap() + .is_empty()); } #[async_test] @@ -892,9 +911,9 @@ mod test { .await .unwrap(); - let request = machine.outgoing_to_device_requests.iter().next().unwrap(); + let requests = machine.outgoing_to_device_requests().await.unwrap(); + let request = requests.get(0).unwrap(); let id = request.request_id; - drop(request); machine.mark_outgoing_request_as_sent(id).await.unwrap(); @@ -949,11 +968,13 @@ mod test { .await .unwrap(); - let request = machine.outgoing_to_device_requests.iter().next().unwrap(); - let id = request.request_id; - drop(request); + let requests = machine.outgoing_to_device_requests().await.unwrap(); + let request = &requests[0]; - machine.mark_outgoing_request_as_sent(id).await.unwrap(); + machine + .mark_outgoing_request_as_sent(request.request_id) + .await + .unwrap(); let export = session.export_at_index(15).await; @@ -1160,11 +1181,8 @@ mod test { .insert(group_session.clone()); // Get the request and convert it into a event. - let request = alice_machine - .outgoing_to_device_requests - .iter() - .next() - .unwrap(); + let requests = alice_machine.outgoing_to_device_requests().await.unwrap(); + let request = &requests[0]; let id = request.request_id; let content = request .request @@ -1178,7 +1196,6 @@ mod test { let content: RoomKeyRequestToDeviceEventContent = serde_json::from_str(content.get()).unwrap(); - drop(request); alice_machine .mark_outgoing_request_as_sent(id) .await @@ -1199,11 +1216,8 @@ mod test { assert!(!bob_machine.outgoing_to_device_requests.is_empty()); // Get the request and convert it to a encrypted to-device event. - let request = bob_machine - .outgoing_to_device_requests - .iter() - .next() - .unwrap(); + let requests = bob_machine.outgoing_to_device_requests().await.unwrap(); + let request = &requests[0]; let id = request.request_id; let content = request @@ -1217,7 +1231,6 @@ mod test { .unwrap(); let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap(); - drop(request); bob_machine.mark_outgoing_request_as_sent(id).await.unwrap(); let event = ToDeviceEvent { @@ -1326,11 +1339,8 @@ mod test { .insert(group_session.clone()); // Get the request and convert it into a event. - let request = alice_machine - .outgoing_to_device_requests - .iter() - .next() - .unwrap(); + let requests = alice_machine.outgoing_to_device_requests().await.unwrap(); + let request = &requests[0]; let id = request.request_id; let content = request .request @@ -1344,7 +1354,6 @@ mod test { let content: RoomKeyRequestToDeviceEventContent = serde_json::from_str(content.get()).unwrap(); - drop(request); alice_machine .mark_outgoing_request_as_sent(id) .await @@ -1356,7 +1365,11 @@ mod test { }; // Bob doesn't have any outgoing requests. - assert!(bob_machine.outgoing_to_device_requests.is_empty()); + assert!(bob_machine + .outgoing_to_device_requests() + .await + .unwrap() + .is_empty()); assert!(bob_machine.users_for_key_claim.is_empty()); assert!(bob_machine.wait_queue.is_empty()); @@ -1364,7 +1377,11 @@ mod test { bob_machine.receive_incoming_key_request(&event); bob_machine.collect_incoming_key_requests().await.unwrap(); // Bob doens't have an outgoing requests since we're lacking a session. - assert!(bob_machine.outgoing_to_device_requests.is_empty()); + assert!(bob_machine + .outgoing_to_device_requests() + .await + .unwrap() + .is_empty()); assert!(!bob_machine.users_for_key_claim.is_empty()); assert!(!bob_machine.wait_queue.is_empty()); @@ -1384,15 +1401,17 @@ mod test { assert!(bob_machine.users_for_key_claim.is_empty()); bob_machine.collect_incoming_key_requests().await.unwrap(); // Bob now has an outgoing requests. - assert!(!bob_machine.outgoing_to_device_requests.is_empty()); + assert!(!bob_machine + .outgoing_to_device_requests() + .await + .unwrap() + .is_empty()); assert!(bob_machine.wait_queue.is_empty()); // Get the request and convert it to a encrypted to-device event. - let request = bob_machine - .outgoing_to_device_requests - .iter() - .next() - .unwrap(); + let requests = bob_machine.outgoing_to_device_requests().await.unwrap(); + + let request = &requests[0]; let id = request.request_id; let content = request @@ -1406,7 +1425,6 @@ mod test { .unwrap(); let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap(); - drop(request); bob_machine.mark_outgoing_request_as_sent(id).await.unwrap(); let event = ToDeviceEvent { diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 746523f3..e7203043 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -245,10 +245,9 @@ impl OlmMachine { } }; - let mut machine = OlmMachine::new_helper(&user_id, device_id, store, account, identity); - machine.key_request_machine.load_outgoing_requests().await?; - - Ok(machine) + Ok(OlmMachine::new_helper( + &user_id, device_id, store, account, identity, + )) } /// Create a new machine with the default crypto store. @@ -295,7 +294,7 @@ impl OlmMachine { /// machine using [`mark_request_as_sent`]. /// /// [`mark_request_as_sent`]: #method.mark_request_as_sent - pub async fn outgoing_requests(&self) -> Vec { + pub async fn outgoing_requests(&self) -> StoreResult> { let mut requests = Vec::new(); if let Some(r) = self.keys_for_upload().await.map(|r| OutgoingRequest { @@ -320,9 +319,14 @@ impl OlmMachine { requests.append(&mut self.outgoing_to_device_requests()); requests.append(&mut self.verification_machine.outgoing_room_message_requests()); - requests.append(&mut self.key_request_machine.outgoing_to_device_requests()); + requests.append( + &mut self + .key_request_machine + .outgoing_to_device_requests() + .await?, + ); - requests + Ok(requests) } /// Mark the request with the given request id as sent. diff --git a/matrix_sdk_crypto/src/store/memorystore.rs b/matrix_sdk_crypto/src/store/memorystore.rs index 150354ca..610262a5 100644 --- a/matrix_sdk_crypto/src/store/memorystore.rs +++ b/matrix_sdk_crypto/src/store/memorystore.rs @@ -268,10 +268,11 @@ impl CryptoStore for MemoryStore { .and_then(|i| self.outgoing_key_requests.get(&i).map(|r| r.clone()))) } - async fn get_outgoing_key_requests(&self) -> Result> { + async fn get_unsent_key_requests(&self) -> Result> { Ok(self .outgoing_key_requests .iter() + .filter(|i| !i.value().sent_out) .map(|i| i.value().clone()) .collect()) } diff --git a/matrix_sdk_crypto/src/store/mod.rs b/matrix_sdk_crypto/src/store/mod.rs index 6f5b1338..1982e1cc 100644 --- a/matrix_sdk_crypto/src/store/mod.rs +++ b/matrix_sdk_crypto/src/store/mod.rs @@ -450,7 +450,7 @@ pub trait CryptoStore: AsyncTraitDeps { ) -> Result>; /// Get all outgoing key requests that we have in the store. - async fn get_outgoing_key_requests(&self) -> Result>; + async fn get_unsent_key_requests(&self) -> Result>; /// Delete an outoing key request that we created that matches the given /// request id. diff --git a/matrix_sdk_crypto/src/store/sled.rs b/matrix_sdk_crypto/src/store/sled.rs index 0bda2de0..d82f62fc 100644 --- a/matrix_sdk_crypto/src/store/sled.rs +++ b/matrix_sdk_crypto/src/store/sled.rs @@ -149,6 +149,7 @@ pub struct SledStore { outbound_group_sessions: Tree, outgoing_key_requests: Tree, + unsent_key_requests: Tree, key_requests_by_info: Tree, devices: Tree, @@ -215,6 +216,7 @@ impl SledStore { let identities = db.open_tree("identities")?; let outgoing_key_requests = db.open_tree("outgoing_key_requests")?; + let unsent_key_requests = db.open_tree("unsent_key_requests")?; let key_requests_by_info = db.open_tree("key_requests_by_info")?; let session_cache = SessionStore::new(); @@ -240,6 +242,7 @@ impl SledStore { inbound_group_sessions, outbound_group_sessions, outgoing_key_requests, + unsent_key_requests, key_requests_by_info, devices, tracked_users, @@ -376,6 +379,7 @@ impl SledStore { &self.outbound_group_sessions, &self.olm_hashes, &self.outgoing_key_requests, + &self.unsent_key_requests, &self.key_requests_by_info, ) .transaction( @@ -389,6 +393,7 @@ impl SledStore { outbound_sessions, hashes, outgoing_key_requests, + unsent_key_requests, key_requests_by_info, )| { if let Some(a) = &account_pickle { @@ -463,11 +468,23 @@ impl SledStore { key_request.request_id.encode(), )?; - outgoing_key_requests.insert( - key_request.request_id.encode(), - serde_json::to_vec(&key_request) - .map_err(ConflictableTransactionError::Abort)?, - )?; + let key_request_id = key_request.request_id.encode(); + + if key_request.sent_out { + unsent_key_requests.remove(key_request_id.clone())?; + outgoing_key_requests.insert( + key_request_id, + serde_json::to_vec(&key_request) + .map_err(ConflictableTransactionError::Abort)?, + )?; + } else { + outgoing_key_requests.remove(key_request_id.clone())?; + unsent_key_requests.insert( + key_request_id, + serde_json::to_vec(&key_request) + .map_err(ConflictableTransactionError::Abort)?, + )?; + } } Ok(()) @@ -479,6 +496,28 @@ impl SledStore { Ok(()) } + + async fn get_outgoing_key_request_helper( + &self, + id: &[u8], + ) -> Result> { + let request = self + .outgoing_key_requests + .get(id)? + .map(|r| serde_json::from_slice(&r)) + .transpose()?; + + let request = if request.is_none() { + self.unsent_key_requests + .get(id)? + .map(|r| serde_json::from_slice(&r)) + .transpose()? + } else { + request + }; + + Ok(request) + } } #[async_trait] @@ -685,11 +724,9 @@ impl CryptoStore for SledStore { &self, request_id: Uuid, ) -> Result> { - Ok(self - .outgoing_key_requests - .get(request_id.encode())? - .map(|r| serde_json::from_slice(&r)) - .transpose()?) + let request_id = request_id.encode(); + + self.get_outgoing_key_request_helper(&request_id).await } async fn get_key_request_by_info( @@ -699,19 +736,15 @@ impl CryptoStore for SledStore { let id = self.key_requests_by_info.get(key_info.encode())?; if let Some(id) = id { - Ok(self - .outgoing_key_requests - .get(id)? - .map(|r| serde_json::from_slice(&r)) - .transpose()?) + self.get_outgoing_key_request_helper(&id).await } else { Ok(None) } } - async fn get_outgoing_key_requests(&self) -> Result> { + async fn get_unsent_key_requests(&self) -> Result> { let requests: Result> = self - .outgoing_key_requests + .unsent_key_requests .iter() .map(|i| serde_json::from_slice(&i?.1).map_err(CryptoStoreError::from)) .collect(); @@ -720,16 +753,30 @@ impl CryptoStore for SledStore { } async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> { - let ret: Result<(), TransactionError> = - (&self.outgoing_key_requests, &self.key_requests_by_info).transaction( - |(outgoing_key_requests, key_requests_by_info)| { - let request: Option = outgoing_key_requests + let ret: Result<(), TransactionError> = ( + &self.outgoing_key_requests, + &self.unsent_key_requests, + &self.key_requests_by_info, + ) + .transaction( + |(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| { + let sent_request: Option = outgoing_key_requests .remove(request_id.encode())? .map(|r| serde_json::from_slice(&r)) .transpose() .map_err(ConflictableTransactionError::Abort)?; - if let Some(request) = request { + let unsent_request: Option = unsent_key_requests + .remove(request_id.encode())? + .map(|r| serde_json::from_slice(&r)) + .transpose() + .map_err(ConflictableTransactionError::Abort)?; + + if let Some(request) = sent_request { + key_requests_by_info.remove((&request.info).encode())?; + } + + if let Some(request) = unsent_request { key_requests_by_info.remove((&request.info).encode())?; } @@ -1328,7 +1375,21 @@ mod test { let stored_request = store.get_key_request_by_info(&info).await.unwrap(); assert_eq!(request, stored_request); - assert!(!store.get_outgoing_key_requests().await.unwrap().is_empty()); + assert!(!store.get_unsent_key_requests().await.unwrap().is_empty()); + + let request = OutgoingKeyRequest { + request_id: id, + info: info.clone(), + sent_out: true, + }; + + let mut changes = Changes::default(); + changes.key_requests.push(request.clone()); + store.save_changes(changes).await.unwrap(); + + assert!(store.get_unsent_key_requests().await.unwrap().is_empty()); + let stored_request = store.get_outgoing_key_request(id).await.unwrap(); + assert_eq!(Some(request), stored_request); store.delete_outgoing_key_request(id).await.unwrap(); @@ -1337,6 +1398,6 @@ mod test { let stored_request = store.get_key_request_by_info(&info).await.unwrap(); assert_eq!(None, stored_request); - assert!(store.get_outgoing_key_requests().await.unwrap().is_empty()); + assert!(store.get_unsent_key_requests().await.unwrap().is_empty()); } } From 78b7dcac615434b13d7c9f1c1bb07e313bac2af8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Mon, 19 Apr 2021 15:00:21 +0200 Subject: [PATCH 08/11] crypto: Add a public method to request and re-request keys. --- matrix_sdk_crypto/src/key_request.rs | 157 +++++++++++++++++++++------ matrix_sdk_crypto/src/machine.rs | 32 ++++++ matrix_sdk_crypto/src/store/sled.rs | 4 +- 3 files changed, 158 insertions(+), 35 deletions(-) diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index 52ee3ade..2ff5363b 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -141,6 +141,8 @@ pub(crate) struct KeyRequestMachine { /// A struct describing an outgoing key request. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OutgoingKeyRequest { + /// The user we requested the key from + pub request_recipient: UserId, /// The unique id of the key request. pub request_id: Uuid, /// The info of the requested key. @@ -150,11 +152,7 @@ pub struct OutgoingKeyRequest { } impl OutgoingKeyRequest { - fn to_request( - &self, - recipient: &UserId, - own_device_id: &DeviceId, - ) -> Result { + fn to_request(&self, own_device_id: &DeviceId) -> Result { let content = RoomKeyRequestToDeviceEventContent { action: Action::Request, request_id: self.request_id.to_string(), @@ -162,7 +160,22 @@ impl OutgoingKeyRequest { body: Some(self.info.clone()), }; - wrap_key_request_content(recipient.to_owned(), self.request_id, &content) + wrap_key_request_content(self.request_recipient.clone(), self.request_id, &content) + } + + fn to_cancelation( + &self, + own_device_id: &DeviceId, + ) -> Result { + let content = RoomKeyRequestToDeviceEventContent { + action: Action::CancelRequest, + request_id: self.request_id.to_string(), + requesting_device_id: own_device_id.to_owned(), + body: None, + }; + + let id = Uuid::new_v4(); + wrap_key_request_content(self.request_recipient.clone(), id, &content) } } @@ -229,7 +242,7 @@ impl KeyRequestMachine { .into_iter() .filter(|i| !i.sent_out) .map(|info| { - info.to_request(self.user_id(), self.device_id()) + info.to_request(self.device_id()) .map_err(CryptoStoreError::from) }) .collect() @@ -541,6 +554,69 @@ impl KeyRequestMachine { } } + /// Create a new outgoing key request for the key with the given session id. + /// + /// This will queue up a new to-device request and store the key info so + /// once we receive a forwarded room key we can check that it matches the + /// key we requested. + /// + /// This method will return a cancel request and a new key request if the + /// key was already requested, otherwise it will return just the key + /// request. + /// + /// # Arguments + /// + /// * `room_id` - The id of the room where the key is used in. + /// + /// * `sender_key` - The curve25519 key of the sender that owns the key. + /// + /// * `session_id` - The id that uniquely identifies the session. + pub async fn request_key( + &self, + room_id: &RoomId, + sender_key: &str, + session_id: &str, + ) -> Result<(Option, OutgoingRequest), CryptoStoreError> { + let key_info = RequestedKeyInfo { + algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2, + room_id: room_id.to_owned(), + sender_key: sender_key.to_owned(), + session_id: session_id.to_owned(), + }; + + let request = self.store.get_key_request_by_info(&key_info).await?; + + if let Some(request) = request { + let cancel = request.to_cancelation(self.device_id())?; + let request = request.to_request(self.device_id())?; + + Ok((Some(cancel), request)) + } else { + let request = self.request_key_helper(key_info).await?; + + Ok((None, request)) + } + } + + async fn request_key_helper( + &self, + key_info: RequestedKeyInfo, + ) -> Result { + info!("Creating new outgoing room key request {:#?}", key_info); + + let request = OutgoingKeyRequest { + request_recipient: self.user_id().to_owned(), + request_id: Uuid::new_v4(), + info: key_info, + sent_out: false, + }; + + let outgoing_request = request.to_request(self.device_id())?; + self.save_outgoing_key_info(request).await?; + + Ok(outgoing_request) + } + /// Create a new outgoing key request for the key with the given session id. /// /// This will queue up a new to-device request and store the key info so @@ -570,23 +646,10 @@ impl KeyRequestMachine { let request = self.store.get_key_request_by_info(&key_info).await?; - if request.is_some() { - // We already sent out a request for this key, nothing to do. - return Ok(()); + if request.is_none() { + self.request_key_helper(key_info).await?; } - info!("Creating new outgoing room key request {:#?}", key_info); - - let id = Uuid::new_v4(); - - let info = OutgoingKeyRequest { - request_id: id, - info: key_info, - sent_out: false, - }; - - self.save_outgoing_key_info(info).await?; - Ok(()) } @@ -655,18 +718,9 @@ impl KeyRequestMachine { // can delete it in one transaction. self.delete_key_info(&key_info).await?; - let content = RoomKeyRequestToDeviceEventContent { - action: Action::CancelRequest, - request_id: key_info.request_id.to_string(), - requesting_device_id: (&*self.device_id).clone(), - body: None, - }; - - let id = Uuid::new_v4(); - - let request = wrap_key_request_content(self.user_id().clone(), id, &content)?; - - self.outgoing_to_device_requests.insert(id, request); + let request = key_info.to_cancelation(self.device_id())?; + self.outgoing_to_device_requests + .insert(request.request_id, request); Ok(()) } @@ -840,6 +894,41 @@ mod test { let machine = get_machine().await; let account = account(); + let (_, session) = account + .create_group_session_pair_with_defaults(&room_id()) + .await + .unwrap(); + + assert!(machine + .outgoing_to_device_requests() + .await + .unwrap() + .is_empty()); + let (cancel, request) = machine + .request_key(session.room_id(), &session.sender_key, session.session_id()) + .await + .unwrap(); + + assert!(cancel.is_none()); + + machine + .mark_outgoing_request_as_sent(request.request_id) + .await + .unwrap(); + + let (cancel, _) = machine + .request_key(session.room_id(), &session.sender_key, session.session_id()) + .await + .unwrap(); + + assert!(cancel.is_some()); + } + + #[async_test] + async fn re_request_keys() { + let machine = get_machine().await; + let account = account(); + let (_, session) = account .create_group_session_pair_with_defaults(&room_id()) .await diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index e7203043..0eb13fb1 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -918,6 +918,38 @@ impl OlmMachine { Ok(ToDevice { events }) } + /// Request a room key from our devices. + /// + /// This method will return a request cancelation and a new key request if + /// the key was already requested, otherwise it will return just the key + /// request. + /// + /// The request cancelation *must* be sent out before the request is sent + /// out, otherwise devices will ignore the key request. + /// + /// # Arguments + /// + /// * `room_id` - The id of the room where the key is used in. + /// + /// * `sender_key` - The curve25519 key of the sender that owns the key. + /// + /// * `session_id` - The id that uniquely identifies the session. + pub async fn request_room_key( + &self, + event: &SyncMessageEvent, + room_id: &RoomId, + ) -> MegolmResult<(Option, OutgoingRequest)> { + let content = match &event.content { + EncryptedEventContent::MegolmV1AesSha2(c) => c, + _ => return Err(EventError::UnsupportedAlgorithm.into()), + }; + + Ok(self + .key_request_machine + .request_key(room_id, &content.sender_key, &content.session_id) + .await?) + } + /// Decrypt an event from a room timeline. /// /// # Arguments diff --git a/matrix_sdk_crypto/src/store/sled.rs b/matrix_sdk_crypto/src/store/sled.rs index d82f62fc..b6928e78 100644 --- a/matrix_sdk_crypto/src/store/sled.rs +++ b/matrix_sdk_crypto/src/store/sled.rs @@ -1346,7 +1346,7 @@ mod test { #[async_test] async fn key_request_saving() { - let (_, store, _dir) = get_loaded_store().await; + let (account, store, _dir) = get_loaded_store().await; let id = Uuid::new_v4(); let info = RequestedKeyInfo { @@ -1357,6 +1357,7 @@ mod test { }; let request = OutgoingKeyRequest { + request_recipient: account.user_id().to_owned(), request_id: id, info: info.clone(), sent_out: false, @@ -1378,6 +1379,7 @@ mod test { assert!(!store.get_unsent_key_requests().await.unwrap().is_empty()); let request = OutgoingKeyRequest { + request_recipient: account.user_id().to_owned(), request_id: id, info: info.clone(), sent_out: true, From 4a7be139618c14e5be8bfab903a3ff570cb0d62c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Tue, 20 Apr 2021 11:47:11 +0200 Subject: [PATCH 09/11] crypto: Only send out automatic key requests if we have a verified device Sending out automatic key requests is a bit spammy for new logins, they'll likely have many undecryptable events upon an initial sync. It's unlikely that anyone will respond to such a key request since keys are shared only with verified devices between devices of the same user or if the key owner knows that the device should have received the key. Upon initial sync it's unlikely that we have been verified and the key owner likely did not intend to send us the key since we just created the new device. --- matrix_sdk_crypto/src/identities/device.rs | 8 ++ matrix_sdk_crypto/src/key_request.rs | 101 ++++++++++++++++++--- 2 files changed, 95 insertions(+), 14 deletions(-) diff --git a/matrix_sdk_crypto/src/identities/device.rs b/matrix_sdk_crypto/src/identities/device.rs index 76e2498d..a3b8c3f8 100644 --- a/matrix_sdk_crypto/src/identities/device.rs +++ b/matrix_sdk_crypto/src/identities/device.rs @@ -258,6 +258,14 @@ impl UserDevices { }) } + /// Returns true if there is at least one devices of this user that is + /// considered to be verified, false otherwise. + pub fn is_any_verified(&self) -> bool { + self.inner + .values() + .any(|d| d.trust_state(&self.own_identity, &self.device_owner_identity)) + } + /// Iterator over all the device ids of the user devices. pub fn keys(&self) -> impl Iterator { self.inner.keys() diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index 2ff5363b..bc4df8c1 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -407,7 +407,7 @@ impl KeyRequestMachine { .await?; if let Some(device) = device { - match self.should_share_session(&device, &session).await { + match self.should_share_key(&device, &session).await { Err(e) => { info!( "Received a key request from {} {} that we won't serve: {}", @@ -509,7 +509,7 @@ impl KeyRequestMachine { /// * `device` - The device that is requesting a session from us. /// /// * `session` - The session that was requested to be shared. - async fn should_share_session( + async fn should_share_key( &self, device: &Device, session: &InboundGroupSession, @@ -554,6 +554,38 @@ impl KeyRequestMachine { } } + /// Check if it's ok, or rather if it makes sense to automatically request + /// a key from our other devices. + /// + /// # Arguments + /// + /// * `key_info` - The info of our key request containing information about + /// the key we wish to request. + async fn should_request_key( + &self, + key_info: &RequestedKeyInfo, + ) -> Result { + let request = self.store.get_key_request_by_info(&key_info).await?; + + // Don't send out duplicate requests, users can re-request them if they + // think a second request might succeed. + if request.is_none() { + let devices = self.store.get_user_devices(self.user_id()).await?; + + // Devices will only respond to key requests if the devices are + // verified, if the device isn't verified by us it's unlikely that + // we're verified by them either. Don't request keys if there isn't + // at least one verified device. + if devices.is_any_verified() { + Ok(true) + } else { + Ok(false) + } + } else { + Ok(false) + } + } + /// Create a new outgoing key request for the key with the given session id. /// /// This will queue up a new to-device request and store the key info so @@ -644,9 +676,7 @@ impl KeyRequestMachine { session_id: session_id.to_owned(), }; - let request = self.store.get_key_request_by_info(&key_info).await?; - - if request.is_none() { + if self.should_request_key(&key_info).await? { self.request_key_helper(key_info).await?; } @@ -828,6 +858,10 @@ mod test { "ILMLKASTES".into() } + fn alice2_device_id() -> DeviceIdBox { + "ILMLKASTES".into() + } + fn room_id() -> RoomId { room_id!("!test:example.org") } @@ -840,6 +874,10 @@ mod test { ReadOnlyAccount::new(&bob_id(), &bob_device_id()) } + fn alice_2_account() -> ReadOnlyAccount { + ReadOnlyAccount::new(&alice_id(), &alice2_device_id()) + } + fn bob_machine() -> KeyRequestMachine { let user_id = Arc::new(bob_id()); let account = ReadOnlyAccount::new(&user_id, &alice_device_id()); @@ -890,7 +928,7 @@ mod test { } #[async_test] - async fn create_key_request() { + async fn re_request_keys() { let machine = get_machine().await; let account = account(); @@ -925,9 +963,15 @@ mod test { } #[async_test] - async fn re_request_keys() { + async fn create_key_request() { let machine = get_machine().await; let account = account(); + let second_account = alice_2_account(); + let alice_device = ReadOnlyDevice::from_account(&second_account).await; + + // We need a trusted device, otherwise we won't request keys + alice_device.set_trust_state(LocalTrust::Verified); + machine.store.save_devices(&[alice_device]).await.unwrap(); let (_, session) = account .create_group_session_pair_with_defaults(&room_id()) @@ -987,6 +1031,13 @@ mod test { let machine = get_machine().await; let account = account(); + let second_account = alice_2_account(); + let alice_device = ReadOnlyDevice::from_account(&second_account).await; + + // We need a trusted device, otherwise we won't request keys + alice_device.set_trust_state(LocalTrust::Verified); + machine.store.save_devices(&[alice_device]).await.unwrap(); + let (_, session) = account .create_group_session_pair_with_defaults(&room_id()) .await @@ -1118,7 +1169,7 @@ mod test { // We don't share keys with untrusted devices. assert_eq!( machine - .should_share_session(&own_device, &inbound) + .should_share_key(&own_device, &inbound) .await .expect_err("Should not share with untrusted"), KeyshareDecision::UntrustedDevice @@ -1126,7 +1177,7 @@ mod test { own_device.set_trust_state(LocalTrust::Verified); // Now we do want to share the keys. assert!(machine - .should_share_session(&own_device, &inbound) + .should_share_key(&own_device, &inbound) .await .is_ok()); @@ -1144,7 +1195,7 @@ mod test { // session was provided. assert_eq!( machine - .should_share_session(&bob_device, &inbound) + .should_share_key(&bob_device, &inbound) .await .expect_err("Should not share with other."), KeyshareDecision::MissingOutboundSession @@ -1161,7 +1212,7 @@ mod test { // wasn't shared in the first place. assert_eq!( machine - .should_share_session(&bob_device, &inbound) + .should_share_key(&bob_device, &inbound) .await .expect_err("Should not share with other unless shared."), KeyshareDecision::OutboundSessionNotShared @@ -1173,7 +1224,7 @@ mod test { // wasn't shared in the first place even if the device is trusted. assert_eq!( machine - .should_share_session(&bob_device, &inbound) + .should_share_key(&bob_device, &inbound) .await .expect_err("Should not share with other unless shared."), KeyshareDecision::OutboundSessionNotShared @@ -1182,7 +1233,7 @@ mod test { // We now share the session, since it was shared before. outbound.mark_shared_with(bob_device.user_id(), bob_device.device_id()); assert!(machine - .should_share_session(&bob_device, &inbound) + .should_share_key(&bob_device, &inbound) .await .is_ok()); @@ -1195,7 +1246,7 @@ mod test { assert_eq!( machine - .should_share_session(&bob_device, &other_inbound) + .should_share_key(&bob_device, &other_inbound) .await .expect_err("Should not share with other unless shared."), KeyshareDecision::MissingOutboundSession @@ -1213,6 +1264,17 @@ mod test { let bob_machine = bob_machine(); let bob_account = bob_account(); + let second_account = alice_2_account(); + let alice_device = ReadOnlyDevice::from_account(&second_account).await; + + // We need a trusted device, otherwise we won't request keys + alice_device.set_trust_state(LocalTrust::Verified); + alice_machine + .store + .save_devices(&[alice_device]) + .await + .unwrap(); + // Create Olm sessions for our two accounts. let (alice_session, bob_session) = alice_account.create_session_for(&bob_account).await; @@ -1381,6 +1443,17 @@ mod test { let bob_machine = bob_machine(); let bob_account = bob_account(); + let second_account = alice_2_account(); + let alice_device = ReadOnlyDevice::from_account(&second_account).await; + + // We need a trusted device, otherwise we won't request keys + alice_device.set_trust_state(LocalTrust::Verified); + alice_machine + .store + .save_devices(&[alice_device]) + .await + .unwrap(); + // Create Olm sessions for our two accounts. let (alice_session, bob_session) = alice_account.create_session_for(&bob_account).await; From e15f7264dc363381e6e02a6eddaa79c5bec0f25e Mon Sep 17 00:00:00 2001 From: poljar Date: Tue, 20 Apr 2021 12:27:56 +0200 Subject: [PATCH 10/11] crypto: Don't borrow inside a format unnecessarily Co-authored-by: Jonas Platte --- matrix_sdk_crypto/src/store/memorystore.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/matrix_sdk_crypto/src/store/memorystore.rs b/matrix_sdk_crypto/src/store/memorystore.rs index 610262a5..e9ca5307 100644 --- a/matrix_sdk_crypto/src/store/memorystore.rs +++ b/matrix_sdk_crypto/src/store/memorystore.rs @@ -39,7 +39,7 @@ use crate::{ fn encode_key_info(info: &RequestedKeyInfo) -> String { format!( "{}{}{}{}", - &info.room_id, &info.sender_key, &info.algorithm, &info.session_id + info.room_id, info.sender_key, info.algorithm, info.session_id ) } From bfc7434f7e1ea24728dd2c236b5316cd49c59542 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Tue, 20 Apr 2021 13:35:47 +0200 Subject: [PATCH 11/11] crypto: Move the outbound session filter logic into the group session cache --- matrix_sdk_crypto/src/key_request.rs | 5 ++--- .../src/session_manager/group_sessions.rs | 21 +++++++++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index bc4df8c1..1eec2981 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -516,11 +516,10 @@ impl KeyRequestMachine { ) -> Result, KeyshareDecision> { let outbound_session = self .outbound_group_sessions - .get_or_load(session.room_id()) + .get_with_id(session.room_id(), session.session_id()) .await .ok() - .flatten() - .filter(|o| session.session_id() == o.session_id()); + .flatten(); let own_device_check = || { if device.trust_state() { diff --git a/matrix_sdk_crypto/src/session_manager/group_sessions.rs b/matrix_sdk_crypto/src/session_manager/group_sessions.rs index a656d00e..71645714 100644 --- a/matrix_sdk_crypto/src/session_manager/group_sessions.rs +++ b/matrix_sdk_crypto/src/session_manager/group_sessions.rs @@ -62,6 +62,12 @@ impl GroupSessionCache { self.sessions.insert(session.room_id().to_owned(), session); } + /// Either get a session for the given room from the cache or load it from + /// the store. + /// + /// # Arguments + /// + /// * `room_id` - The id of the room this session is used for. pub async fn get_or_load(&self, room_id: &RoomId) -> StoreResult> { // Get the cached session, if there isn't one load one from the store // and put it in the cache. @@ -89,6 +95,21 @@ impl GroupSessionCache { fn get(&self, room_id: &RoomId) -> Option { self.sessions.get(room_id).map(|s| s.clone()) } + + /// Get or load the session for the given room with the given session id. + /// + /// This is the same as [get_or_load()](#method.get_or_load) but it will + /// filter out the session if it doesn't match the given session id. + pub async fn get_with_id( + &self, + room_id: &RoomId, + session_id: &str, + ) -> StoreResult> { + Ok(self + .get_or_load(room_id) + .await? + .filter(|o| session_id == o.session_id())) + } } #[derive(Debug, Clone)]