crypto: Add support to request secrets

This commit is contained in:
Damir Jelić 2021-07-30 11:27:49 +02:00
parent a916288d03
commit 5192feb836
6 changed files with 279 additions and 126 deletions

View file

@ -29,7 +29,9 @@ use ruma::{
forwarded_room_key::ForwardedRoomKeyToDeviceEventContent, forwarded_room_key::ForwardedRoomKeyToDeviceEventContent,
room_key_request::{Action, RequestedKeyInfo, RoomKeyRequestToDeviceEventContent}, room_key_request::{Action, RequestedKeyInfo, RoomKeyRequestToDeviceEventContent},
secret::{ secret::{
request::{RequestAction, RequestToDeviceEventContent as SecretRequestEventContent}, request::{
RequestAction, RequestToDeviceEventContent as SecretRequestEventContent, SecretName,
},
send::SendToDeviceEventContent as SecretSendEventContent, send::SendToDeviceEventContent as SecretSendEventContent,
}, },
AnyToDeviceEvent, AnyToDeviceEventContent, ToDeviceEvent, AnyToDeviceEvent, AnyToDeviceEventContent, ToDeviceEvent,
@ -201,24 +203,56 @@ pub struct OutgoingKeyRequest {
/// The unique id of the key request. /// The unique id of the key request.
pub request_id: Uuid, pub request_id: Uuid,
/// The info of the requested key. /// The info of the requested key.
pub info: RequestedKeyInfo, pub info: SecretInfo,
/// Has the request been sent out. /// Has the request been sent out.
pub sent_out: bool, pub sent_out: bool,
} }
/// An enum over the various secret request types we can have.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SecretInfo {
// Info for the `m.room_key_request` variant
KeyRequest(RequestedKeyInfo),
// Info for the `m.secret.request` variant
SecretRequest(SecretName),
}
impl From<RequestedKeyInfo> for SecretInfo {
fn from(i: RequestedKeyInfo) -> Self {
Self::KeyRequest(i)
}
}
impl From<SecretName> for SecretInfo {
fn from(i: SecretName) -> Self {
Self::SecretRequest(i)
}
}
impl OutgoingKeyRequest { impl OutgoingKeyRequest {
fn to_request(&self, own_device_id: &DeviceId) -> OutgoingRequest { fn to_request(&self, own_device_id: &DeviceId) -> OutgoingRequest {
let content = RoomKeyRequestToDeviceEventContent::new( let content = match &self.info {
Action::Request, SecretInfo::KeyRequest(r) => {
Some(self.info.clone()), AnyToDeviceEventContent::RoomKeyRequest(RoomKeyRequestToDeviceEventContent::new(
own_device_id.to_owned(), Action::Request,
self.request_id.to_string(), Some(r.clone()),
); own_device_id.to_owned(),
self.request_id.to_string(),
))
}
SecretInfo::SecretRequest(s) => {
AnyToDeviceEventContent::SecretRequest(SecretRequestEventContent::new(
RequestAction::Request(s.clone()),
own_device_id.to_owned(),
self.request_id.to_string(),
))
}
};
let request = ToDeviceRequest::new_with_id( let request = ToDeviceRequest::new_with_id(
&self.request_recipient, &self.request_recipient,
DeviceIdOrAllDevices::AllDevices, DeviceIdOrAllDevices::AllDevices,
AnyToDeviceEventContent::RoomKeyRequest(content), content,
self.request_id, self.request_id,
); );
@ -226,17 +260,28 @@ impl OutgoingKeyRequest {
} }
fn to_cancellation(&self, own_device_id: &DeviceId) -> OutgoingRequest { fn to_cancellation(&self, own_device_id: &DeviceId) -> OutgoingRequest {
let content = RoomKeyRequestToDeviceEventContent::new( let content = match self.info {
Action::CancelRequest, SecretInfo::KeyRequest(_) => {
None, AnyToDeviceEventContent::RoomKeyRequest(RoomKeyRequestToDeviceEventContent::new(
own_device_id.to_owned(), Action::CancelRequest,
self.request_id.to_string(), None,
); own_device_id.to_owned(),
self.request_id.to_string(),
))
}
SecretInfo::SecretRequest(_) => {
AnyToDeviceEventContent::SecretRequest(SecretRequestEventContent::new(
RequestAction::RequestCancellation,
own_device_id.to_owned(),
self.request_id.to_string(),
))
}
};
let request = ToDeviceRequest::new( let request = ToDeviceRequest::new(
&self.request_recipient, &self.request_recipient,
DeviceIdOrAllDevices::AllDevices, DeviceIdOrAllDevices::AllDevices,
AnyToDeviceEventContent::RoomKeyRequest(content), content,
); );
OutgoingRequest { request_id: request.txn_id, request: Arc::new(request.into()) } OutgoingRequest { request_id: request.txn_id, request: Arc::new(request.into()) }
@ -245,11 +290,20 @@ impl OutgoingKeyRequest {
impl PartialEq for OutgoingKeyRequest { impl PartialEq for OutgoingKeyRequest {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.request_id == other.request_id let is_info_equal = match (&self.info, &other.info) {
&& self.info.algorithm == other.info.algorithm (SecretInfo::KeyRequest(first), SecretInfo::KeyRequest(second)) => {
&& self.info.room_id == other.info.room_id first.algorithm == second.algorithm
&& self.info.session_id == other.info.session_id && first.room_id == second.room_id
&& self.info.sender_key == other.info.sender_key && first.session_id == second.session_id
}
(SecretInfo::SecretRequest(first), SecretInfo::SecretRequest(second)) => {
first == second
}
(SecretInfo::KeyRequest(_), SecretInfo::SecretRequest(_))
| (SecretInfo::SecretRequest(_), SecretInfo::KeyRequest(_)) => false,
};
self.request_id == other.request_id && is_info_equal
} }
} }
@ -277,7 +331,7 @@ impl KeyRequestMachine {
async fn load_outgoing_requests(&self) -> Result<Vec<OutgoingRequest>, CryptoStoreError> { async fn load_outgoing_requests(&self) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
Ok(self Ok(self
.store .store
.get_unsent_key_requests() .get_unsent_secret_requests()
.await? .await?
.into_iter() .into_iter()
.filter(|i| !i.sent_out) .filter(|i| !i.sent_out)
@ -675,11 +729,8 @@ impl KeyRequestMachine {
/// ///
/// * `key_info` - The info of our key request containing information about /// * `key_info` - The info of our key request containing information about
/// the key we wish to request. /// the key we wish to request.
async fn should_request_key( async fn should_request_key(&self, key_info: &SecretInfo) -> Result<bool, CryptoStoreError> {
&self, let request = self.store.get_secret_request_by_info(key_info).await?;
key_info: &RequestedKeyInfo,
) -> Result<bool, CryptoStoreError> {
let request = self.store.get_key_request_by_info(key_info).await?;
// Don't send out duplicate requests, users can re-request them if they // Don't send out duplicate requests, users can re-request them if they
// think a second request might succeed. // think a second request might succeed.
@ -728,9 +779,10 @@ impl KeyRequestMachine {
room_id.to_owned(), room_id.to_owned(),
sender_key.to_owned(), sender_key.to_owned(),
session_id.to_owned(), session_id.to_owned(),
); )
.into();
let request = self.store.get_key_request_by_info(&key_info).await?; let request = self.store.get_secret_request_by_info(&key_info).await?;
if let Some(request) = request { if let Some(request) = request {
let cancel = request.to_cancellation(self.device_id()); let cancel = request.to_cancellation(self.device_id());
@ -744,9 +796,39 @@ impl KeyRequestMachine {
} }
} }
#[allow(dead_code)]
pub async fn request_missing_secrets(&self) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
let secret_names = self.store.get_missing_secrets().await;
Ok(if secret_names.is_empty() {
info!(secret_names =? secret_names, "Creating new outgoing secret requests");
let requests: Vec<OutgoingKeyRequest> = secret_names
.into_iter()
.map(|n| OutgoingKeyRequest {
request_recipient: self.user_id().to_owned(),
request_id: Uuid::new_v4(),
info: n.into(),
sent_out: false,
})
.collect();
let outgoing_requests =
requests.iter().map(|r| r.to_request(self.device_id())).collect();
let changes = Changes { key_requests: requests, ..Default::default() };
self.store.save_changes(changes).await?;
outgoing_requests
} else {
trace!("No secrets are missing from our store, not requesting them");
vec![]
})
}
async fn request_key_helper( async fn request_key_helper(
&self, &self,
key_info: RequestedKeyInfo, key_info: SecretInfo,
) -> Result<OutgoingRequest, CryptoStoreError> { ) -> Result<OutgoingRequest, CryptoStoreError> {
info!("Creating new outgoing room key request {:#?}", key_info); info!("Creating new outgoing room key request {:#?}", key_info);
@ -788,7 +870,8 @@ impl KeyRequestMachine {
room_id.to_owned(), room_id.to_owned(),
sender_key.to_owned(), sender_key.to_owned(),
session_id.to_owned(), session_id.to_owned(),
); )
.into();
if self.should_request_key(&key_info).await? { if self.should_request_key(&key_info).await? {
self.request_key_helper(key_info).await?; self.request_key_helper(key_info).await?;
@ -819,19 +902,20 @@ impl KeyRequestMachine {
content.room_id.clone(), content.room_id.clone(),
content.sender_key.clone(), content.sender_key.clone(),
content.session_id.clone(), content.session_id.clone(),
); )
.into();
self.store.get_key_request_by_info(&info).await self.store.get_secret_request_by_info(&info).await
} }
/// Delete the given outgoing key info. /// Delete the given outgoing key info.
async fn delete_key_info(&self, info: &OutgoingKeyRequest) -> Result<(), CryptoStoreError> { async fn delete_key_info(&self, info: &OutgoingKeyRequest) -> Result<(), CryptoStoreError> {
self.store.delete_outgoing_key_request(info.request_id).await self.store.delete_outgoing_secret_requests(info.request_id).await
} }
/// Mark the outgoing request as sent. /// Mark the outgoing request as sent.
pub async fn mark_outgoing_request_as_sent(&self, id: Uuid) -> Result<(), CryptoStoreError> { pub async fn mark_outgoing_request_as_sent(&self, id: Uuid) -> Result<(), CryptoStoreError> {
let info = self.store.get_outgoing_key_request(id).await?; let info = self.store.get_outgoing_secret_requests(id).await?;
if let Some(mut info) = info { if let Some(mut info) = info {
trace!("Marking outgoing key request as sent {:#?}", info); trace!("Marking outgoing key request as sent {:#?}", info);

View file

@ -267,7 +267,7 @@ pub(crate) mod test {
}) })
.to_string(); .to_string();
let event: AnySyncRoomEvent = serde_json::from_str(&event).expect("WHAAAT?!?!?"); let event: AnySyncRoomEvent = serde_json::from_str(&event).unwrap();
let event = let event =
if let AnySyncRoomEvent::Message(AnySyncMessageEvent::RoomEncrypted(event)) = event { if let AnySyncRoomEvent::Message(AnySyncMessageEvent::RoomEncrypted(event)) = event {

View file

@ -97,6 +97,11 @@ impl PrivateCrossSigningIdentity {
self.self_signing_key.lock().await.is_some() self.self_signing_key.lock().await.is_some()
} }
/// Can we sign other users, i.e. do we have a user signing key.
pub async fn can_sign_users(&self) -> bool {
self.user_signing_key.lock().await.is_some()
}
/// Do we have the master key. /// Do we have the master key.
pub async fn has_master_key(&self) -> bool { pub async fn has_master_key(&self) -> bool {
self.master_key.lock().await.is_some() self.master_key.lock().await.is_some()
@ -130,6 +135,25 @@ impl PrivateCrossSigningIdentity {
} }
} }
/// Get the names of the secrets we are missing.
pub(crate) async fn get_missing_secrets(&self) -> Vec<SecretName> {
let mut missing = Vec::new();
if !self.has_master_key().await {
missing.push(SecretName::CrossSigningMasterKey);
}
if !self.can_sign_devices().await {
missing.push(SecretName::CrossSigningSelfSigningKey);
}
if !self.can_sign_users().await {
missing.push(SecretName::CrossSigningUserSigningKey);
}
missing
}
/// Create a new empty identity. /// Create a new empty identity.
pub(crate) fn empty(user_id: UserId) -> Self { pub(crate) fn empty(user_id: UserId) -> Self {
Self { Self {

View file

@ -19,7 +19,7 @@ use std::{
use dashmap::{DashMap, DashSet}; use dashmap::{DashMap, DashSet};
use matrix_sdk_common::{async_trait, locks::Mutex, uuid::Uuid}; use matrix_sdk_common::{async_trait, locks::Mutex, uuid::Uuid};
use ruma::{events::room_key_request::RequestedKeyInfo, DeviceId, DeviceIdBox, RoomId, UserId}; use ruma::{DeviceId, DeviceIdBox, RoomId, UserId};
use super::{ use super::{
caches::{DeviceStore, GroupSessionStore, SessionStore}, caches::{DeviceStore, GroupSessionStore, SessionStore},
@ -27,12 +27,21 @@ use super::{
}; };
use crate::{ use crate::{
identities::{ReadOnlyDevice, ReadOnlyUserIdentities}, identities::{ReadOnlyDevice, ReadOnlyUserIdentities},
key_request::OutgoingKeyRequest, key_request::{OutgoingKeyRequest, SecretInfo},
olm::{OutboundGroupSession, PrivateCrossSigningIdentity}, olm::{OutboundGroupSession, PrivateCrossSigningIdentity},
}; };
fn encode_key_info(info: &RequestedKeyInfo) -> String { fn encode_key_info(info: &SecretInfo) -> String {
format!("{}{}{}{}", info.room_id, info.sender_key, info.algorithm, info.session_id) match info {
SecretInfo::KeyRequest(info) => {
format!("{}{}{}{}", info.room_id, info.sender_key, info.algorithm, info.session_id)
}
SecretInfo::SecretRequest(i) => {
// TODO don't use serde here, use `as_ref()` when it becomes
// available
serde_json::to_string(i).expect("Can't serialize secret name")
}
}
} }
/// An in-memory only store that will forget all the E2EE key once it's dropped. /// An in-memory only store that will forget all the E2EE key once it's dropped.
@ -228,16 +237,16 @@ impl CryptoStore for MemoryStore {
.contains(&message_hash.hash)) .contains(&message_hash.hash))
} }
async fn get_outgoing_key_request( async fn get_outgoing_secret_requests(
&self, &self,
request_id: Uuid, request_id: Uuid,
) -> Result<Option<OutgoingKeyRequest>> { ) -> Result<Option<OutgoingKeyRequest>> {
Ok(self.outgoing_key_requests.get(&request_id).map(|r| r.clone())) Ok(self.outgoing_key_requests.get(&request_id).map(|r| r.clone()))
} }
async fn get_key_request_by_info( async fn get_secret_request_by_info(
&self, &self,
key_info: &RequestedKeyInfo, key_info: &SecretInfo,
) -> Result<Option<OutgoingKeyRequest>> { ) -> Result<Option<OutgoingKeyRequest>> {
let key_info_string = encode_key_info(key_info); let key_info_string = encode_key_info(key_info);
@ -247,7 +256,7 @@ impl CryptoStore for MemoryStore {
.and_then(|i| self.outgoing_key_requests.get(&i).map(|r| r.clone()))) .and_then(|i| self.outgoing_key_requests.get(&i).map(|r| r.clone())))
} }
async fn get_unsent_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>> { async fn get_unsent_secret_requests(&self) -> Result<Vec<OutgoingKeyRequest>> {
Ok(self Ok(self
.outgoing_key_requests .outgoing_key_requests
.iter() .iter()
@ -256,7 +265,7 @@ impl CryptoStore for MemoryStore {
.collect()) .collect())
} }
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> { async fn delete_outgoing_secret_requests(&self, request_id: Uuid) -> Result<()> {
self.outgoing_key_requests.remove(&request_id).and_then(|(_, i)| { self.outgoing_key_requests.remove(&request_id).and_then(|(_, i)| {
let key_info_string = encode_key_info(&i.info); let key_info_string = encode_key_info(&i.info);
self.key_requests_by_info.remove(&key_info_string) self.key_requests_by_info.remove(&key_info_string)

View file

@ -56,9 +56,8 @@ pub use memorystore::MemoryStore;
use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError}; use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
pub use pickle_key::{EncryptedPickleKey, PickleKey}; pub use pickle_key::{EncryptedPickleKey, PickleKey};
use ruma::{ use ruma::{
events::{room_key_request::RequestedKeyInfo, secret::request::SecretName}, events::secret::request::SecretName, identifiers::Error as IdentifierValidationError, DeviceId,
identifiers::Error as IdentifierValidationError, DeviceIdBox, DeviceKeyAlgorithm, RoomId, UserId,
DeviceId, DeviceIdBox, DeviceKeyAlgorithm, RoomId, UserId,
}; };
use serde_json::Error as SerdeError; use serde_json::Error as SerdeError;
use thiserror::Error; use thiserror::Error;
@ -72,7 +71,7 @@ use crate::{
user::{OwnUserIdentity, UserIdentities, UserIdentity}, user::{OwnUserIdentity, UserIdentities, UserIdentity},
Device, ReadOnlyDevice, ReadOnlyUserIdentities, UserDevices, Device, ReadOnlyDevice, ReadOnlyUserIdentities, UserDevices,
}, },
key_request::OutgoingKeyRequest, key_request::{OutgoingKeyRequest, SecretInfo},
olm::{ olm::{
InboundGroupSession, OlmMessageHash, OutboundGroupSession, PrivateCrossSigningIdentity, InboundGroupSession, OlmMessageHash, OutboundGroupSession, PrivateCrossSigningIdentity,
ReadOnlyAccount, Session, ReadOnlyAccount, Session,
@ -273,6 +272,11 @@ impl Store {
} }
} }
} }
pub async fn get_missing_secrets(&self) -> Vec<SecretName> {
// TODO add the backup key to our missing secrets
self.identity.lock().await.get_missing_secrets().await
}
} }
impl Deref for Store { impl Deref for Store {
@ -440,14 +444,14 @@ pub trait CryptoStore: AsyncTraitDeps {
/// Check if a hash for an Olm message stored in the database. /// Check if a hash for an Olm message stored in the database.
async fn is_message_known(&self, message_hash: &OlmMessageHash) -> Result<bool>; async fn is_message_known(&self, message_hash: &OlmMessageHash) -> Result<bool>;
/// Get an outgoing key request that we created that matches the given /// Get an outgoing secret request that we created that matches the given
/// request id. /// request id.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `request_id` - The unique request id that identifies this outgoing key /// * `request_id` - The unique request id that identifies this outgoing
/// request. /// secret request.
async fn get_outgoing_key_request( async fn get_outgoing_secret_requests(
&self, &self,
request_id: Uuid, request_id: Uuid,
) -> Result<Option<OutgoingKeyRequest>>; ) -> Result<Option<OutgoingKeyRequest>>;
@ -457,14 +461,14 @@ pub trait CryptoStore: AsyncTraitDeps {
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `key_info` - The key info of an outgoing key request. /// * `key_info` - The key info of an outgoing secret request.
async fn get_key_request_by_info( async fn get_secret_request_by_info(
&self, &self,
key_info: &RequestedKeyInfo, secret_info: &SecretInfo,
) -> Result<Option<OutgoingKeyRequest>>; ) -> Result<Option<OutgoingKeyRequest>>;
/// Get all outgoing key requests that we have in the store. /// Get all outgoing secret requests that we have in the store.
async fn get_unsent_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>>; async fn get_unsent_secret_requests(&self) -> Result<Vec<OutgoingKeyRequest>>;
/// Delete an outgoing key request that we created that matches the given /// Delete an outgoing key request that we created that matches the given
/// request id. /// request id.
@ -473,5 +477,5 @@ pub trait CryptoStore: AsyncTraitDeps {
/// ///
/// * `request_id` - The unique request id that identifies this outgoing key /// * `request_id` - The unique request id that identifies this outgoing key
/// request. /// request.
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()>; async fn delete_outgoing_secret_requests(&self, request_id: Uuid) -> Result<()>;
} }

View file

@ -22,7 +22,10 @@ use std::{
use dashmap::DashSet; use dashmap::DashSet;
use matrix_sdk_common::{async_trait, locks::Mutex, uuid}; use matrix_sdk_common::{async_trait, locks::Mutex, uuid};
use olm_rs::{account::IdentityKeys, PicklingMode}; use olm_rs::{account::IdentityKeys, PicklingMode};
use ruma::{events::room_key_request::RequestedKeyInfo, DeviceId, DeviceIdBox, RoomId, UserId}; use ruma::{
events::{room_key_request::RequestedKeyInfo, secret::request::SecretName},
DeviceId, DeviceIdBox, RoomId, UserId,
};
pub use sled::Error; pub use sled::Error;
use sled::{ use sled::{
transaction::{ConflictableTransactionError, TransactionError}, transaction::{ConflictableTransactionError, TransactionError},
@ -36,7 +39,7 @@ use super::{
}; };
use crate::{ use crate::{
identities::{ReadOnlyDevice, ReadOnlyUserIdentities}, identities::{ReadOnlyDevice, ReadOnlyUserIdentities},
key_request::OutgoingKeyRequest, key_request::{OutgoingKeyRequest, SecretInfo},
olm::{OutboundGroupSession, PickledInboundGroupSession, PrivateCrossSigningIdentity}, olm::{OutboundGroupSession, PickledInboundGroupSession, PrivateCrossSigningIdentity},
}; };
@ -55,6 +58,27 @@ impl EncodeKey for Uuid {
} }
} }
impl EncodeKey for SecretName {
fn encode(&self) -> Vec<u8> {
[
// TODO don't use serde here, use `as_ref()` when it becomes
// available
serde_json::to_string(self).expect("Can't serialize secret name").as_bytes(),
&[Self::SEPARATOR],
]
.concat()
}
}
impl EncodeKey for SecretInfo {
fn encode(&self) -> Vec<u8> {
match self {
SecretInfo::KeyRequest(k) => k.encode(),
SecretInfo::SecretRequest(s) => s.encode(),
}
}
}
impl EncodeKey for &RequestedKeyInfo { impl EncodeKey for &RequestedKeyInfo {
fn encode(&self) -> Vec<u8> { fn encode(&self) -> Vec<u8> {
[ [
@ -136,9 +160,9 @@ pub struct SledStore {
inbound_group_sessions: Tree, inbound_group_sessions: Tree,
outbound_group_sessions: Tree, outbound_group_sessions: Tree,
outgoing_key_requests: Tree, outgoing_secret_requests: Tree,
unsent_key_requests: Tree, unsent_secret_requests: Tree,
key_requests_by_info: Tree, secret_requests_by_info: Tree,
devices: Tree, devices: Tree,
identities: Tree, identities: Tree,
@ -201,9 +225,9 @@ impl SledStore {
let devices = db.open_tree("devices")?; let devices = db.open_tree("devices")?;
let identities = db.open_tree("identities")?; let identities = db.open_tree("identities")?;
let outgoing_key_requests = db.open_tree("outgoing_key_requests")?; let outgoing_secret_requests = db.open_tree("outgoing_secret_requests")?;
let unsent_key_requests = db.open_tree("unsent_key_requests")?; let unsent_secret_requests = db.open_tree("unsent_secret_requests")?;
let key_requests_by_info = db.open_tree("key_requests_by_info")?; let secret_requests_by_info = db.open_tree("secret_requests_by_info")?;
let session_cache = SessionStore::new(); let session_cache = SessionStore::new();
@ -227,9 +251,9 @@ impl SledStore {
users_for_key_query_cache: DashSet::new().into(), users_for_key_query_cache: DashSet::new().into(),
inbound_group_sessions, inbound_group_sessions,
outbound_group_sessions, outbound_group_sessions,
outgoing_key_requests, outgoing_secret_requests,
unsent_key_requests, unsent_secret_requests,
key_requests_by_info, secret_requests_by_info,
devices, devices,
tracked_users, tracked_users,
users_for_key_query, users_for_key_query,
@ -361,9 +385,9 @@ impl SledStore {
&self.inbound_group_sessions, &self.inbound_group_sessions,
&self.outbound_group_sessions, &self.outbound_group_sessions,
&self.olm_hashes, &self.olm_hashes,
&self.outgoing_key_requests, &self.outgoing_secret_requests,
&self.unsent_key_requests, &self.unsent_secret_requests,
&self.key_requests_by_info, &self.secret_requests_by_info,
) )
.transaction( .transaction(
|( |(
@ -375,9 +399,9 @@ impl SledStore {
inbound_sessions, inbound_sessions,
outbound_sessions, outbound_sessions,
hashes, hashes,
outgoing_key_requests, outgoing_secret_requests,
unsent_key_requests, unsent_secret_requests,
key_requests_by_info, secret_requests_by_info,
)| { )| {
if let Some(a) = &account_pickle { if let Some(a) = &account_pickle {
account.insert( account.insert(
@ -446,7 +470,7 @@ impl SledStore {
} }
for key_request in &key_requests { for key_request in &key_requests {
key_requests_by_info.insert( secret_requests_by_info.insert(
(&key_request.info).encode(), (&key_request.info).encode(),
key_request.request_id.encode(), key_request.request_id.encode(),
)?; )?;
@ -454,15 +478,15 @@ impl SledStore {
let key_request_id = key_request.request_id.encode(); let key_request_id = key_request.request_id.encode();
if key_request.sent_out { if key_request.sent_out {
unsent_key_requests.remove(key_request_id.clone())?; unsent_secret_requests.remove(key_request_id.clone())?;
outgoing_key_requests.insert( outgoing_secret_requests.insert(
key_request_id, key_request_id,
serde_json::to_vec(&key_request) serde_json::to_vec(&key_request)
.map_err(ConflictableTransactionError::Abort)?, .map_err(ConflictableTransactionError::Abort)?,
)?; )?;
} else { } else {
outgoing_key_requests.remove(key_request_id.clone())?; outgoing_secret_requests.remove(key_request_id.clone())?;
unsent_key_requests.insert( unsent_secret_requests.insert(
key_request_id, key_request_id,
serde_json::to_vec(&key_request) serde_json::to_vec(&key_request)
.map_err(ConflictableTransactionError::Abort)?, .map_err(ConflictableTransactionError::Abort)?,
@ -484,11 +508,14 @@ impl SledStore {
&self, &self,
id: &[u8], id: &[u8],
) -> Result<Option<OutgoingKeyRequest>> { ) -> Result<Option<OutgoingKeyRequest>> {
let request = let request = self
self.outgoing_key_requests.get(id)?.map(|r| serde_json::from_slice(&r)).transpose()?; .outgoing_secret_requests
.get(id)?
.map(|r| serde_json::from_slice(&r))
.transpose()?;
let request = if request.is_none() { let request = if request.is_none() {
self.unsent_key_requests.get(id)?.map(|r| serde_json::from_slice(&r)).transpose()? self.unsent_secret_requests.get(id)?.map(|r| serde_json::from_slice(&r)).transpose()?
} else { } else {
request request
}; };
@ -681,7 +708,7 @@ impl CryptoStore for SledStore {
Ok(self.olm_hashes.contains_key(serde_json::to_vec(message_hash)?)?) Ok(self.olm_hashes.contains_key(serde_json::to_vec(message_hash)?)?)
} }
async fn get_outgoing_key_request( async fn get_outgoing_secret_requests(
&self, &self,
request_id: Uuid, request_id: Uuid,
) -> Result<Option<OutgoingKeyRequest>> { ) -> Result<Option<OutgoingKeyRequest>> {
@ -690,11 +717,11 @@ impl CryptoStore for SledStore {
self.get_outgoing_key_request_helper(&request_id).await self.get_outgoing_key_request_helper(&request_id).await
} }
async fn get_key_request_by_info( async fn get_secret_request_by_info(
&self, &self,
key_info: &RequestedKeyInfo, key_info: &SecretInfo,
) -> Result<Option<OutgoingKeyRequest>> { ) -> Result<Option<OutgoingKeyRequest>> {
let id = self.key_requests_by_info.get(key_info.encode())?; let id = self.secret_requests_by_info.get(key_info.encode())?;
if let Some(id) = id { if let Some(id) = id {
self.get_outgoing_key_request_helper(&id).await self.get_outgoing_key_request_helper(&id).await
@ -703,9 +730,9 @@ impl CryptoStore for SledStore {
} }
} }
async fn get_unsent_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>> { async fn get_unsent_secret_requests(&self) -> Result<Vec<OutgoingKeyRequest>> {
let requests: Result<Vec<OutgoingKeyRequest>> = self let requests: Result<Vec<OutgoingKeyRequest>> = self
.unsent_key_requests .unsent_secret_requests
.iter() .iter()
.map(|i| serde_json::from_slice(&i?.1).map_err(CryptoStoreError::from)) .map(|i| serde_json::from_slice(&i?.1).map_err(CryptoStoreError::from))
.collect(); .collect();
@ -713,34 +740,37 @@ impl CryptoStore for SledStore {
requests requests
} }
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> { async fn delete_outgoing_secret_requests(&self, request_id: Uuid) -> Result<()> {
let ret: Result<(), TransactionError<serde_json::Error>> = let ret: Result<(), TransactionError<serde_json::Error>> = (
(&self.outgoing_key_requests, &self.unsent_key_requests, &self.key_requests_by_info) &self.outgoing_secret_requests,
.transaction( &self.unsent_secret_requests,
|(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| { &self.secret_requests_by_info,
let sent_request: Option<OutgoingKeyRequest> = outgoing_key_requests )
.remove(request_id.encode())? .transaction(
.map(|r| serde_json::from_slice(&r)) |(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| {
.transpose() let sent_request: Option<OutgoingKeyRequest> = outgoing_key_requests
.map_err(ConflictableTransactionError::Abort)?; .remove(request_id.encode())?
.map(|r| serde_json::from_slice(&r))
.transpose()
.map_err(ConflictableTransactionError::Abort)?;
let unsent_request: Option<OutgoingKeyRequest> = unsent_key_requests let unsent_request: Option<OutgoingKeyRequest> = unsent_key_requests
.remove(request_id.encode())? .remove(request_id.encode())?
.map(|r| serde_json::from_slice(&r)) .map(|r| serde_json::from_slice(&r))
.transpose() .transpose()
.map_err(ConflictableTransactionError::Abort)?; .map_err(ConflictableTransactionError::Abort)?;
if let Some(request) = sent_request { if let Some(request) = sent_request {
key_requests_by_info.remove((&request.info).encode())?; key_requests_by_info.remove((&request.info).encode())?;
} }
if let Some(request) = unsent_request { if let Some(request) = unsent_request {
key_requests_by_info.remove((&request.info).encode())?; key_requests_by_info.remove((&request.info).encode())?;
} }
Ok(()) Ok(())
}, },
); );
ret?; ret?;
self.inner.flush_async().await?; self.inner.flush_async().await?;
@ -768,6 +798,7 @@ mod test {
device::test::get_device, device::test::get_device,
user::test::{get_other_identity, get_own_identity}, user::test::{get_other_identity, get_own_identity},
}, },
key_request::SecretInfo,
olm::{ olm::{
GroupSessionKey, InboundGroupSession, OlmMessageHash, PrivateCrossSigningIdentity, GroupSessionKey, InboundGroupSession, OlmMessageHash, PrivateCrossSigningIdentity,
ReadOnlyAccount, Session, ReadOnlyAccount, Session,
@ -1199,12 +1230,13 @@ mod test {
let (account, store, _dir) = get_loaded_store().await; let (account, store, _dir) = get_loaded_store().await;
let id = Uuid::new_v4(); let id = Uuid::new_v4();
let info = RequestedKeyInfo::new( let info: SecretInfo = RequestedKeyInfo::new(
EventEncryptionAlgorithm::MegolmV1AesSha2, EventEncryptionAlgorithm::MegolmV1AesSha2,
room_id!("!test:localhost"), room_id!("!test:localhost"),
"test_sender_key".to_string(), "test_sender_key".to_string(),
"test_session_id".to_string(), "test_session_id".to_string(),
); )
.into();
let request = OutgoingKeyRequest { let request = OutgoingKeyRequest {
request_recipient: account.user_id().to_owned(), request_recipient: account.user_id().to_owned(),
@ -1213,7 +1245,7 @@ mod test {
sent_out: false, sent_out: false,
}; };
assert!(store.get_outgoing_key_request(id).await.unwrap().is_none()); assert!(store.get_outgoing_secret_requests(id).await.unwrap().is_none());
let mut changes = Changes::default(); let mut changes = Changes::default();
changes.key_requests.push(request.clone()); changes.key_requests.push(request.clone());
@ -1221,12 +1253,12 @@ mod test {
let request = Some(request); let request = Some(request);
let stored_request = store.get_outgoing_key_request(id).await.unwrap(); let stored_request = store.get_outgoing_secret_requests(id).await.unwrap();
assert_eq!(request, stored_request); assert_eq!(request, stored_request);
let stored_request = store.get_key_request_by_info(&info).await.unwrap(); let stored_request = store.get_secret_request_by_info(&info).await.unwrap();
assert_eq!(request, stored_request); assert_eq!(request, stored_request);
assert!(!store.get_unsent_key_requests().await.unwrap().is_empty()); assert!(!store.get_unsent_secret_requests().await.unwrap().is_empty());
let request = OutgoingKeyRequest { let request = OutgoingKeyRequest {
request_recipient: account.user_id().to_owned(), request_recipient: account.user_id().to_owned(),
@ -1239,17 +1271,17 @@ mod test {
changes.key_requests.push(request.clone()); changes.key_requests.push(request.clone());
store.save_changes(changes).await.unwrap(); store.save_changes(changes).await.unwrap();
assert!(store.get_unsent_key_requests().await.unwrap().is_empty()); assert!(store.get_unsent_secret_requests().await.unwrap().is_empty());
let stored_request = store.get_outgoing_key_request(id).await.unwrap(); let stored_request = store.get_outgoing_secret_requests(id).await.unwrap();
assert_eq!(Some(request), stored_request); assert_eq!(Some(request), stored_request);
store.delete_outgoing_key_request(id).await.unwrap(); store.delete_outgoing_secret_requests(id).await.unwrap();
let stored_request = store.get_outgoing_key_request(id).await.unwrap(); let stored_request = store.get_outgoing_secret_requests(id).await.unwrap();
assert_eq!(None, stored_request); assert_eq!(None, stored_request);
let stored_request = store.get_key_request_by_info(&info).await.unwrap(); let stored_request = store.get_secret_request_by_info(&info).await.unwrap();
assert_eq!(None, stored_request); assert_eq!(None, stored_request);
assert!(store.get_unsent_key_requests().await.unwrap().is_empty()); assert!(store.get_unsent_secret_requests().await.unwrap().is_empty());
} }
} }