From 5192feb836150462ccd69044cef63f2389bc65e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Fri, 30 Jul 2021 11:27:49 +0200 Subject: [PATCH] crypto: Add support to request secrets --- matrix_sdk_crypto/src/key_request.rs | 154 ++++++++++++++----- matrix_sdk_crypto/src/olm/mod.rs | 2 +- matrix_sdk_crypto/src/olm/signing/mod.rs | 24 +++ matrix_sdk_crypto/src/store/memorystore.rs | 27 ++-- matrix_sdk_crypto/src/store/mod.rs | 32 ++-- matrix_sdk_crypto/src/store/sled.rs | 166 ++++++++++++--------- 6 files changed, 279 insertions(+), 126 deletions(-) diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index 53065dbc..827a8ffc 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -29,7 +29,9 @@ use ruma::{ forwarded_room_key::ForwardedRoomKeyToDeviceEventContent, room_key_request::{Action, RequestedKeyInfo, RoomKeyRequestToDeviceEventContent}, secret::{ - request::{RequestAction, RequestToDeviceEventContent as SecretRequestEventContent}, + request::{ + RequestAction, RequestToDeviceEventContent as SecretRequestEventContent, SecretName, + }, send::SendToDeviceEventContent as SecretSendEventContent, }, AnyToDeviceEvent, AnyToDeviceEventContent, ToDeviceEvent, @@ -201,24 +203,56 @@ pub struct OutgoingKeyRequest { /// The unique id of the key request. pub request_id: Uuid, /// The info of the requested key. - pub info: RequestedKeyInfo, + pub info: SecretInfo, /// Has the request been sent out. pub sent_out: bool, } +/// An enum over the various secret request types we can have. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum SecretInfo { + // Info for the `m.room_key_request` variant + KeyRequest(RequestedKeyInfo), + // Info for the `m.secret.request` variant + SecretRequest(SecretName), +} + +impl From for SecretInfo { + fn from(i: RequestedKeyInfo) -> Self { + Self::KeyRequest(i) + } +} + +impl From for SecretInfo { + fn from(i: SecretName) -> Self { + Self::SecretRequest(i) + } +} + impl OutgoingKeyRequest { fn to_request(&self, own_device_id: &DeviceId) -> OutgoingRequest { - let content = RoomKeyRequestToDeviceEventContent::new( - Action::Request, - Some(self.info.clone()), - own_device_id.to_owned(), - self.request_id.to_string(), - ); + let content = match &self.info { + SecretInfo::KeyRequest(r) => { + AnyToDeviceEventContent::RoomKeyRequest(RoomKeyRequestToDeviceEventContent::new( + Action::Request, + Some(r.clone()), + own_device_id.to_owned(), + self.request_id.to_string(), + )) + } + SecretInfo::SecretRequest(s) => { + AnyToDeviceEventContent::SecretRequest(SecretRequestEventContent::new( + RequestAction::Request(s.clone()), + own_device_id.to_owned(), + self.request_id.to_string(), + )) + } + }; let request = ToDeviceRequest::new_with_id( &self.request_recipient, DeviceIdOrAllDevices::AllDevices, - AnyToDeviceEventContent::RoomKeyRequest(content), + content, self.request_id, ); @@ -226,17 +260,28 @@ impl OutgoingKeyRequest { } fn to_cancellation(&self, own_device_id: &DeviceId) -> OutgoingRequest { - let content = RoomKeyRequestToDeviceEventContent::new( - Action::CancelRequest, - None, - own_device_id.to_owned(), - self.request_id.to_string(), - ); + let content = match self.info { + SecretInfo::KeyRequest(_) => { + AnyToDeviceEventContent::RoomKeyRequest(RoomKeyRequestToDeviceEventContent::new( + Action::CancelRequest, + None, + own_device_id.to_owned(), + self.request_id.to_string(), + )) + } + SecretInfo::SecretRequest(_) => { + AnyToDeviceEventContent::SecretRequest(SecretRequestEventContent::new( + RequestAction::RequestCancellation, + own_device_id.to_owned(), + self.request_id.to_string(), + )) + } + }; let request = ToDeviceRequest::new( &self.request_recipient, DeviceIdOrAllDevices::AllDevices, - AnyToDeviceEventContent::RoomKeyRequest(content), + content, ); OutgoingRequest { request_id: request.txn_id, request: Arc::new(request.into()) } @@ -245,11 +290,20 @@ impl OutgoingKeyRequest { impl PartialEq for OutgoingKeyRequest { fn eq(&self, other: &Self) -> bool { - self.request_id == other.request_id - && self.info.algorithm == other.info.algorithm - && self.info.room_id == other.info.room_id - && self.info.session_id == other.info.session_id - && self.info.sender_key == other.info.sender_key + let is_info_equal = match (&self.info, &other.info) { + (SecretInfo::KeyRequest(first), SecretInfo::KeyRequest(second)) => { + first.algorithm == second.algorithm + && first.room_id == second.room_id + && first.session_id == second.session_id + } + (SecretInfo::SecretRequest(first), SecretInfo::SecretRequest(second)) => { + first == second + } + (SecretInfo::KeyRequest(_), SecretInfo::SecretRequest(_)) + | (SecretInfo::SecretRequest(_), SecretInfo::KeyRequest(_)) => false, + }; + + self.request_id == other.request_id && is_info_equal } } @@ -277,7 +331,7 @@ impl KeyRequestMachine { async fn load_outgoing_requests(&self) -> Result, CryptoStoreError> { Ok(self .store - .get_unsent_key_requests() + .get_unsent_secret_requests() .await? .into_iter() .filter(|i| !i.sent_out) @@ -675,11 +729,8 @@ impl KeyRequestMachine { /// /// * `key_info` - The info of our key request containing information about /// the key we wish to request. - async fn should_request_key( - &self, - key_info: &RequestedKeyInfo, - ) -> Result { - let request = self.store.get_key_request_by_info(key_info).await?; + async fn should_request_key(&self, key_info: &SecretInfo) -> Result { + let request = self.store.get_secret_request_by_info(key_info).await?; // Don't send out duplicate requests, users can re-request them if they // think a second request might succeed. @@ -728,9 +779,10 @@ impl KeyRequestMachine { room_id.to_owned(), sender_key.to_owned(), session_id.to_owned(), - ); + ) + .into(); - let request = self.store.get_key_request_by_info(&key_info).await?; + let request = self.store.get_secret_request_by_info(&key_info).await?; if let Some(request) = request { let cancel = request.to_cancellation(self.device_id()); @@ -744,9 +796,39 @@ impl KeyRequestMachine { } } + #[allow(dead_code)] + pub async fn request_missing_secrets(&self) -> Result, CryptoStoreError> { + let secret_names = self.store.get_missing_secrets().await; + + Ok(if secret_names.is_empty() { + info!(secret_names =? secret_names, "Creating new outgoing secret requests"); + + let requests: Vec = secret_names + .into_iter() + .map(|n| OutgoingKeyRequest { + request_recipient: self.user_id().to_owned(), + request_id: Uuid::new_v4(), + info: n.into(), + sent_out: false, + }) + .collect(); + + let outgoing_requests = + requests.iter().map(|r| r.to_request(self.device_id())).collect(); + + let changes = Changes { key_requests: requests, ..Default::default() }; + self.store.save_changes(changes).await?; + + outgoing_requests + } else { + trace!("No secrets are missing from our store, not requesting them"); + vec![] + }) + } + async fn request_key_helper( &self, - key_info: RequestedKeyInfo, + key_info: SecretInfo, ) -> Result { info!("Creating new outgoing room key request {:#?}", key_info); @@ -788,7 +870,8 @@ impl KeyRequestMachine { room_id.to_owned(), sender_key.to_owned(), session_id.to_owned(), - ); + ) + .into(); if self.should_request_key(&key_info).await? { self.request_key_helper(key_info).await?; @@ -819,19 +902,20 @@ impl KeyRequestMachine { content.room_id.clone(), content.sender_key.clone(), content.session_id.clone(), - ); + ) + .into(); - self.store.get_key_request_by_info(&info).await + self.store.get_secret_request_by_info(&info).await } /// Delete the given outgoing key info. async fn delete_key_info(&self, info: &OutgoingKeyRequest) -> Result<(), CryptoStoreError> { - self.store.delete_outgoing_key_request(info.request_id).await + self.store.delete_outgoing_secret_requests(info.request_id).await } /// Mark the outgoing request as sent. pub async fn mark_outgoing_request_as_sent(&self, id: Uuid) -> Result<(), CryptoStoreError> { - let info = self.store.get_outgoing_key_request(id).await?; + let info = self.store.get_outgoing_secret_requests(id).await?; if let Some(mut info) = info { trace!("Marking outgoing key request as sent {:#?}", info); diff --git a/matrix_sdk_crypto/src/olm/mod.rs b/matrix_sdk_crypto/src/olm/mod.rs index 9abe2620..799f58ff 100644 --- a/matrix_sdk_crypto/src/olm/mod.rs +++ b/matrix_sdk_crypto/src/olm/mod.rs @@ -267,7 +267,7 @@ pub(crate) mod test { }) .to_string(); - let event: AnySyncRoomEvent = serde_json::from_str(&event).expect("WHAAAT?!?!?"); + let event: AnySyncRoomEvent = serde_json::from_str(&event).unwrap(); let event = if let AnySyncRoomEvent::Message(AnySyncMessageEvent::RoomEncrypted(event)) = event { diff --git a/matrix_sdk_crypto/src/olm/signing/mod.rs b/matrix_sdk_crypto/src/olm/signing/mod.rs index 80823faf..579f0a73 100644 --- a/matrix_sdk_crypto/src/olm/signing/mod.rs +++ b/matrix_sdk_crypto/src/olm/signing/mod.rs @@ -97,6 +97,11 @@ impl PrivateCrossSigningIdentity { self.self_signing_key.lock().await.is_some() } + /// Can we sign other users, i.e. do we have a user signing key. + pub async fn can_sign_users(&self) -> bool { + self.user_signing_key.lock().await.is_some() + } + /// Do we have the master key. pub async fn has_master_key(&self) -> bool { self.master_key.lock().await.is_some() @@ -130,6 +135,25 @@ impl PrivateCrossSigningIdentity { } } + /// Get the names of the secrets we are missing. + pub(crate) async fn get_missing_secrets(&self) -> Vec { + let mut missing = Vec::new(); + + if !self.has_master_key().await { + missing.push(SecretName::CrossSigningMasterKey); + } + + if !self.can_sign_devices().await { + missing.push(SecretName::CrossSigningSelfSigningKey); + } + + if !self.can_sign_users().await { + missing.push(SecretName::CrossSigningUserSigningKey); + } + + missing + } + /// Create a new empty identity. pub(crate) fn empty(user_id: UserId) -> Self { Self { diff --git a/matrix_sdk_crypto/src/store/memorystore.rs b/matrix_sdk_crypto/src/store/memorystore.rs index ef52fd44..549f5f1c 100644 --- a/matrix_sdk_crypto/src/store/memorystore.rs +++ b/matrix_sdk_crypto/src/store/memorystore.rs @@ -19,7 +19,7 @@ use std::{ use dashmap::{DashMap, DashSet}; use matrix_sdk_common::{async_trait, locks::Mutex, uuid::Uuid}; -use ruma::{events::room_key_request::RequestedKeyInfo, DeviceId, DeviceIdBox, RoomId, UserId}; +use ruma::{DeviceId, DeviceIdBox, RoomId, UserId}; use super::{ caches::{DeviceStore, GroupSessionStore, SessionStore}, @@ -27,12 +27,21 @@ use super::{ }; use crate::{ identities::{ReadOnlyDevice, ReadOnlyUserIdentities}, - key_request::OutgoingKeyRequest, + key_request::{OutgoingKeyRequest, SecretInfo}, olm::{OutboundGroupSession, PrivateCrossSigningIdentity}, }; -fn encode_key_info(info: &RequestedKeyInfo) -> String { - format!("{}{}{}{}", info.room_id, info.sender_key, info.algorithm, info.session_id) +fn encode_key_info(info: &SecretInfo) -> String { + match info { + SecretInfo::KeyRequest(info) => { + format!("{}{}{}{}", info.room_id, info.sender_key, info.algorithm, info.session_id) + } + SecretInfo::SecretRequest(i) => { + // TODO don't use serde here, use `as_ref()` when it becomes + // available + serde_json::to_string(i).expect("Can't serialize secret name") + } + } } /// An in-memory only store that will forget all the E2EE key once it's dropped. @@ -228,16 +237,16 @@ impl CryptoStore for MemoryStore { .contains(&message_hash.hash)) } - async fn get_outgoing_key_request( + async fn get_outgoing_secret_requests( &self, request_id: Uuid, ) -> Result> { Ok(self.outgoing_key_requests.get(&request_id).map(|r| r.clone())) } - async fn get_key_request_by_info( + async fn get_secret_request_by_info( &self, - key_info: &RequestedKeyInfo, + key_info: &SecretInfo, ) -> Result> { let key_info_string = encode_key_info(key_info); @@ -247,7 +256,7 @@ impl CryptoStore for MemoryStore { .and_then(|i| self.outgoing_key_requests.get(&i).map(|r| r.clone()))) } - async fn get_unsent_key_requests(&self) -> Result> { + async fn get_unsent_secret_requests(&self) -> Result> { Ok(self .outgoing_key_requests .iter() @@ -256,7 +265,7 @@ impl CryptoStore for MemoryStore { .collect()) } - async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> { + async fn delete_outgoing_secret_requests(&self, request_id: Uuid) -> Result<()> { self.outgoing_key_requests.remove(&request_id).and_then(|(_, i)| { let key_info_string = encode_key_info(&i.info); self.key_requests_by_info.remove(&key_info_string) diff --git a/matrix_sdk_crypto/src/store/mod.rs b/matrix_sdk_crypto/src/store/mod.rs index 0df3d983..56175f6e 100644 --- a/matrix_sdk_crypto/src/store/mod.rs +++ b/matrix_sdk_crypto/src/store/mod.rs @@ -56,9 +56,8 @@ pub use memorystore::MemoryStore; use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError}; pub use pickle_key::{EncryptedPickleKey, PickleKey}; use ruma::{ - events::{room_key_request::RequestedKeyInfo, secret::request::SecretName}, - identifiers::Error as IdentifierValidationError, - DeviceId, DeviceIdBox, DeviceKeyAlgorithm, RoomId, UserId, + events::secret::request::SecretName, identifiers::Error as IdentifierValidationError, DeviceId, + DeviceIdBox, DeviceKeyAlgorithm, RoomId, UserId, }; use serde_json::Error as SerdeError; use thiserror::Error; @@ -72,7 +71,7 @@ use crate::{ user::{OwnUserIdentity, UserIdentities, UserIdentity}, Device, ReadOnlyDevice, ReadOnlyUserIdentities, UserDevices, }, - key_request::OutgoingKeyRequest, + key_request::{OutgoingKeyRequest, SecretInfo}, olm::{ InboundGroupSession, OlmMessageHash, OutboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount, Session, @@ -273,6 +272,11 @@ impl Store { } } } + + pub async fn get_missing_secrets(&self) -> Vec { + // TODO add the backup key to our missing secrets + self.identity.lock().await.get_missing_secrets().await + } } impl Deref for Store { @@ -440,14 +444,14 @@ pub trait CryptoStore: AsyncTraitDeps { /// Check if a hash for an Olm message stored in the database. async fn is_message_known(&self, message_hash: &OlmMessageHash) -> Result; - /// Get an outgoing key request that we created that matches the given + /// Get an outgoing secret request that we created that matches the given /// request id. /// /// # Arguments /// - /// * `request_id` - The unique request id that identifies this outgoing key - /// request. - async fn get_outgoing_key_request( + /// * `request_id` - The unique request id that identifies this outgoing + /// secret request. + async fn get_outgoing_secret_requests( &self, request_id: Uuid, ) -> Result>; @@ -457,14 +461,14 @@ pub trait CryptoStore: AsyncTraitDeps { /// /// # Arguments /// - /// * `key_info` - The key info of an outgoing key request. - async fn get_key_request_by_info( + /// * `key_info` - The key info of an outgoing secret request. + async fn get_secret_request_by_info( &self, - key_info: &RequestedKeyInfo, + secret_info: &SecretInfo, ) -> Result>; - /// Get all outgoing key requests that we have in the store. - async fn get_unsent_key_requests(&self) -> Result>; + /// Get all outgoing secret requests that we have in the store. + async fn get_unsent_secret_requests(&self) -> Result>; /// Delete an outgoing key request that we created that matches the given /// request id. @@ -473,5 +477,5 @@ pub trait CryptoStore: AsyncTraitDeps { /// /// * `request_id` - The unique request id that identifies this outgoing key /// request. - async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()>; + async fn delete_outgoing_secret_requests(&self, request_id: Uuid) -> Result<()>; } diff --git a/matrix_sdk_crypto/src/store/sled.rs b/matrix_sdk_crypto/src/store/sled.rs index 9ed76e07..b8e82b4a 100644 --- a/matrix_sdk_crypto/src/store/sled.rs +++ b/matrix_sdk_crypto/src/store/sled.rs @@ -22,7 +22,10 @@ use std::{ use dashmap::DashSet; use matrix_sdk_common::{async_trait, locks::Mutex, uuid}; use olm_rs::{account::IdentityKeys, PicklingMode}; -use ruma::{events::room_key_request::RequestedKeyInfo, DeviceId, DeviceIdBox, RoomId, UserId}; +use ruma::{ + events::{room_key_request::RequestedKeyInfo, secret::request::SecretName}, + DeviceId, DeviceIdBox, RoomId, UserId, +}; pub use sled::Error; use sled::{ transaction::{ConflictableTransactionError, TransactionError}, @@ -36,7 +39,7 @@ use super::{ }; use crate::{ identities::{ReadOnlyDevice, ReadOnlyUserIdentities}, - key_request::OutgoingKeyRequest, + key_request::{OutgoingKeyRequest, SecretInfo}, olm::{OutboundGroupSession, PickledInboundGroupSession, PrivateCrossSigningIdentity}, }; @@ -55,6 +58,27 @@ impl EncodeKey for Uuid { } } +impl EncodeKey for SecretName { + fn encode(&self) -> Vec { + [ + // TODO don't use serde here, use `as_ref()` when it becomes + // available + serde_json::to_string(self).expect("Can't serialize secret name").as_bytes(), + &[Self::SEPARATOR], + ] + .concat() + } +} + +impl EncodeKey for SecretInfo { + fn encode(&self) -> Vec { + match self { + SecretInfo::KeyRequest(k) => k.encode(), + SecretInfo::SecretRequest(s) => s.encode(), + } + } +} + impl EncodeKey for &RequestedKeyInfo { fn encode(&self) -> Vec { [ @@ -136,9 +160,9 @@ pub struct SledStore { inbound_group_sessions: Tree, outbound_group_sessions: Tree, - outgoing_key_requests: Tree, - unsent_key_requests: Tree, - key_requests_by_info: Tree, + outgoing_secret_requests: Tree, + unsent_secret_requests: Tree, + secret_requests_by_info: Tree, devices: Tree, identities: Tree, @@ -201,9 +225,9 @@ impl SledStore { let devices = db.open_tree("devices")?; let identities = db.open_tree("identities")?; - let outgoing_key_requests = db.open_tree("outgoing_key_requests")?; - let unsent_key_requests = db.open_tree("unsent_key_requests")?; - let key_requests_by_info = db.open_tree("key_requests_by_info")?; + let outgoing_secret_requests = db.open_tree("outgoing_secret_requests")?; + let unsent_secret_requests = db.open_tree("unsent_secret_requests")?; + let secret_requests_by_info = db.open_tree("secret_requests_by_info")?; let session_cache = SessionStore::new(); @@ -227,9 +251,9 @@ impl SledStore { users_for_key_query_cache: DashSet::new().into(), inbound_group_sessions, outbound_group_sessions, - outgoing_key_requests, - unsent_key_requests, - key_requests_by_info, + outgoing_secret_requests, + unsent_secret_requests, + secret_requests_by_info, devices, tracked_users, users_for_key_query, @@ -361,9 +385,9 @@ impl SledStore { &self.inbound_group_sessions, &self.outbound_group_sessions, &self.olm_hashes, - &self.outgoing_key_requests, - &self.unsent_key_requests, - &self.key_requests_by_info, + &self.outgoing_secret_requests, + &self.unsent_secret_requests, + &self.secret_requests_by_info, ) .transaction( |( @@ -375,9 +399,9 @@ impl SledStore { inbound_sessions, outbound_sessions, hashes, - outgoing_key_requests, - unsent_key_requests, - key_requests_by_info, + outgoing_secret_requests, + unsent_secret_requests, + secret_requests_by_info, )| { if let Some(a) = &account_pickle { account.insert( @@ -446,7 +470,7 @@ impl SledStore { } for key_request in &key_requests { - key_requests_by_info.insert( + secret_requests_by_info.insert( (&key_request.info).encode(), key_request.request_id.encode(), )?; @@ -454,15 +478,15 @@ impl SledStore { let key_request_id = key_request.request_id.encode(); if key_request.sent_out { - unsent_key_requests.remove(key_request_id.clone())?; - outgoing_key_requests.insert( + unsent_secret_requests.remove(key_request_id.clone())?; + outgoing_secret_requests.insert( key_request_id, serde_json::to_vec(&key_request) .map_err(ConflictableTransactionError::Abort)?, )?; } else { - outgoing_key_requests.remove(key_request_id.clone())?; - unsent_key_requests.insert( + outgoing_secret_requests.remove(key_request_id.clone())?; + unsent_secret_requests.insert( key_request_id, serde_json::to_vec(&key_request) .map_err(ConflictableTransactionError::Abort)?, @@ -484,11 +508,14 @@ impl SledStore { &self, id: &[u8], ) -> Result> { - let request = - self.outgoing_key_requests.get(id)?.map(|r| serde_json::from_slice(&r)).transpose()?; + let request = self + .outgoing_secret_requests + .get(id)? + .map(|r| serde_json::from_slice(&r)) + .transpose()?; let request = if request.is_none() { - self.unsent_key_requests.get(id)?.map(|r| serde_json::from_slice(&r)).transpose()? + self.unsent_secret_requests.get(id)?.map(|r| serde_json::from_slice(&r)).transpose()? } else { request }; @@ -681,7 +708,7 @@ impl CryptoStore for SledStore { Ok(self.olm_hashes.contains_key(serde_json::to_vec(message_hash)?)?) } - async fn get_outgoing_key_request( + async fn get_outgoing_secret_requests( &self, request_id: Uuid, ) -> Result> { @@ -690,11 +717,11 @@ impl CryptoStore for SledStore { self.get_outgoing_key_request_helper(&request_id).await } - async fn get_key_request_by_info( + async fn get_secret_request_by_info( &self, - key_info: &RequestedKeyInfo, + key_info: &SecretInfo, ) -> Result> { - let id = self.key_requests_by_info.get(key_info.encode())?; + let id = self.secret_requests_by_info.get(key_info.encode())?; if let Some(id) = id { self.get_outgoing_key_request_helper(&id).await @@ -703,9 +730,9 @@ impl CryptoStore for SledStore { } } - async fn get_unsent_key_requests(&self) -> Result> { + async fn get_unsent_secret_requests(&self) -> Result> { let requests: Result> = self - .unsent_key_requests + .unsent_secret_requests .iter() .map(|i| serde_json::from_slice(&i?.1).map_err(CryptoStoreError::from)) .collect(); @@ -713,34 +740,37 @@ impl CryptoStore for SledStore { requests } - async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> { - let ret: Result<(), TransactionError> = - (&self.outgoing_key_requests, &self.unsent_key_requests, &self.key_requests_by_info) - .transaction( - |(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| { - let sent_request: Option = outgoing_key_requests - .remove(request_id.encode())? - .map(|r| serde_json::from_slice(&r)) - .transpose() - .map_err(ConflictableTransactionError::Abort)?; + async fn delete_outgoing_secret_requests(&self, request_id: Uuid) -> Result<()> { + let ret: Result<(), TransactionError> = ( + &self.outgoing_secret_requests, + &self.unsent_secret_requests, + &self.secret_requests_by_info, + ) + .transaction( + |(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| { + let sent_request: Option = outgoing_key_requests + .remove(request_id.encode())? + .map(|r| serde_json::from_slice(&r)) + .transpose() + .map_err(ConflictableTransactionError::Abort)?; - let unsent_request: Option = unsent_key_requests - .remove(request_id.encode())? - .map(|r| serde_json::from_slice(&r)) - .transpose() - .map_err(ConflictableTransactionError::Abort)?; + let unsent_request: Option = unsent_key_requests + .remove(request_id.encode())? + .map(|r| serde_json::from_slice(&r)) + .transpose() + .map_err(ConflictableTransactionError::Abort)?; - if let Some(request) = sent_request { - key_requests_by_info.remove((&request.info).encode())?; - } + if let Some(request) = sent_request { + key_requests_by_info.remove((&request.info).encode())?; + } - if let Some(request) = unsent_request { - key_requests_by_info.remove((&request.info).encode())?; - } + if let Some(request) = unsent_request { + key_requests_by_info.remove((&request.info).encode())?; + } - Ok(()) - }, - ); + Ok(()) + }, + ); ret?; self.inner.flush_async().await?; @@ -768,6 +798,7 @@ mod test { device::test::get_device, user::test::{get_other_identity, get_own_identity}, }, + key_request::SecretInfo, olm::{ GroupSessionKey, InboundGroupSession, OlmMessageHash, PrivateCrossSigningIdentity, ReadOnlyAccount, Session, @@ -1199,12 +1230,13 @@ mod test { let (account, store, _dir) = get_loaded_store().await; let id = Uuid::new_v4(); - let info = RequestedKeyInfo::new( + let info: SecretInfo = RequestedKeyInfo::new( EventEncryptionAlgorithm::MegolmV1AesSha2, room_id!("!test:localhost"), "test_sender_key".to_string(), "test_session_id".to_string(), - ); + ) + .into(); let request = OutgoingKeyRequest { request_recipient: account.user_id().to_owned(), @@ -1213,7 +1245,7 @@ mod test { sent_out: false, }; - assert!(store.get_outgoing_key_request(id).await.unwrap().is_none()); + assert!(store.get_outgoing_secret_requests(id).await.unwrap().is_none()); let mut changes = Changes::default(); changes.key_requests.push(request.clone()); @@ -1221,12 +1253,12 @@ mod test { let request = Some(request); - let stored_request = store.get_outgoing_key_request(id).await.unwrap(); + let stored_request = store.get_outgoing_secret_requests(id).await.unwrap(); assert_eq!(request, stored_request); - let stored_request = store.get_key_request_by_info(&info).await.unwrap(); + let stored_request = store.get_secret_request_by_info(&info).await.unwrap(); assert_eq!(request, stored_request); - assert!(!store.get_unsent_key_requests().await.unwrap().is_empty()); + assert!(!store.get_unsent_secret_requests().await.unwrap().is_empty()); let request = OutgoingKeyRequest { request_recipient: account.user_id().to_owned(), @@ -1239,17 +1271,17 @@ mod test { changes.key_requests.push(request.clone()); store.save_changes(changes).await.unwrap(); - assert!(store.get_unsent_key_requests().await.unwrap().is_empty()); - let stored_request = store.get_outgoing_key_request(id).await.unwrap(); + assert!(store.get_unsent_secret_requests().await.unwrap().is_empty()); + let stored_request = store.get_outgoing_secret_requests(id).await.unwrap(); assert_eq!(Some(request), stored_request); - store.delete_outgoing_key_request(id).await.unwrap(); + store.delete_outgoing_secret_requests(id).await.unwrap(); - let stored_request = store.get_outgoing_key_request(id).await.unwrap(); + let stored_request = store.get_outgoing_secret_requests(id).await.unwrap(); assert_eq!(None, stored_request); - let stored_request = store.get_key_request_by_info(&info).await.unwrap(); + let stored_request = store.get_secret_request_by_info(&info).await.unwrap(); assert_eq!(None, stored_request); - assert!(store.get_unsent_key_requests().await.unwrap().is_empty()); + assert!(store.get_unsent_secret_requests().await.unwrap().is_empty()); } }