diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index e24ba4c6..a554f0ba 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -1839,7 +1839,16 @@ impl Client { warn!("Error while claiming one-time keys {:?}", e); } - for r in self.base_client.outgoing_requests().await { + // TODO we should probably abort if we get an cryptostore error here + let outgoing_requests = match self.base_client.outgoing_requests().await { + Ok(r) => r, + Err(e) => { + warn!("Could not fetch the outgoing requests {:?}", e); + vec![] + } + }; + + for r in outgoing_requests { match r.request() { OutgoingRequests::KeysQuery(request) => { if let Err(e) = self diff --git a/matrix_sdk_base/src/client.rs b/matrix_sdk_base/src/client.rs index dad4ea36..e53c77d1 100644 --- a/matrix_sdk_base/src/client.rs +++ b/matrix_sdk_base/src/client.rs @@ -1085,12 +1085,12 @@ impl BaseClient { /// [`mark_request_as_sent`]: #method.mark_request_as_sent #[cfg(feature = "encryption")] #[cfg_attr(feature = "docs", doc(cfg(encryption)))] - pub async fn outgoing_requests(&self) -> Vec { + pub async fn outgoing_requests(&self) -> Result, CryptoStoreError> { let olm = self.olm.lock().await; match &*olm { Some(o) => o.outgoing_requests().await, - None => vec![], + None => Ok(vec![]), } } diff --git a/matrix_sdk_crypto/benches/crypto_bench.rs b/matrix_sdk_crypto/benches/crypto_bench.rs index f2b23110..9fdfab95 100644 --- a/matrix_sdk_crypto/benches/crypto_bench.rs +++ b/matrix_sdk_crypto/benches/crypto_bench.rs @@ -215,7 +215,7 @@ pub fn room_key_sharing(c: &mut Criterion) { .await .unwrap(); - assert!(requests.len() >= 8); + assert!(!requests.is_empty()); for request in requests { machine @@ -251,7 +251,7 @@ pub fn room_key_sharing(c: &mut Criterion) { .await .unwrap(); - assert!(requests.len() >= 8); + assert!(!requests.is_empty()); for request in requests { machine diff --git a/matrix_sdk_crypto/src/identities/device.rs b/matrix_sdk_crypto/src/identities/device.rs index 76e2498d..a3b8c3f8 100644 --- a/matrix_sdk_crypto/src/identities/device.rs +++ b/matrix_sdk_crypto/src/identities/device.rs @@ -258,6 +258,14 @@ impl UserDevices { }) } + /// Returns true if there is at least one devices of this user that is + /// considered to be verified, false otherwise. + pub fn is_any_verified(&self) -> bool { + self.inner + .values() + .any(|d| d.trust_state(&self.own_identity, &self.device_owner_identity)) + } + /// Iterator over all the device ids of the user devices. pub fn keys(&self) -> impl Iterator { self.inner.keys() diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index 5fa560a5..1eec2981 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -40,9 +40,10 @@ use matrix_sdk_common::{ use crate::{ error::{OlmError, OlmResult}, - olm::{InboundGroupSession, OutboundGroupSession, Session, ShareState}, + olm::{InboundGroupSession, Session, ShareState}, requests::{OutgoingRequest, ToDeviceRequest}, - store::{CryptoStoreError, Store}, + session_manager::GroupSessionCache, + store::{Changes, CryptoStoreError, Store}, Device, }; @@ -128,7 +129,7 @@ pub(crate) struct KeyRequestMachine { user_id: Arc, device_id: Arc, store: Store, - outbound_group_sessions: Arc>, + outbound_group_sessions: GroupSessionCache, outgoing_to_device_requests: Arc>, incoming_key_requests: Arc< DashMap<(UserId, DeviceIdBox, String), ToDeviceEvent>, @@ -137,32 +138,54 @@ pub(crate) struct KeyRequestMachine { users_for_key_claim: Arc>>, } -#[derive(Debug, Serialize, Deserialize)] -struct OugoingKeyInfo { - request_id: Uuid, - info: RequestedKeyInfo, - sent_out: bool, +/// A struct describing an outgoing key request. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OutgoingKeyRequest { + /// The user we requested the key from + pub request_recipient: UserId, + /// The unique id of the key request. + pub request_id: Uuid, + /// The info of the requested key. + pub info: RequestedKeyInfo, + /// Has the request been sent out. + pub sent_out: bool, } -trait Encode { - fn encode(&self) -> String; -} +impl OutgoingKeyRequest { + fn to_request(&self, own_device_id: &DeviceId) -> Result { + let content = RoomKeyRequestToDeviceEventContent { + action: Action::Request, + request_id: self.request_id.to_string(), + requesting_device_id: own_device_id.to_owned(), + body: Some(self.info.clone()), + }; -impl Encode for RequestedKeyInfo { - fn encode(&self) -> String { - format!( - "{}|{}|{}|{}", - self.sender_key, self.room_id, self.session_id, self.algorithm - ) + wrap_key_request_content(self.request_recipient.clone(), self.request_id, &content) + } + + fn to_cancelation( + &self, + own_device_id: &DeviceId, + ) -> Result { + let content = RoomKeyRequestToDeviceEventContent { + action: Action::CancelRequest, + request_id: self.request_id.to_string(), + requesting_device_id: own_device_id.to_owned(), + body: None, + }; + + let id = Uuid::new_v4(); + wrap_key_request_content(self.request_recipient.clone(), id, &content) } } -impl Encode for ForwardedRoomKeyToDeviceEventContent { - fn encode(&self) -> String { - format!( - "{}|{}|{}|{}", - self.sender_key, self.room_id, self.session_id, self.algorithm - ) +impl PartialEq for OutgoingKeyRequest { + fn eq(&self, other: &Self) -> bool { + self.request_id == other.request_id + && self.info.algorithm == other.info.algorithm + && self.info.room_id == other.info.room_id + && self.info.session_id == other.info.session_id + && self.info.sender_key == other.info.sender_key } } @@ -196,7 +219,7 @@ impl KeyRequestMachine { user_id: Arc, device_id: Arc, store: Store, - outbound_group_sessions: Arc>, + outbound_group_sessions: GroupSessionCache, users_for_key_claim: Arc>>, ) -> Self { Self { @@ -204,13 +227,27 @@ impl KeyRequestMachine { device_id, store, outbound_group_sessions, - outgoing_to_device_requests: Arc::new(DashMap::new()), - incoming_key_requests: Arc::new(DashMap::new()), + outgoing_to_device_requests: DashMap::new().into(), + incoming_key_requests: DashMap::new().into(), wait_queue: WaitQueue::new(), users_for_key_claim, } } + /// Load stored outgoing requests that were not yet sent out. + async fn load_outgoing_requests(&self) -> Result, CryptoStoreError> { + self.store + .get_unsent_key_requests() + .await? + .into_iter() + .filter(|i| !i.sent_out) + .map(|info| { + info.to_request(self.device_id()) + .map_err(CryptoStoreError::from) + }) + .collect() + } + /// Our own user id. pub fn user_id(&self) -> &UserId { &self.user_id @@ -221,12 +258,18 @@ impl KeyRequestMachine { &self.device_id } - pub fn outgoing_to_device_requests(&self) -> Vec { - #[allow(clippy::map_clone)] - self.outgoing_to_device_requests + pub async fn outgoing_to_device_requests( + &self, + ) -> Result, CryptoStoreError> { + let mut key_requests = self.load_outgoing_requests().await?; + let key_forwards: Vec = self + .outgoing_to_device_requests .iter() - .map(|r| (*r).clone()) - .collect() + .map(|i| i.value().clone()) + .collect(); + key_requests.extend(key_forwards); + + Ok(key_requests) } /// Receive a room key request event. @@ -246,6 +289,7 @@ impl KeyRequestMachine { /// key request queue. pub async fn collect_incoming_key_requests(&self) -> OlmResult> { let mut changed_sessions = Vec::new(); + for item in self.incoming_key_requests.iter() { let event = item.value(); if let Some(s) = self.handle_key_request(event).await? { @@ -363,12 +407,7 @@ impl KeyRequestMachine { .await?; if let Some(device) = device { - match self.should_share_session( - &device, - self.outbound_group_sessions - .get(&key_info.room_id) - .as_deref(), - ) { + match self.should_share_key(&device, &session).await { Err(e) => { info!( "Received a key request from {} {} that we won't serve: {}", @@ -469,33 +508,146 @@ impl KeyRequestMachine { /// /// * `device` - The device that is requesting a session from us. /// - /// * `outbound_session` - If one still exists, the matching outbound - /// session that was used to create the inbound session that is being - /// requested. - fn should_share_session( + /// * `session` - The session that was requested to be shared. + async fn should_share_key( &self, device: &Device, - outbound_session: Option<&OutboundGroupSession>, + session: &InboundGroupSession, ) -> Result, KeyshareDecision> { - if device.user_id() == self.user_id() { + let outbound_session = self + .outbound_group_sessions + .get_with_id(session.room_id(), session.session_id()) + .await + .ok() + .flatten(); + + let own_device_check = || { if device.trust_state() { Ok(None) } else { Err(KeyshareDecision::UntrustedDevice) } - } else if let Some(outbound) = outbound_session { + }; + + // If we have a matching outbound session we can check the list of + // users/devices that received the session, if it wasn't shared check if + // it's our own device and if it's trusted. + if let Some(outbound) = outbound_session { if let ShareState::Shared(message_index) = outbound.is_shared_with(device.user_id(), device.device_id()) { Ok(Some(message_index)) + } else if device.user_id() == self.user_id() { + own_device_check() } else { Err(KeyshareDecision::OutboundSessionNotShared) } + // Else just check if it's one of our own devices that requested the key and + // check if the device is trusted. + } else if device.user_id() == self.user_id() { + own_device_check() + // Otherwise, there's not enough info to decide if we can safely share + // the session. } else { Err(KeyshareDecision::MissingOutboundSession) } } + /// Check if it's ok, or rather if it makes sense to automatically request + /// a key from our other devices. + /// + /// # Arguments + /// + /// * `key_info` - The info of our key request containing information about + /// the key we wish to request. + async fn should_request_key( + &self, + key_info: &RequestedKeyInfo, + ) -> Result { + let request = self.store.get_key_request_by_info(&key_info).await?; + + // Don't send out duplicate requests, users can re-request them if they + // think a second request might succeed. + if request.is_none() { + let devices = self.store.get_user_devices(self.user_id()).await?; + + // Devices will only respond to key requests if the devices are + // verified, if the device isn't verified by us it's unlikely that + // we're verified by them either. Don't request keys if there isn't + // at least one verified device. + if devices.is_any_verified() { + Ok(true) + } else { + Ok(false) + } + } else { + Ok(false) + } + } + + /// Create a new outgoing key request for the key with the given session id. + /// + /// This will queue up a new to-device request and store the key info so + /// once we receive a forwarded room key we can check that it matches the + /// key we requested. + /// + /// This method will return a cancel request and a new key request if the + /// key was already requested, otherwise it will return just the key + /// request. + /// + /// # Arguments + /// + /// * `room_id` - The id of the room where the key is used in. + /// + /// * `sender_key` - The curve25519 key of the sender that owns the key. + /// + /// * `session_id` - The id that uniquely identifies the session. + pub async fn request_key( + &self, + room_id: &RoomId, + sender_key: &str, + session_id: &str, + ) -> Result<(Option, OutgoingRequest), CryptoStoreError> { + let key_info = RequestedKeyInfo { + algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2, + room_id: room_id.to_owned(), + sender_key: sender_key.to_owned(), + session_id: session_id.to_owned(), + }; + + let request = self.store.get_key_request_by_info(&key_info).await?; + + if let Some(request) = request { + let cancel = request.to_cancelation(self.device_id())?; + let request = request.to_request(self.device_id())?; + + Ok((Some(cancel), request)) + } else { + let request = self.request_key_helper(key_info).await?; + + Ok((None, request)) + } + } + + async fn request_key_helper( + &self, + key_info: RequestedKeyInfo, + ) -> Result { + info!("Creating new outgoing room key request {:#?}", key_info); + + let request = OutgoingKeyRequest { + request_recipient: self.user_id().to_owned(), + request_id: Uuid::new_v4(), + info: key_info, + sent_out: false, + }; + + let outgoing_request = request.to_request(self.device_id())?; + self.save_outgoing_key_info(request).await?; + + Ok(outgoing_request) + } + /// Create a new outgoing key request for the key with the given session id. /// /// This will queue up a new to-device request and store the key info so @@ -523,51 +675,21 @@ impl KeyRequestMachine { session_id: session_id.to_owned(), }; - let id: Option = self.store.get_object(&key_info.encode()).await?; - - if id.is_some() { - // We already sent out a request for this key, nothing to do. - return Ok(()); + if self.should_request_key(&key_info).await? { + self.request_key_helper(key_info).await?; } - info!("Creating new outgoing room key request {:#?}", key_info); - - let id = Uuid::new_v4(); - - let content = RoomKeyRequestToDeviceEventContent { - action: Action::Request, - request_id: id.to_string(), - requesting_device_id: (&*self.device_id).clone(), - body: Some(key_info), - }; - - let request = wrap_key_request_content(self.user_id().clone(), id, &content)?; - - let info = OugoingKeyInfo { - request_id: id, - info: content.body.unwrap(), - sent_out: false, - }; - - self.save_outgoing_key_info(id, info).await?; - self.outgoing_to_device_requests.insert(id, request); - Ok(()) } /// Save an outgoing key info. async fn save_outgoing_key_info( &self, - id: Uuid, - info: OugoingKeyInfo, + info: OutgoingKeyRequest, ) -> Result<(), CryptoStoreError> { - // TODO we'll want to use a transaction to store those atomically. - // To allow this we'll need to rework our cryptostore trait to return - // a transaction trait and the transaction trait will have the save_X - // methods. - let id_string = id.to_string(); - self.store.save_object(&id_string, &info).await?; - self.store.save_object(&info.info.encode(), &id).await?; + let mut changes = Changes::default(); + changes.key_requests.push(info); + self.store.save_changes(changes).await?; Ok(()) } @@ -576,44 +698,43 @@ impl KeyRequestMachine { async fn get_key_info( &self, content: &ForwardedRoomKeyToDeviceEventContent, - ) -> Result, CryptoStoreError> { - let id: Option = self.store.get_object(&content.encode()).await?; + ) -> Result, CryptoStoreError> { + let info = RequestedKeyInfo { + algorithm: content.algorithm.clone(), + room_id: content.room_id.clone(), + sender_key: content.sender_key.clone(), + session_id: content.session_id.clone(), + }; - if let Some(id) = id { - self.store.get_object(&id.to_string()).await - } else { - Ok(None) - } + self.store.get_key_request_by_info(&info).await } /// Delete the given outgoing key info. - async fn delete_key_info(&self, info: &OugoingKeyInfo) -> Result<(), CryptoStoreError> { + async fn delete_key_info(&self, info: &OutgoingKeyRequest) -> Result<(), CryptoStoreError> { self.store - .delete_object(&info.request_id.to_string()) - .await?; - self.store.delete_object(&info.info.encode()).await?; - - Ok(()) + .delete_outgoing_key_request(info.request_id) + .await } /// Mark the outgoing request as sent. - pub async fn mark_outgoing_request_as_sent(&self, id: &Uuid) -> Result<(), CryptoStoreError> { - self.outgoing_to_device_requests.remove(id); - let info: Option = self.store.get_object(&id.to_string()).await?; + pub async fn mark_outgoing_request_as_sent(&self, id: Uuid) -> Result<(), CryptoStoreError> { + let info = self.store.get_outgoing_key_request(id).await?; if let Some(mut info) = info { trace!("Marking outgoing key request as sent {:#?}", info); info.sent_out = true; - self.save_outgoing_key_info(*id, info).await?; + self.save_outgoing_key_info(info).await?; } + self.outgoing_to_device_requests.remove(&id); + Ok(()) } /// Mark the given outgoing key info as done. /// /// This will queue up a request cancelation. - async fn mark_as_done(&self, key_info: OugoingKeyInfo) -> Result<(), CryptoStoreError> { + async fn mark_as_done(&self, key_info: OutgoingKeyRequest) -> Result<(), CryptoStoreError> { // TODO perhaps only remove the key info if the first known index is 0. trace!( "Successfully received a forwarded room key for {:#?}", @@ -626,18 +747,9 @@ impl KeyRequestMachine { // can delete it in one transaction. self.delete_key_info(&key_info).await?; - let content = RoomKeyRequestToDeviceEventContent { - action: Action::CancelRequest, - request_id: key_info.request_id.to_string(), - requesting_device_id: (&*self.device_id).clone(), - body: None, - }; - - let id = Uuid::new_v4(); - - let request = wrap_key_request_content(self.user_id().clone(), id, &content)?; - - self.outgoing_to_device_requests.insert(id, request); + let request = key_info.to_cancelation(self.device_id())?; + self.outgoing_to_device_requests + .insert(request.request_id, request); Ok(()) } @@ -722,7 +834,8 @@ mod test { use crate::{ identities::{LocalTrust, ReadOnlyDevice}, olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount}, - store::{CryptoStore, MemoryStore, Store}, + session_manager::GroupSessionCache, + store::{Changes, CryptoStore, MemoryStore, Store}, verification::VerificationMachine, }; @@ -744,6 +857,10 @@ mod test { "ILMLKASTES".into() } + fn alice2_device_id() -> DeviceIdBox { + "ILMLKASTES".into() + } + fn room_id() -> RoomId { room_id!("!test:example.org") } @@ -756,6 +873,10 @@ mod test { ReadOnlyAccount::new(&bob_id(), &bob_device_id()) } + fn alice_2_account() -> ReadOnlyAccount { + ReadOnlyAccount::new(&alice_id(), &alice2_device_id()) + } + fn bob_machine() -> KeyRequestMachine { let user_id = Arc::new(bob_id()); let account = ReadOnlyAccount::new(&user_id, &alice_device_id()); @@ -763,12 +884,13 @@ mod test { let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(bob_id()))); let verification = VerificationMachine::new(account, identity.clone(), store.clone()); let store = Store::new(user_id.clone(), identity, store, verification); + let session_cache = GroupSessionCache::new(store.clone()); KeyRequestMachine::new( user_id, Arc::new(bob_device_id()), store, - Arc::new(DashMap::new()), + session_cache, Arc::new(DashMap::new()), ) } @@ -782,12 +904,13 @@ mod test { let verification = VerificationMachine::new(account, identity.clone(), store.clone()); let store = Store::new(user_id.clone(), identity, store, verification); store.save_devices(&[device]).await.unwrap(); + let session_cache = GroupSessionCache::new(store.clone()); KeyRequestMachine::new( user_id, Arc::new(alice_device_id()), store, - Arc::new(DashMap::new()), + session_cache, Arc::new(DashMap::new()), ) } @@ -796,11 +919,15 @@ mod test { async fn create_machine() { let machine = get_machine().await; - assert!(machine.outgoing_to_device_requests().is_empty()); + assert!(machine + .outgoing_to_device_requests() + .await + .unwrap() + .is_empty()); } #[async_test] - async fn create_key_request() { + async fn re_request_keys() { let machine = get_machine().await; let account = account(); @@ -809,7 +936,52 @@ mod test { .await .unwrap(); - assert!(machine.outgoing_to_device_requests().is_empty()); + assert!(machine + .outgoing_to_device_requests() + .await + .unwrap() + .is_empty()); + let (cancel, request) = machine + .request_key(session.room_id(), &session.sender_key, session.session_id()) + .await + .unwrap(); + + assert!(cancel.is_none()); + + machine + .mark_outgoing_request_as_sent(request.request_id) + .await + .unwrap(); + + let (cancel, _) = machine + .request_key(session.room_id(), &session.sender_key, session.session_id()) + .await + .unwrap(); + + assert!(cancel.is_some()); + } + + #[async_test] + async fn create_key_request() { + let machine = get_machine().await; + let account = account(); + let second_account = alice_2_account(); + let alice_device = ReadOnlyDevice::from_account(&second_account).await; + + // We need a trusted device, otherwise we won't request keys + alice_device.set_trust_state(LocalTrust::Verified); + machine.store.save_devices(&[alice_device]).await.unwrap(); + + let (_, session) = account + .create_group_session_pair_with_defaults(&room_id()) + .await + .unwrap(); + + assert!(machine + .outgoing_to_device_requests() + .await + .unwrap() + .is_empty()); machine .create_outgoing_key_request( session.room_id(), @@ -818,8 +990,15 @@ mod test { ) .await .unwrap(); - assert!(!machine.outgoing_to_device_requests().is_empty()); - assert_eq!(machine.outgoing_to_device_requests().len(), 1); + assert!(!machine + .outgoing_to_device_requests() + .await + .unwrap() + .is_empty()); + assert_eq!( + machine.outgoing_to_device_requests().await.unwrap().len(), + 1 + ); machine .create_outgoing_key_request( @@ -829,15 +1008,21 @@ mod test { ) .await .unwrap(); - assert_eq!(machine.outgoing_to_device_requests.len(), 1); - let request = machine.outgoing_to_device_requests.iter().next().unwrap(); + let requests = machine.outgoing_to_device_requests().await.unwrap(); + assert_eq!(requests.len(), 1); - let id = request.request_id; - drop(request); + let request = requests.get(0).unwrap(); - machine.mark_outgoing_request_as_sent(&id).await.unwrap(); - assert!(machine.outgoing_to_device_requests.is_empty()); + machine + .mark_outgoing_request_as_sent(request.request_id) + .await + .unwrap(); + assert!(machine + .outgoing_to_device_requests() + .await + .unwrap() + .is_empty()); } #[async_test] @@ -845,6 +1030,13 @@ mod test { let machine = get_machine().await; let account = account(); + let second_account = alice_2_account(); + let alice_device = ReadOnlyDevice::from_account(&second_account).await; + + // We need a trusted device, otherwise we won't request keys + alice_device.set_trust_state(LocalTrust::Verified); + machine.store.save_devices(&[alice_device]).await.unwrap(); + let (_, session) = account .create_group_session_pair_with_defaults(&room_id()) .await @@ -858,11 +1050,11 @@ mod test { .await .unwrap(); - let request = machine.outgoing_to_device_requests.iter().next().unwrap(); + let requests = machine.outgoing_to_device_requests().await.unwrap(); + let request = requests.get(0).unwrap(); let id = request.request_id; - drop(request); - machine.mark_outgoing_request_as_sent(&id).await.unwrap(); + machine.mark_outgoing_request_as_sent(id).await.unwrap(); let export = session.export_at_index(10).await; @@ -904,7 +1096,7 @@ mod test { let request = machine.outgoing_to_device_requests.iter().next().unwrap(); let id = request.request_id; drop(request); - machine.mark_outgoing_request_as_sent(&id).await.unwrap(); + machine.mark_outgoing_request_as_sent(id).await.unwrap(); machine .create_outgoing_key_request( @@ -915,11 +1107,13 @@ mod test { .await .unwrap(); - let request = machine.outgoing_to_device_requests.iter().next().unwrap(); - let id = request.request_id; - drop(request); + let requests = machine.outgoing_to_device_requests().await.unwrap(); + let request = &requests[0]; - machine.mark_outgoing_request_as_sent(&id).await.unwrap(); + machine + .mark_outgoing_request_as_sent(request.request_id) + .await + .unwrap(); let export = session.export_at_index(15).await; @@ -966,16 +1160,25 @@ mod test { .unwrap() .unwrap(); + let (outbound, inbound) = account + .create_group_session_pair_with_defaults(&room_id()) + .await + .unwrap(); + // We don't share keys with untrusted devices. assert_eq!( machine - .should_share_session(&own_device, None) + .should_share_key(&own_device, &inbound) + .await .expect_err("Should not share with untrusted"), KeyshareDecision::UntrustedDevice ); own_device.set_trust_state(LocalTrust::Verified); // Now we do want to share the keys. - assert!(machine.should_share_session(&own_device, None).is_ok()); + assert!(machine + .should_share_key(&own_device, &inbound) + .await + .is_ok()); let bob_device = ReadOnlyDevice::from_account(&bob_account()).await; machine.store.save_devices(&[bob_device]).await.unwrap(); @@ -991,21 +1194,25 @@ mod test { // session was provided. assert_eq!( machine - .should_share_session(&bob_device, None) + .should_share_key(&bob_device, &inbound) + .await .expect_err("Should not share with other."), KeyshareDecision::MissingOutboundSession ); - let (session, _) = account - .create_group_session_pair_with_defaults(&room_id()) - .await - .unwrap(); + let mut changes = Changes::default(); + + changes.outbound_group_sessions.push(outbound.clone()); + changes.inbound_group_sessions.push(inbound.clone()); + machine.store.save_changes(changes).await.unwrap(); + machine.outbound_group_sessions.insert(outbound.clone()); // We don't share sessions with other user's devices if the session // wasn't shared in the first place. assert_eq!( machine - .should_share_session(&bob_device, Some(&session)) + .should_share_key(&bob_device, &inbound) + .await .expect_err("Should not share with other unless shared."), KeyshareDecision::OutboundSessionNotShared ); @@ -1016,15 +1223,33 @@ mod test { // wasn't shared in the first place even if the device is trusted. assert_eq!( machine - .should_share_session(&bob_device, Some(&session)) + .should_share_key(&bob_device, &inbound) + .await .expect_err("Should not share with other unless shared."), KeyshareDecision::OutboundSessionNotShared ); - session.mark_shared_with(bob_device.user_id(), bob_device.device_id()); + // We now share the session, since it was shared before. + outbound.mark_shared_with(bob_device.user_id(), bob_device.device_id()); assert!(machine - .should_share_session(&bob_device, Some(&session)) + .should_share_key(&bob_device, &inbound) + .await .is_ok()); + + // But we don't share some other session that doesn't match our outbound + // session + let (_, other_inbound) = account + .create_group_session_pair_with_defaults(&room_id()) + .await + .unwrap(); + + assert_eq!( + machine + .should_share_key(&bob_device, &other_inbound) + .await + .expect_err("Should not share with other unless shared."), + KeyshareDecision::MissingOutboundSession + ); } #[async_test] @@ -1038,6 +1263,17 @@ mod test { let bob_machine = bob_machine(); let bob_account = bob_account(); + let second_account = alice_2_account(); + let alice_device = ReadOnlyDevice::from_account(&second_account).await; + + // We need a trusted device, otherwise we won't request keys + alice_device.set_trust_state(LocalTrust::Verified); + alice_machine + .store + .save_devices(&[alice_device]) + .await + .unwrap(); + // Create Olm sessions for our two accounts. let (alice_session, bob_session) = alice_account.create_session_for(&bob_account).await; @@ -1092,14 +1328,11 @@ mod test { // Put the outbound session into bobs store. bob_machine .outbound_group_sessions - .insert(room_id(), group_session.clone()); + .insert(group_session.clone()); // Get the request and convert it into a event. - let request = alice_machine - .outgoing_to_device_requests - .iter() - .next() - .unwrap(); + let requests = alice_machine.outgoing_to_device_requests().await.unwrap(); + let request = &requests[0]; let id = request.request_id; let content = request .request @@ -1113,9 +1346,8 @@ mod test { let content: RoomKeyRequestToDeviceEventContent = serde_json::from_str(content.get()).unwrap(); - drop(request); alice_machine - .mark_outgoing_request_as_sent(&id) + .mark_outgoing_request_as_sent(id) .await .unwrap(); @@ -1134,11 +1366,8 @@ mod test { assert!(!bob_machine.outgoing_to_device_requests.is_empty()); // Get the request and convert it to a encrypted to-device event. - let request = bob_machine - .outgoing_to_device_requests - .iter() - .next() - .unwrap(); + let requests = bob_machine.outgoing_to_device_requests().await.unwrap(); + let request = &requests[0]; let id = request.request_id; let content = request @@ -1152,11 +1381,7 @@ mod test { .unwrap(); let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap(); - drop(request); - bob_machine - .mark_outgoing_request_as_sent(&id) - .await - .unwrap(); + bob_machine.mark_outgoing_request_as_sent(id).await.unwrap(); let event = ToDeviceEvent { sender: bob_id(), @@ -1217,6 +1442,17 @@ mod test { let bob_machine = bob_machine(); let bob_account = bob_account(); + let second_account = alice_2_account(); + let alice_device = ReadOnlyDevice::from_account(&second_account).await; + + // We need a trusted device, otherwise we won't request keys + alice_device.set_trust_state(LocalTrust::Verified); + alice_machine + .store + .save_devices(&[alice_device]) + .await + .unwrap(); + // Create Olm sessions for our two accounts. let (alice_session, bob_session) = alice_account.create_session_for(&bob_account).await; @@ -1261,14 +1497,11 @@ mod test { // Put the outbound session into bobs store. bob_machine .outbound_group_sessions - .insert(room_id(), group_session.clone()); + .insert(group_session.clone()); // Get the request and convert it into a event. - let request = alice_machine - .outgoing_to_device_requests - .iter() - .next() - .unwrap(); + let requests = alice_machine.outgoing_to_device_requests().await.unwrap(); + let request = &requests[0]; let id = request.request_id; let content = request .request @@ -1282,9 +1515,8 @@ mod test { let content: RoomKeyRequestToDeviceEventContent = serde_json::from_str(content.get()).unwrap(); - drop(request); alice_machine - .mark_outgoing_request_as_sent(&id) + .mark_outgoing_request_as_sent(id) .await .unwrap(); @@ -1294,7 +1526,11 @@ mod test { }; // Bob doesn't have any outgoing requests. - assert!(bob_machine.outgoing_to_device_requests.is_empty()); + assert!(bob_machine + .outgoing_to_device_requests() + .await + .unwrap() + .is_empty()); assert!(bob_machine.users_for_key_claim.is_empty()); assert!(bob_machine.wait_queue.is_empty()); @@ -1302,7 +1538,11 @@ mod test { bob_machine.receive_incoming_key_request(&event); bob_machine.collect_incoming_key_requests().await.unwrap(); // Bob doens't have an outgoing requests since we're lacking a session. - assert!(bob_machine.outgoing_to_device_requests.is_empty()); + assert!(bob_machine + .outgoing_to_device_requests() + .await + .unwrap() + .is_empty()); assert!(!bob_machine.users_for_key_claim.is_empty()); assert!(!bob_machine.wait_queue.is_empty()); @@ -1322,15 +1562,17 @@ mod test { assert!(bob_machine.users_for_key_claim.is_empty()); bob_machine.collect_incoming_key_requests().await.unwrap(); // Bob now has an outgoing requests. - assert!(!bob_machine.outgoing_to_device_requests.is_empty()); + assert!(!bob_machine + .outgoing_to_device_requests() + .await + .unwrap() + .is_empty()); assert!(bob_machine.wait_queue.is_empty()); // Get the request and convert it to a encrypted to-device event. - let request = bob_machine - .outgoing_to_device_requests - .iter() - .next() - .unwrap(); + let requests = bob_machine.outgoing_to_device_requests().await.unwrap(); + + let request = &requests[0]; let id = request.request_id; let content = request @@ -1344,11 +1586,7 @@ mod test { .unwrap(); let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap(); - drop(request); - bob_machine - .mark_outgoing_request_as_sent(&id) - .await - .unwrap(); + bob_machine.mark_outgoing_request_as_sent(id).await.unwrap(); let event = ToDeviceEvent { sender: bob_id(), diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index b4d40f8d..447d7a12 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -156,29 +156,29 @@ impl OlmMachine { verification_machine.clone(), ); let device_id: Arc = Arc::new(device_id); - let outbound_group_sessions = Arc::new(DashMap::new()); let users_for_key_claim = Arc::new(DashMap::new()); - let key_request_machine = KeyRequestMachine::new( - user_id.clone(), - device_id.clone(), - store.clone(), - outbound_group_sessions, - users_for_key_claim.clone(), - ); - let account = Account { inner: account, store: store.clone(), }; + let group_session_manager = GroupSessionManager::new(account.clone(), store.clone()); + + let key_request_machine = KeyRequestMachine::new( + user_id.clone(), + device_id.clone(), + store.clone(), + group_session_manager.session_cache(), + users_for_key_claim.clone(), + ); + let session_manager = SessionManager::new( account.clone(), users_for_key_claim, key_request_machine.clone(), store.clone(), ); - let group_session_manager = GroupSessionManager::new(account.clone(), store.clone()); let identity_manager = IdentityManager::new(user_id.clone(), device_id.clone(), store.clone()); @@ -294,7 +294,7 @@ impl OlmMachine { /// machine using [`mark_request_as_sent`]. /// /// [`mark_request_as_sent`]: #method.mark_request_as_sent - pub async fn outgoing_requests(&self) -> Vec { + pub async fn outgoing_requests(&self) -> StoreResult> { let mut requests = Vec::new(); if let Some(r) = self.keys_for_upload().await.map(|r| OutgoingRequest { @@ -319,9 +319,14 @@ impl OlmMachine { requests.append(&mut self.outgoing_to_device_requests()); requests.append(&mut self.verification_machine.outgoing_room_message_requests()); - requests.append(&mut self.key_request_machine.outgoing_to_device_requests()); + requests.append( + &mut self + .key_request_machine + .outgoing_to_device_requests() + .await?, + ); - requests + Ok(requests) } /// Mark the request with the given request id as sent. @@ -751,7 +756,7 @@ impl OlmMachine { async fn mark_to_device_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> { self.verification_machine.mark_request_as_sent(request_id); self.key_request_machine - .mark_outgoing_request_as_sent(request_id) + .mark_outgoing_request_as_sent(*request_id) .await?; self.group_session_manager .mark_request_as_sent(request_id) @@ -913,6 +918,38 @@ impl OlmMachine { Ok(ToDevice { events }) } + /// Request a room key from our devices. + /// + /// This method will return a request cancelation and a new key request if + /// the key was already requested, otherwise it will return just the key + /// request. + /// + /// The request cancelation *must* be sent out before the request is sent + /// out, otherwise devices will ignore the key request. + /// + /// # Arguments + /// + /// * `room_id` - The id of the room where the key is used in. + /// + /// * `sender_key` - The curve25519 key of the sender that owns the key. + /// + /// * `session_id` - The id that uniquely identifies the session. + pub async fn request_room_key( + &self, + event: &SyncMessageEvent, + room_id: &RoomId, + ) -> MegolmResult<(Option, OutgoingRequest)> { + let content = match &event.content { + EncryptedEventContent::MegolmV1AesSha2(c) => c, + _ => return Err(EventError::UnsupportedAlgorithm.into()), + }; + + Ok(self + .key_request_machine + .request_key(room_id, &content.sender_key, &content.session_id) + .await?) + } + /// Decrypt an event from a room timeline. /// /// # Arguments diff --git a/matrix_sdk_crypto/src/session_manager/group_sessions.rs b/matrix_sdk_crypto/src/session_manager/group_sessions.rs index df73f84b..c36aac37 100644 --- a/matrix_sdk_crypto/src/session_manager/group_sessions.rs +++ b/matrix_sdk_crypto/src/session_manager/group_sessions.rs @@ -40,6 +40,78 @@ use crate::{ Device, EncryptionSettings, OlmError, ToDeviceRequest, }; +#[derive(Clone, Debug)] +pub(crate) struct GroupSessionCache { + store: Store, + sessions: Arc>, + /// A map from the request id to the group session that the request belongs + /// to. Used to mark requests belonging to the session as shared. + sessions_being_shared: Arc>, +} + +impl GroupSessionCache { + pub(crate) fn new(store: Store) -> Self { + Self { + store, + sessions: DashMap::new().into(), + sessions_being_shared: Arc::new(DashMap::new()), + } + } + + pub(crate) fn insert(&self, session: OutboundGroupSession) { + self.sessions.insert(session.room_id().to_owned(), session); + } + + /// Either get a session for the given room from the cache or load it from + /// the store. + /// + /// # Arguments + /// + /// * `room_id` - The id of the room this session is used for. + pub async fn get_or_load(&self, room_id: &RoomId) -> StoreResult> { + // Get the cached session, if there isn't one load one from the store + // and put it in the cache. + if let Some(s) = self.sessions.get(room_id) { + Ok(Some(s.clone())) + } else if let Some(s) = self.store.get_outbound_group_sessions(room_id).await? { + for request_id in s.pending_request_ids() { + self.sessions_being_shared.insert(request_id, s.clone()); + } + + self.sessions.insert(room_id.clone(), s.clone()); + + Ok(Some(s)) + } else { + Ok(None) + } + } + + /// Get an outbound group session for a room, if one exists. + /// + /// # Arguments + /// + /// * `room_id` - The id of the room for which we should get the outbound + /// group session. + fn get(&self, room_id: &RoomId) -> Option { + self.sessions.get(room_id).map(|s| s.clone()) + } + + /// Get or load the session for the given room with the given session id. + /// + /// This is the same as [get_or_load()](#method.get_or_load) but it will + /// filter out the session if it doesn't match the given session id. + pub async fn get_with_id( + &self, + room_id: &RoomId, + session_id: &str, + ) -> StoreResult> { + Ok(self + .get_or_load(room_id) + .await? + .filter(|o| session_id == o.session_id())) + } +} + #[derive(Debug, Clone)] pub struct GroupSessionManager { account: Account, @@ -48,10 +120,7 @@ pub struct GroupSessionManager { /// without the need to create new keys. store: Store, /// The currently active outbound group sessions. - outbound_group_sessions: Arc>, - /// A map from the request id to the group session that the request belongs - /// to. Used to mark requests belonging to the session as shared. - outbound_sessions_being_shared: Arc>, + sessions: GroupSessionCache, } impl GroupSessionManager { @@ -60,14 +129,13 @@ impl GroupSessionManager { pub(crate) fn new(account: Account, store: Store) -> Self { Self { account, - store, - outbound_group_sessions: Arc::new(DashMap::new()), - outbound_sessions_being_shared: Arc::new(DashMap::new()), + store: store.clone(), + sessions: GroupSessionCache::new(store), } } pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult { - if let Some(s) = self.outbound_group_sessions.get(room_id) { + if let Some(s) = self.sessions.get(room_id) { s.invalidate_session(); let mut changes = Changes::default(); @@ -81,7 +149,7 @@ impl GroupSessionManager { } pub async fn mark_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> { - if let Some((_, s)) = self.outbound_sessions_being_shared.remove(request_id) { + if let Some((_, s)) = self.sessions.sessions_being_shared.remove(request_id) { s.mark_request_as_sent(request_id); let mut changes = Changes::default(); @@ -97,15 +165,9 @@ impl GroupSessionManager { Ok(()) } - /// Get an outbound group session for a room, if one exists. - /// - /// # Arguments - /// - /// * `room_id` - The id of the room for which we should get the outbound - /// group session. + #[cfg(test)] pub fn get_outbound_group_session(&self, room_id: &RoomId) -> Option { - #[allow(clippy::map_clone)] - self.outbound_group_sessions.get(room_id).map(|s| s.clone()) + self.sessions.get(room_id) } pub async fn encrypt( @@ -113,7 +175,7 @@ impl GroupSessionManager { room_id: &RoomId, content: AnyMessageEventContent, ) -> MegolmResult { - let session = if let Some(s) = self.get_outbound_group_session(room_id) { + let session = if let Some(s) = self.sessions.get(room_id) { s } else { panic!("Session wasn't created nor shared"); @@ -147,9 +209,7 @@ impl GroupSessionManager { .await .map_err(|_| EventError::UnsupportedAlgorithm)?; - let _ = self - .outbound_group_sessions - .insert(room_id.to_owned(), outbound.clone()); + self.sessions.insert(outbound.clone()); Ok((outbound, inbound)) } @@ -158,23 +218,7 @@ impl GroupSessionManager { room_id: &RoomId, settings: EncryptionSettings, ) -> OlmResult<(OutboundGroupSession, Option)> { - // Get the cached session, if there isn't one load one from the store - // and put it in the cache. - let outbound_session = if let Some(s) = self.outbound_group_sessions.get(room_id) { - Some(s.clone()) - } else if let Some(s) = self.store.get_outbound_group_sessions(room_id).await? { - for request_id in s.pending_request_ids() { - self.outbound_sessions_being_shared - .insert(request_id, s.clone()); - } - - self.outbound_group_sessions - .insert(room_id.clone(), s.clone()); - - Some(s) - } else { - None - }; + let outbound_session = self.sessions.get_or_load(&room_id).await?; // If there is no session or the session has expired or is invalid, // create a new one. @@ -388,6 +432,10 @@ impl GroupSessionManager { Ok(used_sessions) } + pub(crate) fn session_cache(&self) -> GroupSessionCache { + self.sessions.clone() + } + /// Get to-device requests to share a group session with users in a room. /// /// # Arguments @@ -489,7 +537,7 @@ impl GroupSessionManager { key_content.clone(), outbound.clone(), message_index, - self.outbound_sessions_being_shared.clone(), + self.sessions.sessions_being_shared.clone(), )) }) .collect(); diff --git a/matrix_sdk_crypto/src/session_manager/mod.rs b/matrix_sdk_crypto/src/session_manager/mod.rs index 7750262e..1af686ef 100644 --- a/matrix_sdk_crypto/src/session_manager/mod.rs +++ b/matrix_sdk_crypto/src/session_manager/mod.rs @@ -15,5 +15,5 @@ mod group_sessions; mod sessions; -pub(crate) use group_sessions::GroupSessionManager; +pub(crate) use group_sessions::{GroupSessionCache, GroupSessionManager}; pub(crate) use sessions::SessionManager; diff --git a/matrix_sdk_crypto/src/session_manager/sessions.rs b/matrix_sdk_crypto/src/session_manager/sessions.rs index 9417c3ea..274314a8 100644 --- a/matrix_sdk_crypto/src/session_manager/sessions.rs +++ b/matrix_sdk_crypto/src/session_manager/sessions.rs @@ -322,6 +322,7 @@ mod test { identities::ReadOnlyDevice, key_request::KeyRequestMachine, olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount}, + session_manager::GroupSessionCache, store::{CryptoStore, MemoryStore, Store}, verification::VerificationMachine, }; @@ -342,7 +343,6 @@ mod test { let user_id = user_id(); let device_id = device_id(); - let outbound_sessions = Arc::new(DashMap::new()); let users_for_key_claim = Arc::new(DashMap::new()); let account = ReadOnlyAccount::new(&user_id, &device_id); let store: Arc> = Arc::new(Box::new(MemoryStore::new())); @@ -363,11 +363,13 @@ mod test { store: store.clone(), }; + let session_cache = GroupSessionCache::new(store.clone()); + let key_request = KeyRequestMachine::new( user_id, device_id, store.clone(), - outbound_sessions, + session_cache, users_for_key_claim.clone(), ); diff --git a/matrix_sdk_crypto/src/store/memorystore.rs b/matrix_sdk_crypto/src/store/memorystore.rs index 3d249c82..e9ca5307 100644 --- a/matrix_sdk_crypto/src/store/memorystore.rs +++ b/matrix_sdk_crypto/src/store/memorystore.rs @@ -20,8 +20,10 @@ use std::{ use dashmap::{DashMap, DashSet}; use matrix_sdk_common::{ async_trait, + events::room_key_request::RequestedKeyInfo, identifiers::{DeviceId, DeviceIdBox, RoomId, UserId}, locks::Mutex, + uuid::Uuid, }; use super::{ @@ -30,9 +32,17 @@ use super::{ }; use crate::{ identities::{ReadOnlyDevice, UserIdentities}, + key_request::OutgoingKeyRequest, olm::{OutboundGroupSession, PrivateCrossSigningIdentity}, }; +fn encode_key_info(info: &RequestedKeyInfo) -> String { + format!( + "{}{}{}{}", + info.room_id, info.sender_key, info.algorithm, info.session_id + ) +} + /// An in-memory only store that will forget all the E2EE key once it's dropped. #[derive(Debug, Clone)] pub struct MemoryStore { @@ -43,7 +53,8 @@ pub struct MemoryStore { olm_hashes: Arc>>, devices: DeviceStore, identities: Arc>, - values: Arc>, + outgoing_key_requests: Arc>, + key_requests_by_info: Arc>, } impl Default for MemoryStore { @@ -56,7 +67,8 @@ impl Default for MemoryStore { olm_hashes: Arc::new(DashMap::new()), devices: DeviceStore::new(), identities: Arc::new(DashMap::new()), - values: Arc::new(DashMap::new()), + outgoing_key_requests: Arc::new(DashMap::new()), + key_requests_by_info: Arc::new(DashMap::new()), } } } @@ -103,6 +115,10 @@ impl CryptoStore for MemoryStore { Ok(()) } + async fn load_identity(&self) -> Result> { + Ok(None) + } + async fn save_changes(&self, mut changes: Changes) -> Result<()> { self.save_sessions(changes.sessions).await; self.save_inbound_group_sessions(changes.inbound_group_sessions) @@ -130,6 +146,14 @@ impl CryptoStore for MemoryStore { .insert(hash.hash.clone()); } + for key_request in changes.key_requests { + let id = key_request.request_id; + let info_string = encode_key_info(&key_request.info); + + self.outgoing_key_requests.insert(id, key_request); + self.key_requests_by_info.insert(info_string, id); + } + Ok(()) } @@ -152,9 +176,11 @@ impl CryptoStore for MemoryStore { Ok(self.inbound_group_sessions.get_all()) } - fn users_for_key_query(&self) -> HashSet { - #[allow(clippy::map_clone)] - self.users_for_key_query.iter().map(|u| u.clone()).collect() + async fn get_outbound_group_sessions( + &self, + _: &RoomId, + ) -> Result> { + Ok(None) } fn is_user_tracked(&self, user_id: &UserId) -> bool { @@ -165,6 +191,11 @@ impl CryptoStore for MemoryStore { !self.users_for_key_query.is_empty() } + fn users_for_key_query(&self) -> HashSet { + #[allow(clippy::map_clone)] + self.users_for_key_query.iter().map(|u| u.clone()).collect() + } + async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result { // TODO to prevent a race between the sync and a key query in flight we // need to have an additional state to mention that the user changed. @@ -207,24 +238,6 @@ impl CryptoStore for MemoryStore { Ok(self.identities.get(user_id).map(|i| i.clone())) } - async fn save_value(&self, key: String, value: String) -> Result<()> { - self.values.insert(key, value); - Ok(()) - } - - async fn remove_value(&self, key: &str) -> Result<()> { - self.values.remove(key); - Ok(()) - } - - async fn get_value(&self, key: &str) -> Result> { - Ok(self.values.get(key).map(|v| v.to_owned())) - } - - async fn load_identity(&self) -> Result> { - Ok(None) - } - async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result { Ok(self .olm_hashes @@ -233,11 +246,46 @@ impl CryptoStore for MemoryStore { .contains(&message_hash.hash)) } - async fn get_outbound_group_sessions( + async fn get_outgoing_key_request( &self, - _: &RoomId, - ) -> Result> { - Ok(None) + request_id: Uuid, + ) -> Result> { + Ok(self + .outgoing_key_requests + .get(&request_id) + .map(|r| r.clone())) + } + + async fn get_key_request_by_info( + &self, + key_info: &RequestedKeyInfo, + ) -> Result> { + let key_info_string = encode_key_info(key_info); + + Ok(self + .key_requests_by_info + .get(&key_info_string) + .and_then(|i| self.outgoing_key_requests.get(&i).map(|r| r.clone()))) + } + + async fn get_unsent_key_requests(&self) -> Result> { + Ok(self + .outgoing_key_requests + .iter() + .filter(|i| !i.value().sent_out) + .map(|i| i.value().clone()) + .collect()) + } + + async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> { + self.outgoing_key_requests + .remove(&request_id) + .and_then(|(_, i)| { + let key_info_string = encode_key_info(&i.info); + self.key_requests_by_info.remove(&key_info_string) + }); + + Ok(()) } } diff --git a/matrix_sdk_crypto/src/store/mod.rs b/matrix_sdk_crypto/src/store/mod.rs index a837dfd0..1982e1cc 100644 --- a/matrix_sdk_crypto/src/store/mod.rs +++ b/matrix_sdk_crypto/src/store/mod.rs @@ -57,23 +57,25 @@ use std::{ }; use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError}; -use serde::{Deserialize, Serialize}; use serde_json::Error as SerdeError; use thiserror::Error; use matrix_sdk_common::{ async_trait, + events::room_key_request::RequestedKeyInfo, identifiers::{ DeviceId, DeviceIdBox, DeviceKeyAlgorithm, Error as IdentifierValidationError, RoomId, UserId, }, locks::Mutex, + uuid::Uuid, AsyncTraitDeps, }; use crate::{ error::SessionUnpicklingError, identities::{Device, ReadOnlyDevice, UserDevices, UserIdentities}, + key_request::OutgoingKeyRequest, olm::{ InboundGroupSession, OlmMessageHash, OutboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount, Session, @@ -108,6 +110,7 @@ pub struct Changes { pub inbound_group_sessions: Vec, pub outbound_group_sessions: Vec, pub identities: IdentityChanges, + pub key_requests: Vec, pub devices: DeviceChanges, } @@ -257,24 +260,6 @@ impl Store { device_owner_identity, })) } - - pub async fn get_object Deserialize<'b>>(&self, key: &str) -> Result> { - if let Some(value) = self.get_value(key).await? { - Ok(Some(serde_json::from_str(&value)?)) - } else { - Ok(None) - } - } - - pub async fn save_object(&self, key: &str, value: &impl Serialize) -> Result<()> { - let value = serde_json::to_string(value)?; - self.save_value(key.to_owned(), value).await - } - - pub async fn delete_object(&self, key: &str) -> Result<()> { - self.inner.remove_value(key).await?; - Ok(()) - } } impl Deref for Store { @@ -438,15 +423,41 @@ pub trait CryptoStore: AsyncTraitDeps { /// * `user_id` - The user for which we should get the identity. async fn get_user_identity(&self, user_id: &UserId) -> Result>; - /// Save a serializeable object in the store. - async fn save_value(&self, key: String, value: String) -> Result<()>; - - /// Remove a value from the store. - async fn remove_value(&self, key: &str) -> Result<()>; - - /// Load a serializeable object from the store. - async fn get_value(&self, key: &str) -> Result>; - /// Check if a hash for an Olm message stored in the database. async fn is_message_known(&self, message_hash: &OlmMessageHash) -> Result; + + /// Get an outoing key request that we created that matches the given + /// request id. + /// + /// # Arguments + /// + /// * `request_id` - The unique request id that identifies this outgoing key + /// request. + async fn get_outgoing_key_request( + &self, + request_id: Uuid, + ) -> Result>; + + /// Get an outoing key request that we created that matches the given + /// requested key info. + /// + /// # Arguments + /// + /// * `key_info` - The key info of an outgoing key request. + async fn get_key_request_by_info( + &self, + key_info: &RequestedKeyInfo, + ) -> Result>; + + /// Get all outgoing key requests that we have in the store. + async fn get_unsent_key_requests(&self) -> Result>; + + /// Delete an outoing key request that we created that matches the given + /// request id. + /// + /// # Arguments + /// + /// * `request_id` - The unique request id that identifies this outgoing key + /// request. + async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()>; } diff --git a/matrix_sdk_crypto/src/store/sled.rs b/matrix_sdk_crypto/src/store/sled.rs index 00950d6f..b6928e78 100644 --- a/matrix_sdk_crypto/src/store/sled.rs +++ b/matrix_sdk_crypto/src/store/sled.rs @@ -29,9 +29,12 @@ use sled::{ use matrix_sdk_common::{ async_trait, + events::room_key_request::RequestedKeyInfo, identifiers::{DeviceId, DeviceIdBox, RoomId, UserId}, locks::Mutex, + uuid, }; +use uuid::Uuid; use super::{ caches::SessionStore, Changes, CryptoStore, CryptoStoreError, InboundGroupSession, PickleKey, @@ -39,6 +42,7 @@ use super::{ }; use crate::{ identities::{ReadOnlyDevice, UserIdentities}, + key_request::OutgoingKeyRequest, olm::{OutboundGroupSession, PickledInboundGroupSession, PrivateCrossSigningIdentity}, }; @@ -51,6 +55,28 @@ trait EncodeKey { fn encode(&self) -> Vec; } +impl EncodeKey for Uuid { + fn encode(&self) -> Vec { + self.as_u128().to_be_bytes().to_vec() + } +} + +impl EncodeKey for &RequestedKeyInfo { + fn encode(&self) -> Vec { + [ + self.room_id.as_bytes(), + &[Self::SEPARATOR], + self.sender_key.as_bytes(), + &[Self::SEPARATOR], + self.algorithm.as_ref().as_bytes(), + &[Self::SEPARATOR], + self.session_id.as_bytes(), + &[Self::SEPARATOR], + ] + .concat() + } +} + impl EncodeKey for &UserId { fn encode(&self) -> Vec { self.as_str().encode() @@ -122,12 +148,15 @@ pub struct SledStore { inbound_group_sessions: Tree, outbound_group_sessions: Tree, + outgoing_key_requests: Tree, + unsent_key_requests: Tree, + key_requests_by_info: Tree, + devices: Tree, identities: Tree, tracked_users: Tree, users_for_key_query: Tree, - values: Tree, } impl std::fmt::Debug for SledStore { @@ -178,13 +207,17 @@ impl SledStore { let sessions = db.open_tree("session")?; let inbound_group_sessions = db.open_tree("inbound_group_sessions")?; let outbound_group_sessions = db.open_tree("outbound_group_sessions")?; + let tracked_users = db.open_tree("tracked_users")?; let users_for_key_query = db.open_tree("users_for_key_query")?; let olm_hashes = db.open_tree("olm_hashes")?; let devices = db.open_tree("devices")?; let identities = db.open_tree("identities")?; - let values = db.open_tree("values")?; + + let outgoing_key_requests = db.open_tree("outgoing_key_requests")?; + let unsent_key_requests = db.open_tree("unsent_key_requests")?; + let key_requests_by_info = db.open_tree("key_requests_by_info")?; let session_cache = SessionStore::new(); @@ -208,12 +241,14 @@ impl SledStore { users_for_key_query_cache: DashSet::new().into(), inbound_group_sessions, outbound_group_sessions, + outgoing_key_requests, + unsent_key_requests, + key_requests_by_info, devices, tracked_users, users_for_key_query, olm_hashes, identities, - values, }) } @@ -332,6 +367,7 @@ impl SledStore { let identity_changes = changes.identities; let olm_hashes = changes.message_hashes; + let key_requests = changes.key_requests; let ret: Result<(), TransactionError> = ( &self.account, @@ -342,6 +378,9 @@ impl SledStore { &self.inbound_group_sessions, &self.outbound_group_sessions, &self.olm_hashes, + &self.outgoing_key_requests, + &self.unsent_key_requests, + &self.key_requests_by_info, ) .transaction( |( @@ -353,6 +392,9 @@ impl SledStore { inbound_sessions, outbound_sessions, hashes, + outgoing_key_requests, + unsent_key_requests, + key_requests_by_info, )| { if let Some(a) = &account_pickle { account.insert( @@ -420,6 +462,31 @@ impl SledStore { )?; } + for key_request in &key_requests { + key_requests_by_info.insert( + (&key_request.info).encode(), + key_request.request_id.encode(), + )?; + + let key_request_id = key_request.request_id.encode(); + + if key_request.sent_out { + unsent_key_requests.remove(key_request_id.clone())?; + outgoing_key_requests.insert( + key_request_id, + serde_json::to_vec(&key_request) + .map_err(ConflictableTransactionError::Abort)?, + )?; + } else { + outgoing_key_requests.remove(key_request_id.clone())?; + unsent_key_requests.insert( + key_request_id, + serde_json::to_vec(&key_request) + .map_err(ConflictableTransactionError::Abort)?, + )?; + } + } + Ok(()) }, ); @@ -429,6 +496,28 @@ impl SledStore { Ok(()) } + + async fn get_outgoing_key_request_helper( + &self, + id: &[u8], + ) -> Result> { + let request = self + .outgoing_key_requests + .get(id)? + .map(|r| serde_json::from_slice(&r)) + .transpose()?; + + let request = if request.is_none() { + self.unsent_key_requests + .get(id)? + .map(|r| serde_json::from_slice(&r)) + .transpose()? + } else { + request + }; + + Ok(request) + } } #[async_trait] @@ -472,6 +561,19 @@ impl CryptoStore for SledStore { self.save_changes(changes).await } + async fn load_identity(&self) -> Result> { + if let Some(i) = self.private_identity.get("identity".encode())? { + let pickle = serde_json::from_slice(&i)?; + Ok(Some( + PrivateCrossSigningIdentity::from_pickle(pickle, self.get_pickle_key()) + .await + .map_err(|_| CryptoStoreError::UnpicklingError)?, + )) + } else { + Ok(None) + } + } + async fn save_changes(&self, changes: Changes) -> Result<()> { self.save_changes(changes).await } @@ -539,12 +641,11 @@ impl CryptoStore for SledStore { .collect()) } - fn users_for_key_query(&self) -> HashSet { - #[allow(clippy::map_clone)] - self.users_for_key_query_cache - .iter() - .map(|u| u.clone()) - .collect() + async fn get_outbound_group_sessions( + &self, + room_id: &RoomId, + ) -> Result> { + self.load_outbound_group_session(room_id).await } fn is_user_tracked(&self, user_id: &UserId) -> bool { @@ -555,6 +656,14 @@ impl CryptoStore for SledStore { !self.users_for_key_query_cache.is_empty() } + fn users_for_key_query(&self) -> HashSet { + #[allow(clippy::map_clone)] + self.users_for_key_query_cache + .iter() + .map(|u| u.clone()) + .collect() + } + async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result { let already_added = self.tracked_users_cache.insert(user.clone()); @@ -605,48 +714,80 @@ impl CryptoStore for SledStore { .transpose()?) } - async fn save_value(&self, key: String, value: String) -> Result<()> { - self.values.insert(key.as_str().encode(), value.as_str())?; - self.inner.flush_async().await?; - Ok(()) - } - - async fn remove_value(&self, key: &str) -> Result<()> { - self.values.remove(key.encode())?; - Ok(()) - } - - async fn get_value(&self, key: &str) -> Result> { - Ok(self - .values - .get(key.encode())? - .map(|v| String::from_utf8_lossy(&v).to_string())) - } - - async fn load_identity(&self) -> Result> { - if let Some(i) = self.private_identity.get("identity".encode())? { - let pickle = serde_json::from_slice(&i)?; - Ok(Some( - PrivateCrossSigningIdentity::from_pickle(pickle, self.get_pickle_key()) - .await - .map_err(|_| CryptoStoreError::UnpicklingError)?, - )) - } else { - Ok(None) - } - } - async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result { Ok(self .olm_hashes .contains_key(serde_json::to_vec(message_hash)?)?) } - async fn get_outbound_group_sessions( + async fn get_outgoing_key_request( &self, - room_id: &RoomId, - ) -> Result> { - self.load_outbound_group_session(room_id).await + request_id: Uuid, + ) -> Result> { + let request_id = request_id.encode(); + + self.get_outgoing_key_request_helper(&request_id).await + } + + async fn get_key_request_by_info( + &self, + key_info: &RequestedKeyInfo, + ) -> Result> { + let id = self.key_requests_by_info.get(key_info.encode())?; + + if let Some(id) = id { + self.get_outgoing_key_request_helper(&id).await + } else { + Ok(None) + } + } + + async fn get_unsent_key_requests(&self) -> Result> { + let requests: Result> = self + .unsent_key_requests + .iter() + .map(|i| serde_json::from_slice(&i?.1).map_err(CryptoStoreError::from)) + .collect(); + + requests + } + + async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> { + let ret: Result<(), TransactionError> = ( + &self.outgoing_key_requests, + &self.unsent_key_requests, + &self.key_requests_by_info, + ) + .transaction( + |(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| { + let sent_request: Option = outgoing_key_requests + .remove(request_id.encode())? + .map(|r| serde_json::from_slice(&r)) + .transpose() + .map_err(ConflictableTransactionError::Abort)?; + + let unsent_request: Option = unsent_key_requests + .remove(request_id.encode())? + .map(|r| serde_json::from_slice(&r)) + .transpose() + .map_err(ConflictableTransactionError::Abort)?; + + if let Some(request) = sent_request { + key_requests_by_info.remove((&request.info).encode())?; + } + + if let Some(request) = unsent_request { + key_requests_by_info.remove((&request.info).encode())?; + } + + Ok(()) + }, + ); + + ret?; + self.inner.flush_async().await?; + + Ok(()) } } @@ -665,14 +806,16 @@ mod test { }; use matrix_sdk_common::{ api::r0::keys::SignedKey, - identifiers::{room_id, user_id, DeviceId, UserId}, + events::room_key_request::RequestedKeyInfo, + identifiers::{room_id, user_id, DeviceId, EventEncryptionAlgorithm, UserId}, + uuid::Uuid, }; use matrix_sdk_test::async_test; use olm_rs::outbound_group_session::OlmOutboundGroupSession; use std::collections::BTreeMap; use tempfile::tempdir; - use super::{CryptoStore, SledStore}; + use super::{CryptoStore, OutgoingKeyRequest, SledStore}; fn alice_id() -> UserId { user_id!("@alice:example.org") @@ -1184,21 +1327,6 @@ mod test { assert_eq!(identity.user_id(), loaded_identity.user_id()); } - #[async_test] - async fn key_value_saving() { - let (_, store, _dir) = get_loaded_store().await; - let key = "test_key".to_string(); - let value = "secret value".to_string(); - - store.save_value(key.clone(), value.clone()).await.unwrap(); - let stored_value = store.get_value(&key).await.unwrap().unwrap(); - - assert_eq!(value, stored_value); - - store.remove_value(&key).await.unwrap(); - assert!(store.get_value(&key).await.unwrap().is_none()); - } - #[async_test] async fn olm_hash_saving() { let (_, store, _dir) = get_loaded_store().await; @@ -1215,4 +1343,63 @@ mod test { store.save_changes(changes).await.unwrap(); assert!(store.is_message_known(&hash).await.unwrap()); } + + #[async_test] + async fn key_request_saving() { + let (account, store, _dir) = get_loaded_store().await; + + let id = Uuid::new_v4(); + let info = RequestedKeyInfo { + algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2, + room_id: room_id!("!test:localhost"), + sender_key: "test_sender_key".to_string(), + session_id: "test_session_id".to_string(), + }; + + let request = OutgoingKeyRequest { + request_recipient: account.user_id().to_owned(), + request_id: id, + info: info.clone(), + sent_out: false, + }; + + assert!(store.get_outgoing_key_request(id).await.unwrap().is_none()); + + let mut changes = Changes::default(); + changes.key_requests.push(request.clone()); + store.save_changes(changes).await.unwrap(); + + let request = Some(request); + + let stored_request = store.get_outgoing_key_request(id).await.unwrap(); + assert_eq!(request, stored_request); + + let stored_request = store.get_key_request_by_info(&info).await.unwrap(); + assert_eq!(request, stored_request); + assert!(!store.get_unsent_key_requests().await.unwrap().is_empty()); + + let request = OutgoingKeyRequest { + request_recipient: account.user_id().to_owned(), + request_id: id, + info: info.clone(), + sent_out: true, + }; + + let mut changes = Changes::default(); + changes.key_requests.push(request.clone()); + store.save_changes(changes).await.unwrap(); + + assert!(store.get_unsent_key_requests().await.unwrap().is_empty()); + let stored_request = store.get_outgoing_key_request(id).await.unwrap(); + assert_eq!(Some(request), stored_request); + + store.delete_outgoing_key_request(id).await.unwrap(); + + let stored_request = store.get_outgoing_key_request(id).await.unwrap(); + assert_eq!(None, stored_request); + + let stored_request = store.get_key_request_by_info(&info).await.unwrap(); + assert_eq!(None, stored_request); + assert!(store.get_unsent_key_requests().await.unwrap().is_empty()); + } }