crypto: Add specialized methods to store outgoing key requests
This commit is contained in:
parent
5637ca3080
commit
02331fa325
5 changed files with 327 additions and 187 deletions
|
@ -42,7 +42,7 @@ use crate::{
|
|||
error::{OlmError, OlmResult},
|
||||
olm::{InboundGroupSession, OutboundGroupSession, Session, ShareState},
|
||||
requests::{OutgoingRequest, ToDeviceRequest},
|
||||
store::{CryptoStoreError, Store},
|
||||
store::{Changes, CryptoStoreError, Store},
|
||||
Device,
|
||||
};
|
||||
|
||||
|
@ -137,32 +137,24 @@ pub(crate) struct KeyRequestMachine {
|
|||
users_for_key_claim: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct OugoingKeyInfo {
|
||||
request_id: Uuid,
|
||||
info: RequestedKeyInfo,
|
||||
sent_out: bool,
|
||||
/// A struct describing an outgoing key request.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OutgoingKeyRequest {
|
||||
/// The unique id of the key request.
|
||||
pub request_id: Uuid,
|
||||
/// The info of the requested key.
|
||||
pub info: RequestedKeyInfo,
|
||||
/// Has the request been sent out.
|
||||
pub sent_out: bool,
|
||||
}
|
||||
|
||||
trait Encode {
|
||||
fn encode(&self) -> String;
|
||||
}
|
||||
|
||||
impl Encode for RequestedKeyInfo {
|
||||
fn encode(&self) -> String {
|
||||
format!(
|
||||
"{}|{}|{}|{}",
|
||||
self.sender_key, self.room_id, self.session_id, self.algorithm
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode for ForwardedRoomKeyToDeviceEventContent {
|
||||
fn encode(&self) -> String {
|
||||
format!(
|
||||
"{}|{}|{}|{}",
|
||||
self.sender_key, self.room_id, self.session_id, self.algorithm
|
||||
)
|
||||
impl PartialEq for OutgoingKeyRequest {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.request_id == other.request_id
|
||||
&& self.info.algorithm == other.info.algorithm
|
||||
&& self.info.room_id == other.info.room_id
|
||||
&& self.info.session_id == other.info.session_id
|
||||
&& self.info.sender_key == other.info.sender_key
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -246,6 +238,7 @@ impl KeyRequestMachine {
|
|||
/// key request queue.
|
||||
pub async fn collect_incoming_key_requests(&self) -> OlmResult<Vec<Session>> {
|
||||
let mut changed_sessions = Vec::new();
|
||||
|
||||
for item in self.incoming_key_requests.iter() {
|
||||
let event = item.value();
|
||||
if let Some(s) = self.handle_key_request(event).await? {
|
||||
|
@ -534,9 +527,9 @@ impl KeyRequestMachine {
|
|||
session_id: session_id.to_owned(),
|
||||
};
|
||||
|
||||
let id: Option<String> = self.store.get_object(&key_info.encode()).await?;
|
||||
let request = self.store.get_key_request_by_info(&key_info).await?;
|
||||
|
||||
if id.is_some() {
|
||||
if request.is_some() {
|
||||
// We already sent out a request for this key, nothing to do.
|
||||
return Ok(());
|
||||
}
|
||||
|
@ -554,13 +547,13 @@ impl KeyRequestMachine {
|
|||
|
||||
let request = wrap_key_request_content(self.user_id().clone(), id, &content)?;
|
||||
|
||||
let info = OugoingKeyInfo {
|
||||
let info = OutgoingKeyRequest {
|
||||
request_id: id,
|
||||
info: content.body.unwrap(),
|
||||
sent_out: false,
|
||||
};
|
||||
|
||||
self.save_outgoing_key_info(id, info).await?;
|
||||
self.save_outgoing_key_info(info).await?;
|
||||
self.outgoing_to_device_requests.insert(id, request);
|
||||
|
||||
Ok(())
|
||||
|
@ -569,16 +562,11 @@ impl KeyRequestMachine {
|
|||
/// Save an outgoing key info.
|
||||
async fn save_outgoing_key_info(
|
||||
&self,
|
||||
id: Uuid,
|
||||
info: OugoingKeyInfo,
|
||||
info: OutgoingKeyRequest,
|
||||
) -> Result<(), CryptoStoreError> {
|
||||
// TODO we'll want to use a transaction to store those atomically.
|
||||
// To allow this we'll need to rework our cryptostore trait to return
|
||||
// a transaction trait and the transaction trait will have the save_X
|
||||
// methods.
|
||||
let id_string = id.to_string();
|
||||
self.store.save_object(&id_string, &info).await?;
|
||||
self.store.save_object(&info.info.encode(), &id).await?;
|
||||
let mut changes = Changes::default();
|
||||
changes.key_requests.push(info);
|
||||
self.store.save_changes(changes).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -587,44 +575,43 @@ impl KeyRequestMachine {
|
|||
async fn get_key_info(
|
||||
&self,
|
||||
content: &ForwardedRoomKeyToDeviceEventContent,
|
||||
) -> Result<Option<OugoingKeyInfo>, CryptoStoreError> {
|
||||
let id: Option<Uuid> = self.store.get_object(&content.encode()).await?;
|
||||
) -> Result<Option<OutgoingKeyRequest>, CryptoStoreError> {
|
||||
let info = RequestedKeyInfo {
|
||||
algorithm: content.algorithm.clone(),
|
||||
room_id: content.room_id.clone(),
|
||||
sender_key: content.sender_key.clone(),
|
||||
session_id: content.session_id.clone(),
|
||||
};
|
||||
|
||||
if let Some(id) = id {
|
||||
self.store.get_object(&id.to_string()).await
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
self.store.get_key_request_by_info(&info).await
|
||||
}
|
||||
|
||||
/// Delete the given outgoing key info.
|
||||
async fn delete_key_info(&self, info: &OugoingKeyInfo) -> Result<(), CryptoStoreError> {
|
||||
async fn delete_key_info(&self, info: &OutgoingKeyRequest) -> Result<(), CryptoStoreError> {
|
||||
self.store
|
||||
.delete_object(&info.request_id.to_string())
|
||||
.await?;
|
||||
self.store.delete_object(&info.info.encode()).await?;
|
||||
|
||||
Ok(())
|
||||
.delete_outgoing_key_request(info.request_id)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Mark the outgoing request as sent.
|
||||
pub async fn mark_outgoing_request_as_sent(&self, id: &Uuid) -> Result<(), CryptoStoreError> {
|
||||
self.outgoing_to_device_requests.remove(id);
|
||||
let info: Option<OugoingKeyInfo> = self.store.get_object(&id.to_string()).await?;
|
||||
pub async fn mark_outgoing_request_as_sent(&self, id: Uuid) -> Result<(), CryptoStoreError> {
|
||||
let info = self.store.get_outgoing_key_request(id).await?;
|
||||
|
||||
if let Some(mut info) = info {
|
||||
trace!("Marking outgoing key request as sent {:#?}", info);
|
||||
info.sent_out = true;
|
||||
self.save_outgoing_key_info(*id, info).await?;
|
||||
self.save_outgoing_key_info(info).await?;
|
||||
}
|
||||
|
||||
self.outgoing_to_device_requests.remove(&id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Mark the given outgoing key info as done.
|
||||
///
|
||||
/// This will queue up a request cancelation.
|
||||
async fn mark_as_done(&self, key_info: OugoingKeyInfo) -> Result<(), CryptoStoreError> {
|
||||
async fn mark_as_done(&self, key_info: OutgoingKeyRequest) -> Result<(), CryptoStoreError> {
|
||||
// TODO perhaps only remove the key info if the first known index is 0.
|
||||
trace!(
|
||||
"Successfully received a forwarded room key for {:#?}",
|
||||
|
@ -847,7 +834,7 @@ mod test {
|
|||
let id = request.request_id;
|
||||
drop(request);
|
||||
|
||||
machine.mark_outgoing_request_as_sent(&id).await.unwrap();
|
||||
machine.mark_outgoing_request_as_sent(id).await.unwrap();
|
||||
assert!(machine.outgoing_to_device_requests.is_empty());
|
||||
}
|
||||
|
||||
|
@ -873,7 +860,7 @@ mod test {
|
|||
let id = request.request_id;
|
||||
drop(request);
|
||||
|
||||
machine.mark_outgoing_request_as_sent(&id).await.unwrap();
|
||||
machine.mark_outgoing_request_as_sent(id).await.unwrap();
|
||||
|
||||
let export = session.export_at_index(10).await;
|
||||
|
||||
|
@ -915,7 +902,7 @@ mod test {
|
|||
let request = machine.outgoing_to_device_requests.iter().next().unwrap();
|
||||
let id = request.request_id;
|
||||
drop(request);
|
||||
machine.mark_outgoing_request_as_sent(&id).await.unwrap();
|
||||
machine.mark_outgoing_request_as_sent(id).await.unwrap();
|
||||
|
||||
machine
|
||||
.create_outgoing_key_request(
|
||||
|
@ -930,7 +917,7 @@ mod test {
|
|||
let id = request.request_id;
|
||||
drop(request);
|
||||
|
||||
machine.mark_outgoing_request_as_sent(&id).await.unwrap();
|
||||
machine.mark_outgoing_request_as_sent(id).await.unwrap();
|
||||
|
||||
let export = session.export_at_index(15).await;
|
||||
|
||||
|
@ -1148,7 +1135,7 @@ mod test {
|
|||
|
||||
drop(request);
|
||||
alice_machine
|
||||
.mark_outgoing_request_as_sent(&id)
|
||||
.mark_outgoing_request_as_sent(id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
@ -1186,10 +1173,7 @@ mod test {
|
|||
let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap();
|
||||
|
||||
drop(request);
|
||||
bob_machine
|
||||
.mark_outgoing_request_as_sent(&id)
|
||||
.await
|
||||
.unwrap();
|
||||
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();
|
||||
|
||||
let event = ToDeviceEvent {
|
||||
sender: bob_id(),
|
||||
|
@ -1317,7 +1301,7 @@ mod test {
|
|||
|
||||
drop(request);
|
||||
alice_machine
|
||||
.mark_outgoing_request_as_sent(&id)
|
||||
.mark_outgoing_request_as_sent(id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
@ -1378,10 +1362,7 @@ mod test {
|
|||
let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap();
|
||||
|
||||
drop(request);
|
||||
bob_machine
|
||||
.mark_outgoing_request_as_sent(&id)
|
||||
.await
|
||||
.unwrap();
|
||||
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();
|
||||
|
||||
let event = ToDeviceEvent {
|
||||
sender: bob_id(),
|
||||
|
|
|
@ -751,7 +751,7 @@ impl OlmMachine {
|
|||
async fn mark_to_device_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> {
|
||||
self.verification_machine.mark_request_as_sent(request_id);
|
||||
self.key_request_machine
|
||||
.mark_outgoing_request_as_sent(request_id)
|
||||
.mark_outgoing_request_as_sent(*request_id)
|
||||
.await?;
|
||||
self.group_session_manager
|
||||
.mark_request_as_sent(request_id)
|
||||
|
|
|
@ -20,8 +20,10 @@ use std::{
|
|||
use dashmap::{DashMap, DashSet};
|
||||
use matrix_sdk_common::{
|
||||
async_trait,
|
||||
events::room_key_request::RequestedKeyInfo,
|
||||
identifiers::{DeviceId, DeviceIdBox, RoomId, UserId},
|
||||
locks::Mutex,
|
||||
uuid::Uuid,
|
||||
};
|
||||
|
||||
use super::{
|
||||
|
@ -30,9 +32,17 @@ use super::{
|
|||
};
|
||||
use crate::{
|
||||
identities::{ReadOnlyDevice, UserIdentities},
|
||||
key_request::OutgoingKeyRequest,
|
||||
olm::{OutboundGroupSession, PrivateCrossSigningIdentity},
|
||||
};
|
||||
|
||||
fn encode_key_info(info: &RequestedKeyInfo) -> String {
|
||||
format!(
|
||||
"{}{}{}{}",
|
||||
&info.room_id, &info.sender_key, &info.algorithm, &info.session_id
|
||||
)
|
||||
}
|
||||
|
||||
/// An in-memory only store that will forget all the E2EE key once it's dropped.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MemoryStore {
|
||||
|
@ -43,7 +53,8 @@ pub struct MemoryStore {
|
|||
olm_hashes: Arc<DashMap<String, DashSet<String>>>,
|
||||
devices: DeviceStore,
|
||||
identities: Arc<DashMap<UserId, UserIdentities>>,
|
||||
values: Arc<DashMap<String, String>>,
|
||||
outgoing_key_requests: Arc<DashMap<Uuid, OutgoingKeyRequest>>,
|
||||
key_requests_by_info: Arc<DashMap<String, Uuid>>,
|
||||
}
|
||||
|
||||
impl Default for MemoryStore {
|
||||
|
@ -56,7 +67,8 @@ impl Default for MemoryStore {
|
|||
olm_hashes: Arc::new(DashMap::new()),
|
||||
devices: DeviceStore::new(),
|
||||
identities: Arc::new(DashMap::new()),
|
||||
values: Arc::new(DashMap::new()),
|
||||
outgoing_key_requests: Arc::new(DashMap::new()),
|
||||
key_requests_by_info: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -103,6 +115,10 @@ impl CryptoStore for MemoryStore {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn save_changes(&self, mut changes: Changes) -> Result<()> {
|
||||
self.save_sessions(changes.sessions).await;
|
||||
self.save_inbound_group_sessions(changes.inbound_group_sessions)
|
||||
|
@ -130,6 +146,14 @@ impl CryptoStore for MemoryStore {
|
|||
.insert(hash.hash.clone());
|
||||
}
|
||||
|
||||
for key_request in changes.key_requests {
|
||||
let id = key_request.request_id;
|
||||
let info_string = encode_key_info(&key_request.info);
|
||||
|
||||
self.outgoing_key_requests.insert(id, key_request);
|
||||
self.key_requests_by_info.insert(info_string, id);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -152,9 +176,11 @@ impl CryptoStore for MemoryStore {
|
|||
Ok(self.inbound_group_sessions.get_all())
|
||||
}
|
||||
|
||||
fn users_for_key_query(&self) -> HashSet<UserId> {
|
||||
#[allow(clippy::map_clone)]
|
||||
self.users_for_key_query.iter().map(|u| u.clone()).collect()
|
||||
async fn get_outbound_group_sessions(
|
||||
&self,
|
||||
_: &RoomId,
|
||||
) -> Result<Option<OutboundGroupSession>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn is_user_tracked(&self, user_id: &UserId) -> bool {
|
||||
|
@ -165,6 +191,11 @@ impl CryptoStore for MemoryStore {
|
|||
!self.users_for_key_query.is_empty()
|
||||
}
|
||||
|
||||
fn users_for_key_query(&self) -> HashSet<UserId> {
|
||||
#[allow(clippy::map_clone)]
|
||||
self.users_for_key_query.iter().map(|u| u.clone()).collect()
|
||||
}
|
||||
|
||||
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
|
||||
// TODO to prevent a race between the sync and a key query in flight we
|
||||
// need to have an additional state to mention that the user changed.
|
||||
|
@ -207,24 +238,6 @@ impl CryptoStore for MemoryStore {
|
|||
Ok(self.identities.get(user_id).map(|i| i.clone()))
|
||||
}
|
||||
|
||||
async fn save_value(&self, key: String, value: String) -> Result<()> {
|
||||
self.values.insert(key, value);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn remove_value(&self, key: &str) -> Result<()> {
|
||||
self.values.remove(key);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_value(&self, key: &str) -> Result<Option<String>> {
|
||||
Ok(self.values.get(key).map(|v| v.to_owned()))
|
||||
}
|
||||
|
||||
async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result<bool> {
|
||||
Ok(self
|
||||
.olm_hashes
|
||||
|
@ -233,11 +246,37 @@ impl CryptoStore for MemoryStore {
|
|||
.contains(&message_hash.hash))
|
||||
}
|
||||
|
||||
async fn get_outbound_group_sessions(
|
||||
async fn get_outgoing_key_request(
|
||||
&self,
|
||||
_: &RoomId,
|
||||
) -> Result<Option<OutboundGroupSession>> {
|
||||
Ok(None)
|
||||
request_id: Uuid,
|
||||
) -> Result<Option<OutgoingKeyRequest>> {
|
||||
Ok(self
|
||||
.outgoing_key_requests
|
||||
.get(&request_id)
|
||||
.map(|r| r.clone()))
|
||||
}
|
||||
|
||||
async fn get_key_request_by_info(
|
||||
&self,
|
||||
key_info: &RequestedKeyInfo,
|
||||
) -> Result<Option<OutgoingKeyRequest>> {
|
||||
let key_info_string = encode_key_info(key_info);
|
||||
|
||||
Ok(self
|
||||
.key_requests_by_info
|
||||
.get(&key_info_string)
|
||||
.and_then(|i| self.outgoing_key_requests.get(&i).map(|r| r.clone())))
|
||||
}
|
||||
|
||||
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> {
|
||||
self.outgoing_key_requests
|
||||
.remove(&request_id)
|
||||
.and_then(|(_, i)| {
|
||||
let key_info_string = encode_key_info(&i.info);
|
||||
self.key_requests_by_info.remove(&key_info_string)
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -57,23 +57,25 @@ use std::{
|
|||
};
|
||||
|
||||
use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Error as SerdeError;
|
||||
use thiserror::Error;
|
||||
|
||||
use matrix_sdk_common::{
|
||||
async_trait,
|
||||
events::room_key_request::RequestedKeyInfo,
|
||||
identifiers::{
|
||||
DeviceId, DeviceIdBox, DeviceKeyAlgorithm, Error as IdentifierValidationError, RoomId,
|
||||
UserId,
|
||||
},
|
||||
locks::Mutex,
|
||||
uuid::Uuid,
|
||||
AsyncTraitDeps,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
error::SessionUnpicklingError,
|
||||
identities::{Device, ReadOnlyDevice, UserDevices, UserIdentities},
|
||||
key_request::OutgoingKeyRequest,
|
||||
olm::{
|
||||
InboundGroupSession, OlmMessageHash, OutboundGroupSession, PrivateCrossSigningIdentity,
|
||||
ReadOnlyAccount, Session,
|
||||
|
@ -108,6 +110,7 @@ pub struct Changes {
|
|||
pub inbound_group_sessions: Vec<InboundGroupSession>,
|
||||
pub outbound_group_sessions: Vec<OutboundGroupSession>,
|
||||
pub identities: IdentityChanges,
|
||||
pub key_requests: Vec<OutgoingKeyRequest>,
|
||||
pub devices: DeviceChanges,
|
||||
}
|
||||
|
||||
|
@ -257,24 +260,6 @@ impl Store {
|
|||
device_owner_identity,
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn get_object<V: for<'b> Deserialize<'b>>(&self, key: &str) -> Result<Option<V>> {
|
||||
if let Some(value) = self.get_value(key).await? {
|
||||
Ok(Some(serde_json::from_str(&value)?))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn save_object(&self, key: &str, value: &impl Serialize) -> Result<()> {
|
||||
let value = serde_json::to_string(value)?;
|
||||
self.save_value(key.to_owned(), value).await
|
||||
}
|
||||
|
||||
pub async fn delete_object(&self, key: &str) -> Result<()> {
|
||||
self.inner.remove_value(key).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for Store {
|
||||
|
@ -438,15 +423,38 @@ pub trait CryptoStore: AsyncTraitDeps {
|
|||
/// * `user_id` - The user for which we should get the identity.
|
||||
async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentities>>;
|
||||
|
||||
/// Save a serializeable object in the store.
|
||||
async fn save_value(&self, key: String, value: String) -> Result<()>;
|
||||
|
||||
/// Remove a value from the store.
|
||||
async fn remove_value(&self, key: &str) -> Result<()>;
|
||||
|
||||
/// Load a serializeable object from the store.
|
||||
async fn get_value(&self, key: &str) -> Result<Option<String>>;
|
||||
|
||||
/// Check if a hash for an Olm message stored in the database.
|
||||
async fn is_message_known(&self, message_hash: &OlmMessageHash) -> Result<bool>;
|
||||
|
||||
/// Get an outoing key request that we created that matches the given
|
||||
/// request id.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `request_id` - The unique request id that identifies this outgoing key
|
||||
/// request.
|
||||
async fn get_outgoing_key_request(
|
||||
&self,
|
||||
request_id: Uuid,
|
||||
) -> Result<Option<OutgoingKeyRequest>>;
|
||||
|
||||
/// Get an outoing key request that we created that matches the given
|
||||
/// requested key info.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `key_info` - The key info of an outgoing key request.
|
||||
async fn get_key_request_by_info(
|
||||
&self,
|
||||
key_info: &RequestedKeyInfo,
|
||||
) -> Result<Option<OutgoingKeyRequest>>;
|
||||
|
||||
/// Delete an outoing key request that we created that matches the given
|
||||
/// request id.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `request_id` - The unique request id that identifies this outgoing key
|
||||
/// request.
|
||||
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()>;
|
||||
}
|
||||
|
|
|
@ -29,9 +29,12 @@ use sled::{
|
|||
|
||||
use matrix_sdk_common::{
|
||||
async_trait,
|
||||
events::room_key_request::RequestedKeyInfo,
|
||||
identifiers::{DeviceId, DeviceIdBox, RoomId, UserId},
|
||||
locks::Mutex,
|
||||
uuid,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{
|
||||
caches::SessionStore, Changes, CryptoStore, CryptoStoreError, InboundGroupSession, PickleKey,
|
||||
|
@ -39,6 +42,7 @@ use super::{
|
|||
};
|
||||
use crate::{
|
||||
identities::{ReadOnlyDevice, UserIdentities},
|
||||
key_request::OutgoingKeyRequest,
|
||||
olm::{OutboundGroupSession, PickledInboundGroupSession, PrivateCrossSigningIdentity},
|
||||
};
|
||||
|
||||
|
@ -51,6 +55,28 @@ trait EncodeKey {
|
|||
fn encode(&self) -> Vec<u8>;
|
||||
}
|
||||
|
||||
impl EncodeKey for Uuid {
|
||||
fn encode(&self) -> Vec<u8> {
|
||||
self.as_u128().to_be_bytes().to_vec()
|
||||
}
|
||||
}
|
||||
|
||||
impl EncodeKey for &RequestedKeyInfo {
|
||||
fn encode(&self) -> Vec<u8> {
|
||||
[
|
||||
self.room_id.as_bytes(),
|
||||
&[Self::SEPARATOR],
|
||||
self.sender_key.as_bytes(),
|
||||
&[Self::SEPARATOR],
|
||||
self.algorithm.as_ref().as_bytes(),
|
||||
&[Self::SEPARATOR],
|
||||
self.session_id.as_bytes(),
|
||||
&[Self::SEPARATOR],
|
||||
]
|
||||
.concat()
|
||||
}
|
||||
}
|
||||
|
||||
impl EncodeKey for &UserId {
|
||||
fn encode(&self) -> Vec<u8> {
|
||||
self.as_str().encode()
|
||||
|
@ -122,12 +148,14 @@ pub struct SledStore {
|
|||
inbound_group_sessions: Tree,
|
||||
outbound_group_sessions: Tree,
|
||||
|
||||
outgoing_key_requests: Tree,
|
||||
key_requests_by_info: Tree,
|
||||
|
||||
devices: Tree,
|
||||
identities: Tree,
|
||||
|
||||
tracked_users: Tree,
|
||||
users_for_key_query: Tree,
|
||||
values: Tree,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for SledStore {
|
||||
|
@ -178,13 +206,16 @@ impl SledStore {
|
|||
let sessions = db.open_tree("session")?;
|
||||
let inbound_group_sessions = db.open_tree("inbound_group_sessions")?;
|
||||
let outbound_group_sessions = db.open_tree("outbound_group_sessions")?;
|
||||
|
||||
let tracked_users = db.open_tree("tracked_users")?;
|
||||
let users_for_key_query = db.open_tree("users_for_key_query")?;
|
||||
let olm_hashes = db.open_tree("olm_hashes")?;
|
||||
|
||||
let devices = db.open_tree("devices")?;
|
||||
let identities = db.open_tree("identities")?;
|
||||
let values = db.open_tree("values")?;
|
||||
|
||||
let outgoing_key_requests = db.open_tree("outgoing_key_requests")?;
|
||||
let key_requests_by_info = db.open_tree("key_requests_by_info")?;
|
||||
|
||||
let session_cache = SessionStore::new();
|
||||
|
||||
|
@ -208,12 +239,13 @@ impl SledStore {
|
|||
users_for_key_query_cache: DashSet::new().into(),
|
||||
inbound_group_sessions,
|
||||
outbound_group_sessions,
|
||||
outgoing_key_requests,
|
||||
key_requests_by_info,
|
||||
devices,
|
||||
tracked_users,
|
||||
users_for_key_query,
|
||||
olm_hashes,
|
||||
identities,
|
||||
values,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -332,6 +364,7 @@ impl SledStore {
|
|||
|
||||
let identity_changes = changes.identities;
|
||||
let olm_hashes = changes.message_hashes;
|
||||
let key_requests = changes.key_requests;
|
||||
|
||||
let ret: Result<(), TransactionError<serde_json::Error>> = (
|
||||
&self.account,
|
||||
|
@ -342,6 +375,8 @@ impl SledStore {
|
|||
&self.inbound_group_sessions,
|
||||
&self.outbound_group_sessions,
|
||||
&self.olm_hashes,
|
||||
&self.outgoing_key_requests,
|
||||
&self.key_requests_by_info,
|
||||
)
|
||||
.transaction(
|
||||
|(
|
||||
|
@ -353,6 +388,8 @@ impl SledStore {
|
|||
inbound_sessions,
|
||||
outbound_sessions,
|
||||
hashes,
|
||||
outgoing_key_requests,
|
||||
key_requests_by_info,
|
||||
)| {
|
||||
if let Some(a) = &account_pickle {
|
||||
account.insert(
|
||||
|
@ -420,6 +457,19 @@ impl SledStore {
|
|||
)?;
|
||||
}
|
||||
|
||||
for key_request in &key_requests {
|
||||
key_requests_by_info.insert(
|
||||
(&key_request.info).encode(),
|
||||
key_request.request_id.encode(),
|
||||
)?;
|
||||
|
||||
outgoing_key_requests.insert(
|
||||
key_request.request_id.encode(),
|
||||
serde_json::to_vec(&key_request)
|
||||
.map_err(ConflictableTransactionError::Abort)?,
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
},
|
||||
);
|
||||
|
@ -472,6 +522,19 @@ impl CryptoStore for SledStore {
|
|||
self.save_changes(changes).await
|
||||
}
|
||||
|
||||
async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
|
||||
if let Some(i) = self.private_identity.get("identity".encode())? {
|
||||
let pickle = serde_json::from_slice(&i)?;
|
||||
Ok(Some(
|
||||
PrivateCrossSigningIdentity::from_pickle(pickle, self.get_pickle_key())
|
||||
.await
|
||||
.map_err(|_| CryptoStoreError::UnpicklingError)?,
|
||||
))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
async fn save_changes(&self, changes: Changes) -> Result<()> {
|
||||
self.save_changes(changes).await
|
||||
}
|
||||
|
@ -539,12 +602,11 @@ impl CryptoStore for SledStore {
|
|||
.collect())
|
||||
}
|
||||
|
||||
fn users_for_key_query(&self) -> HashSet<UserId> {
|
||||
#[allow(clippy::map_clone)]
|
||||
self.users_for_key_query_cache
|
||||
.iter()
|
||||
.map(|u| u.clone())
|
||||
.collect()
|
||||
async fn get_outbound_group_sessions(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
) -> Result<Option<OutboundGroupSession>> {
|
||||
self.load_outbound_group_session(room_id).await
|
||||
}
|
||||
|
||||
fn is_user_tracked(&self, user_id: &UserId) -> bool {
|
||||
|
@ -555,6 +617,14 @@ impl CryptoStore for SledStore {
|
|||
!self.users_for_key_query_cache.is_empty()
|
||||
}
|
||||
|
||||
fn users_for_key_query(&self) -> HashSet<UserId> {
|
||||
#[allow(clippy::map_clone)]
|
||||
self.users_for_key_query_cache
|
||||
.iter()
|
||||
.map(|u| u.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
|
||||
let already_added = self.tracked_users_cache.insert(user.clone());
|
||||
|
||||
|
@ -605,48 +675,62 @@ impl CryptoStore for SledStore {
|
|||
.transpose()?)
|
||||
}
|
||||
|
||||
async fn save_value(&self, key: String, value: String) -> Result<()> {
|
||||
self.values.insert(key.as_str().encode(), value.as_str())?;
|
||||
self.inner.flush_async().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn remove_value(&self, key: &str) -> Result<()> {
|
||||
self.values.remove(key.encode())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_value(&self, key: &str) -> Result<Option<String>> {
|
||||
Ok(self
|
||||
.values
|
||||
.get(key.encode())?
|
||||
.map(|v| String::from_utf8_lossy(&v).to_string()))
|
||||
}
|
||||
|
||||
async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
|
||||
if let Some(i) = self.private_identity.get("identity".encode())? {
|
||||
let pickle = serde_json::from_slice(&i)?;
|
||||
Ok(Some(
|
||||
PrivateCrossSigningIdentity::from_pickle(pickle, self.get_pickle_key())
|
||||
.await
|
||||
.map_err(|_| CryptoStoreError::UnpicklingError)?,
|
||||
))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result<bool> {
|
||||
Ok(self
|
||||
.olm_hashes
|
||||
.contains_key(serde_json::to_vec(message_hash)?)?)
|
||||
}
|
||||
|
||||
async fn get_outbound_group_sessions(
|
||||
async fn get_outgoing_key_request(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
) -> Result<Option<OutboundGroupSession>> {
|
||||
self.load_outbound_group_session(room_id).await
|
||||
request_id: Uuid,
|
||||
) -> Result<Option<OutgoingKeyRequest>> {
|
||||
Ok(self
|
||||
.outgoing_key_requests
|
||||
.get(request_id.encode())?
|
||||
.map(|r| serde_json::from_slice(&r))
|
||||
.transpose()?)
|
||||
}
|
||||
|
||||
async fn get_key_request_by_info(
|
||||
&self,
|
||||
key_info: &RequestedKeyInfo,
|
||||
) -> Result<Option<OutgoingKeyRequest>> {
|
||||
let id = self.key_requests_by_info.get(key_info.encode())?;
|
||||
|
||||
if let Some(id) = id {
|
||||
Ok(self
|
||||
.outgoing_key_requests
|
||||
.get(id)?
|
||||
.map(|r| serde_json::from_slice(&r))
|
||||
.transpose()?)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> {
|
||||
let ret: Result<(), TransactionError<serde_json::Error>> =
|
||||
(&self.outgoing_key_requests, &self.key_requests_by_info).transaction(
|
||||
|(outgoing_key_requests, key_requests_by_info)| {
|
||||
let request: Option<OutgoingKeyRequest> = outgoing_key_requests
|
||||
.remove(request_id.encode())?
|
||||
.map(|r| serde_json::from_slice(&r))
|
||||
.transpose()
|
||||
.map_err(ConflictableTransactionError::Abort)?;
|
||||
|
||||
if let Some(request) = request {
|
||||
key_requests_by_info.remove((&request.info).encode())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
},
|
||||
);
|
||||
|
||||
ret?;
|
||||
self.inner.flush_async().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -665,14 +749,16 @@ mod test {
|
|||
};
|
||||
use matrix_sdk_common::{
|
||||
api::r0::keys::SignedKey,
|
||||
identifiers::{room_id, user_id, DeviceId, UserId},
|
||||
events::room_key_request::RequestedKeyInfo,
|
||||
identifiers::{room_id, user_id, DeviceId, EventEncryptionAlgorithm, UserId},
|
||||
uuid::Uuid,
|
||||
};
|
||||
use matrix_sdk_test::async_test;
|
||||
use olm_rs::outbound_group_session::OlmOutboundGroupSession;
|
||||
use std::collections::BTreeMap;
|
||||
use tempfile::tempdir;
|
||||
|
||||
use super::{CryptoStore, SledStore};
|
||||
use super::{CryptoStore, OutgoingKeyRequest, SledStore};
|
||||
|
||||
fn alice_id() -> UserId {
|
||||
user_id!("@alice:example.org")
|
||||
|
@ -1184,21 +1270,6 @@ mod test {
|
|||
assert_eq!(identity.user_id(), loaded_identity.user_id());
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
async fn key_value_saving() {
|
||||
let (_, store, _dir) = get_loaded_store().await;
|
||||
let key = "test_key".to_string();
|
||||
let value = "secret value".to_string();
|
||||
|
||||
store.save_value(key.clone(), value.clone()).await.unwrap();
|
||||
let stored_value = store.get_value(&key).await.unwrap().unwrap();
|
||||
|
||||
assert_eq!(value, stored_value);
|
||||
|
||||
store.remove_value(&key).await.unwrap();
|
||||
assert!(store.get_value(&key).await.unwrap().is_none());
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
async fn olm_hash_saving() {
|
||||
let (_, store, _dir) = get_loaded_store().await;
|
||||
|
@ -1215,4 +1286,45 @@ mod test {
|
|||
store.save_changes(changes).await.unwrap();
|
||||
assert!(store.is_message_known(&hash).await.unwrap());
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
async fn key_request_saving() {
|
||||
let (_, store, _dir) = get_loaded_store().await;
|
||||
|
||||
let id = Uuid::new_v4();
|
||||
let info = RequestedKeyInfo {
|
||||
algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2,
|
||||
room_id: room_id!("!test:localhost"),
|
||||
sender_key: "test_sender_key".to_string(),
|
||||
session_id: "test_session_id".to_string(),
|
||||
};
|
||||
|
||||
let request = OutgoingKeyRequest {
|
||||
request_id: id,
|
||||
info: info.clone(),
|
||||
sent_out: false,
|
||||
};
|
||||
|
||||
assert!(store.get_outgoing_key_request(id).await.unwrap().is_none());
|
||||
|
||||
let mut changes = Changes::default();
|
||||
changes.key_requests.push(request.clone());
|
||||
store.save_changes(changes).await.unwrap();
|
||||
|
||||
let request = Some(request);
|
||||
|
||||
let stored_request = store.get_outgoing_key_request(id).await.unwrap();
|
||||
assert_eq!(request, stored_request);
|
||||
|
||||
let stored_request = store.get_key_request_by_info(&info).await.unwrap();
|
||||
assert_eq!(request, stored_request);
|
||||
|
||||
store.delete_outgoing_key_request(id).await.unwrap();
|
||||
|
||||
let stored_request = store.get_outgoing_key_request(id).await.unwrap();
|
||||
assert_eq!(None, stored_request);
|
||||
|
||||
let stored_request = store.get_key_request_by_info(&info).await.unwrap();
|
||||
assert_eq!(None, stored_request);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue