Merge branch 'crypto-improvements' into master
commit
7c3e751d6e
|
@ -300,7 +300,7 @@ mod test {
|
||||||
let room_id = room_id!("!test:localhost");
|
let room_id = room_id!("!test:localhost");
|
||||||
|
|
||||||
machine
|
machine
|
||||||
.create_outnbound_group_session_with_defaults(&room_id)
|
.create_outbound_group_session_with_defaults(&room_id)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let export = machine
|
let export = machine
|
||||||
|
|
|
@ -0,0 +1,349 @@
|
||||||
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use std::{
|
||||||
|
collections::{BTreeMap, HashSet},
|
||||||
|
convert::TryFrom,
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
|
use tracing::{info, trace, warn};
|
||||||
|
|
||||||
|
use matrix_sdk_common::{
|
||||||
|
api::r0::keys::get_keys::Response as KeysQueryResponse,
|
||||||
|
encryption::DeviceKeys,
|
||||||
|
identifiers::{DeviceId, DeviceIdBox, UserId},
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
error::OlmResult,
|
||||||
|
identities::{
|
||||||
|
MasterPubkey, OwnUserIdentity, ReadOnlyDevice, SelfSigningPubkey, UserIdentities,
|
||||||
|
UserIdentity, UserSigningPubkey,
|
||||||
|
},
|
||||||
|
requests::KeysQueryRequest,
|
||||||
|
store::{Result as StoreResult, Store},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub(crate) struct IdentityManager {
|
||||||
|
user_id: Arc<UserId>,
|
||||||
|
device_id: Arc<DeviceIdBox>,
|
||||||
|
store: Store,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IdentityManager {
|
||||||
|
pub fn new(user_id: Arc<UserId>, device_id: Arc<DeviceIdBox>, store: Store) -> Self {
|
||||||
|
IdentityManager {
|
||||||
|
user_id,
|
||||||
|
device_id,
|
||||||
|
store,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn user_id(&self) -> &UserId {
|
||||||
|
&self.user_id
|
||||||
|
}
|
||||||
|
|
||||||
|
fn device_id(&self) -> &DeviceId {
|
||||||
|
&self.device_id
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Receive a successful keys query response.
|
||||||
|
///
|
||||||
|
/// Returns a list of devices newly discovered devices and devices that
|
||||||
|
/// changed.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `response` - The keys query response of the request that the client
|
||||||
|
/// performed.
|
||||||
|
pub async fn receive_keys_query_response(
|
||||||
|
&self,
|
||||||
|
response: &KeysQueryResponse,
|
||||||
|
) -> OlmResult<(Vec<ReadOnlyDevice>, Vec<UserIdentities>)> {
|
||||||
|
// TODO create a enum that tells us how the device/identity changed,
|
||||||
|
// e.g. new/deleted/display name change.
|
||||||
|
//
|
||||||
|
// TODO create a struct that will hold the device/identity and the
|
||||||
|
// change enum and return the struct.
|
||||||
|
//
|
||||||
|
// TODO once outbound group sessions hold on to the set of users that
|
||||||
|
// received the session, invalidate the session if a user device
|
||||||
|
// got added/deleted.
|
||||||
|
let changed_devices = self
|
||||||
|
.handle_devices_from_key_query(&response.device_keys)
|
||||||
|
.await?;
|
||||||
|
self.store.save_devices(&changed_devices).await?;
|
||||||
|
let changed_identities = self.handle_cross_singing_keys(response).await?;
|
||||||
|
self.store.save_user_identities(&changed_identities).await?;
|
||||||
|
|
||||||
|
Ok((changed_devices, changed_identities))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle the device keys part of a key query response.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `device_keys_map` - A map holding the device keys of the users for
|
||||||
|
/// which the key query was done.
|
||||||
|
///
|
||||||
|
/// Returns a list of devices that changed. Changed here means either
|
||||||
|
/// they are new, one of their properties has changed or they got deleted.
|
||||||
|
async fn handle_devices_from_key_query(
|
||||||
|
&self,
|
||||||
|
device_keys_map: &BTreeMap<UserId, BTreeMap<DeviceIdBox, DeviceKeys>>,
|
||||||
|
) -> StoreResult<Vec<ReadOnlyDevice>> {
|
||||||
|
let mut changed_devices = Vec::new();
|
||||||
|
|
||||||
|
for (user_id, device_map) in device_keys_map {
|
||||||
|
// TODO move this out into the handle keys query response method
|
||||||
|
// since we might fail handle the new device at any point here or
|
||||||
|
// when updating the user identities.
|
||||||
|
self.store.update_tracked_user(user_id, false).await?;
|
||||||
|
|
||||||
|
for (device_id, device_keys) in device_map.iter() {
|
||||||
|
// We don't need our own device in the device store.
|
||||||
|
if user_id == self.user_id() && &**device_id == self.device_id() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if user_id != &device_keys.user_id || device_id != &device_keys.device_id {
|
||||||
|
warn!(
|
||||||
|
"Mismatch in device keys payload of device {} from user {}",
|
||||||
|
device_keys.device_id, device_keys.user_id
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let device = self.store.get_device(&user_id, device_id).await?;
|
||||||
|
|
||||||
|
let device = if let Some(mut device) = device {
|
||||||
|
if let Err(e) = device.update_device(device_keys) {
|
||||||
|
warn!(
|
||||||
|
"Failed to update the device keys for {} {}: {:?}",
|
||||||
|
user_id, device_id, e
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
device
|
||||||
|
} else {
|
||||||
|
let device = match ReadOnlyDevice::try_from(device_keys) {
|
||||||
|
Ok(d) => d,
|
||||||
|
Err(e) => {
|
||||||
|
warn!(
|
||||||
|
"Failed to create a new device for {} {}: {:?}",
|
||||||
|
user_id, device_id, e
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
info!("Adding a new device to the device store {:?}", device);
|
||||||
|
device
|
||||||
|
};
|
||||||
|
|
||||||
|
changed_devices.push(device);
|
||||||
|
}
|
||||||
|
|
||||||
|
let current_devices: HashSet<&DeviceId> =
|
||||||
|
device_map.keys().map(|id| id.as_ref()).collect();
|
||||||
|
let stored_devices = self.store.get_user_devices(&user_id).await.unwrap();
|
||||||
|
let stored_devices_set: HashSet<&DeviceId> = stored_devices.keys().collect();
|
||||||
|
|
||||||
|
let deleted_devices = stored_devices_set.difference(¤t_devices);
|
||||||
|
|
||||||
|
for device_id in deleted_devices {
|
||||||
|
if let Some(device) = stored_devices.get(device_id) {
|
||||||
|
device.mark_as_deleted();
|
||||||
|
self.store.delete_device(device).await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(changed_devices)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle the device keys part of a key query response.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `response` - The keys query response.
|
||||||
|
///
|
||||||
|
/// Returns a list of identities that changed. Changed here means either
|
||||||
|
/// they are new, one of their properties has changed or they got deleted.
|
||||||
|
async fn handle_cross_singing_keys(
|
||||||
|
&self,
|
||||||
|
response: &KeysQueryResponse,
|
||||||
|
) -> StoreResult<Vec<UserIdentities>> {
|
||||||
|
let mut changed = Vec::new();
|
||||||
|
|
||||||
|
for (user_id, master_key) in &response.master_keys {
|
||||||
|
let master_key = MasterPubkey::from(master_key);
|
||||||
|
|
||||||
|
let self_signing = if let Some(s) = response.self_signing_keys.get(user_id) {
|
||||||
|
SelfSigningPubkey::from(s)
|
||||||
|
} else {
|
||||||
|
warn!(
|
||||||
|
"User identity for user {} didn't contain a self signing pubkey",
|
||||||
|
user_id
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
let identity = if let Some(mut i) = self.store.get_user_identity(user_id).await? {
|
||||||
|
match &mut i {
|
||||||
|
UserIdentities::Own(ref mut identity) => {
|
||||||
|
let user_signing = if let Some(s) = response.user_signing_keys.get(user_id)
|
||||||
|
{
|
||||||
|
UserSigningPubkey::from(s)
|
||||||
|
} else {
|
||||||
|
warn!(
|
||||||
|
"User identity for our own user {} didn't \
|
||||||
|
contain a user signing pubkey",
|
||||||
|
user_id
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
identity
|
||||||
|
.update(master_key, self_signing, user_signing)
|
||||||
|
.map(|_| i)
|
||||||
|
}
|
||||||
|
UserIdentities::Other(ref mut identity) => {
|
||||||
|
identity.update(master_key, self_signing).map(|_| i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if user_id == self.user_id() {
|
||||||
|
if let Some(s) = response.user_signing_keys.get(user_id) {
|
||||||
|
let user_signing = UserSigningPubkey::from(s);
|
||||||
|
|
||||||
|
if master_key.user_id() != user_id
|
||||||
|
|| self_signing.user_id() != user_id
|
||||||
|
|| user_signing.user_id() != user_id
|
||||||
|
{
|
||||||
|
warn!(
|
||||||
|
"User id mismatch in one of the cross signing keys for user {}",
|
||||||
|
user_id
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
OwnUserIdentity::new(master_key, self_signing, user_signing)
|
||||||
|
.map(UserIdentities::Own)
|
||||||
|
} else {
|
||||||
|
warn!(
|
||||||
|
"User identity for our own user {} didn't contain a \
|
||||||
|
user signing pubkey",
|
||||||
|
user_id
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
} else if master_key.user_id() != user_id || self_signing.user_id() != user_id {
|
||||||
|
warn!(
|
||||||
|
"User id mismatch in one of the cross signing keys for user {}",
|
||||||
|
user_id
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
UserIdentity::new(master_key, self_signing).map(UserIdentities::Other)
|
||||||
|
};
|
||||||
|
|
||||||
|
match identity {
|
||||||
|
Ok(i) => {
|
||||||
|
trace!(
|
||||||
|
"Updated or created new user identity for {}: {:?}",
|
||||||
|
user_id,
|
||||||
|
i
|
||||||
|
);
|
||||||
|
changed.push(i);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!(
|
||||||
|
"Couldn't update or create new user identity for {}: {:?}",
|
||||||
|
user_id, e
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(changed)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a key query request if one is needed.
|
||||||
|
///
|
||||||
|
/// Returns a key query reqeust if the client should query E2E keys,
|
||||||
|
/// otherwise None.
|
||||||
|
///
|
||||||
|
/// The response of a successful key query requests needs to be passed to
|
||||||
|
/// the [`OlmMachine`] with the [`receive_keys_query_response`].
|
||||||
|
///
|
||||||
|
/// [`OlmMachine`]: struct.OlmMachine.html
|
||||||
|
/// [`receive_keys_query_response`]: #method.receive_keys_query_response
|
||||||
|
pub async fn users_for_key_query(&self) -> Option<KeysQueryRequest> {
|
||||||
|
let mut users = self.store.users_for_key_query();
|
||||||
|
|
||||||
|
if users.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
let mut device_keys: BTreeMap<UserId, Vec<Box<DeviceId>>> = BTreeMap::new();
|
||||||
|
|
||||||
|
for user in users.drain() {
|
||||||
|
device_keys.insert(user, Vec::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(KeysQueryRequest::new(device_keys))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mark that the given user has changed his devices.
|
||||||
|
///
|
||||||
|
/// This will queue up the given user for a key query.
|
||||||
|
///
|
||||||
|
/// Note: The user already needs to be tracked for it to be queued up for a
|
||||||
|
/// key query.
|
||||||
|
///
|
||||||
|
/// Returns true if the user was queued up for a key query, false otherwise.
|
||||||
|
pub async fn mark_user_as_changed(&self, user_id: &UserId) -> StoreResult<bool> {
|
||||||
|
if self.store.is_user_tracked(user_id) {
|
||||||
|
self.store.update_tracked_user(user_id, true).await?;
|
||||||
|
Ok(true)
|
||||||
|
} else {
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Update the tracked users.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `users` - An iterator over user ids that should be marked for
|
||||||
|
/// tracking.
|
||||||
|
///
|
||||||
|
/// This will mark users that weren't seen before for a key query and
|
||||||
|
/// tracking.
|
||||||
|
///
|
||||||
|
/// If the user is already known to the Olm machine it will not be
|
||||||
|
/// considered for a key query.
|
||||||
|
pub async fn update_tracked_users(&self, users: impl IntoIterator<Item = &UserId>) {
|
||||||
|
for user in users {
|
||||||
|
if self.store.is_user_tracked(user) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(e) = self.store.update_tracked_user(user, true).await {
|
||||||
|
warn!("Error storing users for tracking {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -41,9 +41,11 @@
|
||||||
//! Both identity sets need to reqularly fetched from the server using the
|
//! Both identity sets need to reqularly fetched from the server using the
|
||||||
//! `/keys/query` API call.
|
//! `/keys/query` API call.
|
||||||
pub(crate) mod device;
|
pub(crate) mod device;
|
||||||
|
mod manager;
|
||||||
pub(crate) mod user;
|
pub(crate) mod user;
|
||||||
|
|
||||||
pub use device::{Device, LocalTrust, ReadOnlyDevice, UserDevices};
|
pub use device::{Device, LocalTrust, ReadOnlyDevice, UserDevices};
|
||||||
|
pub(crate) use manager::IdentityManager;
|
||||||
pub use user::{
|
pub use user::{
|
||||||
MasterPubkey, OwnUserIdentity, SelfSigningPubkey, UserIdentities, UserIdentity,
|
MasterPubkey, OwnUserIdentity, SelfSigningPubkey, UserIdentities, UserIdentity,
|
||||||
UserSigningPubkey,
|
UserSigningPubkey,
|
||||||
|
|
|
@ -374,7 +374,7 @@ mod test {
|
||||||
let account = account();
|
let account = account();
|
||||||
|
|
||||||
let (_, session) = account
|
let (_, session) = account
|
||||||
.create_group_session_pair(&room_id(), Default::default())
|
.create_group_session_pair_with_defaults(&room_id())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -415,7 +415,7 @@ mod test {
|
||||||
let account = account();
|
let account = account();
|
||||||
|
|
||||||
let (_, session) = account
|
let (_, session) = account
|
||||||
.create_group_session_pair(&room_id(), Default::default())
|
.create_group_session_pair_with_defaults(&room_id())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
machine
|
machine
|
||||||
|
|
|
@ -14,13 +14,7 @@
|
||||||
|
|
||||||
#[cfg(feature = "sqlite_cryptostore")]
|
#[cfg(feature = "sqlite_cryptostore")]
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::{
|
use std::{collections::BTreeMap, convert::TryInto, mem, sync::Arc, time::Duration};
|
||||||
collections::{BTreeMap, HashSet},
|
|
||||||
convert::{TryFrom, TryInto},
|
|
||||||
mem,
|
|
||||||
sync::Arc,
|
|
||||||
time::Duration,
|
|
||||||
};
|
|
||||||
|
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
@ -37,7 +31,6 @@ use matrix_sdk_common::{
|
||||||
to_device::DeviceIdOrAllDevices,
|
to_device::DeviceIdOrAllDevices,
|
||||||
},
|
},
|
||||||
assign,
|
assign,
|
||||||
encryption::DeviceKeys,
|
|
||||||
events::{
|
events::{
|
||||||
room::encrypted::EncryptedEventContent, room_key::RoomKeyEventContent,
|
room::encrypted::EncryptedEventContent, room_key::RoomKeyEventContent,
|
||||||
room_key_request::RoomKeyRequestEventContent, AnyMessageEventContent, AnySyncRoomEvent,
|
room_key_request::RoomKeyRequestEventContent, AnyMessageEventContent, AnySyncRoomEvent,
|
||||||
|
@ -54,16 +47,13 @@ use matrix_sdk_common::{
|
||||||
use super::store::sqlite::SqliteStore;
|
use super::store::sqlite::SqliteStore;
|
||||||
use super::{
|
use super::{
|
||||||
error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult},
|
error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult},
|
||||||
identities::{
|
identities::{Device, IdentityManager, ReadOnlyDevice, UserDevices, UserIdentities},
|
||||||
Device, MasterPubkey, OwnUserIdentity, ReadOnlyDevice, SelfSigningPubkey, UserDevices,
|
|
||||||
UserIdentities, UserIdentity, UserSigningPubkey,
|
|
||||||
},
|
|
||||||
key_request::KeyRequestMachine,
|
key_request::KeyRequestMachine,
|
||||||
olm::{
|
olm::{
|
||||||
Account, EncryptionSettings, ExportedRoomKey, GroupSessionKey, IdentityKeys,
|
Account, EncryptionSettings, ExportedRoomKey, GroupSessionKey, IdentityKeys,
|
||||||
InboundGroupSession, OlmMessage, OutboundGroupSession,
|
InboundGroupSession, OlmMessage, OutboundGroupSession,
|
||||||
},
|
},
|
||||||
requests::{IncomingResponse, KeysQueryRequest, OutgoingRequest, ToDeviceRequest},
|
requests::{IncomingResponse, OutgoingRequest, ToDeviceRequest},
|
||||||
store::{CryptoStore, MemoryStore, Result as StoreResult, Store},
|
store::{CryptoStore, MemoryStore, Result as StoreResult, Store},
|
||||||
verification::{Sas, VerificationMachine},
|
verification::{Sas, VerificationMachine},
|
||||||
};
|
};
|
||||||
|
@ -90,6 +80,9 @@ pub struct OlmMachine {
|
||||||
/// The state machine that is responsible to handle outgoing and incoming
|
/// The state machine that is responsible to handle outgoing and incoming
|
||||||
/// key requests.
|
/// key requests.
|
||||||
key_request_machine: KeyRequestMachine,
|
key_request_machine: KeyRequestMachine,
|
||||||
|
/// State machine handling public user identities and devices, keeping track
|
||||||
|
/// of when a key query needs to be done and handling one.
|
||||||
|
identity_manager: IdentityManager,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(tarpaulin_include))]
|
#[cfg(not(tarpaulin_include))]
|
||||||
|
@ -118,19 +111,36 @@ impl OlmMachine {
|
||||||
/// * `device_id` - The unique id of the device that owns this machine.
|
/// * `device_id` - The unique id of the device that owns this machine.
|
||||||
pub fn new(user_id: &UserId, device_id: &DeviceId) -> Self {
|
pub fn new(user_id: &UserId, device_id: &DeviceId) -> Self {
|
||||||
let store: Box<dyn CryptoStore> = Box::new(MemoryStore::new());
|
let store: Box<dyn CryptoStore> = Box::new(MemoryStore::new());
|
||||||
|
let device_id: DeviceIdBox = device_id.into();
|
||||||
|
let account = Account::new(&user_id, &device_id);
|
||||||
|
|
||||||
|
OlmMachine::new_helper(user_id, device_id, store, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_helper(
|
||||||
|
user_id: &UserId,
|
||||||
|
device_id: DeviceIdBox,
|
||||||
|
store: Box<dyn CryptoStore>,
|
||||||
|
account: Account,
|
||||||
|
) -> Self {
|
||||||
let store = Store::new(store);
|
let store = Store::new(store);
|
||||||
let account = Account::new(user_id, device_id);
|
let verification_machine = VerificationMachine::new(account.clone(), store.clone());
|
||||||
let user_id = Arc::new(user_id.clone());
|
let user_id = Arc::new(user_id.clone());
|
||||||
let device_id: Arc<DeviceIdBox> = Arc::new(device_id.into());
|
let device_id: Arc<DeviceIdBox> = Arc::new(device_id);
|
||||||
|
let key_request_machine =
|
||||||
|
KeyRequestMachine::new(user_id.clone(), device_id.clone(), store.clone());
|
||||||
|
let identity_manager =
|
||||||
|
IdentityManager::new(user_id.clone(), device_id.clone(), store.clone());
|
||||||
|
|
||||||
OlmMachine {
|
OlmMachine {
|
||||||
user_id: user_id.clone(),
|
user_id,
|
||||||
device_id: device_id.clone(),
|
device_id,
|
||||||
account: account.clone(),
|
account,
|
||||||
store: store.clone(),
|
store,
|
||||||
outbound_group_sessions: Arc::new(DashMap::new()),
|
outbound_group_sessions: Arc::new(DashMap::new()),
|
||||||
verification_machine: VerificationMachine::new(account, store.clone()),
|
verification_machine,
|
||||||
key_request_machine: KeyRequestMachine::new(user_id, device_id, store),
|
key_request_machine,
|
||||||
|
identity_manager,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -155,7 +165,7 @@ impl OlmMachine {
|
||||||
/// [`Cryptostore`]: trait.CryptoStore.html
|
/// [`Cryptostore`]: trait.CryptoStore.html
|
||||||
pub async fn new_with_store(
|
pub async fn new_with_store(
|
||||||
user_id: UserId,
|
user_id: UserId,
|
||||||
device_id: Box<DeviceId>,
|
device_id: DeviceIdBox,
|
||||||
store: Box<dyn CryptoStore>,
|
store: Box<dyn CryptoStore>,
|
||||||
) -> StoreResult<Self> {
|
) -> StoreResult<Self> {
|
||||||
let account = match store.load_account().await? {
|
let account = match store.load_account().await? {
|
||||||
|
@ -169,22 +179,7 @@ impl OlmMachine {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let store = Store::new(store);
|
Ok(OlmMachine::new_helper(&user_id, device_id, store, account))
|
||||||
let verification_machine = VerificationMachine::new(account.clone(), store.clone());
|
|
||||||
let user_id = Arc::new(user_id.clone());
|
|
||||||
let device_id: Arc<DeviceIdBox> = Arc::new(device_id);
|
|
||||||
let key_request_machine =
|
|
||||||
KeyRequestMachine::new(user_id.clone(), device_id.clone(), store.clone());
|
|
||||||
|
|
||||||
Ok(OlmMachine {
|
|
||||||
user_id,
|
|
||||||
device_id,
|
|
||||||
account,
|
|
||||||
store,
|
|
||||||
outbound_group_sessions: Arc::new(DashMap::new()),
|
|
||||||
verification_machine,
|
|
||||||
key_request_machine,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new machine with the default crypto store.
|
/// Create a new machine with the default crypto store.
|
||||||
|
@ -243,10 +238,15 @@ impl OlmMachine {
|
||||||
requests.push(r);
|
requests.push(r);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(r) = self.users_for_key_query().await.map(|r| OutgoingRequest {
|
if let Some(r) =
|
||||||
request_id: Uuid::new_v4(),
|
self.identity_manager
|
||||||
request: Arc::new(r.into()),
|
.users_for_key_query()
|
||||||
}) {
|
.await
|
||||||
|
.map(|r| OutgoingRequest {
|
||||||
|
request_id: Uuid::new_v4(),
|
||||||
|
request: Arc::new(r.into()),
|
||||||
|
})
|
||||||
|
{
|
||||||
requests.push(r);
|
requests.push(r);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -491,195 +491,6 @@ impl OlmMachine {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handle the device keys part of a key query response.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `device_keys_map` - A map holding the device keys of the users for
|
|
||||||
/// which the key query was done.
|
|
||||||
///
|
|
||||||
/// Returns a list of devices that changed. Changed here means either
|
|
||||||
/// they are new, one of their properties has changed or they got deleted.
|
|
||||||
async fn handle_devices_from_key_query(
|
|
||||||
&self,
|
|
||||||
device_keys_map: &BTreeMap<UserId, BTreeMap<Box<DeviceId>, DeviceKeys>>,
|
|
||||||
) -> StoreResult<Vec<ReadOnlyDevice>> {
|
|
||||||
let mut changed_devices = Vec::new();
|
|
||||||
|
|
||||||
for (user_id, device_map) in device_keys_map {
|
|
||||||
// TODO move this out into the handle keys query response method
|
|
||||||
// since we might fail handle the new device at any point here or
|
|
||||||
// when updating the user identities.
|
|
||||||
self.store.update_tracked_user(user_id, false).await?;
|
|
||||||
|
|
||||||
for (device_id, device_keys) in device_map.iter() {
|
|
||||||
// We don't need our own device in the device store.
|
|
||||||
if user_id == self.user_id() && &**device_id == self.device_id() {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if user_id != &device_keys.user_id || device_id != &device_keys.device_id {
|
|
||||||
warn!(
|
|
||||||
"Mismatch in device keys payload of device {} from user {}",
|
|
||||||
device_keys.device_id, device_keys.user_id
|
|
||||||
);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let device = self.store.get_device(&user_id, device_id).await?;
|
|
||||||
|
|
||||||
let device = if let Some(mut device) = device {
|
|
||||||
if let Err(e) = device.update_device(device_keys) {
|
|
||||||
warn!(
|
|
||||||
"Failed to update the device keys for {} {}: {:?}",
|
|
||||||
user_id, device_id, e
|
|
||||||
);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
device
|
|
||||||
} else {
|
|
||||||
let device = match ReadOnlyDevice::try_from(device_keys) {
|
|
||||||
Ok(d) => d,
|
|
||||||
Err(e) => {
|
|
||||||
warn!(
|
|
||||||
"Failed to create a new device for {} {}: {:?}",
|
|
||||||
user_id, device_id, e
|
|
||||||
);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
info!("Adding a new device to the device store {:?}", device);
|
|
||||||
device
|
|
||||||
};
|
|
||||||
|
|
||||||
changed_devices.push(device);
|
|
||||||
}
|
|
||||||
|
|
||||||
let current_devices: HashSet<&DeviceId> =
|
|
||||||
device_map.keys().map(|id| id.as_ref()).collect();
|
|
||||||
let stored_devices = self.store.get_user_devices(&user_id).await.unwrap();
|
|
||||||
let stored_devices_set: HashSet<&DeviceId> = stored_devices.keys().collect();
|
|
||||||
|
|
||||||
let deleted_devices = stored_devices_set.difference(¤t_devices);
|
|
||||||
|
|
||||||
for device_id in deleted_devices {
|
|
||||||
if let Some(device) = stored_devices.get(device_id) {
|
|
||||||
device.mark_as_deleted();
|
|
||||||
self.store.delete_device(device).await?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(changed_devices)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Handle the device keys part of a key query response.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `response` - The keys query response.
|
|
||||||
///
|
|
||||||
/// Returns a list of identities that changed. Changed here means either
|
|
||||||
/// they are new, one of their properties has changed or they got deleted.
|
|
||||||
async fn handle_cross_singing_keys(
|
|
||||||
&self,
|
|
||||||
response: &KeysQueryResponse,
|
|
||||||
) -> StoreResult<Vec<UserIdentities>> {
|
|
||||||
let mut changed = Vec::new();
|
|
||||||
|
|
||||||
for (user_id, master_key) in &response.master_keys {
|
|
||||||
let master_key = MasterPubkey::from(master_key);
|
|
||||||
|
|
||||||
let self_signing = if let Some(s) = response.self_signing_keys.get(user_id) {
|
|
||||||
SelfSigningPubkey::from(s)
|
|
||||||
} else {
|
|
||||||
warn!(
|
|
||||||
"User identity for user {} didn't contain a self signing pubkey",
|
|
||||||
user_id
|
|
||||||
);
|
|
||||||
continue;
|
|
||||||
};
|
|
||||||
|
|
||||||
let identity = if let Some(mut i) = self.store.get_user_identity(user_id).await? {
|
|
||||||
match &mut i {
|
|
||||||
UserIdentities::Own(ref mut identity) => {
|
|
||||||
let user_signing = if let Some(s) = response.user_signing_keys.get(user_id)
|
|
||||||
{
|
|
||||||
UserSigningPubkey::from(s)
|
|
||||||
} else {
|
|
||||||
warn!(
|
|
||||||
"User identity for our own user {} didn't \
|
|
||||||
contain a user signing pubkey",
|
|
||||||
user_id
|
|
||||||
);
|
|
||||||
continue;
|
|
||||||
};
|
|
||||||
|
|
||||||
identity
|
|
||||||
.update(master_key, self_signing, user_signing)
|
|
||||||
.map(|_| i)
|
|
||||||
}
|
|
||||||
UserIdentities::Other(ref mut identity) => {
|
|
||||||
identity.update(master_key, self_signing).map(|_| i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if user_id == self.user_id() {
|
|
||||||
if let Some(s) = response.user_signing_keys.get(user_id) {
|
|
||||||
let user_signing = UserSigningPubkey::from(s);
|
|
||||||
|
|
||||||
if master_key.user_id() != user_id
|
|
||||||
|| self_signing.user_id() != user_id
|
|
||||||
|| user_signing.user_id() != user_id
|
|
||||||
{
|
|
||||||
warn!(
|
|
||||||
"User id mismatch in one of the cross signing keys for user {}",
|
|
||||||
user_id
|
|
||||||
);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
OwnUserIdentity::new(master_key, self_signing, user_signing)
|
|
||||||
.map(UserIdentities::Own)
|
|
||||||
} else {
|
|
||||||
warn!(
|
|
||||||
"User identity for our own user {} didn't contain a \
|
|
||||||
user signing pubkey",
|
|
||||||
user_id
|
|
||||||
);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
} else if master_key.user_id() != user_id || self_signing.user_id() != user_id {
|
|
||||||
warn!(
|
|
||||||
"User id mismatch in one of the cross signing keys for user {}",
|
|
||||||
user_id
|
|
||||||
);
|
|
||||||
continue;
|
|
||||||
} else {
|
|
||||||
UserIdentity::new(master_key, self_signing).map(UserIdentities::Other)
|
|
||||||
};
|
|
||||||
|
|
||||||
match identity {
|
|
||||||
Ok(i) => {
|
|
||||||
trace!(
|
|
||||||
"Updated or created new user identity for {}: {:?}",
|
|
||||||
user_id,
|
|
||||||
i
|
|
||||||
);
|
|
||||||
changed.push(i);
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
warn!(
|
|
||||||
"Couldn't update or create new user identity for {}: {:?}",
|
|
||||||
user_id, e
|
|
||||||
);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(changed)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Receive a successful keys query response.
|
/// Receive a successful keys query response.
|
||||||
///
|
///
|
||||||
/// Returns a list of devices newly discovered devices and devices that
|
/// Returns a list of devices newly discovered devices and devices that
|
||||||
|
@ -693,23 +504,9 @@ impl OlmMachine {
|
||||||
&self,
|
&self,
|
||||||
response: &KeysQueryResponse,
|
response: &KeysQueryResponse,
|
||||||
) -> OlmResult<(Vec<ReadOnlyDevice>, Vec<UserIdentities>)> {
|
) -> OlmResult<(Vec<ReadOnlyDevice>, Vec<UserIdentities>)> {
|
||||||
// TODO create a enum that tells us how the device/identity changed,
|
self.identity_manager
|
||||||
// e.g. new/deleted/display name change.
|
.receive_keys_query_response(response)
|
||||||
//
|
.await
|
||||||
// TODO create a struct that will hold the device/identity and the
|
|
||||||
// change enum and return the struct.
|
|
||||||
//
|
|
||||||
// TODO once outbound group sessions hold on to the set of users that
|
|
||||||
// received the session, invalidate the session if a user device
|
|
||||||
// got added/deleted.
|
|
||||||
let changed_devices = self
|
|
||||||
.handle_devices_from_key_query(&response.device_keys)
|
|
||||||
.await?;
|
|
||||||
self.store.save_devices(&changed_devices).await?;
|
|
||||||
let changed_identities = self.handle_cross_singing_keys(response).await?;
|
|
||||||
self.store.save_user_identities(&changed_identities).await?;
|
|
||||||
|
|
||||||
Ok((changed_devices, changed_identities))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a request to upload E2EE keys to the server.
|
/// Get a request to upload E2EE keys to the server.
|
||||||
|
@ -1026,10 +823,11 @@ impl OlmMachine {
|
||||||
&self,
|
&self,
|
||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
settings: EncryptionSettings,
|
settings: EncryptionSettings,
|
||||||
|
users_to_share_with: impl Iterator<Item = &UserId>,
|
||||||
) -> OlmResult<()> {
|
) -> OlmResult<()> {
|
||||||
let (outbound, inbound) = self
|
let (outbound, inbound) = self
|
||||||
.account
|
.account
|
||||||
.create_group_session_pair(room_id, settings)
|
.create_group_session_pair(room_id, settings, users_to_share_with)
|
||||||
.await
|
.await
|
||||||
.map_err(|_| EventError::UnsupportedAlgorithm)?;
|
.map_err(|_| EventError::UnsupportedAlgorithm)?;
|
||||||
|
|
||||||
|
@ -1042,11 +840,11 @@ impl OlmMachine {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub(crate) async fn create_outnbound_group_session_with_defaults(
|
pub(crate) async fn create_outbound_group_session_with_defaults(
|
||||||
&self,
|
&self,
|
||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
) -> OlmResult<()> {
|
) -> OlmResult<()> {
|
||||||
self.create_outbound_group_session(room_id, EncryptionSettings::default())
|
self.create_outbound_group_session(room_id, EncryptionSettings::default(), [].iter())
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1143,7 +941,7 @@ impl OlmMachine {
|
||||||
users: impl Iterator<Item = &UserId>,
|
users: impl Iterator<Item = &UserId>,
|
||||||
encryption_settings: impl Into<EncryptionSettings>,
|
encryption_settings: impl Into<EncryptionSettings>,
|
||||||
) -> OlmResult<Vec<ToDeviceRequest>> {
|
) -> OlmResult<Vec<ToDeviceRequest>> {
|
||||||
self.create_outbound_group_session(room_id, encryption_settings.into())
|
self.create_outbound_group_session(room_id, encryption_settings.into(), users)
|
||||||
.await?;
|
.await?;
|
||||||
let session = self.outbound_group_sessions.get(room_id).unwrap();
|
let session = self.outbound_group_sessions.get(room_id).unwrap();
|
||||||
|
|
||||||
|
@ -1159,8 +957,8 @@ impl OlmMachine {
|
||||||
|
|
||||||
let mut devices = Vec::new();
|
let mut devices = Vec::new();
|
||||||
|
|
||||||
for user_id in users {
|
for user_id in session.users_to_share_with() {
|
||||||
for device in self.get_user_devices(user_id).await?.devices() {
|
for device in self.get_user_devices(&user_id).await?.devices() {
|
||||||
if !device.is_blacklisted() {
|
if !device.is_blacklisted() {
|
||||||
devices.push(device.clone());
|
devices.push(device.clone());
|
||||||
}
|
}
|
||||||
|
@ -1299,8 +1097,12 @@ impl OlmMachine {
|
||||||
let count: u64 = one_time_key_count.map_or(0, |c| (*c).into());
|
let count: u64 = one_time_key_count.map_or(0, |c| (*c).into());
|
||||||
self.update_key_count(count);
|
self.update_key_count(count);
|
||||||
|
|
||||||
|
if let Err(e) = self.store.save_account(self.account.clone()).await {
|
||||||
|
error!("Error updating the one-time key count {:?}", e);
|
||||||
|
}
|
||||||
|
|
||||||
for user_id in &response.device_lists.changed {
|
for user_id in &response.device_lists.changed {
|
||||||
if let Err(e) = self.mark_user_as_changed(&user_id).await {
|
if let Err(e) = self.identity_manager.mark_user_as_changed(&user_id).await {
|
||||||
error!("Error marking a tracked user as changed {:?}", e);
|
error!("Error marking a tracked user as changed {:?}", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1389,23 +1191,6 @@ impl OlmMachine {
|
||||||
Ok(decrypted_event)
|
Ok(decrypted_event)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mark that the given user has changed his devices.
|
|
||||||
///
|
|
||||||
/// This will queue up the given user for a key query.
|
|
||||||
///
|
|
||||||
/// Note: The user already needs to be tracked for it to be queued up for a
|
|
||||||
/// key query.
|
|
||||||
///
|
|
||||||
/// Returns true if the user was queued up for a key query, false otherwise.
|
|
||||||
async fn mark_user_as_changed(&self, user_id: &UserId) -> StoreResult<bool> {
|
|
||||||
if self.store.is_user_tracked(user_id) {
|
|
||||||
self.store.update_tracked_user(user_id, true).await?;
|
|
||||||
Ok(true)
|
|
||||||
} else {
|
|
||||||
Ok(false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Update the tracked users.
|
/// Update the tracked users.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
|
@ -1419,41 +1204,7 @@ impl OlmMachine {
|
||||||
/// If the user is already known to the Olm machine it will not be
|
/// If the user is already known to the Olm machine it will not be
|
||||||
/// considered for a key query.
|
/// considered for a key query.
|
||||||
pub async fn update_tracked_users(&self, users: impl IntoIterator<Item = &UserId>) {
|
pub async fn update_tracked_users(&self, users: impl IntoIterator<Item = &UserId>) {
|
||||||
for user in users {
|
self.identity_manager.update_tracked_users(users).await
|
||||||
if self.store.is_user_tracked(user) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Err(e) = self.store.update_tracked_user(user, true).await {
|
|
||||||
warn!("Error storing users for tracking {}", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a key query request if one is needed.
|
|
||||||
///
|
|
||||||
/// Returns a key query reqeust if the client should query E2E keys,
|
|
||||||
/// otherwise None.
|
|
||||||
///
|
|
||||||
/// The response of a successful key query requests needs to be passed to
|
|
||||||
/// the [`OlmMachine`] with the [`receive_keys_query_response`].
|
|
||||||
///
|
|
||||||
/// [`OlmMachine`]: struct.OlmMachine.html
|
|
||||||
/// [`receive_keys_query_response`]: #method.receive_keys_query_response
|
|
||||||
async fn users_for_key_query(&self) -> Option<KeysQueryRequest> {
|
|
||||||
let mut users = self.store.users_for_key_query();
|
|
||||||
|
|
||||||
if users.is_empty() {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
let mut device_keys: BTreeMap<UserId, Vec<Box<DeviceId>>> = BTreeMap::new();
|
|
||||||
|
|
||||||
for user in users.drain() {
|
|
||||||
device_keys.insert(user, Vec::new());
|
|
||||||
}
|
|
||||||
|
|
||||||
Some(KeysQueryRequest::new(device_keys))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a specific device of a user.
|
/// Get a specific device of a user.
|
||||||
|
@ -1928,7 +1679,7 @@ pub(crate) mod test {
|
||||||
let room_id = room_id!("!test:example.org");
|
let room_id = room_id!("!test:example.org");
|
||||||
|
|
||||||
machine
|
machine
|
||||||
.create_outnbound_group_session_with_defaults(&room_id)
|
.create_outbound_group_session_with_defaults(&room_id)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert!(machine.outbound_group_sessions.get(&room_id).is_some());
|
assert!(machine.outbound_group_sessions.get(&room_id).is_some());
|
||||||
|
|
|
@ -580,6 +580,7 @@ impl Account {
|
||||||
&self,
|
&self,
|
||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
settings: EncryptionSettings,
|
settings: EncryptionSettings,
|
||||||
|
users_to_share_with: impl Iterator<Item = &UserId>,
|
||||||
) -> Result<(OutboundGroupSession, InboundGroupSession), ()> {
|
) -> Result<(OutboundGroupSession, InboundGroupSession), ()> {
|
||||||
if settings.algorithm != EventEncryptionAlgorithm::MegolmV1AesSha2 {
|
if settings.algorithm != EventEncryptionAlgorithm::MegolmV1AesSha2 {
|
||||||
return Err(());
|
return Err(());
|
||||||
|
@ -590,6 +591,7 @@ impl Account {
|
||||||
self.identity_keys.clone(),
|
self.identity_keys.clone(),
|
||||||
room_id,
|
room_id,
|
||||||
settings,
|
settings,
|
||||||
|
users_to_share_with,
|
||||||
);
|
);
|
||||||
let identity_keys = self.identity_keys();
|
let identity_keys = self.identity_keys();
|
||||||
|
|
||||||
|
@ -606,6 +608,15 @@ impl Account {
|
||||||
|
|
||||||
Ok((outbound, inbound))
|
Ok((outbound, inbound))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub(crate) async fn create_group_session_pair_with_defaults(
|
||||||
|
&self,
|
||||||
|
room_id: &RoomId,
|
||||||
|
) -> Result<(OutboundGroupSession, InboundGroupSession), ()> {
|
||||||
|
self.create_group_session_pair(room_id, EncryptionSettings::default(), [].iter())
|
||||||
|
.await
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PartialEq for Account {
|
impl PartialEq for Account {
|
||||||
|
|
|
@ -147,7 +147,7 @@ mod test {
|
||||||
|
|
||||||
let account = Account::new(&user_id!("@alice:example.org"), "DEVICEID".into());
|
let account = Account::new(&user_id!("@alice:example.org"), "DEVICEID".into());
|
||||||
let (session, _) = account
|
let (session, _) = account
|
||||||
.create_group_session_pair(&room_id!("!test_room:example.org"), settings)
|
.create_group_session_pair(&room_id!("!test_room:example.org"), settings, [].iter())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -165,7 +165,7 @@ mod test {
|
||||||
};
|
};
|
||||||
|
|
||||||
let (mut session, _) = account
|
let (mut session, _) = account
|
||||||
.create_group_session_pair(&room_id!("!test_room:example.org"), settings)
|
.create_group_session_pair(&room_id!("!test_room:example.org"), settings, [].iter())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
use dashmap::{setref::multiple::RefMulti, DashSet};
|
||||||
use std::{
|
use std::{
|
||||||
cmp::min,
|
cmp::min,
|
||||||
fmt,
|
fmt,
|
||||||
|
@ -27,7 +28,7 @@ use matrix_sdk_common::{
|
||||||
room::{encrypted::EncryptedEventContent, encryption::EncryptionEventContent},
|
room::{encrypted::EncryptedEventContent, encryption::EncryptionEventContent},
|
||||||
AnyMessageEventContent, EventContent,
|
AnyMessageEventContent, EventContent,
|
||||||
},
|
},
|
||||||
identifiers::{DeviceId, EventEncryptionAlgorithm, RoomId},
|
identifiers::{DeviceIdBox, EventEncryptionAlgorithm, RoomId, UserId},
|
||||||
instant::Instant,
|
instant::Instant,
|
||||||
locks::Mutex,
|
locks::Mutex,
|
||||||
};
|
};
|
||||||
|
@ -92,7 +93,7 @@ impl From<&EncryptionEventContent> for EncryptionSettings {
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct OutboundGroupSession {
|
pub struct OutboundGroupSession {
|
||||||
inner: Arc<Mutex<OlmOutboundGroupSession>>,
|
inner: Arc<Mutex<OlmOutboundGroupSession>>,
|
||||||
device_id: Arc<Box<DeviceId>>,
|
device_id: Arc<DeviceIdBox>,
|
||||||
account_identity_keys: Arc<IdentityKeys>,
|
account_identity_keys: Arc<IdentityKeys>,
|
||||||
session_id: Arc<String>,
|
session_id: Arc<String>,
|
||||||
room_id: Arc<RoomId>,
|
room_id: Arc<RoomId>,
|
||||||
|
@ -100,6 +101,8 @@ pub struct OutboundGroupSession {
|
||||||
message_count: Arc<AtomicU64>,
|
message_count: Arc<AtomicU64>,
|
||||||
shared: Arc<AtomicBool>,
|
shared: Arc<AtomicBool>,
|
||||||
settings: Arc<EncryptionSettings>,
|
settings: Arc<EncryptionSettings>,
|
||||||
|
shared_with_set: Arc<DashSet<UserId>>,
|
||||||
|
to_share_with_set: Arc<DashSet<UserId>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OutboundGroupSession {
|
impl OutboundGroupSession {
|
||||||
|
@ -118,15 +121,18 @@ impl OutboundGroupSession {
|
||||||
///
|
///
|
||||||
/// * `settings` - Settings determining the algorithm and rotation period of
|
/// * `settings` - Settings determining the algorithm and rotation period of
|
||||||
/// the outbound group session.
|
/// the outbound group session.
|
||||||
pub fn new(
|
pub fn new<'a>(
|
||||||
device_id: Arc<Box<DeviceId>>,
|
device_id: Arc<DeviceIdBox>,
|
||||||
identity_keys: Arc<IdentityKeys>,
|
identity_keys: Arc<IdentityKeys>,
|
||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
settings: EncryptionSettings,
|
settings: EncryptionSettings,
|
||||||
|
users_to_share_with: impl Iterator<Item = &'a UserId>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let session = OlmOutboundGroupSession::new();
|
let session = OlmOutboundGroupSession::new();
|
||||||
let session_id = session.session_id();
|
let session_id = session.session_id();
|
||||||
|
|
||||||
|
let users_to_share_with = users_to_share_with.cloned().collect();
|
||||||
|
|
||||||
OutboundGroupSession {
|
OutboundGroupSession {
|
||||||
inner: Arc::new(Mutex::new(session)),
|
inner: Arc::new(Mutex::new(session)),
|
||||||
room_id: Arc::new(room_id.to_owned()),
|
room_id: Arc::new(room_id.to_owned()),
|
||||||
|
@ -137,9 +143,15 @@ impl OutboundGroupSession {
|
||||||
message_count: Arc::new(AtomicU64::new(0)),
|
message_count: Arc::new(AtomicU64::new(0)),
|
||||||
shared: Arc::new(AtomicBool::new(false)),
|
shared: Arc::new(AtomicBool::new(false)),
|
||||||
settings: Arc::new(settings),
|
settings: Arc::new(settings),
|
||||||
|
shared_with_set: Arc::new(DashSet::new()),
|
||||||
|
to_share_with_set: Arc::new(users_to_share_with),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn users_to_share_with(&self) -> impl Iterator<Item = RefMulti<'_, UserId>> + '_ {
|
||||||
|
self.to_share_with_set.iter()
|
||||||
|
}
|
||||||
|
|
||||||
/// Encrypt the given plaintext using this session.
|
/// Encrypt the given plaintext using this session.
|
||||||
///
|
///
|
||||||
/// Returns the encrypted ciphertext.
|
/// Returns the encrypted ciphertext.
|
||||||
|
|
|
@ -194,7 +194,7 @@ pub(crate) mod test {
|
||||||
let room_id = room_id!("!test:localhost");
|
let room_id = room_id!("!test:localhost");
|
||||||
|
|
||||||
let (outbound, _) = alice
|
let (outbound, _) = alice
|
||||||
.create_group_session_pair(&room_id, Default::default())
|
.create_group_session_pair_with_defaults(&room_id)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -230,7 +230,7 @@ pub(crate) mod test {
|
||||||
let room_id = room_id!("!test:localhost");
|
let room_id = room_id!("!test:localhost");
|
||||||
|
|
||||||
let (_, inbound) = alice
|
let (_, inbound) = alice
|
||||||
.create_group_session_pair(&room_id, Default::default())
|
.create_group_session_pair_with_defaults(&room_id)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
|
@ -265,7 +265,7 @@ mod test {
|
||||||
let room_id = room_id!("!test:localhost");
|
let room_id = room_id!("!test:localhost");
|
||||||
|
|
||||||
let (outbound, _) = account
|
let (outbound, _) = account
|
||||||
.create_group_session_pair(&room_id, Default::default())
|
.create_group_session_pair_with_defaults(&room_id)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
|
@ -219,7 +219,7 @@ mod test {
|
||||||
let room_id = room_id!("!test:localhost");
|
let room_id = room_id!("!test:localhost");
|
||||||
|
|
||||||
let (outbound, _) = account
|
let (outbound, _) = account
|
||||||
.create_group_session_pair(&room_id, Default::default())
|
.create_group_session_pair_with_defaults(&room_id)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let inbound = InboundGroupSession::new(
|
let inbound = InboundGroupSession::new(
|
||||||
|
|
Loading…
Reference in New Issue