crypto: Add support to request secrets
This commit is contained in:
parent
a916288d03
commit
5192feb836
6 changed files with 279 additions and 126 deletions
|
@ -29,7 +29,9 @@ use ruma::{
|
|||
forwarded_room_key::ForwardedRoomKeyToDeviceEventContent,
|
||||
room_key_request::{Action, RequestedKeyInfo, RoomKeyRequestToDeviceEventContent},
|
||||
secret::{
|
||||
request::{RequestAction, RequestToDeviceEventContent as SecretRequestEventContent},
|
||||
request::{
|
||||
RequestAction, RequestToDeviceEventContent as SecretRequestEventContent, SecretName,
|
||||
},
|
||||
send::SendToDeviceEventContent as SecretSendEventContent,
|
||||
},
|
||||
AnyToDeviceEvent, AnyToDeviceEventContent, ToDeviceEvent,
|
||||
|
@ -201,24 +203,56 @@ pub struct OutgoingKeyRequest {
|
|||
/// The unique id of the key request.
|
||||
pub request_id: Uuid,
|
||||
/// The info of the requested key.
|
||||
pub info: RequestedKeyInfo,
|
||||
pub info: SecretInfo,
|
||||
/// Has the request been sent out.
|
||||
pub sent_out: bool,
|
||||
}
|
||||
|
||||
/// An enum over the various secret request types we can have.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum SecretInfo {
|
||||
// Info for the `m.room_key_request` variant
|
||||
KeyRequest(RequestedKeyInfo),
|
||||
// Info for the `m.secret.request` variant
|
||||
SecretRequest(SecretName),
|
||||
}
|
||||
|
||||
impl From<RequestedKeyInfo> for SecretInfo {
|
||||
fn from(i: RequestedKeyInfo) -> Self {
|
||||
Self::KeyRequest(i)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SecretName> for SecretInfo {
|
||||
fn from(i: SecretName) -> Self {
|
||||
Self::SecretRequest(i)
|
||||
}
|
||||
}
|
||||
|
||||
impl OutgoingKeyRequest {
|
||||
fn to_request(&self, own_device_id: &DeviceId) -> OutgoingRequest {
|
||||
let content = RoomKeyRequestToDeviceEventContent::new(
|
||||
Action::Request,
|
||||
Some(self.info.clone()),
|
||||
own_device_id.to_owned(),
|
||||
self.request_id.to_string(),
|
||||
);
|
||||
let content = match &self.info {
|
||||
SecretInfo::KeyRequest(r) => {
|
||||
AnyToDeviceEventContent::RoomKeyRequest(RoomKeyRequestToDeviceEventContent::new(
|
||||
Action::Request,
|
||||
Some(r.clone()),
|
||||
own_device_id.to_owned(),
|
||||
self.request_id.to_string(),
|
||||
))
|
||||
}
|
||||
SecretInfo::SecretRequest(s) => {
|
||||
AnyToDeviceEventContent::SecretRequest(SecretRequestEventContent::new(
|
||||
RequestAction::Request(s.clone()),
|
||||
own_device_id.to_owned(),
|
||||
self.request_id.to_string(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let request = ToDeviceRequest::new_with_id(
|
||||
&self.request_recipient,
|
||||
DeviceIdOrAllDevices::AllDevices,
|
||||
AnyToDeviceEventContent::RoomKeyRequest(content),
|
||||
content,
|
||||
self.request_id,
|
||||
);
|
||||
|
||||
|
@ -226,17 +260,28 @@ impl OutgoingKeyRequest {
|
|||
}
|
||||
|
||||
fn to_cancellation(&self, own_device_id: &DeviceId) -> OutgoingRequest {
|
||||
let content = RoomKeyRequestToDeviceEventContent::new(
|
||||
Action::CancelRequest,
|
||||
None,
|
||||
own_device_id.to_owned(),
|
||||
self.request_id.to_string(),
|
||||
);
|
||||
let content = match self.info {
|
||||
SecretInfo::KeyRequest(_) => {
|
||||
AnyToDeviceEventContent::RoomKeyRequest(RoomKeyRequestToDeviceEventContent::new(
|
||||
Action::CancelRequest,
|
||||
None,
|
||||
own_device_id.to_owned(),
|
||||
self.request_id.to_string(),
|
||||
))
|
||||
}
|
||||
SecretInfo::SecretRequest(_) => {
|
||||
AnyToDeviceEventContent::SecretRequest(SecretRequestEventContent::new(
|
||||
RequestAction::RequestCancellation,
|
||||
own_device_id.to_owned(),
|
||||
self.request_id.to_string(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let request = ToDeviceRequest::new(
|
||||
&self.request_recipient,
|
||||
DeviceIdOrAllDevices::AllDevices,
|
||||
AnyToDeviceEventContent::RoomKeyRequest(content),
|
||||
content,
|
||||
);
|
||||
|
||||
OutgoingRequest { request_id: request.txn_id, request: Arc::new(request.into()) }
|
||||
|
@ -245,11 +290,20 @@ impl OutgoingKeyRequest {
|
|||
|
||||
impl PartialEq for OutgoingKeyRequest {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.request_id == other.request_id
|
||||
&& self.info.algorithm == other.info.algorithm
|
||||
&& self.info.room_id == other.info.room_id
|
||||
&& self.info.session_id == other.info.session_id
|
||||
&& self.info.sender_key == other.info.sender_key
|
||||
let is_info_equal = match (&self.info, &other.info) {
|
||||
(SecretInfo::KeyRequest(first), SecretInfo::KeyRequest(second)) => {
|
||||
first.algorithm == second.algorithm
|
||||
&& first.room_id == second.room_id
|
||||
&& first.session_id == second.session_id
|
||||
}
|
||||
(SecretInfo::SecretRequest(first), SecretInfo::SecretRequest(second)) => {
|
||||
first == second
|
||||
}
|
||||
(SecretInfo::KeyRequest(_), SecretInfo::SecretRequest(_))
|
||||
| (SecretInfo::SecretRequest(_), SecretInfo::KeyRequest(_)) => false,
|
||||
};
|
||||
|
||||
self.request_id == other.request_id && is_info_equal
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -277,7 +331,7 @@ impl KeyRequestMachine {
|
|||
async fn load_outgoing_requests(&self) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
|
||||
Ok(self
|
||||
.store
|
||||
.get_unsent_key_requests()
|
||||
.get_unsent_secret_requests()
|
||||
.await?
|
||||
.into_iter()
|
||||
.filter(|i| !i.sent_out)
|
||||
|
@ -675,11 +729,8 @@ impl KeyRequestMachine {
|
|||
///
|
||||
/// * `key_info` - The info of our key request containing information about
|
||||
/// the key we wish to request.
|
||||
async fn should_request_key(
|
||||
&self,
|
||||
key_info: &RequestedKeyInfo,
|
||||
) -> Result<bool, CryptoStoreError> {
|
||||
let request = self.store.get_key_request_by_info(key_info).await?;
|
||||
async fn should_request_key(&self, key_info: &SecretInfo) -> Result<bool, CryptoStoreError> {
|
||||
let request = self.store.get_secret_request_by_info(key_info).await?;
|
||||
|
||||
// Don't send out duplicate requests, users can re-request them if they
|
||||
// think a second request might succeed.
|
||||
|
@ -728,9 +779,10 @@ impl KeyRequestMachine {
|
|||
room_id.to_owned(),
|
||||
sender_key.to_owned(),
|
||||
session_id.to_owned(),
|
||||
);
|
||||
)
|
||||
.into();
|
||||
|
||||
let request = self.store.get_key_request_by_info(&key_info).await?;
|
||||
let request = self.store.get_secret_request_by_info(&key_info).await?;
|
||||
|
||||
if let Some(request) = request {
|
||||
let cancel = request.to_cancellation(self.device_id());
|
||||
|
@ -744,9 +796,39 @@ impl KeyRequestMachine {
|
|||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub async fn request_missing_secrets(&self) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
|
||||
let secret_names = self.store.get_missing_secrets().await;
|
||||
|
||||
Ok(if secret_names.is_empty() {
|
||||
info!(secret_names =? secret_names, "Creating new outgoing secret requests");
|
||||
|
||||
let requests: Vec<OutgoingKeyRequest> = secret_names
|
||||
.into_iter()
|
||||
.map(|n| OutgoingKeyRequest {
|
||||
request_recipient: self.user_id().to_owned(),
|
||||
request_id: Uuid::new_v4(),
|
||||
info: n.into(),
|
||||
sent_out: false,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let outgoing_requests =
|
||||
requests.iter().map(|r| r.to_request(self.device_id())).collect();
|
||||
|
||||
let changes = Changes { key_requests: requests, ..Default::default() };
|
||||
self.store.save_changes(changes).await?;
|
||||
|
||||
outgoing_requests
|
||||
} else {
|
||||
trace!("No secrets are missing from our store, not requesting them");
|
||||
vec![]
|
||||
})
|
||||
}
|
||||
|
||||
async fn request_key_helper(
|
||||
&self,
|
||||
key_info: RequestedKeyInfo,
|
||||
key_info: SecretInfo,
|
||||
) -> Result<OutgoingRequest, CryptoStoreError> {
|
||||
info!("Creating new outgoing room key request {:#?}", key_info);
|
||||
|
||||
|
@ -788,7 +870,8 @@ impl KeyRequestMachine {
|
|||
room_id.to_owned(),
|
||||
sender_key.to_owned(),
|
||||
session_id.to_owned(),
|
||||
);
|
||||
)
|
||||
.into();
|
||||
|
||||
if self.should_request_key(&key_info).await? {
|
||||
self.request_key_helper(key_info).await?;
|
||||
|
@ -819,19 +902,20 @@ impl KeyRequestMachine {
|
|||
content.room_id.clone(),
|
||||
content.sender_key.clone(),
|
||||
content.session_id.clone(),
|
||||
);
|
||||
)
|
||||
.into();
|
||||
|
||||
self.store.get_key_request_by_info(&info).await
|
||||
self.store.get_secret_request_by_info(&info).await
|
||||
}
|
||||
|
||||
/// Delete the given outgoing key info.
|
||||
async fn delete_key_info(&self, info: &OutgoingKeyRequest) -> Result<(), CryptoStoreError> {
|
||||
self.store.delete_outgoing_key_request(info.request_id).await
|
||||
self.store.delete_outgoing_secret_requests(info.request_id).await
|
||||
}
|
||||
|
||||
/// Mark the outgoing request as sent.
|
||||
pub async fn mark_outgoing_request_as_sent(&self, id: Uuid) -> Result<(), CryptoStoreError> {
|
||||
let info = self.store.get_outgoing_key_request(id).await?;
|
||||
let info = self.store.get_outgoing_secret_requests(id).await?;
|
||||
|
||||
if let Some(mut info) = info {
|
||||
trace!("Marking outgoing key request as sent {:#?}", info);
|
||||
|
|
|
@ -267,7 +267,7 @@ pub(crate) mod test {
|
|||
})
|
||||
.to_string();
|
||||
|
||||
let event: AnySyncRoomEvent = serde_json::from_str(&event).expect("WHAAAT?!?!?");
|
||||
let event: AnySyncRoomEvent = serde_json::from_str(&event).unwrap();
|
||||
|
||||
let event =
|
||||
if let AnySyncRoomEvent::Message(AnySyncMessageEvent::RoomEncrypted(event)) = event {
|
||||
|
|
|
@ -97,6 +97,11 @@ impl PrivateCrossSigningIdentity {
|
|||
self.self_signing_key.lock().await.is_some()
|
||||
}
|
||||
|
||||
/// Can we sign other users, i.e. do we have a user signing key.
|
||||
pub async fn can_sign_users(&self) -> bool {
|
||||
self.user_signing_key.lock().await.is_some()
|
||||
}
|
||||
|
||||
/// Do we have the master key.
|
||||
pub async fn has_master_key(&self) -> bool {
|
||||
self.master_key.lock().await.is_some()
|
||||
|
@ -130,6 +135,25 @@ impl PrivateCrossSigningIdentity {
|
|||
}
|
||||
}
|
||||
|
||||
/// Get the names of the secrets we are missing.
|
||||
pub(crate) async fn get_missing_secrets(&self) -> Vec<SecretName> {
|
||||
let mut missing = Vec::new();
|
||||
|
||||
if !self.has_master_key().await {
|
||||
missing.push(SecretName::CrossSigningMasterKey);
|
||||
}
|
||||
|
||||
if !self.can_sign_devices().await {
|
||||
missing.push(SecretName::CrossSigningSelfSigningKey);
|
||||
}
|
||||
|
||||
if !self.can_sign_users().await {
|
||||
missing.push(SecretName::CrossSigningUserSigningKey);
|
||||
}
|
||||
|
||||
missing
|
||||
}
|
||||
|
||||
/// Create a new empty identity.
|
||||
pub(crate) fn empty(user_id: UserId) -> Self {
|
||||
Self {
|
||||
|
|
|
@ -19,7 +19,7 @@ use std::{
|
|||
|
||||
use dashmap::{DashMap, DashSet};
|
||||
use matrix_sdk_common::{async_trait, locks::Mutex, uuid::Uuid};
|
||||
use ruma::{events::room_key_request::RequestedKeyInfo, DeviceId, DeviceIdBox, RoomId, UserId};
|
||||
use ruma::{DeviceId, DeviceIdBox, RoomId, UserId};
|
||||
|
||||
use super::{
|
||||
caches::{DeviceStore, GroupSessionStore, SessionStore},
|
||||
|
@ -27,12 +27,21 @@ use super::{
|
|||
};
|
||||
use crate::{
|
||||
identities::{ReadOnlyDevice, ReadOnlyUserIdentities},
|
||||
key_request::OutgoingKeyRequest,
|
||||
key_request::{OutgoingKeyRequest, SecretInfo},
|
||||
olm::{OutboundGroupSession, PrivateCrossSigningIdentity},
|
||||
};
|
||||
|
||||
fn encode_key_info(info: &RequestedKeyInfo) -> String {
|
||||
format!("{}{}{}{}", info.room_id, info.sender_key, info.algorithm, info.session_id)
|
||||
fn encode_key_info(info: &SecretInfo) -> String {
|
||||
match info {
|
||||
SecretInfo::KeyRequest(info) => {
|
||||
format!("{}{}{}{}", info.room_id, info.sender_key, info.algorithm, info.session_id)
|
||||
}
|
||||
SecretInfo::SecretRequest(i) => {
|
||||
// TODO don't use serde here, use `as_ref()` when it becomes
|
||||
// available
|
||||
serde_json::to_string(i).expect("Can't serialize secret name")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An in-memory only store that will forget all the E2EE key once it's dropped.
|
||||
|
@ -228,16 +237,16 @@ impl CryptoStore for MemoryStore {
|
|||
.contains(&message_hash.hash))
|
||||
}
|
||||
|
||||
async fn get_outgoing_key_request(
|
||||
async fn get_outgoing_secret_requests(
|
||||
&self,
|
||||
request_id: Uuid,
|
||||
) -> Result<Option<OutgoingKeyRequest>> {
|
||||
Ok(self.outgoing_key_requests.get(&request_id).map(|r| r.clone()))
|
||||
}
|
||||
|
||||
async fn get_key_request_by_info(
|
||||
async fn get_secret_request_by_info(
|
||||
&self,
|
||||
key_info: &RequestedKeyInfo,
|
||||
key_info: &SecretInfo,
|
||||
) -> Result<Option<OutgoingKeyRequest>> {
|
||||
let key_info_string = encode_key_info(key_info);
|
||||
|
||||
|
@ -247,7 +256,7 @@ impl CryptoStore for MemoryStore {
|
|||
.and_then(|i| self.outgoing_key_requests.get(&i).map(|r| r.clone())))
|
||||
}
|
||||
|
||||
async fn get_unsent_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>> {
|
||||
async fn get_unsent_secret_requests(&self) -> Result<Vec<OutgoingKeyRequest>> {
|
||||
Ok(self
|
||||
.outgoing_key_requests
|
||||
.iter()
|
||||
|
@ -256,7 +265,7 @@ impl CryptoStore for MemoryStore {
|
|||
.collect())
|
||||
}
|
||||
|
||||
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> {
|
||||
async fn delete_outgoing_secret_requests(&self, request_id: Uuid) -> Result<()> {
|
||||
self.outgoing_key_requests.remove(&request_id).and_then(|(_, i)| {
|
||||
let key_info_string = encode_key_info(&i.info);
|
||||
self.key_requests_by_info.remove(&key_info_string)
|
||||
|
|
|
@ -56,9 +56,8 @@ pub use memorystore::MemoryStore;
|
|||
use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
|
||||
pub use pickle_key::{EncryptedPickleKey, PickleKey};
|
||||
use ruma::{
|
||||
events::{room_key_request::RequestedKeyInfo, secret::request::SecretName},
|
||||
identifiers::Error as IdentifierValidationError,
|
||||
DeviceId, DeviceIdBox, DeviceKeyAlgorithm, RoomId, UserId,
|
||||
events::secret::request::SecretName, identifiers::Error as IdentifierValidationError, DeviceId,
|
||||
DeviceIdBox, DeviceKeyAlgorithm, RoomId, UserId,
|
||||
};
|
||||
use serde_json::Error as SerdeError;
|
||||
use thiserror::Error;
|
||||
|
@ -72,7 +71,7 @@ use crate::{
|
|||
user::{OwnUserIdentity, UserIdentities, UserIdentity},
|
||||
Device, ReadOnlyDevice, ReadOnlyUserIdentities, UserDevices,
|
||||
},
|
||||
key_request::OutgoingKeyRequest,
|
||||
key_request::{OutgoingKeyRequest, SecretInfo},
|
||||
olm::{
|
||||
InboundGroupSession, OlmMessageHash, OutboundGroupSession, PrivateCrossSigningIdentity,
|
||||
ReadOnlyAccount, Session,
|
||||
|
@ -273,6 +272,11 @@ impl Store {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_missing_secrets(&self) -> Vec<SecretName> {
|
||||
// TODO add the backup key to our missing secrets
|
||||
self.identity.lock().await.get_missing_secrets().await
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for Store {
|
||||
|
@ -440,14 +444,14 @@ pub trait CryptoStore: AsyncTraitDeps {
|
|||
/// Check if a hash for an Olm message stored in the database.
|
||||
async fn is_message_known(&self, message_hash: &OlmMessageHash) -> Result<bool>;
|
||||
|
||||
/// Get an outgoing key request that we created that matches the given
|
||||
/// Get an outgoing secret request that we created that matches the given
|
||||
/// request id.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `request_id` - The unique request id that identifies this outgoing key
|
||||
/// request.
|
||||
async fn get_outgoing_key_request(
|
||||
/// * `request_id` - The unique request id that identifies this outgoing
|
||||
/// secret request.
|
||||
async fn get_outgoing_secret_requests(
|
||||
&self,
|
||||
request_id: Uuid,
|
||||
) -> Result<Option<OutgoingKeyRequest>>;
|
||||
|
@ -457,14 +461,14 @@ pub trait CryptoStore: AsyncTraitDeps {
|
|||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `key_info` - The key info of an outgoing key request.
|
||||
async fn get_key_request_by_info(
|
||||
/// * `key_info` - The key info of an outgoing secret request.
|
||||
async fn get_secret_request_by_info(
|
||||
&self,
|
||||
key_info: &RequestedKeyInfo,
|
||||
secret_info: &SecretInfo,
|
||||
) -> Result<Option<OutgoingKeyRequest>>;
|
||||
|
||||
/// Get all outgoing key requests that we have in the store.
|
||||
async fn get_unsent_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>>;
|
||||
/// Get all outgoing secret requests that we have in the store.
|
||||
async fn get_unsent_secret_requests(&self) -> Result<Vec<OutgoingKeyRequest>>;
|
||||
|
||||
/// Delete an outgoing key request that we created that matches the given
|
||||
/// request id.
|
||||
|
@ -473,5 +477,5 @@ pub trait CryptoStore: AsyncTraitDeps {
|
|||
///
|
||||
/// * `request_id` - The unique request id that identifies this outgoing key
|
||||
/// request.
|
||||
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()>;
|
||||
async fn delete_outgoing_secret_requests(&self, request_id: Uuid) -> Result<()>;
|
||||
}
|
||||
|
|
|
@ -22,7 +22,10 @@ use std::{
|
|||
use dashmap::DashSet;
|
||||
use matrix_sdk_common::{async_trait, locks::Mutex, uuid};
|
||||
use olm_rs::{account::IdentityKeys, PicklingMode};
|
||||
use ruma::{events::room_key_request::RequestedKeyInfo, DeviceId, DeviceIdBox, RoomId, UserId};
|
||||
use ruma::{
|
||||
events::{room_key_request::RequestedKeyInfo, secret::request::SecretName},
|
||||
DeviceId, DeviceIdBox, RoomId, UserId,
|
||||
};
|
||||
pub use sled::Error;
|
||||
use sled::{
|
||||
transaction::{ConflictableTransactionError, TransactionError},
|
||||
|
@ -36,7 +39,7 @@ use super::{
|
|||
};
|
||||
use crate::{
|
||||
identities::{ReadOnlyDevice, ReadOnlyUserIdentities},
|
||||
key_request::OutgoingKeyRequest,
|
||||
key_request::{OutgoingKeyRequest, SecretInfo},
|
||||
olm::{OutboundGroupSession, PickledInboundGroupSession, PrivateCrossSigningIdentity},
|
||||
};
|
||||
|
||||
|
@ -55,6 +58,27 @@ impl EncodeKey for Uuid {
|
|||
}
|
||||
}
|
||||
|
||||
impl EncodeKey for SecretName {
|
||||
fn encode(&self) -> Vec<u8> {
|
||||
[
|
||||
// TODO don't use serde here, use `as_ref()` when it becomes
|
||||
// available
|
||||
serde_json::to_string(self).expect("Can't serialize secret name").as_bytes(),
|
||||
&[Self::SEPARATOR],
|
||||
]
|
||||
.concat()
|
||||
}
|
||||
}
|
||||
|
||||
impl EncodeKey for SecretInfo {
|
||||
fn encode(&self) -> Vec<u8> {
|
||||
match self {
|
||||
SecretInfo::KeyRequest(k) => k.encode(),
|
||||
SecretInfo::SecretRequest(s) => s.encode(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EncodeKey for &RequestedKeyInfo {
|
||||
fn encode(&self) -> Vec<u8> {
|
||||
[
|
||||
|
@ -136,9 +160,9 @@ pub struct SledStore {
|
|||
inbound_group_sessions: Tree,
|
||||
outbound_group_sessions: Tree,
|
||||
|
||||
outgoing_key_requests: Tree,
|
||||
unsent_key_requests: Tree,
|
||||
key_requests_by_info: Tree,
|
||||
outgoing_secret_requests: Tree,
|
||||
unsent_secret_requests: Tree,
|
||||
secret_requests_by_info: Tree,
|
||||
|
||||
devices: Tree,
|
||||
identities: Tree,
|
||||
|
@ -201,9 +225,9 @@ impl SledStore {
|
|||
let devices = db.open_tree("devices")?;
|
||||
let identities = db.open_tree("identities")?;
|
||||
|
||||
let outgoing_key_requests = db.open_tree("outgoing_key_requests")?;
|
||||
let unsent_key_requests = db.open_tree("unsent_key_requests")?;
|
||||
let key_requests_by_info = db.open_tree("key_requests_by_info")?;
|
||||
let outgoing_secret_requests = db.open_tree("outgoing_secret_requests")?;
|
||||
let unsent_secret_requests = db.open_tree("unsent_secret_requests")?;
|
||||
let secret_requests_by_info = db.open_tree("secret_requests_by_info")?;
|
||||
|
||||
let session_cache = SessionStore::new();
|
||||
|
||||
|
@ -227,9 +251,9 @@ impl SledStore {
|
|||
users_for_key_query_cache: DashSet::new().into(),
|
||||
inbound_group_sessions,
|
||||
outbound_group_sessions,
|
||||
outgoing_key_requests,
|
||||
unsent_key_requests,
|
||||
key_requests_by_info,
|
||||
outgoing_secret_requests,
|
||||
unsent_secret_requests,
|
||||
secret_requests_by_info,
|
||||
devices,
|
||||
tracked_users,
|
||||
users_for_key_query,
|
||||
|
@ -361,9 +385,9 @@ impl SledStore {
|
|||
&self.inbound_group_sessions,
|
||||
&self.outbound_group_sessions,
|
||||
&self.olm_hashes,
|
||||
&self.outgoing_key_requests,
|
||||
&self.unsent_key_requests,
|
||||
&self.key_requests_by_info,
|
||||
&self.outgoing_secret_requests,
|
||||
&self.unsent_secret_requests,
|
||||
&self.secret_requests_by_info,
|
||||
)
|
||||
.transaction(
|
||||
|(
|
||||
|
@ -375,9 +399,9 @@ impl SledStore {
|
|||
inbound_sessions,
|
||||
outbound_sessions,
|
||||
hashes,
|
||||
outgoing_key_requests,
|
||||
unsent_key_requests,
|
||||
key_requests_by_info,
|
||||
outgoing_secret_requests,
|
||||
unsent_secret_requests,
|
||||
secret_requests_by_info,
|
||||
)| {
|
||||
if let Some(a) = &account_pickle {
|
||||
account.insert(
|
||||
|
@ -446,7 +470,7 @@ impl SledStore {
|
|||
}
|
||||
|
||||
for key_request in &key_requests {
|
||||
key_requests_by_info.insert(
|
||||
secret_requests_by_info.insert(
|
||||
(&key_request.info).encode(),
|
||||
key_request.request_id.encode(),
|
||||
)?;
|
||||
|
@ -454,15 +478,15 @@ impl SledStore {
|
|||
let key_request_id = key_request.request_id.encode();
|
||||
|
||||
if key_request.sent_out {
|
||||
unsent_key_requests.remove(key_request_id.clone())?;
|
||||
outgoing_key_requests.insert(
|
||||
unsent_secret_requests.remove(key_request_id.clone())?;
|
||||
outgoing_secret_requests.insert(
|
||||
key_request_id,
|
||||
serde_json::to_vec(&key_request)
|
||||
.map_err(ConflictableTransactionError::Abort)?,
|
||||
)?;
|
||||
} else {
|
||||
outgoing_key_requests.remove(key_request_id.clone())?;
|
||||
unsent_key_requests.insert(
|
||||
outgoing_secret_requests.remove(key_request_id.clone())?;
|
||||
unsent_secret_requests.insert(
|
||||
key_request_id,
|
||||
serde_json::to_vec(&key_request)
|
||||
.map_err(ConflictableTransactionError::Abort)?,
|
||||
|
@ -484,11 +508,14 @@ impl SledStore {
|
|||
&self,
|
||||
id: &[u8],
|
||||
) -> Result<Option<OutgoingKeyRequest>> {
|
||||
let request =
|
||||
self.outgoing_key_requests.get(id)?.map(|r| serde_json::from_slice(&r)).transpose()?;
|
||||
let request = self
|
||||
.outgoing_secret_requests
|
||||
.get(id)?
|
||||
.map(|r| serde_json::from_slice(&r))
|
||||
.transpose()?;
|
||||
|
||||
let request = if request.is_none() {
|
||||
self.unsent_key_requests.get(id)?.map(|r| serde_json::from_slice(&r)).transpose()?
|
||||
self.unsent_secret_requests.get(id)?.map(|r| serde_json::from_slice(&r)).transpose()?
|
||||
} else {
|
||||
request
|
||||
};
|
||||
|
@ -681,7 +708,7 @@ impl CryptoStore for SledStore {
|
|||
Ok(self.olm_hashes.contains_key(serde_json::to_vec(message_hash)?)?)
|
||||
}
|
||||
|
||||
async fn get_outgoing_key_request(
|
||||
async fn get_outgoing_secret_requests(
|
||||
&self,
|
||||
request_id: Uuid,
|
||||
) -> Result<Option<OutgoingKeyRequest>> {
|
||||
|
@ -690,11 +717,11 @@ impl CryptoStore for SledStore {
|
|||
self.get_outgoing_key_request_helper(&request_id).await
|
||||
}
|
||||
|
||||
async fn get_key_request_by_info(
|
||||
async fn get_secret_request_by_info(
|
||||
&self,
|
||||
key_info: &RequestedKeyInfo,
|
||||
key_info: &SecretInfo,
|
||||
) -> Result<Option<OutgoingKeyRequest>> {
|
||||
let id = self.key_requests_by_info.get(key_info.encode())?;
|
||||
let id = self.secret_requests_by_info.get(key_info.encode())?;
|
||||
|
||||
if let Some(id) = id {
|
||||
self.get_outgoing_key_request_helper(&id).await
|
||||
|
@ -703,9 +730,9 @@ impl CryptoStore for SledStore {
|
|||
}
|
||||
}
|
||||
|
||||
async fn get_unsent_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>> {
|
||||
async fn get_unsent_secret_requests(&self) -> Result<Vec<OutgoingKeyRequest>> {
|
||||
let requests: Result<Vec<OutgoingKeyRequest>> = self
|
||||
.unsent_key_requests
|
||||
.unsent_secret_requests
|
||||
.iter()
|
||||
.map(|i| serde_json::from_slice(&i?.1).map_err(CryptoStoreError::from))
|
||||
.collect();
|
||||
|
@ -713,34 +740,37 @@ impl CryptoStore for SledStore {
|
|||
requests
|
||||
}
|
||||
|
||||
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> {
|
||||
let ret: Result<(), TransactionError<serde_json::Error>> =
|
||||
(&self.outgoing_key_requests, &self.unsent_key_requests, &self.key_requests_by_info)
|
||||
.transaction(
|
||||
|(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| {
|
||||
let sent_request: Option<OutgoingKeyRequest> = outgoing_key_requests
|
||||
.remove(request_id.encode())?
|
||||
.map(|r| serde_json::from_slice(&r))
|
||||
.transpose()
|
||||
.map_err(ConflictableTransactionError::Abort)?;
|
||||
async fn delete_outgoing_secret_requests(&self, request_id: Uuid) -> Result<()> {
|
||||
let ret: Result<(), TransactionError<serde_json::Error>> = (
|
||||
&self.outgoing_secret_requests,
|
||||
&self.unsent_secret_requests,
|
||||
&self.secret_requests_by_info,
|
||||
)
|
||||
.transaction(
|
||||
|(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| {
|
||||
let sent_request: Option<OutgoingKeyRequest> = outgoing_key_requests
|
||||
.remove(request_id.encode())?
|
||||
.map(|r| serde_json::from_slice(&r))
|
||||
.transpose()
|
||||
.map_err(ConflictableTransactionError::Abort)?;
|
||||
|
||||
let unsent_request: Option<OutgoingKeyRequest> = unsent_key_requests
|
||||
.remove(request_id.encode())?
|
||||
.map(|r| serde_json::from_slice(&r))
|
||||
.transpose()
|
||||
.map_err(ConflictableTransactionError::Abort)?;
|
||||
let unsent_request: Option<OutgoingKeyRequest> = unsent_key_requests
|
||||
.remove(request_id.encode())?
|
||||
.map(|r| serde_json::from_slice(&r))
|
||||
.transpose()
|
||||
.map_err(ConflictableTransactionError::Abort)?;
|
||||
|
||||
if let Some(request) = sent_request {
|
||||
key_requests_by_info.remove((&request.info).encode())?;
|
||||
}
|
||||
if let Some(request) = sent_request {
|
||||
key_requests_by_info.remove((&request.info).encode())?;
|
||||
}
|
||||
|
||||
if let Some(request) = unsent_request {
|
||||
key_requests_by_info.remove((&request.info).encode())?;
|
||||
}
|
||||
if let Some(request) = unsent_request {
|
||||
key_requests_by_info.remove((&request.info).encode())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
},
|
||||
);
|
||||
Ok(())
|
||||
},
|
||||
);
|
||||
|
||||
ret?;
|
||||
self.inner.flush_async().await?;
|
||||
|
@ -768,6 +798,7 @@ mod test {
|
|||
device::test::get_device,
|
||||
user::test::{get_other_identity, get_own_identity},
|
||||
},
|
||||
key_request::SecretInfo,
|
||||
olm::{
|
||||
GroupSessionKey, InboundGroupSession, OlmMessageHash, PrivateCrossSigningIdentity,
|
||||
ReadOnlyAccount, Session,
|
||||
|
@ -1199,12 +1230,13 @@ mod test {
|
|||
let (account, store, _dir) = get_loaded_store().await;
|
||||
|
||||
let id = Uuid::new_v4();
|
||||
let info = RequestedKeyInfo::new(
|
||||
let info: SecretInfo = RequestedKeyInfo::new(
|
||||
EventEncryptionAlgorithm::MegolmV1AesSha2,
|
||||
room_id!("!test:localhost"),
|
||||
"test_sender_key".to_string(),
|
||||
"test_session_id".to_string(),
|
||||
);
|
||||
)
|
||||
.into();
|
||||
|
||||
let request = OutgoingKeyRequest {
|
||||
request_recipient: account.user_id().to_owned(),
|
||||
|
@ -1213,7 +1245,7 @@ mod test {
|
|||
sent_out: false,
|
||||
};
|
||||
|
||||
assert!(store.get_outgoing_key_request(id).await.unwrap().is_none());
|
||||
assert!(store.get_outgoing_secret_requests(id).await.unwrap().is_none());
|
||||
|
||||
let mut changes = Changes::default();
|
||||
changes.key_requests.push(request.clone());
|
||||
|
@ -1221,12 +1253,12 @@ mod test {
|
|||
|
||||
let request = Some(request);
|
||||
|
||||
let stored_request = store.get_outgoing_key_request(id).await.unwrap();
|
||||
let stored_request = store.get_outgoing_secret_requests(id).await.unwrap();
|
||||
assert_eq!(request, stored_request);
|
||||
|
||||
let stored_request = store.get_key_request_by_info(&info).await.unwrap();
|
||||
let stored_request = store.get_secret_request_by_info(&info).await.unwrap();
|
||||
assert_eq!(request, stored_request);
|
||||
assert!(!store.get_unsent_key_requests().await.unwrap().is_empty());
|
||||
assert!(!store.get_unsent_secret_requests().await.unwrap().is_empty());
|
||||
|
||||
let request = OutgoingKeyRequest {
|
||||
request_recipient: account.user_id().to_owned(),
|
||||
|
@ -1239,17 +1271,17 @@ mod test {
|
|||
changes.key_requests.push(request.clone());
|
||||
store.save_changes(changes).await.unwrap();
|
||||
|
||||
assert!(store.get_unsent_key_requests().await.unwrap().is_empty());
|
||||
let stored_request = store.get_outgoing_key_request(id).await.unwrap();
|
||||
assert!(store.get_unsent_secret_requests().await.unwrap().is_empty());
|
||||
let stored_request = store.get_outgoing_secret_requests(id).await.unwrap();
|
||||
assert_eq!(Some(request), stored_request);
|
||||
|
||||
store.delete_outgoing_key_request(id).await.unwrap();
|
||||
store.delete_outgoing_secret_requests(id).await.unwrap();
|
||||
|
||||
let stored_request = store.get_outgoing_key_request(id).await.unwrap();
|
||||
let stored_request = store.get_outgoing_secret_requests(id).await.unwrap();
|
||||
assert_eq!(None, stored_request);
|
||||
|
||||
let stored_request = store.get_key_request_by_info(&info).await.unwrap();
|
||||
let stored_request = store.get_secret_request_by_info(&info).await.unwrap();
|
||||
assert_eq!(None, stored_request);
|
||||
assert!(store.get_unsent_key_requests().await.unwrap().is_empty());
|
||||
assert!(store.get_unsent_secret_requests().await.unwrap().is_empty());
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue