Merge branch 'key-share-improvements'
commit
324a0aafca
|
@ -1839,7 +1839,16 @@ impl Client {
|
||||||
warn!("Error while claiming one-time keys {:?}", e);
|
warn!("Error while claiming one-time keys {:?}", e);
|
||||||
}
|
}
|
||||||
|
|
||||||
for r in self.base_client.outgoing_requests().await {
|
// TODO we should probably abort if we get an cryptostore error here
|
||||||
|
let outgoing_requests = match self.base_client.outgoing_requests().await {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Could not fetch the outgoing requests {:?}", e);
|
||||||
|
vec![]
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
for r in outgoing_requests {
|
||||||
match r.request() {
|
match r.request() {
|
||||||
OutgoingRequests::KeysQuery(request) => {
|
OutgoingRequests::KeysQuery(request) => {
|
||||||
if let Err(e) = self
|
if let Err(e) = self
|
||||||
|
|
|
@ -1085,12 +1085,12 @@ impl BaseClient {
|
||||||
/// [`mark_request_as_sent`]: #method.mark_request_as_sent
|
/// [`mark_request_as_sent`]: #method.mark_request_as_sent
|
||||||
#[cfg(feature = "encryption")]
|
#[cfg(feature = "encryption")]
|
||||||
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
|
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
|
||||||
pub async fn outgoing_requests(&self) -> Vec<OutgoingRequest> {
|
pub async fn outgoing_requests(&self) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
|
||||||
let olm = self.olm.lock().await;
|
let olm = self.olm.lock().await;
|
||||||
|
|
||||||
match &*olm {
|
match &*olm {
|
||||||
Some(o) => o.outgoing_requests().await,
|
Some(o) => o.outgoing_requests().await,
|
||||||
None => vec![],
|
None => Ok(vec![]),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -215,7 +215,7 @@ pub fn room_key_sharing(c: &mut Criterion) {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert!(requests.len() >= 8);
|
assert!(!requests.is_empty());
|
||||||
|
|
||||||
for request in requests {
|
for request in requests {
|
||||||
machine
|
machine
|
||||||
|
@ -251,7 +251,7 @@ pub fn room_key_sharing(c: &mut Criterion) {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert!(requests.len() >= 8);
|
assert!(!requests.is_empty());
|
||||||
|
|
||||||
for request in requests {
|
for request in requests {
|
||||||
machine
|
machine
|
||||||
|
|
|
@ -258,6 +258,14 @@ impl UserDevices {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns true if there is at least one devices of this user that is
|
||||||
|
/// considered to be verified, false otherwise.
|
||||||
|
pub fn is_any_verified(&self) -> bool {
|
||||||
|
self.inner
|
||||||
|
.values()
|
||||||
|
.any(|d| d.trust_state(&self.own_identity, &self.device_owner_identity))
|
||||||
|
}
|
||||||
|
|
||||||
/// Iterator over all the device ids of the user devices.
|
/// Iterator over all the device ids of the user devices.
|
||||||
pub fn keys(&self) -> impl Iterator<Item = &DeviceIdBox> {
|
pub fn keys(&self) -> impl Iterator<Item = &DeviceIdBox> {
|
||||||
self.inner.keys()
|
self.inner.keys()
|
||||||
|
|
|
@ -40,9 +40,10 @@ use matrix_sdk_common::{
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::{OlmError, OlmResult},
|
error::{OlmError, OlmResult},
|
||||||
olm::{InboundGroupSession, OutboundGroupSession, Session, ShareState},
|
olm::{InboundGroupSession, Session, ShareState},
|
||||||
requests::{OutgoingRequest, ToDeviceRequest},
|
requests::{OutgoingRequest, ToDeviceRequest},
|
||||||
store::{CryptoStoreError, Store},
|
session_manager::GroupSessionCache,
|
||||||
|
store::{Changes, CryptoStoreError, Store},
|
||||||
Device,
|
Device,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -128,7 +129,7 @@ pub(crate) struct KeyRequestMachine {
|
||||||
user_id: Arc<UserId>,
|
user_id: Arc<UserId>,
|
||||||
device_id: Arc<DeviceIdBox>,
|
device_id: Arc<DeviceIdBox>,
|
||||||
store: Store,
|
store: Store,
|
||||||
outbound_group_sessions: Arc<DashMap<RoomId, OutboundGroupSession>>,
|
outbound_group_sessions: GroupSessionCache,
|
||||||
outgoing_to_device_requests: Arc<DashMap<Uuid, OutgoingRequest>>,
|
outgoing_to_device_requests: Arc<DashMap<Uuid, OutgoingRequest>>,
|
||||||
incoming_key_requests: Arc<
|
incoming_key_requests: Arc<
|
||||||
DashMap<(UserId, DeviceIdBox, String), ToDeviceEvent<RoomKeyRequestToDeviceEventContent>>,
|
DashMap<(UserId, DeviceIdBox, String), ToDeviceEvent<RoomKeyRequestToDeviceEventContent>>,
|
||||||
|
@ -137,32 +138,54 @@ pub(crate) struct KeyRequestMachine {
|
||||||
users_for_key_claim: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>,
|
users_for_key_claim: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
/// A struct describing an outgoing key request.
|
||||||
struct OugoingKeyInfo {
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
request_id: Uuid,
|
pub struct OutgoingKeyRequest {
|
||||||
info: RequestedKeyInfo,
|
/// The user we requested the key from
|
||||||
sent_out: bool,
|
pub request_recipient: UserId,
|
||||||
|
/// The unique id of the key request.
|
||||||
|
pub request_id: Uuid,
|
||||||
|
/// The info of the requested key.
|
||||||
|
pub info: RequestedKeyInfo,
|
||||||
|
/// Has the request been sent out.
|
||||||
|
pub sent_out: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
trait Encode {
|
impl OutgoingKeyRequest {
|
||||||
fn encode(&self) -> String;
|
fn to_request(&self, own_device_id: &DeviceId) -> Result<OutgoingRequest, serde_json::Error> {
|
||||||
}
|
let content = RoomKeyRequestToDeviceEventContent {
|
||||||
|
action: Action::Request,
|
||||||
|
request_id: self.request_id.to_string(),
|
||||||
|
requesting_device_id: own_device_id.to_owned(),
|
||||||
|
body: Some(self.info.clone()),
|
||||||
|
};
|
||||||
|
|
||||||
impl Encode for RequestedKeyInfo {
|
wrap_key_request_content(self.request_recipient.clone(), self.request_id, &content)
|
||||||
fn encode(&self) -> String {
|
}
|
||||||
format!(
|
|
||||||
"{}|{}|{}|{}",
|
fn to_cancelation(
|
||||||
self.sender_key, self.room_id, self.session_id, self.algorithm
|
&self,
|
||||||
)
|
own_device_id: &DeviceId,
|
||||||
|
) -> Result<OutgoingRequest, serde_json::Error> {
|
||||||
|
let content = RoomKeyRequestToDeviceEventContent {
|
||||||
|
action: Action::CancelRequest,
|
||||||
|
request_id: self.request_id.to_string(),
|
||||||
|
requesting_device_id: own_device_id.to_owned(),
|
||||||
|
body: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let id = Uuid::new_v4();
|
||||||
|
wrap_key_request_content(self.request_recipient.clone(), id, &content)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Encode for ForwardedRoomKeyToDeviceEventContent {
|
impl PartialEq for OutgoingKeyRequest {
|
||||||
fn encode(&self) -> String {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
format!(
|
self.request_id == other.request_id
|
||||||
"{}|{}|{}|{}",
|
&& self.info.algorithm == other.info.algorithm
|
||||||
self.sender_key, self.room_id, self.session_id, self.algorithm
|
&& self.info.room_id == other.info.room_id
|
||||||
)
|
&& self.info.session_id == other.info.session_id
|
||||||
|
&& self.info.sender_key == other.info.sender_key
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -196,7 +219,7 @@ impl KeyRequestMachine {
|
||||||
user_id: Arc<UserId>,
|
user_id: Arc<UserId>,
|
||||||
device_id: Arc<DeviceIdBox>,
|
device_id: Arc<DeviceIdBox>,
|
||||||
store: Store,
|
store: Store,
|
||||||
outbound_group_sessions: Arc<DashMap<RoomId, OutboundGroupSession>>,
|
outbound_group_sessions: GroupSessionCache,
|
||||||
users_for_key_claim: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>,
|
users_for_key_claim: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
@ -204,13 +227,27 @@ impl KeyRequestMachine {
|
||||||
device_id,
|
device_id,
|
||||||
store,
|
store,
|
||||||
outbound_group_sessions,
|
outbound_group_sessions,
|
||||||
outgoing_to_device_requests: Arc::new(DashMap::new()),
|
outgoing_to_device_requests: DashMap::new().into(),
|
||||||
incoming_key_requests: Arc::new(DashMap::new()),
|
incoming_key_requests: DashMap::new().into(),
|
||||||
wait_queue: WaitQueue::new(),
|
wait_queue: WaitQueue::new(),
|
||||||
users_for_key_claim,
|
users_for_key_claim,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Load stored outgoing requests that were not yet sent out.
|
||||||
|
async fn load_outgoing_requests(&self) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
|
||||||
|
self.store
|
||||||
|
.get_unsent_key_requests()
|
||||||
|
.await?
|
||||||
|
.into_iter()
|
||||||
|
.filter(|i| !i.sent_out)
|
||||||
|
.map(|info| {
|
||||||
|
info.to_request(self.device_id())
|
||||||
|
.map_err(CryptoStoreError::from)
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
/// Our own user id.
|
/// Our own user id.
|
||||||
pub fn user_id(&self) -> &UserId {
|
pub fn user_id(&self) -> &UserId {
|
||||||
&self.user_id
|
&self.user_id
|
||||||
|
@ -221,12 +258,18 @@ impl KeyRequestMachine {
|
||||||
&self.device_id
|
&self.device_id
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn outgoing_to_device_requests(&self) -> Vec<OutgoingRequest> {
|
pub async fn outgoing_to_device_requests(
|
||||||
#[allow(clippy::map_clone)]
|
&self,
|
||||||
self.outgoing_to_device_requests
|
) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
|
||||||
|
let mut key_requests = self.load_outgoing_requests().await?;
|
||||||
|
let key_forwards: Vec<OutgoingRequest> = self
|
||||||
|
.outgoing_to_device_requests
|
||||||
.iter()
|
.iter()
|
||||||
.map(|r| (*r).clone())
|
.map(|i| i.value().clone())
|
||||||
.collect()
|
.collect();
|
||||||
|
key_requests.extend(key_forwards);
|
||||||
|
|
||||||
|
Ok(key_requests)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Receive a room key request event.
|
/// Receive a room key request event.
|
||||||
|
@ -246,6 +289,7 @@ impl KeyRequestMachine {
|
||||||
/// key request queue.
|
/// key request queue.
|
||||||
pub async fn collect_incoming_key_requests(&self) -> OlmResult<Vec<Session>> {
|
pub async fn collect_incoming_key_requests(&self) -> OlmResult<Vec<Session>> {
|
||||||
let mut changed_sessions = Vec::new();
|
let mut changed_sessions = Vec::new();
|
||||||
|
|
||||||
for item in self.incoming_key_requests.iter() {
|
for item in self.incoming_key_requests.iter() {
|
||||||
let event = item.value();
|
let event = item.value();
|
||||||
if let Some(s) = self.handle_key_request(event).await? {
|
if let Some(s) = self.handle_key_request(event).await? {
|
||||||
|
@ -363,12 +407,7 @@ impl KeyRequestMachine {
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
if let Some(device) = device {
|
if let Some(device) = device {
|
||||||
match self.should_share_session(
|
match self.should_share_key(&device, &session).await {
|
||||||
&device,
|
|
||||||
self.outbound_group_sessions
|
|
||||||
.get(&key_info.room_id)
|
|
||||||
.as_deref(),
|
|
||||||
) {
|
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
info!(
|
info!(
|
||||||
"Received a key request from {} {} that we won't serve: {}",
|
"Received a key request from {} {} that we won't serve: {}",
|
||||||
|
@ -469,33 +508,146 @@ impl KeyRequestMachine {
|
||||||
///
|
///
|
||||||
/// * `device` - The device that is requesting a session from us.
|
/// * `device` - The device that is requesting a session from us.
|
||||||
///
|
///
|
||||||
/// * `outbound_session` - If one still exists, the matching outbound
|
/// * `session` - The session that was requested to be shared.
|
||||||
/// session that was used to create the inbound session that is being
|
async fn should_share_key(
|
||||||
/// requested.
|
|
||||||
fn should_share_session(
|
|
||||||
&self,
|
&self,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
outbound_session: Option<&OutboundGroupSession>,
|
session: &InboundGroupSession,
|
||||||
) -> Result<Option<u32>, KeyshareDecision> {
|
) -> Result<Option<u32>, KeyshareDecision> {
|
||||||
if device.user_id() == self.user_id() {
|
let outbound_session = self
|
||||||
|
.outbound_group_sessions
|
||||||
|
.get_with_id(session.room_id(), session.session_id())
|
||||||
|
.await
|
||||||
|
.ok()
|
||||||
|
.flatten();
|
||||||
|
|
||||||
|
let own_device_check = || {
|
||||||
if device.trust_state() {
|
if device.trust_state() {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
} else {
|
} else {
|
||||||
Err(KeyshareDecision::UntrustedDevice)
|
Err(KeyshareDecision::UntrustedDevice)
|
||||||
}
|
}
|
||||||
} else if let Some(outbound) = outbound_session {
|
};
|
||||||
|
|
||||||
|
// If we have a matching outbound session we can check the list of
|
||||||
|
// users/devices that received the session, if it wasn't shared check if
|
||||||
|
// it's our own device and if it's trusted.
|
||||||
|
if let Some(outbound) = outbound_session {
|
||||||
if let ShareState::Shared(message_index) =
|
if let ShareState::Shared(message_index) =
|
||||||
outbound.is_shared_with(device.user_id(), device.device_id())
|
outbound.is_shared_with(device.user_id(), device.device_id())
|
||||||
{
|
{
|
||||||
Ok(Some(message_index))
|
Ok(Some(message_index))
|
||||||
|
} else if device.user_id() == self.user_id() {
|
||||||
|
own_device_check()
|
||||||
} else {
|
} else {
|
||||||
Err(KeyshareDecision::OutboundSessionNotShared)
|
Err(KeyshareDecision::OutboundSessionNotShared)
|
||||||
}
|
}
|
||||||
|
// Else just check if it's one of our own devices that requested the key and
|
||||||
|
// check if the device is trusted.
|
||||||
|
} else if device.user_id() == self.user_id() {
|
||||||
|
own_device_check()
|
||||||
|
// Otherwise, there's not enough info to decide if we can safely share
|
||||||
|
// the session.
|
||||||
} else {
|
} else {
|
||||||
Err(KeyshareDecision::MissingOutboundSession)
|
Err(KeyshareDecision::MissingOutboundSession)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Check if it's ok, or rather if it makes sense to automatically request
|
||||||
|
/// a key from our other devices.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `key_info` - The info of our key request containing information about
|
||||||
|
/// the key we wish to request.
|
||||||
|
async fn should_request_key(
|
||||||
|
&self,
|
||||||
|
key_info: &RequestedKeyInfo,
|
||||||
|
) -> Result<bool, CryptoStoreError> {
|
||||||
|
let request = self.store.get_key_request_by_info(&key_info).await?;
|
||||||
|
|
||||||
|
// Don't send out duplicate requests, users can re-request them if they
|
||||||
|
// think a second request might succeed.
|
||||||
|
if request.is_none() {
|
||||||
|
let devices = self.store.get_user_devices(self.user_id()).await?;
|
||||||
|
|
||||||
|
// Devices will only respond to key requests if the devices are
|
||||||
|
// verified, if the device isn't verified by us it's unlikely that
|
||||||
|
// we're verified by them either. Don't request keys if there isn't
|
||||||
|
// at least one verified device.
|
||||||
|
if devices.is_any_verified() {
|
||||||
|
Ok(true)
|
||||||
|
} else {
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new outgoing key request for the key with the given session id.
|
||||||
|
///
|
||||||
|
/// This will queue up a new to-device request and store the key info so
|
||||||
|
/// once we receive a forwarded room key we can check that it matches the
|
||||||
|
/// key we requested.
|
||||||
|
///
|
||||||
|
/// This method will return a cancel request and a new key request if the
|
||||||
|
/// key was already requested, otherwise it will return just the key
|
||||||
|
/// request.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `room_id` - The id of the room where the key is used in.
|
||||||
|
///
|
||||||
|
/// * `sender_key` - The curve25519 key of the sender that owns the key.
|
||||||
|
///
|
||||||
|
/// * `session_id` - The id that uniquely identifies the session.
|
||||||
|
pub async fn request_key(
|
||||||
|
&self,
|
||||||
|
room_id: &RoomId,
|
||||||
|
sender_key: &str,
|
||||||
|
session_id: &str,
|
||||||
|
) -> Result<(Option<OutgoingRequest>, OutgoingRequest), CryptoStoreError> {
|
||||||
|
let key_info = RequestedKeyInfo {
|
||||||
|
algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2,
|
||||||
|
room_id: room_id.to_owned(),
|
||||||
|
sender_key: sender_key.to_owned(),
|
||||||
|
session_id: session_id.to_owned(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let request = self.store.get_key_request_by_info(&key_info).await?;
|
||||||
|
|
||||||
|
if let Some(request) = request {
|
||||||
|
let cancel = request.to_cancelation(self.device_id())?;
|
||||||
|
let request = request.to_request(self.device_id())?;
|
||||||
|
|
||||||
|
Ok((Some(cancel), request))
|
||||||
|
} else {
|
||||||
|
let request = self.request_key_helper(key_info).await?;
|
||||||
|
|
||||||
|
Ok((None, request))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn request_key_helper(
|
||||||
|
&self,
|
||||||
|
key_info: RequestedKeyInfo,
|
||||||
|
) -> Result<OutgoingRequest, CryptoStoreError> {
|
||||||
|
info!("Creating new outgoing room key request {:#?}", key_info);
|
||||||
|
|
||||||
|
let request = OutgoingKeyRequest {
|
||||||
|
request_recipient: self.user_id().to_owned(),
|
||||||
|
request_id: Uuid::new_v4(),
|
||||||
|
info: key_info,
|
||||||
|
sent_out: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
let outgoing_request = request.to_request(self.device_id())?;
|
||||||
|
self.save_outgoing_key_info(request).await?;
|
||||||
|
|
||||||
|
Ok(outgoing_request)
|
||||||
|
}
|
||||||
|
|
||||||
/// Create a new outgoing key request for the key with the given session id.
|
/// Create a new outgoing key request for the key with the given session id.
|
||||||
///
|
///
|
||||||
/// This will queue up a new to-device request and store the key info so
|
/// This will queue up a new to-device request and store the key info so
|
||||||
|
@ -523,51 +675,21 @@ impl KeyRequestMachine {
|
||||||
session_id: session_id.to_owned(),
|
session_id: session_id.to_owned(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let id: Option<String> = self.store.get_object(&key_info.encode()).await?;
|
if self.should_request_key(&key_info).await? {
|
||||||
|
self.request_key_helper(key_info).await?;
|
||||||
if id.is_some() {
|
|
||||||
// We already sent out a request for this key, nothing to do.
|
|
||||||
return Ok(());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Creating new outgoing room key request {:#?}", key_info);
|
|
||||||
|
|
||||||
let id = Uuid::new_v4();
|
|
||||||
|
|
||||||
let content = RoomKeyRequestToDeviceEventContent {
|
|
||||||
action: Action::Request,
|
|
||||||
request_id: id.to_string(),
|
|
||||||
requesting_device_id: (&*self.device_id).clone(),
|
|
||||||
body: Some(key_info),
|
|
||||||
};
|
|
||||||
|
|
||||||
let request = wrap_key_request_content(self.user_id().clone(), id, &content)?;
|
|
||||||
|
|
||||||
let info = OugoingKeyInfo {
|
|
||||||
request_id: id,
|
|
||||||
info: content.body.unwrap(),
|
|
||||||
sent_out: false,
|
|
||||||
};
|
|
||||||
|
|
||||||
self.save_outgoing_key_info(id, info).await?;
|
|
||||||
self.outgoing_to_device_requests.insert(id, request);
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Save an outgoing key info.
|
/// Save an outgoing key info.
|
||||||
async fn save_outgoing_key_info(
|
async fn save_outgoing_key_info(
|
||||||
&self,
|
&self,
|
||||||
id: Uuid,
|
info: OutgoingKeyRequest,
|
||||||
info: OugoingKeyInfo,
|
|
||||||
) -> Result<(), CryptoStoreError> {
|
) -> Result<(), CryptoStoreError> {
|
||||||
// TODO we'll want to use a transaction to store those atomically.
|
let mut changes = Changes::default();
|
||||||
// To allow this we'll need to rework our cryptostore trait to return
|
changes.key_requests.push(info);
|
||||||
// a transaction trait and the transaction trait will have the save_X
|
self.store.save_changes(changes).await?;
|
||||||
// methods.
|
|
||||||
let id_string = id.to_string();
|
|
||||||
self.store.save_object(&id_string, &info).await?;
|
|
||||||
self.store.save_object(&info.info.encode(), &id).await?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -576,44 +698,43 @@ impl KeyRequestMachine {
|
||||||
async fn get_key_info(
|
async fn get_key_info(
|
||||||
&self,
|
&self,
|
||||||
content: &ForwardedRoomKeyToDeviceEventContent,
|
content: &ForwardedRoomKeyToDeviceEventContent,
|
||||||
) -> Result<Option<OugoingKeyInfo>, CryptoStoreError> {
|
) -> Result<Option<OutgoingKeyRequest>, CryptoStoreError> {
|
||||||
let id: Option<Uuid> = self.store.get_object(&content.encode()).await?;
|
let info = RequestedKeyInfo {
|
||||||
|
algorithm: content.algorithm.clone(),
|
||||||
|
room_id: content.room_id.clone(),
|
||||||
|
sender_key: content.sender_key.clone(),
|
||||||
|
session_id: content.session_id.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
if let Some(id) = id {
|
self.store.get_key_request_by_info(&info).await
|
||||||
self.store.get_object(&id.to_string()).await
|
|
||||||
} else {
|
|
||||||
Ok(None)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Delete the given outgoing key info.
|
/// Delete the given outgoing key info.
|
||||||
async fn delete_key_info(&self, info: &OugoingKeyInfo) -> Result<(), CryptoStoreError> {
|
async fn delete_key_info(&self, info: &OutgoingKeyRequest) -> Result<(), CryptoStoreError> {
|
||||||
self.store
|
self.store
|
||||||
.delete_object(&info.request_id.to_string())
|
.delete_outgoing_key_request(info.request_id)
|
||||||
.await?;
|
.await
|
||||||
self.store.delete_object(&info.info.encode()).await?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mark the outgoing request as sent.
|
/// Mark the outgoing request as sent.
|
||||||
pub async fn mark_outgoing_request_as_sent(&self, id: &Uuid) -> Result<(), CryptoStoreError> {
|
pub async fn mark_outgoing_request_as_sent(&self, id: Uuid) -> Result<(), CryptoStoreError> {
|
||||||
self.outgoing_to_device_requests.remove(id);
|
let info = self.store.get_outgoing_key_request(id).await?;
|
||||||
let info: Option<OugoingKeyInfo> = self.store.get_object(&id.to_string()).await?;
|
|
||||||
|
|
||||||
if let Some(mut info) = info {
|
if let Some(mut info) = info {
|
||||||
trace!("Marking outgoing key request as sent {:#?}", info);
|
trace!("Marking outgoing key request as sent {:#?}", info);
|
||||||
info.sent_out = true;
|
info.sent_out = true;
|
||||||
self.save_outgoing_key_info(*id, info).await?;
|
self.save_outgoing_key_info(info).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self.outgoing_to_device_requests.remove(&id);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mark the given outgoing key info as done.
|
/// Mark the given outgoing key info as done.
|
||||||
///
|
///
|
||||||
/// This will queue up a request cancelation.
|
/// This will queue up a request cancelation.
|
||||||
async fn mark_as_done(&self, key_info: OugoingKeyInfo) -> Result<(), CryptoStoreError> {
|
async fn mark_as_done(&self, key_info: OutgoingKeyRequest) -> Result<(), CryptoStoreError> {
|
||||||
// TODO perhaps only remove the key info if the first known index is 0.
|
// TODO perhaps only remove the key info if the first known index is 0.
|
||||||
trace!(
|
trace!(
|
||||||
"Successfully received a forwarded room key for {:#?}",
|
"Successfully received a forwarded room key for {:#?}",
|
||||||
|
@ -626,18 +747,9 @@ impl KeyRequestMachine {
|
||||||
// can delete it in one transaction.
|
// can delete it in one transaction.
|
||||||
self.delete_key_info(&key_info).await?;
|
self.delete_key_info(&key_info).await?;
|
||||||
|
|
||||||
let content = RoomKeyRequestToDeviceEventContent {
|
let request = key_info.to_cancelation(self.device_id())?;
|
||||||
action: Action::CancelRequest,
|
self.outgoing_to_device_requests
|
||||||
request_id: key_info.request_id.to_string(),
|
.insert(request.request_id, request);
|
||||||
requesting_device_id: (&*self.device_id).clone(),
|
|
||||||
body: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let id = Uuid::new_v4();
|
|
||||||
|
|
||||||
let request = wrap_key_request_content(self.user_id().clone(), id, &content)?;
|
|
||||||
|
|
||||||
self.outgoing_to_device_requests.insert(id, request);
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -722,7 +834,8 @@ mod test {
|
||||||
use crate::{
|
use crate::{
|
||||||
identities::{LocalTrust, ReadOnlyDevice},
|
identities::{LocalTrust, ReadOnlyDevice},
|
||||||
olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount},
|
olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount},
|
||||||
store::{CryptoStore, MemoryStore, Store},
|
session_manager::GroupSessionCache,
|
||||||
|
store::{Changes, CryptoStore, MemoryStore, Store},
|
||||||
verification::VerificationMachine,
|
verification::VerificationMachine,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -744,6 +857,10 @@ mod test {
|
||||||
"ILMLKASTES".into()
|
"ILMLKASTES".into()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn alice2_device_id() -> DeviceIdBox {
|
||||||
|
"ILMLKASTES".into()
|
||||||
|
}
|
||||||
|
|
||||||
fn room_id() -> RoomId {
|
fn room_id() -> RoomId {
|
||||||
room_id!("!test:example.org")
|
room_id!("!test:example.org")
|
||||||
}
|
}
|
||||||
|
@ -756,6 +873,10 @@ mod test {
|
||||||
ReadOnlyAccount::new(&bob_id(), &bob_device_id())
|
ReadOnlyAccount::new(&bob_id(), &bob_device_id())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn alice_2_account() -> ReadOnlyAccount {
|
||||||
|
ReadOnlyAccount::new(&alice_id(), &alice2_device_id())
|
||||||
|
}
|
||||||
|
|
||||||
fn bob_machine() -> KeyRequestMachine {
|
fn bob_machine() -> KeyRequestMachine {
|
||||||
let user_id = Arc::new(bob_id());
|
let user_id = Arc::new(bob_id());
|
||||||
let account = ReadOnlyAccount::new(&user_id, &alice_device_id());
|
let account = ReadOnlyAccount::new(&user_id, &alice_device_id());
|
||||||
|
@ -763,12 +884,13 @@ mod test {
|
||||||
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(bob_id())));
|
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(bob_id())));
|
||||||
let verification = VerificationMachine::new(account, identity.clone(), store.clone());
|
let verification = VerificationMachine::new(account, identity.clone(), store.clone());
|
||||||
let store = Store::new(user_id.clone(), identity, store, verification);
|
let store = Store::new(user_id.clone(), identity, store, verification);
|
||||||
|
let session_cache = GroupSessionCache::new(store.clone());
|
||||||
|
|
||||||
KeyRequestMachine::new(
|
KeyRequestMachine::new(
|
||||||
user_id,
|
user_id,
|
||||||
Arc::new(bob_device_id()),
|
Arc::new(bob_device_id()),
|
||||||
store,
|
store,
|
||||||
Arc::new(DashMap::new()),
|
session_cache,
|
||||||
Arc::new(DashMap::new()),
|
Arc::new(DashMap::new()),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -782,12 +904,13 @@ mod test {
|
||||||
let verification = VerificationMachine::new(account, identity.clone(), store.clone());
|
let verification = VerificationMachine::new(account, identity.clone(), store.clone());
|
||||||
let store = Store::new(user_id.clone(), identity, store, verification);
|
let store = Store::new(user_id.clone(), identity, store, verification);
|
||||||
store.save_devices(&[device]).await.unwrap();
|
store.save_devices(&[device]).await.unwrap();
|
||||||
|
let session_cache = GroupSessionCache::new(store.clone());
|
||||||
|
|
||||||
KeyRequestMachine::new(
|
KeyRequestMachine::new(
|
||||||
user_id,
|
user_id,
|
||||||
Arc::new(alice_device_id()),
|
Arc::new(alice_device_id()),
|
||||||
store,
|
store,
|
||||||
Arc::new(DashMap::new()),
|
session_cache,
|
||||||
Arc::new(DashMap::new()),
|
Arc::new(DashMap::new()),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -796,11 +919,15 @@ mod test {
|
||||||
async fn create_machine() {
|
async fn create_machine() {
|
||||||
let machine = get_machine().await;
|
let machine = get_machine().await;
|
||||||
|
|
||||||
assert!(machine.outgoing_to_device_requests().is_empty());
|
assert!(machine
|
||||||
|
.outgoing_to_device_requests()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_test]
|
#[async_test]
|
||||||
async fn create_key_request() {
|
async fn re_request_keys() {
|
||||||
let machine = get_machine().await;
|
let machine = get_machine().await;
|
||||||
let account = account();
|
let account = account();
|
||||||
|
|
||||||
|
@ -809,7 +936,52 @@ mod test {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert!(machine.outgoing_to_device_requests().is_empty());
|
assert!(machine
|
||||||
|
.outgoing_to_device_requests()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.is_empty());
|
||||||
|
let (cancel, request) = machine
|
||||||
|
.request_key(session.room_id(), &session.sender_key, session.session_id())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(cancel.is_none());
|
||||||
|
|
||||||
|
machine
|
||||||
|
.mark_outgoing_request_as_sent(request.request_id)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let (cancel, _) = machine
|
||||||
|
.request_key(session.room_id(), &session.sender_key, session.session_id())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(cancel.is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_test]
|
||||||
|
async fn create_key_request() {
|
||||||
|
let machine = get_machine().await;
|
||||||
|
let account = account();
|
||||||
|
let second_account = alice_2_account();
|
||||||
|
let alice_device = ReadOnlyDevice::from_account(&second_account).await;
|
||||||
|
|
||||||
|
// We need a trusted device, otherwise we won't request keys
|
||||||
|
alice_device.set_trust_state(LocalTrust::Verified);
|
||||||
|
machine.store.save_devices(&[alice_device]).await.unwrap();
|
||||||
|
|
||||||
|
let (_, session) = account
|
||||||
|
.create_group_session_pair_with_defaults(&room_id())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(machine
|
||||||
|
.outgoing_to_device_requests()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.is_empty());
|
||||||
machine
|
machine
|
||||||
.create_outgoing_key_request(
|
.create_outgoing_key_request(
|
||||||
session.room_id(),
|
session.room_id(),
|
||||||
|
@ -818,8 +990,15 @@ mod test {
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert!(!machine.outgoing_to_device_requests().is_empty());
|
assert!(!machine
|
||||||
assert_eq!(machine.outgoing_to_device_requests().len(), 1);
|
.outgoing_to_device_requests()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.is_empty());
|
||||||
|
assert_eq!(
|
||||||
|
machine.outgoing_to_device_requests().await.unwrap().len(),
|
||||||
|
1
|
||||||
|
);
|
||||||
|
|
||||||
machine
|
machine
|
||||||
.create_outgoing_key_request(
|
.create_outgoing_key_request(
|
||||||
|
@ -829,15 +1008,21 @@ mod test {
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(machine.outgoing_to_device_requests.len(), 1);
|
|
||||||
|
|
||||||
let request = machine.outgoing_to_device_requests.iter().next().unwrap();
|
let requests = machine.outgoing_to_device_requests().await.unwrap();
|
||||||
|
assert_eq!(requests.len(), 1);
|
||||||
|
|
||||||
let id = request.request_id;
|
let request = requests.get(0).unwrap();
|
||||||
drop(request);
|
|
||||||
|
|
||||||
machine.mark_outgoing_request_as_sent(&id).await.unwrap();
|
machine
|
||||||
assert!(machine.outgoing_to_device_requests.is_empty());
|
.mark_outgoing_request_as_sent(request.request_id)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(machine
|
||||||
|
.outgoing_to_device_requests()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_test]
|
#[async_test]
|
||||||
|
@ -845,6 +1030,13 @@ mod test {
|
||||||
let machine = get_machine().await;
|
let machine = get_machine().await;
|
||||||
let account = account();
|
let account = account();
|
||||||
|
|
||||||
|
let second_account = alice_2_account();
|
||||||
|
let alice_device = ReadOnlyDevice::from_account(&second_account).await;
|
||||||
|
|
||||||
|
// We need a trusted device, otherwise we won't request keys
|
||||||
|
alice_device.set_trust_state(LocalTrust::Verified);
|
||||||
|
machine.store.save_devices(&[alice_device]).await.unwrap();
|
||||||
|
|
||||||
let (_, session) = account
|
let (_, session) = account
|
||||||
.create_group_session_pair_with_defaults(&room_id())
|
.create_group_session_pair_with_defaults(&room_id())
|
||||||
.await
|
.await
|
||||||
|
@ -858,11 +1050,11 @@ mod test {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let request = machine.outgoing_to_device_requests.iter().next().unwrap();
|
let requests = machine.outgoing_to_device_requests().await.unwrap();
|
||||||
|
let request = requests.get(0).unwrap();
|
||||||
let id = request.request_id;
|
let id = request.request_id;
|
||||||
drop(request);
|
|
||||||
|
|
||||||
machine.mark_outgoing_request_as_sent(&id).await.unwrap();
|
machine.mark_outgoing_request_as_sent(id).await.unwrap();
|
||||||
|
|
||||||
let export = session.export_at_index(10).await;
|
let export = session.export_at_index(10).await;
|
||||||
|
|
||||||
|
@ -904,7 +1096,7 @@ mod test {
|
||||||
let request = machine.outgoing_to_device_requests.iter().next().unwrap();
|
let request = machine.outgoing_to_device_requests.iter().next().unwrap();
|
||||||
let id = request.request_id;
|
let id = request.request_id;
|
||||||
drop(request);
|
drop(request);
|
||||||
machine.mark_outgoing_request_as_sent(&id).await.unwrap();
|
machine.mark_outgoing_request_as_sent(id).await.unwrap();
|
||||||
|
|
||||||
machine
|
machine
|
||||||
.create_outgoing_key_request(
|
.create_outgoing_key_request(
|
||||||
|
@ -915,11 +1107,13 @@ mod test {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let request = machine.outgoing_to_device_requests.iter().next().unwrap();
|
let requests = machine.outgoing_to_device_requests().await.unwrap();
|
||||||
let id = request.request_id;
|
let request = &requests[0];
|
||||||
drop(request);
|
|
||||||
|
|
||||||
machine.mark_outgoing_request_as_sent(&id).await.unwrap();
|
machine
|
||||||
|
.mark_outgoing_request_as_sent(request.request_id)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let export = session.export_at_index(15).await;
|
let export = session.export_at_index(15).await;
|
||||||
|
|
||||||
|
@ -966,16 +1160,25 @@ mod test {
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
let (outbound, inbound) = account
|
||||||
|
.create_group_session_pair_with_defaults(&room_id())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// We don't share keys with untrusted devices.
|
// We don't share keys with untrusted devices.
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
machine
|
machine
|
||||||
.should_share_session(&own_device, None)
|
.should_share_key(&own_device, &inbound)
|
||||||
|
.await
|
||||||
.expect_err("Should not share with untrusted"),
|
.expect_err("Should not share with untrusted"),
|
||||||
KeyshareDecision::UntrustedDevice
|
KeyshareDecision::UntrustedDevice
|
||||||
);
|
);
|
||||||
own_device.set_trust_state(LocalTrust::Verified);
|
own_device.set_trust_state(LocalTrust::Verified);
|
||||||
// Now we do want to share the keys.
|
// Now we do want to share the keys.
|
||||||
assert!(machine.should_share_session(&own_device, None).is_ok());
|
assert!(machine
|
||||||
|
.should_share_key(&own_device, &inbound)
|
||||||
|
.await
|
||||||
|
.is_ok());
|
||||||
|
|
||||||
let bob_device = ReadOnlyDevice::from_account(&bob_account()).await;
|
let bob_device = ReadOnlyDevice::from_account(&bob_account()).await;
|
||||||
machine.store.save_devices(&[bob_device]).await.unwrap();
|
machine.store.save_devices(&[bob_device]).await.unwrap();
|
||||||
|
@ -991,21 +1194,25 @@ mod test {
|
||||||
// session was provided.
|
// session was provided.
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
machine
|
machine
|
||||||
.should_share_session(&bob_device, None)
|
.should_share_key(&bob_device, &inbound)
|
||||||
|
.await
|
||||||
.expect_err("Should not share with other."),
|
.expect_err("Should not share with other."),
|
||||||
KeyshareDecision::MissingOutboundSession
|
KeyshareDecision::MissingOutboundSession
|
||||||
);
|
);
|
||||||
|
|
||||||
let (session, _) = account
|
let mut changes = Changes::default();
|
||||||
.create_group_session_pair_with_defaults(&room_id())
|
|
||||||
.await
|
changes.outbound_group_sessions.push(outbound.clone());
|
||||||
.unwrap();
|
changes.inbound_group_sessions.push(inbound.clone());
|
||||||
|
machine.store.save_changes(changes).await.unwrap();
|
||||||
|
machine.outbound_group_sessions.insert(outbound.clone());
|
||||||
|
|
||||||
// We don't share sessions with other user's devices if the session
|
// We don't share sessions with other user's devices if the session
|
||||||
// wasn't shared in the first place.
|
// wasn't shared in the first place.
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
machine
|
machine
|
||||||
.should_share_session(&bob_device, Some(&session))
|
.should_share_key(&bob_device, &inbound)
|
||||||
|
.await
|
||||||
.expect_err("Should not share with other unless shared."),
|
.expect_err("Should not share with other unless shared."),
|
||||||
KeyshareDecision::OutboundSessionNotShared
|
KeyshareDecision::OutboundSessionNotShared
|
||||||
);
|
);
|
||||||
|
@ -1016,15 +1223,33 @@ mod test {
|
||||||
// wasn't shared in the first place even if the device is trusted.
|
// wasn't shared in the first place even if the device is trusted.
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
machine
|
machine
|
||||||
.should_share_session(&bob_device, Some(&session))
|
.should_share_key(&bob_device, &inbound)
|
||||||
|
.await
|
||||||
.expect_err("Should not share with other unless shared."),
|
.expect_err("Should not share with other unless shared."),
|
||||||
KeyshareDecision::OutboundSessionNotShared
|
KeyshareDecision::OutboundSessionNotShared
|
||||||
);
|
);
|
||||||
|
|
||||||
session.mark_shared_with(bob_device.user_id(), bob_device.device_id());
|
// We now share the session, since it was shared before.
|
||||||
|
outbound.mark_shared_with(bob_device.user_id(), bob_device.device_id());
|
||||||
assert!(machine
|
assert!(machine
|
||||||
.should_share_session(&bob_device, Some(&session))
|
.should_share_key(&bob_device, &inbound)
|
||||||
|
.await
|
||||||
.is_ok());
|
.is_ok());
|
||||||
|
|
||||||
|
// But we don't share some other session that doesn't match our outbound
|
||||||
|
// session
|
||||||
|
let (_, other_inbound) = account
|
||||||
|
.create_group_session_pair_with_defaults(&room_id())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
machine
|
||||||
|
.should_share_key(&bob_device, &other_inbound)
|
||||||
|
.await
|
||||||
|
.expect_err("Should not share with other unless shared."),
|
||||||
|
KeyshareDecision::MissingOutboundSession
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_test]
|
#[async_test]
|
||||||
|
@ -1038,6 +1263,17 @@ mod test {
|
||||||
let bob_machine = bob_machine();
|
let bob_machine = bob_machine();
|
||||||
let bob_account = bob_account();
|
let bob_account = bob_account();
|
||||||
|
|
||||||
|
let second_account = alice_2_account();
|
||||||
|
let alice_device = ReadOnlyDevice::from_account(&second_account).await;
|
||||||
|
|
||||||
|
// We need a trusted device, otherwise we won't request keys
|
||||||
|
alice_device.set_trust_state(LocalTrust::Verified);
|
||||||
|
alice_machine
|
||||||
|
.store
|
||||||
|
.save_devices(&[alice_device])
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// Create Olm sessions for our two accounts.
|
// Create Olm sessions for our two accounts.
|
||||||
let (alice_session, bob_session) = alice_account.create_session_for(&bob_account).await;
|
let (alice_session, bob_session) = alice_account.create_session_for(&bob_account).await;
|
||||||
|
|
||||||
|
@ -1092,14 +1328,11 @@ mod test {
|
||||||
// Put the outbound session into bobs store.
|
// Put the outbound session into bobs store.
|
||||||
bob_machine
|
bob_machine
|
||||||
.outbound_group_sessions
|
.outbound_group_sessions
|
||||||
.insert(room_id(), group_session.clone());
|
.insert(group_session.clone());
|
||||||
|
|
||||||
// Get the request and convert it into a event.
|
// Get the request and convert it into a event.
|
||||||
let request = alice_machine
|
let requests = alice_machine.outgoing_to_device_requests().await.unwrap();
|
||||||
.outgoing_to_device_requests
|
let request = &requests[0];
|
||||||
.iter()
|
|
||||||
.next()
|
|
||||||
.unwrap();
|
|
||||||
let id = request.request_id;
|
let id = request.request_id;
|
||||||
let content = request
|
let content = request
|
||||||
.request
|
.request
|
||||||
|
@ -1113,9 +1346,8 @@ mod test {
|
||||||
let content: RoomKeyRequestToDeviceEventContent =
|
let content: RoomKeyRequestToDeviceEventContent =
|
||||||
serde_json::from_str(content.get()).unwrap();
|
serde_json::from_str(content.get()).unwrap();
|
||||||
|
|
||||||
drop(request);
|
|
||||||
alice_machine
|
alice_machine
|
||||||
.mark_outgoing_request_as_sent(&id)
|
.mark_outgoing_request_as_sent(id)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -1134,11 +1366,8 @@ mod test {
|
||||||
assert!(!bob_machine.outgoing_to_device_requests.is_empty());
|
assert!(!bob_machine.outgoing_to_device_requests.is_empty());
|
||||||
|
|
||||||
// Get the request and convert it to a encrypted to-device event.
|
// Get the request and convert it to a encrypted to-device event.
|
||||||
let request = bob_machine
|
let requests = bob_machine.outgoing_to_device_requests().await.unwrap();
|
||||||
.outgoing_to_device_requests
|
let request = &requests[0];
|
||||||
.iter()
|
|
||||||
.next()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let id = request.request_id;
|
let id = request.request_id;
|
||||||
let content = request
|
let content = request
|
||||||
|
@ -1152,11 +1381,7 @@ mod test {
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap();
|
let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap();
|
||||||
|
|
||||||
drop(request);
|
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();
|
||||||
bob_machine
|
|
||||||
.mark_outgoing_request_as_sent(&id)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let event = ToDeviceEvent {
|
let event = ToDeviceEvent {
|
||||||
sender: bob_id(),
|
sender: bob_id(),
|
||||||
|
@ -1217,6 +1442,17 @@ mod test {
|
||||||
let bob_machine = bob_machine();
|
let bob_machine = bob_machine();
|
||||||
let bob_account = bob_account();
|
let bob_account = bob_account();
|
||||||
|
|
||||||
|
let second_account = alice_2_account();
|
||||||
|
let alice_device = ReadOnlyDevice::from_account(&second_account).await;
|
||||||
|
|
||||||
|
// We need a trusted device, otherwise we won't request keys
|
||||||
|
alice_device.set_trust_state(LocalTrust::Verified);
|
||||||
|
alice_machine
|
||||||
|
.store
|
||||||
|
.save_devices(&[alice_device])
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// Create Olm sessions for our two accounts.
|
// Create Olm sessions for our two accounts.
|
||||||
let (alice_session, bob_session) = alice_account.create_session_for(&bob_account).await;
|
let (alice_session, bob_session) = alice_account.create_session_for(&bob_account).await;
|
||||||
|
|
||||||
|
@ -1261,14 +1497,11 @@ mod test {
|
||||||
// Put the outbound session into bobs store.
|
// Put the outbound session into bobs store.
|
||||||
bob_machine
|
bob_machine
|
||||||
.outbound_group_sessions
|
.outbound_group_sessions
|
||||||
.insert(room_id(), group_session.clone());
|
.insert(group_session.clone());
|
||||||
|
|
||||||
// Get the request and convert it into a event.
|
// Get the request and convert it into a event.
|
||||||
let request = alice_machine
|
let requests = alice_machine.outgoing_to_device_requests().await.unwrap();
|
||||||
.outgoing_to_device_requests
|
let request = &requests[0];
|
||||||
.iter()
|
|
||||||
.next()
|
|
||||||
.unwrap();
|
|
||||||
let id = request.request_id;
|
let id = request.request_id;
|
||||||
let content = request
|
let content = request
|
||||||
.request
|
.request
|
||||||
|
@ -1282,9 +1515,8 @@ mod test {
|
||||||
let content: RoomKeyRequestToDeviceEventContent =
|
let content: RoomKeyRequestToDeviceEventContent =
|
||||||
serde_json::from_str(content.get()).unwrap();
|
serde_json::from_str(content.get()).unwrap();
|
||||||
|
|
||||||
drop(request);
|
|
||||||
alice_machine
|
alice_machine
|
||||||
.mark_outgoing_request_as_sent(&id)
|
.mark_outgoing_request_as_sent(id)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -1294,7 +1526,11 @@ mod test {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Bob doesn't have any outgoing requests.
|
// Bob doesn't have any outgoing requests.
|
||||||
assert!(bob_machine.outgoing_to_device_requests.is_empty());
|
assert!(bob_machine
|
||||||
|
.outgoing_to_device_requests()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.is_empty());
|
||||||
assert!(bob_machine.users_for_key_claim.is_empty());
|
assert!(bob_machine.users_for_key_claim.is_empty());
|
||||||
assert!(bob_machine.wait_queue.is_empty());
|
assert!(bob_machine.wait_queue.is_empty());
|
||||||
|
|
||||||
|
@ -1302,7 +1538,11 @@ mod test {
|
||||||
bob_machine.receive_incoming_key_request(&event);
|
bob_machine.receive_incoming_key_request(&event);
|
||||||
bob_machine.collect_incoming_key_requests().await.unwrap();
|
bob_machine.collect_incoming_key_requests().await.unwrap();
|
||||||
// Bob doens't have an outgoing requests since we're lacking a session.
|
// Bob doens't have an outgoing requests since we're lacking a session.
|
||||||
assert!(bob_machine.outgoing_to_device_requests.is_empty());
|
assert!(bob_machine
|
||||||
|
.outgoing_to_device_requests()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.is_empty());
|
||||||
assert!(!bob_machine.users_for_key_claim.is_empty());
|
assert!(!bob_machine.users_for_key_claim.is_empty());
|
||||||
assert!(!bob_machine.wait_queue.is_empty());
|
assert!(!bob_machine.wait_queue.is_empty());
|
||||||
|
|
||||||
|
@ -1322,15 +1562,17 @@ mod test {
|
||||||
assert!(bob_machine.users_for_key_claim.is_empty());
|
assert!(bob_machine.users_for_key_claim.is_empty());
|
||||||
bob_machine.collect_incoming_key_requests().await.unwrap();
|
bob_machine.collect_incoming_key_requests().await.unwrap();
|
||||||
// Bob now has an outgoing requests.
|
// Bob now has an outgoing requests.
|
||||||
assert!(!bob_machine.outgoing_to_device_requests.is_empty());
|
assert!(!bob_machine
|
||||||
|
.outgoing_to_device_requests()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.is_empty());
|
||||||
assert!(bob_machine.wait_queue.is_empty());
|
assert!(bob_machine.wait_queue.is_empty());
|
||||||
|
|
||||||
// Get the request and convert it to a encrypted to-device event.
|
// Get the request and convert it to a encrypted to-device event.
|
||||||
let request = bob_machine
|
let requests = bob_machine.outgoing_to_device_requests().await.unwrap();
|
||||||
.outgoing_to_device_requests
|
|
||||||
.iter()
|
let request = &requests[0];
|
||||||
.next()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let id = request.request_id;
|
let id = request.request_id;
|
||||||
let content = request
|
let content = request
|
||||||
|
@ -1344,11 +1586,7 @@ mod test {
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap();
|
let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap();
|
||||||
|
|
||||||
drop(request);
|
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();
|
||||||
bob_machine
|
|
||||||
.mark_outgoing_request_as_sent(&id)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let event = ToDeviceEvent {
|
let event = ToDeviceEvent {
|
||||||
sender: bob_id(),
|
sender: bob_id(),
|
||||||
|
|
|
@ -156,29 +156,29 @@ impl OlmMachine {
|
||||||
verification_machine.clone(),
|
verification_machine.clone(),
|
||||||
);
|
);
|
||||||
let device_id: Arc<DeviceIdBox> = Arc::new(device_id);
|
let device_id: Arc<DeviceIdBox> = Arc::new(device_id);
|
||||||
let outbound_group_sessions = Arc::new(DashMap::new());
|
|
||||||
let users_for_key_claim = Arc::new(DashMap::new());
|
let users_for_key_claim = Arc::new(DashMap::new());
|
||||||
|
|
||||||
let key_request_machine = KeyRequestMachine::new(
|
|
||||||
user_id.clone(),
|
|
||||||
device_id.clone(),
|
|
||||||
store.clone(),
|
|
||||||
outbound_group_sessions,
|
|
||||||
users_for_key_claim.clone(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let account = Account {
|
let account = Account {
|
||||||
inner: account,
|
inner: account,
|
||||||
store: store.clone(),
|
store: store.clone(),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let group_session_manager = GroupSessionManager::new(account.clone(), store.clone());
|
||||||
|
|
||||||
|
let key_request_machine = KeyRequestMachine::new(
|
||||||
|
user_id.clone(),
|
||||||
|
device_id.clone(),
|
||||||
|
store.clone(),
|
||||||
|
group_session_manager.session_cache(),
|
||||||
|
users_for_key_claim.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
let session_manager = SessionManager::new(
|
let session_manager = SessionManager::new(
|
||||||
account.clone(),
|
account.clone(),
|
||||||
users_for_key_claim,
|
users_for_key_claim,
|
||||||
key_request_machine.clone(),
|
key_request_machine.clone(),
|
||||||
store.clone(),
|
store.clone(),
|
||||||
);
|
);
|
||||||
let group_session_manager = GroupSessionManager::new(account.clone(), store.clone());
|
|
||||||
let identity_manager =
|
let identity_manager =
|
||||||
IdentityManager::new(user_id.clone(), device_id.clone(), store.clone());
|
IdentityManager::new(user_id.clone(), device_id.clone(), store.clone());
|
||||||
|
|
||||||
|
@ -294,7 +294,7 @@ impl OlmMachine {
|
||||||
/// machine using [`mark_request_as_sent`].
|
/// machine using [`mark_request_as_sent`].
|
||||||
///
|
///
|
||||||
/// [`mark_request_as_sent`]: #method.mark_request_as_sent
|
/// [`mark_request_as_sent`]: #method.mark_request_as_sent
|
||||||
pub async fn outgoing_requests(&self) -> Vec<OutgoingRequest> {
|
pub async fn outgoing_requests(&self) -> StoreResult<Vec<OutgoingRequest>> {
|
||||||
let mut requests = Vec::new();
|
let mut requests = Vec::new();
|
||||||
|
|
||||||
if let Some(r) = self.keys_for_upload().await.map(|r| OutgoingRequest {
|
if let Some(r) = self.keys_for_upload().await.map(|r| OutgoingRequest {
|
||||||
|
@ -319,9 +319,14 @@ impl OlmMachine {
|
||||||
|
|
||||||
requests.append(&mut self.outgoing_to_device_requests());
|
requests.append(&mut self.outgoing_to_device_requests());
|
||||||
requests.append(&mut self.verification_machine.outgoing_room_message_requests());
|
requests.append(&mut self.verification_machine.outgoing_room_message_requests());
|
||||||
requests.append(&mut self.key_request_machine.outgoing_to_device_requests());
|
requests.append(
|
||||||
|
&mut self
|
||||||
|
.key_request_machine
|
||||||
|
.outgoing_to_device_requests()
|
||||||
|
.await?,
|
||||||
|
);
|
||||||
|
|
||||||
requests
|
Ok(requests)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mark the request with the given request id as sent.
|
/// Mark the request with the given request id as sent.
|
||||||
|
@ -751,7 +756,7 @@ impl OlmMachine {
|
||||||
async fn mark_to_device_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> {
|
async fn mark_to_device_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> {
|
||||||
self.verification_machine.mark_request_as_sent(request_id);
|
self.verification_machine.mark_request_as_sent(request_id);
|
||||||
self.key_request_machine
|
self.key_request_machine
|
||||||
.mark_outgoing_request_as_sent(request_id)
|
.mark_outgoing_request_as_sent(*request_id)
|
||||||
.await?;
|
.await?;
|
||||||
self.group_session_manager
|
self.group_session_manager
|
||||||
.mark_request_as_sent(request_id)
|
.mark_request_as_sent(request_id)
|
||||||
|
@ -913,6 +918,38 @@ impl OlmMachine {
|
||||||
Ok(ToDevice { events })
|
Ok(ToDevice { events })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Request a room key from our devices.
|
||||||
|
///
|
||||||
|
/// This method will return a request cancelation and a new key request if
|
||||||
|
/// the key was already requested, otherwise it will return just the key
|
||||||
|
/// request.
|
||||||
|
///
|
||||||
|
/// The request cancelation *must* be sent out before the request is sent
|
||||||
|
/// out, otherwise devices will ignore the key request.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `room_id` - The id of the room where the key is used in.
|
||||||
|
///
|
||||||
|
/// * `sender_key` - The curve25519 key of the sender that owns the key.
|
||||||
|
///
|
||||||
|
/// * `session_id` - The id that uniquely identifies the session.
|
||||||
|
pub async fn request_room_key(
|
||||||
|
&self,
|
||||||
|
event: &SyncMessageEvent<EncryptedEventContent>,
|
||||||
|
room_id: &RoomId,
|
||||||
|
) -> MegolmResult<(Option<OutgoingRequest>, OutgoingRequest)> {
|
||||||
|
let content = match &event.content {
|
||||||
|
EncryptedEventContent::MegolmV1AesSha2(c) => c,
|
||||||
|
_ => return Err(EventError::UnsupportedAlgorithm.into()),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(self
|
||||||
|
.key_request_machine
|
||||||
|
.request_key(room_id, &content.sender_key, &content.session_id)
|
||||||
|
.await?)
|
||||||
|
}
|
||||||
|
|
||||||
/// Decrypt an event from a room timeline.
|
/// Decrypt an event from a room timeline.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
|
|
|
@ -40,6 +40,78 @@ use crate::{
|
||||||
Device, EncryptionSettings, OlmError, ToDeviceRequest,
|
Device, EncryptionSettings, OlmError, ToDeviceRequest,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub(crate) struct GroupSessionCache {
|
||||||
|
store: Store,
|
||||||
|
sessions: Arc<DashMap<RoomId, OutboundGroupSession>>,
|
||||||
|
/// A map from the request id to the group session that the request belongs
|
||||||
|
/// to. Used to mark requests belonging to the session as shared.
|
||||||
|
sessions_being_shared: Arc<DashMap<Uuid, OutboundGroupSession>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GroupSessionCache {
|
||||||
|
pub(crate) fn new(store: Store) -> Self {
|
||||||
|
Self {
|
||||||
|
store,
|
||||||
|
sessions: DashMap::new().into(),
|
||||||
|
sessions_being_shared: Arc::new(DashMap::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn insert(&self, session: OutboundGroupSession) {
|
||||||
|
self.sessions.insert(session.room_id().to_owned(), session);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Either get a session for the given room from the cache or load it from
|
||||||
|
/// the store.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `room_id` - The id of the room this session is used for.
|
||||||
|
pub async fn get_or_load(&self, room_id: &RoomId) -> StoreResult<Option<OutboundGroupSession>> {
|
||||||
|
// Get the cached session, if there isn't one load one from the store
|
||||||
|
// and put it in the cache.
|
||||||
|
if let Some(s) = self.sessions.get(room_id) {
|
||||||
|
Ok(Some(s.clone()))
|
||||||
|
} else if let Some(s) = self.store.get_outbound_group_sessions(room_id).await? {
|
||||||
|
for request_id in s.pending_request_ids() {
|
||||||
|
self.sessions_being_shared.insert(request_id, s.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
self.sessions.insert(room_id.clone(), s.clone());
|
||||||
|
|
||||||
|
Ok(Some(s))
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get an outbound group session for a room, if one exists.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `room_id` - The id of the room for which we should get the outbound
|
||||||
|
/// group session.
|
||||||
|
fn get(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
|
||||||
|
self.sessions.get(room_id).map(|s| s.clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get or load the session for the given room with the given session id.
|
||||||
|
///
|
||||||
|
/// This is the same as [get_or_load()](#method.get_or_load) but it will
|
||||||
|
/// filter out the session if it doesn't match the given session id.
|
||||||
|
pub async fn get_with_id(
|
||||||
|
&self,
|
||||||
|
room_id: &RoomId,
|
||||||
|
session_id: &str,
|
||||||
|
) -> StoreResult<Option<OutboundGroupSession>> {
|
||||||
|
Ok(self
|
||||||
|
.get_or_load(room_id)
|
||||||
|
.await?
|
||||||
|
.filter(|o| session_id == o.session_id()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct GroupSessionManager {
|
pub struct GroupSessionManager {
|
||||||
account: Account,
|
account: Account,
|
||||||
|
@ -48,10 +120,7 @@ pub struct GroupSessionManager {
|
||||||
/// without the need to create new keys.
|
/// without the need to create new keys.
|
||||||
store: Store,
|
store: Store,
|
||||||
/// The currently active outbound group sessions.
|
/// The currently active outbound group sessions.
|
||||||
outbound_group_sessions: Arc<DashMap<RoomId, OutboundGroupSession>>,
|
sessions: GroupSessionCache,
|
||||||
/// A map from the request id to the group session that the request belongs
|
|
||||||
/// to. Used to mark requests belonging to the session as shared.
|
|
||||||
outbound_sessions_being_shared: Arc<DashMap<Uuid, OutboundGroupSession>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GroupSessionManager {
|
impl GroupSessionManager {
|
||||||
|
@ -60,14 +129,13 @@ impl GroupSessionManager {
|
||||||
pub(crate) fn new(account: Account, store: Store) -> Self {
|
pub(crate) fn new(account: Account, store: Store) -> Self {
|
||||||
Self {
|
Self {
|
||||||
account,
|
account,
|
||||||
store,
|
store: store.clone(),
|
||||||
outbound_group_sessions: Arc::new(DashMap::new()),
|
sessions: GroupSessionCache::new(store),
|
||||||
outbound_sessions_being_shared: Arc::new(DashMap::new()),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
|
pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
|
||||||
if let Some(s) = self.outbound_group_sessions.get(room_id) {
|
if let Some(s) = self.sessions.get(room_id) {
|
||||||
s.invalidate_session();
|
s.invalidate_session();
|
||||||
|
|
||||||
let mut changes = Changes::default();
|
let mut changes = Changes::default();
|
||||||
|
@ -81,7 +149,7 @@ impl GroupSessionManager {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn mark_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> {
|
pub async fn mark_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> {
|
||||||
if let Some((_, s)) = self.outbound_sessions_being_shared.remove(request_id) {
|
if let Some((_, s)) = self.sessions.sessions_being_shared.remove(request_id) {
|
||||||
s.mark_request_as_sent(request_id);
|
s.mark_request_as_sent(request_id);
|
||||||
|
|
||||||
let mut changes = Changes::default();
|
let mut changes = Changes::default();
|
||||||
|
@ -97,15 +165,9 @@ impl GroupSessionManager {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get an outbound group session for a room, if one exists.
|
#[cfg(test)]
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `room_id` - The id of the room for which we should get the outbound
|
|
||||||
/// group session.
|
|
||||||
pub fn get_outbound_group_session(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
|
pub fn get_outbound_group_session(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
|
||||||
#[allow(clippy::map_clone)]
|
self.sessions.get(room_id)
|
||||||
self.outbound_group_sessions.get(room_id).map(|s| s.clone())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn encrypt(
|
pub async fn encrypt(
|
||||||
|
@ -113,7 +175,7 @@ impl GroupSessionManager {
|
||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
content: AnyMessageEventContent,
|
content: AnyMessageEventContent,
|
||||||
) -> MegolmResult<EncryptedEventContent> {
|
) -> MegolmResult<EncryptedEventContent> {
|
||||||
let session = if let Some(s) = self.get_outbound_group_session(room_id) {
|
let session = if let Some(s) = self.sessions.get(room_id) {
|
||||||
s
|
s
|
||||||
} else {
|
} else {
|
||||||
panic!("Session wasn't created nor shared");
|
panic!("Session wasn't created nor shared");
|
||||||
|
@ -147,9 +209,7 @@ impl GroupSessionManager {
|
||||||
.await
|
.await
|
||||||
.map_err(|_| EventError::UnsupportedAlgorithm)?;
|
.map_err(|_| EventError::UnsupportedAlgorithm)?;
|
||||||
|
|
||||||
let _ = self
|
self.sessions.insert(outbound.clone());
|
||||||
.outbound_group_sessions
|
|
||||||
.insert(room_id.to_owned(), outbound.clone());
|
|
||||||
Ok((outbound, inbound))
|
Ok((outbound, inbound))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -158,23 +218,7 @@ impl GroupSessionManager {
|
||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
settings: EncryptionSettings,
|
settings: EncryptionSettings,
|
||||||
) -> OlmResult<(OutboundGroupSession, Option<InboundGroupSession>)> {
|
) -> OlmResult<(OutboundGroupSession, Option<InboundGroupSession>)> {
|
||||||
// Get the cached session, if there isn't one load one from the store
|
let outbound_session = self.sessions.get_or_load(&room_id).await?;
|
||||||
// and put it in the cache.
|
|
||||||
let outbound_session = if let Some(s) = self.outbound_group_sessions.get(room_id) {
|
|
||||||
Some(s.clone())
|
|
||||||
} else if let Some(s) = self.store.get_outbound_group_sessions(room_id).await? {
|
|
||||||
for request_id in s.pending_request_ids() {
|
|
||||||
self.outbound_sessions_being_shared
|
|
||||||
.insert(request_id, s.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
self.outbound_group_sessions
|
|
||||||
.insert(room_id.clone(), s.clone());
|
|
||||||
|
|
||||||
Some(s)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
// If there is no session or the session has expired or is invalid,
|
// If there is no session or the session has expired or is invalid,
|
||||||
// create a new one.
|
// create a new one.
|
||||||
|
@ -388,6 +432,10 @@ impl GroupSessionManager {
|
||||||
Ok(used_sessions)
|
Ok(used_sessions)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn session_cache(&self) -> GroupSessionCache {
|
||||||
|
self.sessions.clone()
|
||||||
|
}
|
||||||
|
|
||||||
/// Get to-device requests to share a group session with users in a room.
|
/// Get to-device requests to share a group session with users in a room.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
|
@ -489,7 +537,7 @@ impl GroupSessionManager {
|
||||||
key_content.clone(),
|
key_content.clone(),
|
||||||
outbound.clone(),
|
outbound.clone(),
|
||||||
message_index,
|
message_index,
|
||||||
self.outbound_sessions_being_shared.clone(),
|
self.sessions.sessions_being_shared.clone(),
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
|
@ -15,5 +15,5 @@
|
||||||
mod group_sessions;
|
mod group_sessions;
|
||||||
mod sessions;
|
mod sessions;
|
||||||
|
|
||||||
pub(crate) use group_sessions::GroupSessionManager;
|
pub(crate) use group_sessions::{GroupSessionCache, GroupSessionManager};
|
||||||
pub(crate) use sessions::SessionManager;
|
pub(crate) use sessions::SessionManager;
|
||||||
|
|
|
@ -322,6 +322,7 @@ mod test {
|
||||||
identities::ReadOnlyDevice,
|
identities::ReadOnlyDevice,
|
||||||
key_request::KeyRequestMachine,
|
key_request::KeyRequestMachine,
|
||||||
olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount},
|
olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount},
|
||||||
|
session_manager::GroupSessionCache,
|
||||||
store::{CryptoStore, MemoryStore, Store},
|
store::{CryptoStore, MemoryStore, Store},
|
||||||
verification::VerificationMachine,
|
verification::VerificationMachine,
|
||||||
};
|
};
|
||||||
|
@ -342,7 +343,6 @@ mod test {
|
||||||
let user_id = user_id();
|
let user_id = user_id();
|
||||||
let device_id = device_id();
|
let device_id = device_id();
|
||||||
|
|
||||||
let outbound_sessions = Arc::new(DashMap::new());
|
|
||||||
let users_for_key_claim = Arc::new(DashMap::new());
|
let users_for_key_claim = Arc::new(DashMap::new());
|
||||||
let account = ReadOnlyAccount::new(&user_id, &device_id);
|
let account = ReadOnlyAccount::new(&user_id, &device_id);
|
||||||
let store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
|
let store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
|
||||||
|
@ -363,11 +363,13 @@ mod test {
|
||||||
store: store.clone(),
|
store: store.clone(),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let session_cache = GroupSessionCache::new(store.clone());
|
||||||
|
|
||||||
let key_request = KeyRequestMachine::new(
|
let key_request = KeyRequestMachine::new(
|
||||||
user_id,
|
user_id,
|
||||||
device_id,
|
device_id,
|
||||||
store.clone(),
|
store.clone(),
|
||||||
outbound_sessions,
|
session_cache,
|
||||||
users_for_key_claim.clone(),
|
users_for_key_claim.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
@ -20,8 +20,10 @@ use std::{
|
||||||
use dashmap::{DashMap, DashSet};
|
use dashmap::{DashMap, DashSet};
|
||||||
use matrix_sdk_common::{
|
use matrix_sdk_common::{
|
||||||
async_trait,
|
async_trait,
|
||||||
|
events::room_key_request::RequestedKeyInfo,
|
||||||
identifiers::{DeviceId, DeviceIdBox, RoomId, UserId},
|
identifiers::{DeviceId, DeviceIdBox, RoomId, UserId},
|
||||||
locks::Mutex,
|
locks::Mutex,
|
||||||
|
uuid::Uuid,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
|
@ -30,9 +32,17 @@ use super::{
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
identities::{ReadOnlyDevice, UserIdentities},
|
identities::{ReadOnlyDevice, UserIdentities},
|
||||||
|
key_request::OutgoingKeyRequest,
|
||||||
olm::{OutboundGroupSession, PrivateCrossSigningIdentity},
|
olm::{OutboundGroupSession, PrivateCrossSigningIdentity},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
fn encode_key_info(info: &RequestedKeyInfo) -> String {
|
||||||
|
format!(
|
||||||
|
"{}{}{}{}",
|
||||||
|
info.room_id, info.sender_key, info.algorithm, info.session_id
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
/// An in-memory only store that will forget all the E2EE key once it's dropped.
|
/// An in-memory only store that will forget all the E2EE key once it's dropped.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct MemoryStore {
|
pub struct MemoryStore {
|
||||||
|
@ -43,7 +53,8 @@ pub struct MemoryStore {
|
||||||
olm_hashes: Arc<DashMap<String, DashSet<String>>>,
|
olm_hashes: Arc<DashMap<String, DashSet<String>>>,
|
||||||
devices: DeviceStore,
|
devices: DeviceStore,
|
||||||
identities: Arc<DashMap<UserId, UserIdentities>>,
|
identities: Arc<DashMap<UserId, UserIdentities>>,
|
||||||
values: Arc<DashMap<String, String>>,
|
outgoing_key_requests: Arc<DashMap<Uuid, OutgoingKeyRequest>>,
|
||||||
|
key_requests_by_info: Arc<DashMap<String, Uuid>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for MemoryStore {
|
impl Default for MemoryStore {
|
||||||
|
@ -56,7 +67,8 @@ impl Default for MemoryStore {
|
||||||
olm_hashes: Arc::new(DashMap::new()),
|
olm_hashes: Arc::new(DashMap::new()),
|
||||||
devices: DeviceStore::new(),
|
devices: DeviceStore::new(),
|
||||||
identities: Arc::new(DashMap::new()),
|
identities: Arc::new(DashMap::new()),
|
||||||
values: Arc::new(DashMap::new()),
|
outgoing_key_requests: Arc::new(DashMap::new()),
|
||||||
|
key_requests_by_info: Arc::new(DashMap::new()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -103,6 +115,10 @@ impl CryptoStore for MemoryStore {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
|
||||||
async fn save_changes(&self, mut changes: Changes) -> Result<()> {
|
async fn save_changes(&self, mut changes: Changes) -> Result<()> {
|
||||||
self.save_sessions(changes.sessions).await;
|
self.save_sessions(changes.sessions).await;
|
||||||
self.save_inbound_group_sessions(changes.inbound_group_sessions)
|
self.save_inbound_group_sessions(changes.inbound_group_sessions)
|
||||||
|
@ -130,6 +146,14 @@ impl CryptoStore for MemoryStore {
|
||||||
.insert(hash.hash.clone());
|
.insert(hash.hash.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for key_request in changes.key_requests {
|
||||||
|
let id = key_request.request_id;
|
||||||
|
let info_string = encode_key_info(&key_request.info);
|
||||||
|
|
||||||
|
self.outgoing_key_requests.insert(id, key_request);
|
||||||
|
self.key_requests_by_info.insert(info_string, id);
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -152,9 +176,11 @@ impl CryptoStore for MemoryStore {
|
||||||
Ok(self.inbound_group_sessions.get_all())
|
Ok(self.inbound_group_sessions.get_all())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn users_for_key_query(&self) -> HashSet<UserId> {
|
async fn get_outbound_group_sessions(
|
||||||
#[allow(clippy::map_clone)]
|
&self,
|
||||||
self.users_for_key_query.iter().map(|u| u.clone()).collect()
|
_: &RoomId,
|
||||||
|
) -> Result<Option<OutboundGroupSession>> {
|
||||||
|
Ok(None)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_user_tracked(&self, user_id: &UserId) -> bool {
|
fn is_user_tracked(&self, user_id: &UserId) -> bool {
|
||||||
|
@ -165,6 +191,11 @@ impl CryptoStore for MemoryStore {
|
||||||
!self.users_for_key_query.is_empty()
|
!self.users_for_key_query.is_empty()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn users_for_key_query(&self) -> HashSet<UserId> {
|
||||||
|
#[allow(clippy::map_clone)]
|
||||||
|
self.users_for_key_query.iter().map(|u| u.clone()).collect()
|
||||||
|
}
|
||||||
|
|
||||||
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
|
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
|
||||||
// TODO to prevent a race between the sync and a key query in flight we
|
// TODO to prevent a race between the sync and a key query in flight we
|
||||||
// need to have an additional state to mention that the user changed.
|
// need to have an additional state to mention that the user changed.
|
||||||
|
@ -207,24 +238,6 @@ impl CryptoStore for MemoryStore {
|
||||||
Ok(self.identities.get(user_id).map(|i| i.clone()))
|
Ok(self.identities.get(user_id).map(|i| i.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn save_value(&self, key: String, value: String) -> Result<()> {
|
|
||||||
self.values.insert(key, value);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn remove_value(&self, key: &str) -> Result<()> {
|
|
||||||
self.values.remove(key);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get_value(&self, key: &str) -> Result<Option<String>> {
|
|
||||||
Ok(self.values.get(key).map(|v| v.to_owned()))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
|
|
||||||
Ok(None)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result<bool> {
|
async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result<bool> {
|
||||||
Ok(self
|
Ok(self
|
||||||
.olm_hashes
|
.olm_hashes
|
||||||
|
@ -233,11 +246,46 @@ impl CryptoStore for MemoryStore {
|
||||||
.contains(&message_hash.hash))
|
.contains(&message_hash.hash))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_outbound_group_sessions(
|
async fn get_outgoing_key_request(
|
||||||
&self,
|
&self,
|
||||||
_: &RoomId,
|
request_id: Uuid,
|
||||||
) -> Result<Option<OutboundGroupSession>> {
|
) -> Result<Option<OutgoingKeyRequest>> {
|
||||||
Ok(None)
|
Ok(self
|
||||||
|
.outgoing_key_requests
|
||||||
|
.get(&request_id)
|
||||||
|
.map(|r| r.clone()))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_key_request_by_info(
|
||||||
|
&self,
|
||||||
|
key_info: &RequestedKeyInfo,
|
||||||
|
) -> Result<Option<OutgoingKeyRequest>> {
|
||||||
|
let key_info_string = encode_key_info(key_info);
|
||||||
|
|
||||||
|
Ok(self
|
||||||
|
.key_requests_by_info
|
||||||
|
.get(&key_info_string)
|
||||||
|
.and_then(|i| self.outgoing_key_requests.get(&i).map(|r| r.clone())))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_unsent_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>> {
|
||||||
|
Ok(self
|
||||||
|
.outgoing_key_requests
|
||||||
|
.iter()
|
||||||
|
.filter(|i| !i.value().sent_out)
|
||||||
|
.map(|i| i.value().clone())
|
||||||
|
.collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> {
|
||||||
|
self.outgoing_key_requests
|
||||||
|
.remove(&request_id)
|
||||||
|
.and_then(|(_, i)| {
|
||||||
|
let key_info_string = encode_key_info(&i.info);
|
||||||
|
self.key_requests_by_info.remove(&key_info_string)
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -57,23 +57,25 @@ use std::{
|
||||||
};
|
};
|
||||||
|
|
||||||
use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
|
use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use serde_json::Error as SerdeError;
|
use serde_json::Error as SerdeError;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
use matrix_sdk_common::{
|
use matrix_sdk_common::{
|
||||||
async_trait,
|
async_trait,
|
||||||
|
events::room_key_request::RequestedKeyInfo,
|
||||||
identifiers::{
|
identifiers::{
|
||||||
DeviceId, DeviceIdBox, DeviceKeyAlgorithm, Error as IdentifierValidationError, RoomId,
|
DeviceId, DeviceIdBox, DeviceKeyAlgorithm, Error as IdentifierValidationError, RoomId,
|
||||||
UserId,
|
UserId,
|
||||||
},
|
},
|
||||||
locks::Mutex,
|
locks::Mutex,
|
||||||
|
uuid::Uuid,
|
||||||
AsyncTraitDeps,
|
AsyncTraitDeps,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::SessionUnpicklingError,
|
error::SessionUnpicklingError,
|
||||||
identities::{Device, ReadOnlyDevice, UserDevices, UserIdentities},
|
identities::{Device, ReadOnlyDevice, UserDevices, UserIdentities},
|
||||||
|
key_request::OutgoingKeyRequest,
|
||||||
olm::{
|
olm::{
|
||||||
InboundGroupSession, OlmMessageHash, OutboundGroupSession, PrivateCrossSigningIdentity,
|
InboundGroupSession, OlmMessageHash, OutboundGroupSession, PrivateCrossSigningIdentity,
|
||||||
ReadOnlyAccount, Session,
|
ReadOnlyAccount, Session,
|
||||||
|
@ -108,6 +110,7 @@ pub struct Changes {
|
||||||
pub inbound_group_sessions: Vec<InboundGroupSession>,
|
pub inbound_group_sessions: Vec<InboundGroupSession>,
|
||||||
pub outbound_group_sessions: Vec<OutboundGroupSession>,
|
pub outbound_group_sessions: Vec<OutboundGroupSession>,
|
||||||
pub identities: IdentityChanges,
|
pub identities: IdentityChanges,
|
||||||
|
pub key_requests: Vec<OutgoingKeyRequest>,
|
||||||
pub devices: DeviceChanges,
|
pub devices: DeviceChanges,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -257,24 +260,6 @@ impl Store {
|
||||||
device_owner_identity,
|
device_owner_identity,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_object<V: for<'b> Deserialize<'b>>(&self, key: &str) -> Result<Option<V>> {
|
|
||||||
if let Some(value) = self.get_value(key).await? {
|
|
||||||
Ok(Some(serde_json::from_str(&value)?))
|
|
||||||
} else {
|
|
||||||
Ok(None)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn save_object(&self, key: &str, value: &impl Serialize) -> Result<()> {
|
|
||||||
let value = serde_json::to_string(value)?;
|
|
||||||
self.save_value(key.to_owned(), value).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn delete_object(&self, key: &str) -> Result<()> {
|
|
||||||
self.inner.remove_value(key).await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Deref for Store {
|
impl Deref for Store {
|
||||||
|
@ -438,15 +423,41 @@ pub trait CryptoStore: AsyncTraitDeps {
|
||||||
/// * `user_id` - The user for which we should get the identity.
|
/// * `user_id` - The user for which we should get the identity.
|
||||||
async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentities>>;
|
async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentities>>;
|
||||||
|
|
||||||
/// Save a serializeable object in the store.
|
|
||||||
async fn save_value(&self, key: String, value: String) -> Result<()>;
|
|
||||||
|
|
||||||
/// Remove a value from the store.
|
|
||||||
async fn remove_value(&self, key: &str) -> Result<()>;
|
|
||||||
|
|
||||||
/// Load a serializeable object from the store.
|
|
||||||
async fn get_value(&self, key: &str) -> Result<Option<String>>;
|
|
||||||
|
|
||||||
/// Check if a hash for an Olm message stored in the database.
|
/// Check if a hash for an Olm message stored in the database.
|
||||||
async fn is_message_known(&self, message_hash: &OlmMessageHash) -> Result<bool>;
|
async fn is_message_known(&self, message_hash: &OlmMessageHash) -> Result<bool>;
|
||||||
|
|
||||||
|
/// Get an outoing key request that we created that matches the given
|
||||||
|
/// request id.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `request_id` - The unique request id that identifies this outgoing key
|
||||||
|
/// request.
|
||||||
|
async fn get_outgoing_key_request(
|
||||||
|
&self,
|
||||||
|
request_id: Uuid,
|
||||||
|
) -> Result<Option<OutgoingKeyRequest>>;
|
||||||
|
|
||||||
|
/// Get an outoing key request that we created that matches the given
|
||||||
|
/// requested key info.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `key_info` - The key info of an outgoing key request.
|
||||||
|
async fn get_key_request_by_info(
|
||||||
|
&self,
|
||||||
|
key_info: &RequestedKeyInfo,
|
||||||
|
) -> Result<Option<OutgoingKeyRequest>>;
|
||||||
|
|
||||||
|
/// Get all outgoing key requests that we have in the store.
|
||||||
|
async fn get_unsent_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>>;
|
||||||
|
|
||||||
|
/// Delete an outoing key request that we created that matches the given
|
||||||
|
/// request id.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `request_id` - The unique request id that identifies this outgoing key
|
||||||
|
/// request.
|
||||||
|
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()>;
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,9 +29,12 @@ use sled::{
|
||||||
|
|
||||||
use matrix_sdk_common::{
|
use matrix_sdk_common::{
|
||||||
async_trait,
|
async_trait,
|
||||||
|
events::room_key_request::RequestedKeyInfo,
|
||||||
identifiers::{DeviceId, DeviceIdBox, RoomId, UserId},
|
identifiers::{DeviceId, DeviceIdBox, RoomId, UserId},
|
||||||
locks::Mutex,
|
locks::Mutex,
|
||||||
|
uuid,
|
||||||
};
|
};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
caches::SessionStore, Changes, CryptoStore, CryptoStoreError, InboundGroupSession, PickleKey,
|
caches::SessionStore, Changes, CryptoStore, CryptoStoreError, InboundGroupSession, PickleKey,
|
||||||
|
@ -39,6 +42,7 @@ use super::{
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
identities::{ReadOnlyDevice, UserIdentities},
|
identities::{ReadOnlyDevice, UserIdentities},
|
||||||
|
key_request::OutgoingKeyRequest,
|
||||||
olm::{OutboundGroupSession, PickledInboundGroupSession, PrivateCrossSigningIdentity},
|
olm::{OutboundGroupSession, PickledInboundGroupSession, PrivateCrossSigningIdentity},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -51,6 +55,28 @@ trait EncodeKey {
|
||||||
fn encode(&self) -> Vec<u8>;
|
fn encode(&self) -> Vec<u8>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl EncodeKey for Uuid {
|
||||||
|
fn encode(&self) -> Vec<u8> {
|
||||||
|
self.as_u128().to_be_bytes().to_vec()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EncodeKey for &RequestedKeyInfo {
|
||||||
|
fn encode(&self) -> Vec<u8> {
|
||||||
|
[
|
||||||
|
self.room_id.as_bytes(),
|
||||||
|
&[Self::SEPARATOR],
|
||||||
|
self.sender_key.as_bytes(),
|
||||||
|
&[Self::SEPARATOR],
|
||||||
|
self.algorithm.as_ref().as_bytes(),
|
||||||
|
&[Self::SEPARATOR],
|
||||||
|
self.session_id.as_bytes(),
|
||||||
|
&[Self::SEPARATOR],
|
||||||
|
]
|
||||||
|
.concat()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl EncodeKey for &UserId {
|
impl EncodeKey for &UserId {
|
||||||
fn encode(&self) -> Vec<u8> {
|
fn encode(&self) -> Vec<u8> {
|
||||||
self.as_str().encode()
|
self.as_str().encode()
|
||||||
|
@ -122,12 +148,15 @@ pub struct SledStore {
|
||||||
inbound_group_sessions: Tree,
|
inbound_group_sessions: Tree,
|
||||||
outbound_group_sessions: Tree,
|
outbound_group_sessions: Tree,
|
||||||
|
|
||||||
|
outgoing_key_requests: Tree,
|
||||||
|
unsent_key_requests: Tree,
|
||||||
|
key_requests_by_info: Tree,
|
||||||
|
|
||||||
devices: Tree,
|
devices: Tree,
|
||||||
identities: Tree,
|
identities: Tree,
|
||||||
|
|
||||||
tracked_users: Tree,
|
tracked_users: Tree,
|
||||||
users_for_key_query: Tree,
|
users_for_key_query: Tree,
|
||||||
values: Tree,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for SledStore {
|
impl std::fmt::Debug for SledStore {
|
||||||
|
@ -178,13 +207,17 @@ impl SledStore {
|
||||||
let sessions = db.open_tree("session")?;
|
let sessions = db.open_tree("session")?;
|
||||||
let inbound_group_sessions = db.open_tree("inbound_group_sessions")?;
|
let inbound_group_sessions = db.open_tree("inbound_group_sessions")?;
|
||||||
let outbound_group_sessions = db.open_tree("outbound_group_sessions")?;
|
let outbound_group_sessions = db.open_tree("outbound_group_sessions")?;
|
||||||
|
|
||||||
let tracked_users = db.open_tree("tracked_users")?;
|
let tracked_users = db.open_tree("tracked_users")?;
|
||||||
let users_for_key_query = db.open_tree("users_for_key_query")?;
|
let users_for_key_query = db.open_tree("users_for_key_query")?;
|
||||||
let olm_hashes = db.open_tree("olm_hashes")?;
|
let olm_hashes = db.open_tree("olm_hashes")?;
|
||||||
|
|
||||||
let devices = db.open_tree("devices")?;
|
let devices = db.open_tree("devices")?;
|
||||||
let identities = db.open_tree("identities")?;
|
let identities = db.open_tree("identities")?;
|
||||||
let values = db.open_tree("values")?;
|
|
||||||
|
let outgoing_key_requests = db.open_tree("outgoing_key_requests")?;
|
||||||
|
let unsent_key_requests = db.open_tree("unsent_key_requests")?;
|
||||||
|
let key_requests_by_info = db.open_tree("key_requests_by_info")?;
|
||||||
|
|
||||||
let session_cache = SessionStore::new();
|
let session_cache = SessionStore::new();
|
||||||
|
|
||||||
|
@ -208,12 +241,14 @@ impl SledStore {
|
||||||
users_for_key_query_cache: DashSet::new().into(),
|
users_for_key_query_cache: DashSet::new().into(),
|
||||||
inbound_group_sessions,
|
inbound_group_sessions,
|
||||||
outbound_group_sessions,
|
outbound_group_sessions,
|
||||||
|
outgoing_key_requests,
|
||||||
|
unsent_key_requests,
|
||||||
|
key_requests_by_info,
|
||||||
devices,
|
devices,
|
||||||
tracked_users,
|
tracked_users,
|
||||||
users_for_key_query,
|
users_for_key_query,
|
||||||
olm_hashes,
|
olm_hashes,
|
||||||
identities,
|
identities,
|
||||||
values,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -332,6 +367,7 @@ impl SledStore {
|
||||||
|
|
||||||
let identity_changes = changes.identities;
|
let identity_changes = changes.identities;
|
||||||
let olm_hashes = changes.message_hashes;
|
let olm_hashes = changes.message_hashes;
|
||||||
|
let key_requests = changes.key_requests;
|
||||||
|
|
||||||
let ret: Result<(), TransactionError<serde_json::Error>> = (
|
let ret: Result<(), TransactionError<serde_json::Error>> = (
|
||||||
&self.account,
|
&self.account,
|
||||||
|
@ -342,6 +378,9 @@ impl SledStore {
|
||||||
&self.inbound_group_sessions,
|
&self.inbound_group_sessions,
|
||||||
&self.outbound_group_sessions,
|
&self.outbound_group_sessions,
|
||||||
&self.olm_hashes,
|
&self.olm_hashes,
|
||||||
|
&self.outgoing_key_requests,
|
||||||
|
&self.unsent_key_requests,
|
||||||
|
&self.key_requests_by_info,
|
||||||
)
|
)
|
||||||
.transaction(
|
.transaction(
|
||||||
|(
|
|(
|
||||||
|
@ -353,6 +392,9 @@ impl SledStore {
|
||||||
inbound_sessions,
|
inbound_sessions,
|
||||||
outbound_sessions,
|
outbound_sessions,
|
||||||
hashes,
|
hashes,
|
||||||
|
outgoing_key_requests,
|
||||||
|
unsent_key_requests,
|
||||||
|
key_requests_by_info,
|
||||||
)| {
|
)| {
|
||||||
if let Some(a) = &account_pickle {
|
if let Some(a) = &account_pickle {
|
||||||
account.insert(
|
account.insert(
|
||||||
|
@ -420,6 +462,31 @@ impl SledStore {
|
||||||
)?;
|
)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for key_request in &key_requests {
|
||||||
|
key_requests_by_info.insert(
|
||||||
|
(&key_request.info).encode(),
|
||||||
|
key_request.request_id.encode(),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let key_request_id = key_request.request_id.encode();
|
||||||
|
|
||||||
|
if key_request.sent_out {
|
||||||
|
unsent_key_requests.remove(key_request_id.clone())?;
|
||||||
|
outgoing_key_requests.insert(
|
||||||
|
key_request_id,
|
||||||
|
serde_json::to_vec(&key_request)
|
||||||
|
.map_err(ConflictableTransactionError::Abort)?,
|
||||||
|
)?;
|
||||||
|
} else {
|
||||||
|
outgoing_key_requests.remove(key_request_id.clone())?;
|
||||||
|
unsent_key_requests.insert(
|
||||||
|
key_request_id,
|
||||||
|
serde_json::to_vec(&key_request)
|
||||||
|
.map_err(ConflictableTransactionError::Abort)?,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
@ -429,6 +496,28 @@ impl SledStore {
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn get_outgoing_key_request_helper(
|
||||||
|
&self,
|
||||||
|
id: &[u8],
|
||||||
|
) -> Result<Option<OutgoingKeyRequest>> {
|
||||||
|
let request = self
|
||||||
|
.outgoing_key_requests
|
||||||
|
.get(id)?
|
||||||
|
.map(|r| serde_json::from_slice(&r))
|
||||||
|
.transpose()?;
|
||||||
|
|
||||||
|
let request = if request.is_none() {
|
||||||
|
self.unsent_key_requests
|
||||||
|
.get(id)?
|
||||||
|
.map(|r| serde_json::from_slice(&r))
|
||||||
|
.transpose()?
|
||||||
|
} else {
|
||||||
|
request
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(request)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
@ -472,6 +561,19 @@ impl CryptoStore for SledStore {
|
||||||
self.save_changes(changes).await
|
self.save_changes(changes).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
|
||||||
|
if let Some(i) = self.private_identity.get("identity".encode())? {
|
||||||
|
let pickle = serde_json::from_slice(&i)?;
|
||||||
|
Ok(Some(
|
||||||
|
PrivateCrossSigningIdentity::from_pickle(pickle, self.get_pickle_key())
|
||||||
|
.await
|
||||||
|
.map_err(|_| CryptoStoreError::UnpicklingError)?,
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn save_changes(&self, changes: Changes) -> Result<()> {
|
async fn save_changes(&self, changes: Changes) -> Result<()> {
|
||||||
self.save_changes(changes).await
|
self.save_changes(changes).await
|
||||||
}
|
}
|
||||||
|
@ -539,12 +641,11 @@ impl CryptoStore for SledStore {
|
||||||
.collect())
|
.collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn users_for_key_query(&self) -> HashSet<UserId> {
|
async fn get_outbound_group_sessions(
|
||||||
#[allow(clippy::map_clone)]
|
&self,
|
||||||
self.users_for_key_query_cache
|
room_id: &RoomId,
|
||||||
.iter()
|
) -> Result<Option<OutboundGroupSession>> {
|
||||||
.map(|u| u.clone())
|
self.load_outbound_group_session(room_id).await
|
||||||
.collect()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_user_tracked(&self, user_id: &UserId) -> bool {
|
fn is_user_tracked(&self, user_id: &UserId) -> bool {
|
||||||
|
@ -555,6 +656,14 @@ impl CryptoStore for SledStore {
|
||||||
!self.users_for_key_query_cache.is_empty()
|
!self.users_for_key_query_cache.is_empty()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn users_for_key_query(&self) -> HashSet<UserId> {
|
||||||
|
#[allow(clippy::map_clone)]
|
||||||
|
self.users_for_key_query_cache
|
||||||
|
.iter()
|
||||||
|
.map(|u| u.clone())
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
|
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
|
||||||
let already_added = self.tracked_users_cache.insert(user.clone());
|
let already_added = self.tracked_users_cache.insert(user.clone());
|
||||||
|
|
||||||
|
@ -605,48 +714,80 @@ impl CryptoStore for SledStore {
|
||||||
.transpose()?)
|
.transpose()?)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn save_value(&self, key: String, value: String) -> Result<()> {
|
|
||||||
self.values.insert(key.as_str().encode(), value.as_str())?;
|
|
||||||
self.inner.flush_async().await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn remove_value(&self, key: &str) -> Result<()> {
|
|
||||||
self.values.remove(key.encode())?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get_value(&self, key: &str) -> Result<Option<String>> {
|
|
||||||
Ok(self
|
|
||||||
.values
|
|
||||||
.get(key.encode())?
|
|
||||||
.map(|v| String::from_utf8_lossy(&v).to_string()))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
|
|
||||||
if let Some(i) = self.private_identity.get("identity".encode())? {
|
|
||||||
let pickle = serde_json::from_slice(&i)?;
|
|
||||||
Ok(Some(
|
|
||||||
PrivateCrossSigningIdentity::from_pickle(pickle, self.get_pickle_key())
|
|
||||||
.await
|
|
||||||
.map_err(|_| CryptoStoreError::UnpicklingError)?,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
Ok(None)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result<bool> {
|
async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result<bool> {
|
||||||
Ok(self
|
Ok(self
|
||||||
.olm_hashes
|
.olm_hashes
|
||||||
.contains_key(serde_json::to_vec(message_hash)?)?)
|
.contains_key(serde_json::to_vec(message_hash)?)?)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_outbound_group_sessions(
|
async fn get_outgoing_key_request(
|
||||||
&self,
|
&self,
|
||||||
room_id: &RoomId,
|
request_id: Uuid,
|
||||||
) -> Result<Option<OutboundGroupSession>> {
|
) -> Result<Option<OutgoingKeyRequest>> {
|
||||||
self.load_outbound_group_session(room_id).await
|
let request_id = request_id.encode();
|
||||||
|
|
||||||
|
self.get_outgoing_key_request_helper(&request_id).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_key_request_by_info(
|
||||||
|
&self,
|
||||||
|
key_info: &RequestedKeyInfo,
|
||||||
|
) -> Result<Option<OutgoingKeyRequest>> {
|
||||||
|
let id = self.key_requests_by_info.get(key_info.encode())?;
|
||||||
|
|
||||||
|
if let Some(id) = id {
|
||||||
|
self.get_outgoing_key_request_helper(&id).await
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_unsent_key_requests(&self) -> Result<Vec<OutgoingKeyRequest>> {
|
||||||
|
let requests: Result<Vec<OutgoingKeyRequest>> = self
|
||||||
|
.unsent_key_requests
|
||||||
|
.iter()
|
||||||
|
.map(|i| serde_json::from_slice(&i?.1).map_err(CryptoStoreError::from))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
requests
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> {
|
||||||
|
let ret: Result<(), TransactionError<serde_json::Error>> = (
|
||||||
|
&self.outgoing_key_requests,
|
||||||
|
&self.unsent_key_requests,
|
||||||
|
&self.key_requests_by_info,
|
||||||
|
)
|
||||||
|
.transaction(
|
||||||
|
|(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| {
|
||||||
|
let sent_request: Option<OutgoingKeyRequest> = outgoing_key_requests
|
||||||
|
.remove(request_id.encode())?
|
||||||
|
.map(|r| serde_json::from_slice(&r))
|
||||||
|
.transpose()
|
||||||
|
.map_err(ConflictableTransactionError::Abort)?;
|
||||||
|
|
||||||
|
let unsent_request: Option<OutgoingKeyRequest> = unsent_key_requests
|
||||||
|
.remove(request_id.encode())?
|
||||||
|
.map(|r| serde_json::from_slice(&r))
|
||||||
|
.transpose()
|
||||||
|
.map_err(ConflictableTransactionError::Abort)?;
|
||||||
|
|
||||||
|
if let Some(request) = sent_request {
|
||||||
|
key_requests_by_info.remove((&request.info).encode())?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(request) = unsent_request {
|
||||||
|
key_requests_by_info.remove((&request.info).encode())?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
ret?;
|
||||||
|
self.inner.flush_async().await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -665,14 +806,16 @@ mod test {
|
||||||
};
|
};
|
||||||
use matrix_sdk_common::{
|
use matrix_sdk_common::{
|
||||||
api::r0::keys::SignedKey,
|
api::r0::keys::SignedKey,
|
||||||
identifiers::{room_id, user_id, DeviceId, UserId},
|
events::room_key_request::RequestedKeyInfo,
|
||||||
|
identifiers::{room_id, user_id, DeviceId, EventEncryptionAlgorithm, UserId},
|
||||||
|
uuid::Uuid,
|
||||||
};
|
};
|
||||||
use matrix_sdk_test::async_test;
|
use matrix_sdk_test::async_test;
|
||||||
use olm_rs::outbound_group_session::OlmOutboundGroupSession;
|
use olm_rs::outbound_group_session::OlmOutboundGroupSession;
|
||||||
use std::collections::BTreeMap;
|
use std::collections::BTreeMap;
|
||||||
use tempfile::tempdir;
|
use tempfile::tempdir;
|
||||||
|
|
||||||
use super::{CryptoStore, SledStore};
|
use super::{CryptoStore, OutgoingKeyRequest, SledStore};
|
||||||
|
|
||||||
fn alice_id() -> UserId {
|
fn alice_id() -> UserId {
|
||||||
user_id!("@alice:example.org")
|
user_id!("@alice:example.org")
|
||||||
|
@ -1184,21 +1327,6 @@ mod test {
|
||||||
assert_eq!(identity.user_id(), loaded_identity.user_id());
|
assert_eq!(identity.user_id(), loaded_identity.user_id());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_test]
|
|
||||||
async fn key_value_saving() {
|
|
||||||
let (_, store, _dir) = get_loaded_store().await;
|
|
||||||
let key = "test_key".to_string();
|
|
||||||
let value = "secret value".to_string();
|
|
||||||
|
|
||||||
store.save_value(key.clone(), value.clone()).await.unwrap();
|
|
||||||
let stored_value = store.get_value(&key).await.unwrap().unwrap();
|
|
||||||
|
|
||||||
assert_eq!(value, stored_value);
|
|
||||||
|
|
||||||
store.remove_value(&key).await.unwrap();
|
|
||||||
assert!(store.get_value(&key).await.unwrap().is_none());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_test]
|
#[async_test]
|
||||||
async fn olm_hash_saving() {
|
async fn olm_hash_saving() {
|
||||||
let (_, store, _dir) = get_loaded_store().await;
|
let (_, store, _dir) = get_loaded_store().await;
|
||||||
|
@ -1215,4 +1343,63 @@ mod test {
|
||||||
store.save_changes(changes).await.unwrap();
|
store.save_changes(changes).await.unwrap();
|
||||||
assert!(store.is_message_known(&hash).await.unwrap());
|
assert!(store.is_message_known(&hash).await.unwrap());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[async_test]
|
||||||
|
async fn key_request_saving() {
|
||||||
|
let (account, store, _dir) = get_loaded_store().await;
|
||||||
|
|
||||||
|
let id = Uuid::new_v4();
|
||||||
|
let info = RequestedKeyInfo {
|
||||||
|
algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2,
|
||||||
|
room_id: room_id!("!test:localhost"),
|
||||||
|
sender_key: "test_sender_key".to_string(),
|
||||||
|
session_id: "test_session_id".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let request = OutgoingKeyRequest {
|
||||||
|
request_recipient: account.user_id().to_owned(),
|
||||||
|
request_id: id,
|
||||||
|
info: info.clone(),
|
||||||
|
sent_out: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(store.get_outgoing_key_request(id).await.unwrap().is_none());
|
||||||
|
|
||||||
|
let mut changes = Changes::default();
|
||||||
|
changes.key_requests.push(request.clone());
|
||||||
|
store.save_changes(changes).await.unwrap();
|
||||||
|
|
||||||
|
let request = Some(request);
|
||||||
|
|
||||||
|
let stored_request = store.get_outgoing_key_request(id).await.unwrap();
|
||||||
|
assert_eq!(request, stored_request);
|
||||||
|
|
||||||
|
let stored_request = store.get_key_request_by_info(&info).await.unwrap();
|
||||||
|
assert_eq!(request, stored_request);
|
||||||
|
assert!(!store.get_unsent_key_requests().await.unwrap().is_empty());
|
||||||
|
|
||||||
|
let request = OutgoingKeyRequest {
|
||||||
|
request_recipient: account.user_id().to_owned(),
|
||||||
|
request_id: id,
|
||||||
|
info: info.clone(),
|
||||||
|
sent_out: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut changes = Changes::default();
|
||||||
|
changes.key_requests.push(request.clone());
|
||||||
|
store.save_changes(changes).await.unwrap();
|
||||||
|
|
||||||
|
assert!(store.get_unsent_key_requests().await.unwrap().is_empty());
|
||||||
|
let stored_request = store.get_outgoing_key_request(id).await.unwrap();
|
||||||
|
assert_eq!(Some(request), stored_request);
|
||||||
|
|
||||||
|
store.delete_outgoing_key_request(id).await.unwrap();
|
||||||
|
|
||||||
|
let stored_request = store.get_outgoing_key_request(id).await.unwrap();
|
||||||
|
assert_eq!(None, stored_request);
|
||||||
|
|
||||||
|
let stored_request = store.get_key_request_by_info(&info).await.unwrap();
|
||||||
|
assert_eq!(None, stored_request);
|
||||||
|
assert!(store.get_unsent_key_requests().await.unwrap().is_empty());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue