crypto: Go through the user device keys in parallel

master
Damir Jelić 2021-03-10 13:45:47 +01:00
parent 570bd2e358
commit daf313e358
2 changed files with 83 additions and 45 deletions

View File

@ -23,7 +23,7 @@ use tracing::{trace, warn};
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::keys::get_keys::Response as KeysQueryResponse, api::r0::keys::get_keys::Response as KeysQueryResponse,
encryption::DeviceKeys, encryption::DeviceKeys,
identifiers::{DeviceId, DeviceIdBox, UserId}, identifiers::{DeviceIdBox, UserId},
}; };
use crate::{ use crate::{
@ -64,10 +64,6 @@ impl IdentityManager {
&self.user_id &self.user_id
} }
fn device_id(&self) -> &DeviceId {
&self.device_id
}
/// Receive a successful keys query response. /// Receive a successful keys query response.
/// ///
/// Returns a list of devices newly discovered devices and devices that /// Returns a list of devices newly discovered devices and devices that
@ -82,7 +78,7 @@ impl IdentityManager {
response: &KeysQueryResponse, response: &KeysQueryResponse,
) -> OlmResult<(DeviceChanges, IdentityChanges)> { ) -> OlmResult<(DeviceChanges, IdentityChanges)> {
let changed_devices = self let changed_devices = self
.handle_devices_from_key_query(&response.device_keys) .handle_devices_from_key_query(response.device_keys.clone())
.await?; .await?;
let changed_identities = self.handle_cross_singing_keys(response).await?; let changed_identities = self.handle_cross_singing_keys(response).await?;
@ -140,6 +136,63 @@ impl IdentityManager {
} }
} }
async fn update_user_devices(
store: Store,
own_user_id: Arc<UserId>,
own_device_id: Arc<DeviceIdBox>,
user_id: UserId,
device_map: BTreeMap<DeviceIdBox, DeviceKeys>,
) -> StoreResult<DeviceChanges> {
let mut changes = DeviceChanges::default();
let current_devices: HashSet<DeviceIdBox> = device_map.keys().cloned().collect();
let tasks = device_map
.into_iter()
.filter_map(|(device_id, device_keys)| {
// We don't need our own device in the device store.
if user_id == *own_user_id && device_id == *own_device_id {
None
} else if user_id != device_keys.user_id || device_id != device_keys.device_id {
warn!(
"Mismatch in device keys payload of device {}|{} from user {}|{}",
device_id, device_keys.device_id, user_id, device_keys.user_id
);
None
} else {
Some(tokio::spawn(Self::update_or_create_device(
store.clone(),
device_keys,
)))
}
});
let results = join_all(tasks).await;
for device in results {
match device.expect("Creating or updating a device panicked")? {
DeviceChange::New(d) => changes.new.push(d),
DeviceChange::Updated(d) => changes.changed.push(d),
DeviceChange::None => (),
}
}
let current_devices: HashSet<&DeviceIdBox> = current_devices.iter().collect();
let stored_devices = store.get_readonly_devices(&user_id).await?;
let stored_devices_set: HashSet<&DeviceIdBox> = stored_devices.keys().collect();
let deleted_devices_set = stored_devices_set.difference(&current_devices);
for device_id in deleted_devices_set {
if let Some(device) = stored_devices.get(*device_id) {
device.mark_as_deleted();
changes.deleted.push(device.clone());
}
}
Ok(changes)
}
/// Handle the device keys part of a key query response. /// Handle the device keys part of a key query response.
/// ///
/// # Arguments /// # Arguments
@ -151,51 +204,27 @@ impl IdentityManager {
/// they are new, one of their properties has changed or they got deleted. /// they are new, one of their properties has changed or they got deleted.
async fn handle_devices_from_key_query( async fn handle_devices_from_key_query(
&self, &self,
device_keys_map: &BTreeMap<UserId, BTreeMap<DeviceIdBox, DeviceKeys>>, device_keys_map: BTreeMap<UserId, BTreeMap<DeviceIdBox, DeviceKeys>>,
) -> StoreResult<DeviceChanges> { ) -> StoreResult<DeviceChanges> {
let mut changes = DeviceChanges::default(); let mut changes = DeviceChanges::default();
for (user_id, device_map) in device_keys_map { let tasks = device_keys_map
let tasks = device_map.iter().filter_map(|(device_id, device_keys)| { .into_iter()
// We don't need our own device in the device store. .map(|(user_id, device_keys_map)| {
if user_id == self.user_id() && &**device_id == self.device_id() { tokio::spawn(Self::update_user_devices(
None self.store.clone(),
} else if user_id != &device_keys.user_id || device_id != &device_keys.device_id { self.user_id.clone(),
warn!( self.device_id.clone(),
"Mismatch in device keys payload of device {}|{} from user {}|{}", user_id,
device_id, device_keys.device_id, user_id, device_keys.user_id device_keys_map,
); ))
None
} else {
Some(tokio::spawn(Self::update_or_create_device(
self.store.clone(),
device_keys.clone(),
)))
}
}); });
let results = join_all(tasks).await; let results = join_all(tasks).await;
for device in results { for result in results {
match device.expect("Creating or updating a device panicked")? { let change_fragment = result.expect("Panic while updating user devices")?;
DeviceChange::New(d) => changes.new.push(d), changes.extend(change_fragment);
DeviceChange::Updated(d) => changes.changed.push(d),
DeviceChange::None => (),
}
}
let current_devices: HashSet<&DeviceIdBox> = device_map.keys().collect();
let stored_devices = self.store.get_readonly_devices(&user_id).await?;
let stored_devices_set: HashSet<&DeviceIdBox> = stored_devices.keys().collect();
let deleted_devices_set = stored_devices_set.difference(&current_devices);
for device_id in deleted_devices_set {
if let Some(device) = stored_devices.get(*device_id) {
device.mark_as_deleted();
changes.deleted.push(device.clone());
}
}
} }
Ok(changes) Ok(changes)

View File

@ -126,6 +126,15 @@ pub struct DeviceChanges {
pub deleted: Vec<ReadOnlyDevice>, pub deleted: Vec<ReadOnlyDevice>,
} }
impl DeviceChanges {
/// Merge the given `DeviceChanges` into this instance of `DeviceChanges`.
pub fn extend(&mut self, other: DeviceChanges) {
self.new.extend(other.new);
self.changed.extend(other.changed);
self.deleted.extend(other.deleted);
}
}
impl Store { impl Store {
pub fn new( pub fn new(
user_id: Arc<UserId>, user_id: Arc<UserId>,