crypto: Refactor out the key share wait queue.

master
Damir Jelić 2020-10-07 12:42:39 +02:00
parent 6a8ac62a51
commit 06b9c71dbc
1 changed files with 85 additions and 36 deletions

View File

@ -20,7 +20,7 @@
// If we don't trust the device store an object that remembers the request and // If we don't trust the device store an object that remembers the request and
// let the users introspect that object. // let the users introspect that object.
use dashmap::{DashMap, DashSet}; use dashmap::{mapref::entry::Entry, DashMap, DashSet};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::value::to_raw_value; use serde_json::value::to_raw_value;
use std::{collections::BTreeMap, sync::Arc}; use std::{collections::BTreeMap, sync::Arc};
@ -63,6 +63,66 @@ pub enum KeyshareDecision {
UntrustedDevice, UntrustedDevice,
} }
/// A queue where we store room key requests that we want to serve but the
/// device that requested the key doesn't share an Olm session with us.
#[derive(Debug, Clone)]
struct WaitQueue {
requests_waiting_for_session:
Arc<DashMap<(UserId, DeviceIdBox, String), ToDeviceEvent<RoomKeyRequestEventContent>>>,
requests_ids_waiting: Arc<DashMap<(UserId, DeviceIdBox), DashSet<String>>>,
}
impl WaitQueue {
fn new() -> Self {
Self {
requests_waiting_for_session: Arc::new(DashMap::new()),
requests_ids_waiting: Arc::new(DashMap::new()),
}
}
#[cfg(test)]
fn is_empty(&self) -> bool {
self.requests_ids_waiting.is_empty() && self.requests_waiting_for_session.is_empty()
}
fn insert(&self, device: &Device, event: &ToDeviceEvent<RoomKeyRequestEventContent>) {
let key = (
device.user_id().to_owned(),
device.device_id().into(),
event.content.request_id.to_owned(),
);
self.requests_waiting_for_session.insert(key, event.clone());
let key = (device.user_id().to_owned(), device.device_id().into());
self.requests_ids_waiting
.entry(key)
.or_insert_with(DashSet::new)
.insert(event.content.request_id.clone());
}
fn remove(
&self,
user_id: &UserId,
device_id: &DeviceId,
) -> Vec<(
(UserId, DeviceIdBox, String),
ToDeviceEvent<RoomKeyRequestEventContent>,
)> {
self.requests_ids_waiting
.remove(&(user_id.to_owned(), device_id.into()))
.map(|(_, request_ids)| {
request_ids
.iter()
.filter_map(|id| {
let key = (user_id.to_owned(), device_id.into(), id.to_owned());
self.requests_waiting_for_session.remove(&key)
})
.collect()
})
.unwrap_or_default()
}
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct KeyRequestMachine { pub(crate) struct KeyRequestMachine {
user_id: Arc<UserId>, user_id: Arc<UserId>,
@ -72,10 +132,7 @@ pub(crate) struct KeyRequestMachine {
outgoing_to_device_requests: Arc<DashMap<Uuid, OutgoingRequest>>, outgoing_to_device_requests: Arc<DashMap<Uuid, OutgoingRequest>>,
incoming_key_requests: incoming_key_requests:
Arc<DashMap<(UserId, DeviceIdBox, String), ToDeviceEvent<RoomKeyRequestEventContent>>>, Arc<DashMap<(UserId, DeviceIdBox, String), ToDeviceEvent<RoomKeyRequestEventContent>>>,
// TODO group these hashmaps into a logical unit. wait_queue: WaitQueue,
requests_waiting_for_session:
Arc<DashMap<(UserId, DeviceIdBox, String), ToDeviceEvent<RoomKeyRequestEventContent>>>,
requests_ids_waiting: Arc<DashMap<(UserId, DeviceIdBox), DashSet<String>>>,
users_for_key_claim: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>, users_for_key_claim: Arc<DashMap<UserId, DashSet<DeviceIdBox>>>,
} }
@ -147,9 +204,8 @@ impl KeyRequestMachine {
outbound_group_sessions, outbound_group_sessions,
outgoing_to_device_requests: Arc::new(DashMap::new()), outgoing_to_device_requests: Arc::new(DashMap::new()),
incoming_key_requests: Arc::new(DashMap::new()), incoming_key_requests: Arc::new(DashMap::new()),
requests_waiting_for_session: Arc::new(DashMap::new()), wait_queue: WaitQueue::new(),
users_for_key_claim: Arc::new(DashMap::new()), users_for_key_claim: Arc::new(DashMap::new()),
requests_ids_waiting: Arc::new(DashMap::new()),
} }
} }
@ -189,6 +245,9 @@ impl KeyRequestMachine {
Ok(()) Ok(())
} }
/// Store the key share request for later, once we get an Olm session with
/// the given device [`retry_keyshare`](#method.retry_keyshare) should be
/// called.
fn handle_key_share_without_session( fn handle_key_share_without_session(
&self, &self,
device: Device, device: Device,
@ -198,18 +257,7 @@ impl KeyRequestMachine {
.entry(device.user_id().to_owned()) .entry(device.user_id().to_owned())
.or_insert_with(DashSet::new) .or_insert_with(DashSet::new)
.insert(device.device_id().into()); .insert(device.device_id().into());
let key = ( self.wait_queue.insert(&device, event);
device.user_id().to_owned(),
device.device_id().into(),
event.content.request_id.to_owned(),
);
self.requests_waiting_for_session.insert(key, event.clone());
let key = (device.user_id().to_owned(), device.device_id().into());
self.requests_ids_waiting
.entry(key)
.or_insert_with(DashSet::new)
.insert(event.content.request_id.clone());
} }
/// Retry keyshares for a device that previously didn't have an Olm session /// Retry keyshares for a device that previously didn't have an Olm session
@ -225,19 +273,20 @@ impl KeyRequestMachine {
/// ///
/// * `device_id` - The device id of the device that got the Olm session. /// * `device_id` - The device id of the device that got the Olm session.
pub fn retry_keyshare(&self, user_id: &UserId, device_id: &DeviceId) { pub fn retry_keyshare(&self, user_id: &UserId, device_id: &DeviceId) {
if let Some((_, request_ids)) = self match self.users_for_key_claim.entry(user_id.to_owned()) {
.requests_ids_waiting Entry::Occupied(e) => {
.remove(&(user_id.to_owned(), device_id.into())) e.get().remove(device_id);
{
for id in request_ids {
let key = (user_id.to_owned(), device_id.into(), id.to_owned());
let content = self.requests_waiting_for_session.remove(&key);
if let Some((_, c)) = content { if e.get().is_empty() {
e.remove();
}
}
_ => (),
}
for (key, event) in self.wait_queue.remove(user_id, device_id) {
if !self.incoming_key_requests.contains_key(&key) { if !self.incoming_key_requests.contains_key(&key) {
self.incoming_key_requests.insert(key, c); self.incoming_key_requests.insert(key, event);
}
}
} }
} }
} }
@ -1188,16 +1237,16 @@ mod test {
// Bob doesn't have any outgoing requests. // Bob doesn't have any outgoing requests.
assert!(bob_machine.outgoing_to_device_requests.is_empty()); assert!(bob_machine.outgoing_to_device_requests.is_empty());
assert!(bob_machine.requests_ids_waiting.is_empty()); assert!(bob_machine.users_for_key_claim.is_empty());
assert!(bob_machine.requests_waiting_for_session.is_empty()); assert!(bob_machine.wait_queue.is_empty());
// Receive the room key request from alice. // Receive the room key request from alice.
bob_machine.receive_incoming_key_request(&event); bob_machine.receive_incoming_key_request(&event);
bob_machine.collect_incoming_key_requests().await.unwrap(); bob_machine.collect_incoming_key_requests().await.unwrap();
// Bob doens't have an outgoing requests since we're lacking a session. // Bob doens't have an outgoing requests since we're lacking a session.
assert!(bob_machine.outgoing_to_device_requests.is_empty()); assert!(bob_machine.outgoing_to_device_requests.is_empty());
assert!(!bob_machine.requests_ids_waiting.is_empty()); assert!(!bob_machine.users_for_key_claim.is_empty());
assert!(!bob_machine.requests_waiting_for_session.is_empty()); assert!(!bob_machine.wait_queue.is_empty());
// We create a session now. // We create a session now.
alice_machine alice_machine
@ -1212,11 +1261,11 @@ mod test {
.unwrap(); .unwrap();
bob_machine.retry_keyshare(&alice_id(), &alice_device_id()); bob_machine.retry_keyshare(&alice_id(), &alice_device_id());
assert!(bob_machine.users_for_key_claim.is_empty());
bob_machine.collect_incoming_key_requests().await.unwrap(); bob_machine.collect_incoming_key_requests().await.unwrap();
// Bob now has an outgoing requests. // Bob now has an outgoing requests.
assert!(!bob_machine.outgoing_to_device_requests.is_empty()); assert!(!bob_machine.outgoing_to_device_requests.is_empty());
assert!(bob_machine.requests_ids_waiting.is_empty()); assert!(bob_machine.wait_queue.is_empty());
assert!(bob_machine.requests_waiting_for_session.is_empty());
// Get the request and convert it to a encrypted to-device event. // Get the request and convert it to a encrypted to-device event.
let request = bob_machine let request = bob_machine