From 02331fa3258a49e8964ad88d6feea039350ec3b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Thu, 15 Apr 2021 12:45:51 +0200 Subject: [PATCH] crypto: Add specialized methods to store outgoing key requests --- matrix_sdk_crypto/src/key_request.rs | 121 +++++------ matrix_sdk_crypto/src/machine.rs | 2 +- matrix_sdk_crypto/src/store/memorystore.rs | 93 +++++--- matrix_sdk_crypto/src/store/mod.rs | 64 +++--- matrix_sdk_crypto/src/store/sled.rs | 234 +++++++++++++++------ 5 files changed, 327 insertions(+), 187 deletions(-) diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index aaabc80d..f8fab2ff 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -42,7 +42,7 @@ use crate::{ error::{OlmError, OlmResult}, olm::{InboundGroupSession, OutboundGroupSession, Session, ShareState}, requests::{OutgoingRequest, ToDeviceRequest}, - store::{CryptoStoreError, Store}, + store::{Changes, CryptoStoreError, Store}, Device, }; @@ -137,32 +137,24 @@ pub(crate) struct KeyRequestMachine { users_for_key_claim: Arc>>, } -#[derive(Debug, Serialize, Deserialize)] -struct OugoingKeyInfo { - request_id: Uuid, - info: RequestedKeyInfo, - sent_out: bool, +/// A struct describing an outgoing key request. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OutgoingKeyRequest { + /// The unique id of the key request. + pub request_id: Uuid, + /// The info of the requested key. + pub info: RequestedKeyInfo, + /// Has the request been sent out. + pub sent_out: bool, } -trait Encode { - fn encode(&self) -> String; -} - -impl Encode for RequestedKeyInfo { - fn encode(&self) -> String { - format!( - "{}|{}|{}|{}", - self.sender_key, self.room_id, self.session_id, self.algorithm - ) - } -} - -impl Encode for ForwardedRoomKeyToDeviceEventContent { - fn encode(&self) -> String { - format!( - "{}|{}|{}|{}", - self.sender_key, self.room_id, self.session_id, self.algorithm - ) +impl PartialEq for OutgoingKeyRequest { + fn eq(&self, other: &Self) -> bool { + self.request_id == other.request_id + && self.info.algorithm == other.info.algorithm + && self.info.room_id == other.info.room_id + && self.info.session_id == other.info.session_id + && self.info.sender_key == other.info.sender_key } } @@ -246,6 +238,7 @@ impl KeyRequestMachine { /// key request queue. pub async fn collect_incoming_key_requests(&self) -> OlmResult> { let mut changed_sessions = Vec::new(); + for item in self.incoming_key_requests.iter() { let event = item.value(); if let Some(s) = self.handle_key_request(event).await? { @@ -534,9 +527,9 @@ impl KeyRequestMachine { session_id: session_id.to_owned(), }; - let id: Option = self.store.get_object(&key_info.encode()).await?; + let request = self.store.get_key_request_by_info(&key_info).await?; - if id.is_some() { + if request.is_some() { // We already sent out a request for this key, nothing to do. return Ok(()); } @@ -554,13 +547,13 @@ impl KeyRequestMachine { let request = wrap_key_request_content(self.user_id().clone(), id, &content)?; - let info = OugoingKeyInfo { + let info = OutgoingKeyRequest { request_id: id, info: content.body.unwrap(), sent_out: false, }; - self.save_outgoing_key_info(id, info).await?; + self.save_outgoing_key_info(info).await?; self.outgoing_to_device_requests.insert(id, request); Ok(()) @@ -569,16 +562,11 @@ impl KeyRequestMachine { /// Save an outgoing key info. async fn save_outgoing_key_info( &self, - id: Uuid, - info: OugoingKeyInfo, + info: OutgoingKeyRequest, ) -> Result<(), CryptoStoreError> { - // TODO we'll want to use a transaction to store those atomically. - // To allow this we'll need to rework our cryptostore trait to return - // a transaction trait and the transaction trait will have the save_X - // methods. - let id_string = id.to_string(); - self.store.save_object(&id_string, &info).await?; - self.store.save_object(&info.info.encode(), &id).await?; + let mut changes = Changes::default(); + changes.key_requests.push(info); + self.store.save_changes(changes).await?; Ok(()) } @@ -587,44 +575,43 @@ impl KeyRequestMachine { async fn get_key_info( &self, content: &ForwardedRoomKeyToDeviceEventContent, - ) -> Result, CryptoStoreError> { - let id: Option = self.store.get_object(&content.encode()).await?; + ) -> Result, CryptoStoreError> { + let info = RequestedKeyInfo { + algorithm: content.algorithm.clone(), + room_id: content.room_id.clone(), + sender_key: content.sender_key.clone(), + session_id: content.session_id.clone(), + }; - if let Some(id) = id { - self.store.get_object(&id.to_string()).await - } else { - Ok(None) - } + self.store.get_key_request_by_info(&info).await } /// Delete the given outgoing key info. - async fn delete_key_info(&self, info: &OugoingKeyInfo) -> Result<(), CryptoStoreError> { + async fn delete_key_info(&self, info: &OutgoingKeyRequest) -> Result<(), CryptoStoreError> { self.store - .delete_object(&info.request_id.to_string()) - .await?; - self.store.delete_object(&info.info.encode()).await?; - - Ok(()) + .delete_outgoing_key_request(info.request_id) + .await } /// Mark the outgoing request as sent. - pub async fn mark_outgoing_request_as_sent(&self, id: &Uuid) -> Result<(), CryptoStoreError> { - self.outgoing_to_device_requests.remove(id); - let info: Option = self.store.get_object(&id.to_string()).await?; + pub async fn mark_outgoing_request_as_sent(&self, id: Uuid) -> Result<(), CryptoStoreError> { + let info = self.store.get_outgoing_key_request(id).await?; if let Some(mut info) = info { trace!("Marking outgoing key request as sent {:#?}", info); info.sent_out = true; - self.save_outgoing_key_info(*id, info).await?; + self.save_outgoing_key_info(info).await?; } + self.outgoing_to_device_requests.remove(&id); + Ok(()) } /// Mark the given outgoing key info as done. /// /// This will queue up a request cancelation. - async fn mark_as_done(&self, key_info: OugoingKeyInfo) -> Result<(), CryptoStoreError> { + async fn mark_as_done(&self, key_info: OutgoingKeyRequest) -> Result<(), CryptoStoreError> { // TODO perhaps only remove the key info if the first known index is 0. trace!( "Successfully received a forwarded room key for {:#?}", @@ -847,7 +834,7 @@ mod test { let id = request.request_id; drop(request); - machine.mark_outgoing_request_as_sent(&id).await.unwrap(); + machine.mark_outgoing_request_as_sent(id).await.unwrap(); assert!(machine.outgoing_to_device_requests.is_empty()); } @@ -873,7 +860,7 @@ mod test { let id = request.request_id; drop(request); - machine.mark_outgoing_request_as_sent(&id).await.unwrap(); + machine.mark_outgoing_request_as_sent(id).await.unwrap(); let export = session.export_at_index(10).await; @@ -915,7 +902,7 @@ mod test { let request = machine.outgoing_to_device_requests.iter().next().unwrap(); let id = request.request_id; drop(request); - machine.mark_outgoing_request_as_sent(&id).await.unwrap(); + machine.mark_outgoing_request_as_sent(id).await.unwrap(); machine .create_outgoing_key_request( @@ -930,7 +917,7 @@ mod test { let id = request.request_id; drop(request); - machine.mark_outgoing_request_as_sent(&id).await.unwrap(); + machine.mark_outgoing_request_as_sent(id).await.unwrap(); let export = session.export_at_index(15).await; @@ -1148,7 +1135,7 @@ mod test { drop(request); alice_machine - .mark_outgoing_request_as_sent(&id) + .mark_outgoing_request_as_sent(id) .await .unwrap(); @@ -1186,10 +1173,7 @@ mod test { let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap(); drop(request); - bob_machine - .mark_outgoing_request_as_sent(&id) - .await - .unwrap(); + bob_machine.mark_outgoing_request_as_sent(id).await.unwrap(); let event = ToDeviceEvent { sender: bob_id(), @@ -1317,7 +1301,7 @@ mod test { drop(request); alice_machine - .mark_outgoing_request_as_sent(&id) + .mark_outgoing_request_as_sent(id) .await .unwrap(); @@ -1378,10 +1362,7 @@ mod test { let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap(); drop(request); - bob_machine - .mark_outgoing_request_as_sent(&id) - .await - .unwrap(); + bob_machine.mark_outgoing_request_as_sent(id).await.unwrap(); let event = ToDeviceEvent { sender: bob_id(), diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 415b80f7..601be381 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -751,7 +751,7 @@ impl OlmMachine { async fn mark_to_device_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> { self.verification_machine.mark_request_as_sent(request_id); self.key_request_machine - .mark_outgoing_request_as_sent(request_id) + .mark_outgoing_request_as_sent(*request_id) .await?; self.group_session_manager .mark_request_as_sent(request_id) diff --git a/matrix_sdk_crypto/src/store/memorystore.rs b/matrix_sdk_crypto/src/store/memorystore.rs index 3d249c82..27993e25 100644 --- a/matrix_sdk_crypto/src/store/memorystore.rs +++ b/matrix_sdk_crypto/src/store/memorystore.rs @@ -20,8 +20,10 @@ use std::{ use dashmap::{DashMap, DashSet}; use matrix_sdk_common::{ async_trait, + events::room_key_request::RequestedKeyInfo, identifiers::{DeviceId, DeviceIdBox, RoomId, UserId}, locks::Mutex, + uuid::Uuid, }; use super::{ @@ -30,9 +32,17 @@ use super::{ }; use crate::{ identities::{ReadOnlyDevice, UserIdentities}, + key_request::OutgoingKeyRequest, olm::{OutboundGroupSession, PrivateCrossSigningIdentity}, }; +fn encode_key_info(info: &RequestedKeyInfo) -> String { + format!( + "{}{}{}{}", + &info.room_id, &info.sender_key, &info.algorithm, &info.session_id + ) +} + /// An in-memory only store that will forget all the E2EE key once it's dropped. #[derive(Debug, Clone)] pub struct MemoryStore { @@ -43,7 +53,8 @@ pub struct MemoryStore { olm_hashes: Arc>>, devices: DeviceStore, identities: Arc>, - values: Arc>, + outgoing_key_requests: Arc>, + key_requests_by_info: Arc>, } impl Default for MemoryStore { @@ -56,7 +67,8 @@ impl Default for MemoryStore { olm_hashes: Arc::new(DashMap::new()), devices: DeviceStore::new(), identities: Arc::new(DashMap::new()), - values: Arc::new(DashMap::new()), + outgoing_key_requests: Arc::new(DashMap::new()), + key_requests_by_info: Arc::new(DashMap::new()), } } } @@ -103,6 +115,10 @@ impl CryptoStore for MemoryStore { Ok(()) } + async fn load_identity(&self) -> Result> { + Ok(None) + } + async fn save_changes(&self, mut changes: Changes) -> Result<()> { self.save_sessions(changes.sessions).await; self.save_inbound_group_sessions(changes.inbound_group_sessions) @@ -130,6 +146,14 @@ impl CryptoStore for MemoryStore { .insert(hash.hash.clone()); } + for key_request in changes.key_requests { + let id = key_request.request_id; + let info_string = encode_key_info(&key_request.info); + + self.outgoing_key_requests.insert(id, key_request); + self.key_requests_by_info.insert(info_string, id); + } + Ok(()) } @@ -152,9 +176,11 @@ impl CryptoStore for MemoryStore { Ok(self.inbound_group_sessions.get_all()) } - fn users_for_key_query(&self) -> HashSet { - #[allow(clippy::map_clone)] - self.users_for_key_query.iter().map(|u| u.clone()).collect() + async fn get_outbound_group_sessions( + &self, + _: &RoomId, + ) -> Result> { + Ok(None) } fn is_user_tracked(&self, user_id: &UserId) -> bool { @@ -165,6 +191,11 @@ impl CryptoStore for MemoryStore { !self.users_for_key_query.is_empty() } + fn users_for_key_query(&self) -> HashSet { + #[allow(clippy::map_clone)] + self.users_for_key_query.iter().map(|u| u.clone()).collect() + } + async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result { // TODO to prevent a race between the sync and a key query in flight we // need to have an additional state to mention that the user changed. @@ -207,24 +238,6 @@ impl CryptoStore for MemoryStore { Ok(self.identities.get(user_id).map(|i| i.clone())) } - async fn save_value(&self, key: String, value: String) -> Result<()> { - self.values.insert(key, value); - Ok(()) - } - - async fn remove_value(&self, key: &str) -> Result<()> { - self.values.remove(key); - Ok(()) - } - - async fn get_value(&self, key: &str) -> Result> { - Ok(self.values.get(key).map(|v| v.to_owned())) - } - - async fn load_identity(&self) -> Result> { - Ok(None) - } - async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result { Ok(self .olm_hashes @@ -233,11 +246,37 @@ impl CryptoStore for MemoryStore { .contains(&message_hash.hash)) } - async fn get_outbound_group_sessions( + async fn get_outgoing_key_request( &self, - _: &RoomId, - ) -> Result> { - Ok(None) + request_id: Uuid, + ) -> Result> { + Ok(self + .outgoing_key_requests + .get(&request_id) + .map(|r| r.clone())) + } + + async fn get_key_request_by_info( + &self, + key_info: &RequestedKeyInfo, + ) -> Result> { + let key_info_string = encode_key_info(key_info); + + Ok(self + .key_requests_by_info + .get(&key_info_string) + .and_then(|i| self.outgoing_key_requests.get(&i).map(|r| r.clone()))) + } + + async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> { + self.outgoing_key_requests + .remove(&request_id) + .and_then(|(_, i)| { + let key_info_string = encode_key_info(&i.info); + self.key_requests_by_info.remove(&key_info_string) + }); + + Ok(()) } } diff --git a/matrix_sdk_crypto/src/store/mod.rs b/matrix_sdk_crypto/src/store/mod.rs index a837dfd0..387bd610 100644 --- a/matrix_sdk_crypto/src/store/mod.rs +++ b/matrix_sdk_crypto/src/store/mod.rs @@ -57,23 +57,25 @@ use std::{ }; use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError}; -use serde::{Deserialize, Serialize}; use serde_json::Error as SerdeError; use thiserror::Error; use matrix_sdk_common::{ async_trait, + events::room_key_request::RequestedKeyInfo, identifiers::{ DeviceId, DeviceIdBox, DeviceKeyAlgorithm, Error as IdentifierValidationError, RoomId, UserId, }, locks::Mutex, + uuid::Uuid, AsyncTraitDeps, }; use crate::{ error::SessionUnpicklingError, identities::{Device, ReadOnlyDevice, UserDevices, UserIdentities}, + key_request::OutgoingKeyRequest, olm::{ InboundGroupSession, OlmMessageHash, OutboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount, Session, @@ -108,6 +110,7 @@ pub struct Changes { pub inbound_group_sessions: Vec, pub outbound_group_sessions: Vec, pub identities: IdentityChanges, + pub key_requests: Vec, pub devices: DeviceChanges, } @@ -257,24 +260,6 @@ impl Store { device_owner_identity, })) } - - pub async fn get_object Deserialize<'b>>(&self, key: &str) -> Result> { - if let Some(value) = self.get_value(key).await? { - Ok(Some(serde_json::from_str(&value)?)) - } else { - Ok(None) - } - } - - pub async fn save_object(&self, key: &str, value: &impl Serialize) -> Result<()> { - let value = serde_json::to_string(value)?; - self.save_value(key.to_owned(), value).await - } - - pub async fn delete_object(&self, key: &str) -> Result<()> { - self.inner.remove_value(key).await?; - Ok(()) - } } impl Deref for Store { @@ -438,15 +423,38 @@ pub trait CryptoStore: AsyncTraitDeps { /// * `user_id` - The user for which we should get the identity. async fn get_user_identity(&self, user_id: &UserId) -> Result>; - /// Save a serializeable object in the store. - async fn save_value(&self, key: String, value: String) -> Result<()>; - - /// Remove a value from the store. - async fn remove_value(&self, key: &str) -> Result<()>; - - /// Load a serializeable object from the store. - async fn get_value(&self, key: &str) -> Result>; - /// Check if a hash for an Olm message stored in the database. async fn is_message_known(&self, message_hash: &OlmMessageHash) -> Result; + + /// Get an outoing key request that we created that matches the given + /// request id. + /// + /// # Arguments + /// + /// * `request_id` - The unique request id that identifies this outgoing key + /// request. + async fn get_outgoing_key_request( + &self, + request_id: Uuid, + ) -> Result>; + + /// Get an outoing key request that we created that matches the given + /// requested key info. + /// + /// # Arguments + /// + /// * `key_info` - The key info of an outgoing key request. + async fn get_key_request_by_info( + &self, + key_info: &RequestedKeyInfo, + ) -> Result>; + + /// Delete an outoing key request that we created that matches the given + /// request id. + /// + /// # Arguments + /// + /// * `request_id` - The unique request id that identifies this outgoing key + /// request. + async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()>; } diff --git a/matrix_sdk_crypto/src/store/sled.rs b/matrix_sdk_crypto/src/store/sled.rs index 00950d6f..f177f10d 100644 --- a/matrix_sdk_crypto/src/store/sled.rs +++ b/matrix_sdk_crypto/src/store/sled.rs @@ -29,9 +29,12 @@ use sled::{ use matrix_sdk_common::{ async_trait, + events::room_key_request::RequestedKeyInfo, identifiers::{DeviceId, DeviceIdBox, RoomId, UserId}, locks::Mutex, + uuid, }; +use uuid::Uuid; use super::{ caches::SessionStore, Changes, CryptoStore, CryptoStoreError, InboundGroupSession, PickleKey, @@ -39,6 +42,7 @@ use super::{ }; use crate::{ identities::{ReadOnlyDevice, UserIdentities}, + key_request::OutgoingKeyRequest, olm::{OutboundGroupSession, PickledInboundGroupSession, PrivateCrossSigningIdentity}, }; @@ -51,6 +55,28 @@ trait EncodeKey { fn encode(&self) -> Vec; } +impl EncodeKey for Uuid { + fn encode(&self) -> Vec { + self.as_u128().to_be_bytes().to_vec() + } +} + +impl EncodeKey for &RequestedKeyInfo { + fn encode(&self) -> Vec { + [ + self.room_id.as_bytes(), + &[Self::SEPARATOR], + self.sender_key.as_bytes(), + &[Self::SEPARATOR], + self.algorithm.as_ref().as_bytes(), + &[Self::SEPARATOR], + self.session_id.as_bytes(), + &[Self::SEPARATOR], + ] + .concat() + } +} + impl EncodeKey for &UserId { fn encode(&self) -> Vec { self.as_str().encode() @@ -122,12 +148,14 @@ pub struct SledStore { inbound_group_sessions: Tree, outbound_group_sessions: Tree, + outgoing_key_requests: Tree, + key_requests_by_info: Tree, + devices: Tree, identities: Tree, tracked_users: Tree, users_for_key_query: Tree, - values: Tree, } impl std::fmt::Debug for SledStore { @@ -178,13 +206,16 @@ impl SledStore { let sessions = db.open_tree("session")?; let inbound_group_sessions = db.open_tree("inbound_group_sessions")?; let outbound_group_sessions = db.open_tree("outbound_group_sessions")?; + let tracked_users = db.open_tree("tracked_users")?; let users_for_key_query = db.open_tree("users_for_key_query")?; let olm_hashes = db.open_tree("olm_hashes")?; let devices = db.open_tree("devices")?; let identities = db.open_tree("identities")?; - let values = db.open_tree("values")?; + + let outgoing_key_requests = db.open_tree("outgoing_key_requests")?; + let key_requests_by_info = db.open_tree("key_requests_by_info")?; let session_cache = SessionStore::new(); @@ -208,12 +239,13 @@ impl SledStore { users_for_key_query_cache: DashSet::new().into(), inbound_group_sessions, outbound_group_sessions, + outgoing_key_requests, + key_requests_by_info, devices, tracked_users, users_for_key_query, olm_hashes, identities, - values, }) } @@ -332,6 +364,7 @@ impl SledStore { let identity_changes = changes.identities; let olm_hashes = changes.message_hashes; + let key_requests = changes.key_requests; let ret: Result<(), TransactionError> = ( &self.account, @@ -342,6 +375,8 @@ impl SledStore { &self.inbound_group_sessions, &self.outbound_group_sessions, &self.olm_hashes, + &self.outgoing_key_requests, + &self.key_requests_by_info, ) .transaction( |( @@ -353,6 +388,8 @@ impl SledStore { inbound_sessions, outbound_sessions, hashes, + outgoing_key_requests, + key_requests_by_info, )| { if let Some(a) = &account_pickle { account.insert( @@ -420,6 +457,19 @@ impl SledStore { )?; } + for key_request in &key_requests { + key_requests_by_info.insert( + (&key_request.info).encode(), + key_request.request_id.encode(), + )?; + + outgoing_key_requests.insert( + key_request.request_id.encode(), + serde_json::to_vec(&key_request) + .map_err(ConflictableTransactionError::Abort)?, + )?; + } + Ok(()) }, ); @@ -472,6 +522,19 @@ impl CryptoStore for SledStore { self.save_changes(changes).await } + async fn load_identity(&self) -> Result> { + if let Some(i) = self.private_identity.get("identity".encode())? { + let pickle = serde_json::from_slice(&i)?; + Ok(Some( + PrivateCrossSigningIdentity::from_pickle(pickle, self.get_pickle_key()) + .await + .map_err(|_| CryptoStoreError::UnpicklingError)?, + )) + } else { + Ok(None) + } + } + async fn save_changes(&self, changes: Changes) -> Result<()> { self.save_changes(changes).await } @@ -539,12 +602,11 @@ impl CryptoStore for SledStore { .collect()) } - fn users_for_key_query(&self) -> HashSet { - #[allow(clippy::map_clone)] - self.users_for_key_query_cache - .iter() - .map(|u| u.clone()) - .collect() + async fn get_outbound_group_sessions( + &self, + room_id: &RoomId, + ) -> Result> { + self.load_outbound_group_session(room_id).await } fn is_user_tracked(&self, user_id: &UserId) -> bool { @@ -555,6 +617,14 @@ impl CryptoStore for SledStore { !self.users_for_key_query_cache.is_empty() } + fn users_for_key_query(&self) -> HashSet { + #[allow(clippy::map_clone)] + self.users_for_key_query_cache + .iter() + .map(|u| u.clone()) + .collect() + } + async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result { let already_added = self.tracked_users_cache.insert(user.clone()); @@ -605,48 +675,62 @@ impl CryptoStore for SledStore { .transpose()?) } - async fn save_value(&self, key: String, value: String) -> Result<()> { - self.values.insert(key.as_str().encode(), value.as_str())?; - self.inner.flush_async().await?; - Ok(()) - } - - async fn remove_value(&self, key: &str) -> Result<()> { - self.values.remove(key.encode())?; - Ok(()) - } - - async fn get_value(&self, key: &str) -> Result> { - Ok(self - .values - .get(key.encode())? - .map(|v| String::from_utf8_lossy(&v).to_string())) - } - - async fn load_identity(&self) -> Result> { - if let Some(i) = self.private_identity.get("identity".encode())? { - let pickle = serde_json::from_slice(&i)?; - Ok(Some( - PrivateCrossSigningIdentity::from_pickle(pickle, self.get_pickle_key()) - .await - .map_err(|_| CryptoStoreError::UnpicklingError)?, - )) - } else { - Ok(None) - } - } - async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result { Ok(self .olm_hashes .contains_key(serde_json::to_vec(message_hash)?)?) } - async fn get_outbound_group_sessions( + async fn get_outgoing_key_request( &self, - room_id: &RoomId, - ) -> Result> { - self.load_outbound_group_session(room_id).await + request_id: Uuid, + ) -> Result> { + Ok(self + .outgoing_key_requests + .get(request_id.encode())? + .map(|r| serde_json::from_slice(&r)) + .transpose()?) + } + + async fn get_key_request_by_info( + &self, + key_info: &RequestedKeyInfo, + ) -> Result> { + let id = self.key_requests_by_info.get(key_info.encode())?; + + if let Some(id) = id { + Ok(self + .outgoing_key_requests + .get(id)? + .map(|r| serde_json::from_slice(&r)) + .transpose()?) + } else { + Ok(None) + } + } + + async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> { + let ret: Result<(), TransactionError> = + (&self.outgoing_key_requests, &self.key_requests_by_info).transaction( + |(outgoing_key_requests, key_requests_by_info)| { + let request: Option = outgoing_key_requests + .remove(request_id.encode())? + .map(|r| serde_json::from_slice(&r)) + .transpose() + .map_err(ConflictableTransactionError::Abort)?; + + if let Some(request) = request { + key_requests_by_info.remove((&request.info).encode())?; + } + + Ok(()) + }, + ); + + ret?; + self.inner.flush_async().await?; + + Ok(()) } } @@ -665,14 +749,16 @@ mod test { }; use matrix_sdk_common::{ api::r0::keys::SignedKey, - identifiers::{room_id, user_id, DeviceId, UserId}, + events::room_key_request::RequestedKeyInfo, + identifiers::{room_id, user_id, DeviceId, EventEncryptionAlgorithm, UserId}, + uuid::Uuid, }; use matrix_sdk_test::async_test; use olm_rs::outbound_group_session::OlmOutboundGroupSession; use std::collections::BTreeMap; use tempfile::tempdir; - use super::{CryptoStore, SledStore}; + use super::{CryptoStore, OutgoingKeyRequest, SledStore}; fn alice_id() -> UserId { user_id!("@alice:example.org") @@ -1184,21 +1270,6 @@ mod test { assert_eq!(identity.user_id(), loaded_identity.user_id()); } - #[async_test] - async fn key_value_saving() { - let (_, store, _dir) = get_loaded_store().await; - let key = "test_key".to_string(); - let value = "secret value".to_string(); - - store.save_value(key.clone(), value.clone()).await.unwrap(); - let stored_value = store.get_value(&key).await.unwrap().unwrap(); - - assert_eq!(value, stored_value); - - store.remove_value(&key).await.unwrap(); - assert!(store.get_value(&key).await.unwrap().is_none()); - } - #[async_test] async fn olm_hash_saving() { let (_, store, _dir) = get_loaded_store().await; @@ -1215,4 +1286,45 @@ mod test { store.save_changes(changes).await.unwrap(); assert!(store.is_message_known(&hash).await.unwrap()); } + + #[async_test] + async fn key_request_saving() { + let (_, store, _dir) = get_loaded_store().await; + + let id = Uuid::new_v4(); + let info = RequestedKeyInfo { + algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2, + room_id: room_id!("!test:localhost"), + sender_key: "test_sender_key".to_string(), + session_id: "test_session_id".to_string(), + }; + + let request = OutgoingKeyRequest { + request_id: id, + info: info.clone(), + sent_out: false, + }; + + assert!(store.get_outgoing_key_request(id).await.unwrap().is_none()); + + let mut changes = Changes::default(); + changes.key_requests.push(request.clone()); + store.save_changes(changes).await.unwrap(); + + let request = Some(request); + + let stored_request = store.get_outgoing_key_request(id).await.unwrap(); + assert_eq!(request, stored_request); + + let stored_request = store.get_key_request_by_info(&info).await.unwrap(); + assert_eq!(request, stored_request); + + store.delete_outgoing_key_request(id).await.unwrap(); + + let stored_request = store.get_outgoing_key_request(id).await.unwrap(); + assert_eq!(None, stored_request); + + let stored_request = store.get_key_request_by_info(&info).await.unwrap(); + assert_eq!(None, stored_request); + } }