crypto: Refactor out the key share wait queue.
parent
6a8ac62a51
commit
06b9c71dbc
|
@ -20,7 +20,7 @@
|
|||
// If we don't trust the device store an object that remembers the request and
|
||||
// let the users introspect that object.
|
||||
|
||||
use dashmap::{DashMap, DashSet};
|
||||
use dashmap::{mapref::entry::Entry, DashMap, DashSet};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::value::to_raw_value;
|
||||
use std::{collections::BTreeMap, sync::Arc};
|
||||
|
@ -63,6 +63,66 @@ pub enum KeyshareDecision {
|
|||
UntrustedDevice,
|
||||
}
|
||||
|
||||
/// A queue where we store room key requests that we want to serve but the
|
||||
/// device that requested the key doesn't share an Olm session with us.
|
||||
#[derive(Debug, Clone)]
|
||||
struct WaitQueue {
|
||||
requests_waiting_for_session:
|
||||
Arc<DashMap<(UserId, DeviceIdBox, String), ToDeviceEvent<RoomKeyRequestEventContent>>>,
|
||||
requests_ids_waiting: Arc<DashMap<(UserId, DeviceIdBox), DashSet<String>>>,
|
||||
}
|
||||
|
||||
impl WaitQueue {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
requests_waiting_for_session: Arc::new(DashMap::new()),
|
||||
requests_ids_waiting: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn is_empty(&self) -> bool {
|
||||
self.requests_ids_waiting.is_empty() && self.requests_waiting_for_session.is_empty()
|
||||
}
|
||||
|
||||
fn insert(&self, device: &Device, event: &ToDeviceEvent<RoomKeyRequestEventContent>) {
|
||||
let key = (
|
||||
device.user_id().to_owned(),
|
||||
device.device_id().into(),
|
||||
event.content.request_id.to_owned(),
|
||||
);
|
||||
self.requests_waiting_for_session.insert(key, event.clone());
|
||||
|
||||
let key = (device.user_id().to_owned(), device.device_id().into());
|
||||
self.requests_ids_waiting
|
||||
.entry(key)
|
||||
.or_insert_with(DashSet::new)
|
||||
.insert(event.content.request_id.clone());
|
||||
}
|
||||
|
||||
fn remove(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
) -> Vec<(
|
||||
(UserId, DeviceIdBox, String),
|
||||
ToDeviceEvent<RoomKeyRequestEventContent>,
|
||||
)> {
|
||||
self.requests_ids_waiting
|
||||
.remove(&(user_id.to_owned(), device_id.into()))
|
||||
.map(|(_, request_ids)| {
|
||||
request_ids
|
||||
.iter()
|
||||
.filter_map(|id| {
|
||||
let key = (user_id.to_owned(), device_id.into(), id.to_owned());
|
||||
self.requests_waiting_for_session.remove(&key)
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct KeyRequestMachine {
|
||||
user_id: Arc<UserId>,
|
||||
|
@ -72,10 +132,7 @@ pub(crate) struct KeyRequestMachine {
|
|||
outgoing_to_device_requests: Arc<DashMap<Uuid, OutgoingRequest>>,
|
||||
incoming_key_requests:
|
||||
Arc<DashMap<(UserId, DeviceIdBox, String), ToDeviceEvent<RoomKeyRequestEventContent>>>,
|
||||
// TODO group these hashmaps into a logical unit.
|
||||
requests_waiting_for_session:
|
||||
Arc<DashMap<(UserId, DeviceIdBox, String), ToDeviceEvent<RoomKeyRequestEventContent>>>,
|
||||
requests_ids_waiting: Arc<DashMap<(UserId, DeviceIdBox), DashSet<String>>>,
|
||||
wait_queue: WaitQueue,
|
||||
users_for_key_claim: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>,
|
||||
}
|
||||
|
||||
|
@ -147,9 +204,8 @@ impl KeyRequestMachine {
|
|||
outbound_group_sessions,
|
||||
outgoing_to_device_requests: Arc::new(DashMap::new()),
|
||||
incoming_key_requests: Arc::new(DashMap::new()),
|
||||
requests_waiting_for_session: Arc::new(DashMap::new()),
|
||||
wait_queue: WaitQueue::new(),
|
||||
users_for_key_claim: Arc::new(DashMap::new()),
|
||||
requests_ids_waiting: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -189,6 +245,9 @@ impl KeyRequestMachine {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Store the key share request for later, once we get an Olm session with
|
||||
/// the given device [`retry_keyshare`](#method.retry_keyshare) should be
|
||||
/// called.
|
||||
fn handle_key_share_without_session(
|
||||
&self,
|
||||
device: Device,
|
||||
|
@ -198,18 +257,7 @@ impl KeyRequestMachine {
|
|||
.entry(device.user_id().to_owned())
|
||||
.or_insert_with(DashSet::new)
|
||||
.insert(device.device_id().into());
|
||||
let key = (
|
||||
device.user_id().to_owned(),
|
||||
device.device_id().into(),
|
||||
event.content.request_id.to_owned(),
|
||||
);
|
||||
self.requests_waiting_for_session.insert(key, event.clone());
|
||||
|
||||
let key = (device.user_id().to_owned(), device.device_id().into());
|
||||
self.requests_ids_waiting
|
||||
.entry(key)
|
||||
.or_insert_with(DashSet::new)
|
||||
.insert(event.content.request_id.clone());
|
||||
self.wait_queue.insert(&device, event);
|
||||
}
|
||||
|
||||
/// Retry keyshares for a device that previously didn't have an Olm session
|
||||
|
@ -225,20 +273,21 @@ impl KeyRequestMachine {
|
|||
///
|
||||
/// * `device_id` - The device id of the device that got the Olm session.
|
||||
pub fn retry_keyshare(&self, user_id: &UserId, device_id: &DeviceId) {
|
||||
if let Some((_, request_ids)) = self
|
||||
.requests_ids_waiting
|
||||
.remove(&(user_id.to_owned(), device_id.into()))
|
||||
{
|
||||
for id in request_ids {
|
||||
let key = (user_id.to_owned(), device_id.into(), id.to_owned());
|
||||
let content = self.requests_waiting_for_session.remove(&key);
|
||||
match self.users_for_key_claim.entry(user_id.to_owned()) {
|
||||
Entry::Occupied(e) => {
|
||||
e.get().remove(device_id);
|
||||
|
||||
if let Some((_, c)) = content {
|
||||
if !self.incoming_key_requests.contains_key(&key) {
|
||||
self.incoming_key_requests.insert(key, c);
|
||||
}
|
||||
if e.get().is_empty() {
|
||||
e.remove();
|
||||
}
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
|
||||
for (key, event) in self.wait_queue.remove(user_id, device_id) {
|
||||
if !self.incoming_key_requests.contains_key(&key) {
|
||||
self.incoming_key_requests.insert(key, event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1188,16 +1237,16 @@ mod test {
|
|||
|
||||
// Bob doesn't have any outgoing requests.
|
||||
assert!(bob_machine.outgoing_to_device_requests.is_empty());
|
||||
assert!(bob_machine.requests_ids_waiting.is_empty());
|
||||
assert!(bob_machine.requests_waiting_for_session.is_empty());
|
||||
assert!(bob_machine.users_for_key_claim.is_empty());
|
||||
assert!(bob_machine.wait_queue.is_empty());
|
||||
|
||||
// Receive the room key request from alice.
|
||||
bob_machine.receive_incoming_key_request(&event);
|
||||
bob_machine.collect_incoming_key_requests().await.unwrap();
|
||||
// Bob doens't have an outgoing requests since we're lacking a session.
|
||||
assert!(bob_machine.outgoing_to_device_requests.is_empty());
|
||||
assert!(!bob_machine.requests_ids_waiting.is_empty());
|
||||
assert!(!bob_machine.requests_waiting_for_session.is_empty());
|
||||
assert!(!bob_machine.users_for_key_claim.is_empty());
|
||||
assert!(!bob_machine.wait_queue.is_empty());
|
||||
|
||||
// We create a session now.
|
||||
alice_machine
|
||||
|
@ -1212,11 +1261,11 @@ mod test {
|
|||
.unwrap();
|
||||
|
||||
bob_machine.retry_keyshare(&alice_id(), &alice_device_id());
|
||||
assert!(bob_machine.users_for_key_claim.is_empty());
|
||||
bob_machine.collect_incoming_key_requests().await.unwrap();
|
||||
// Bob now has an outgoing requests.
|
||||
assert!(!bob_machine.outgoing_to_device_requests.is_empty());
|
||||
assert!(bob_machine.requests_ids_waiting.is_empty());
|
||||
assert!(bob_machine.requests_waiting_for_session.is_empty());
|
||||
assert!(bob_machine.wait_queue.is_empty());
|
||||
|
||||
// Get the request and convert it to a encrypted to-device event.
|
||||
let request = bob_machine
|
||||
|
|
Loading…
Reference in New Issue