Merge branch 'multithreaded-crypto'

master
Damir Jelić 2021-03-23 11:34:07 +01:00
commit 15d5b234ed
16 changed files with 40635 additions and 186 deletions

View File

@ -10,15 +10,15 @@ edition = "2018"
crate-type = ["cdylib"]
[dependencies]
url = "2.1.1"
wasm-bindgen = { version = "0.2.62", features = ["serde-serialize"] }
wasm-bindgen-futures = "0.4.12"
console_error_panic_hook = "*"
web-sys = { version = "0.3.39", features = ["console"] }
url = "2.2.1"
wasm-bindgen = { version = "0.2.72", features = ["serde-serialize"] }
wasm-bindgen-futures = "0.4.22"
console_error_panic_hook = "0.1.6"
web-sys = { version = "0.3.49", features = ["console"] }
[dependencies.matrix-sdk]
path = "../.."
default-features = false
features = ["native-tls"]
features = ["native-tls", "encryption"]
[workspace]

View File

@ -7,6 +7,4 @@ You can build the example locally with:
and then visiting http://localhost:8080 in a browser should run the example!
Note: Encryption isn't supported yet
This example is loosely based off of [this example](https://github.com/seanmonstar/reqwest/tree/master/examples/wasm_github_fetch), an example usage of `fetch` from `wasm-bindgen`.

View File

@ -5,8 +5,8 @@
},
"devDependencies": {
"@wasm-tool/wasm-pack-plugin": "1.0.1",
"text-encoding": "^0.7.0",
"html-webpack-plugin": "^3.2.0",
"text-encoding": "^0.7.0",
"webpack": "^4.29.4",
"webpack-cli": "^3.1.1",
"webpack-dev-server": "^3.1.0"

View File

@ -1,7 +1,7 @@
use matrix_sdk::{
deserialized_responses::SyncResponse,
events::{
room::message::{MessageEventContent, TextMessageEventContent},
room::message::{MessageEventContent, MessageType, TextMessageEventContent},
AnyMessageEventContent, AnySyncMessageEvent, AnySyncRoomEvent, SyncMessageEvent,
},
identifiers::RoomId,
@ -17,35 +17,49 @@ impl WasmBot {
async fn on_room_message(
&self,
room_id: &RoomId,
event: SyncMessageEvent<MessageEventContent>,
event: &SyncMessageEvent<MessageEventContent>,
) {
let msg_body = if let SyncMessageEvent {
content: MessageEventContent::Text(TextMessageEventContent { body: msg_body, .. }),
content:
MessageEventContent {
msgtype: MessageType::Text(TextMessageEventContent { body: msg_body, .. }),
..
},
..
} = event
{
msg_body.clone()
msg_body
} else {
return;
};
console::log_1(&format!("Received message event {:?}", &msg_body).into());
if msg_body.starts_with("!party") {
let content = AnyMessageEventContent::RoomMessage(MessageEventContent::Text(
TextMessageEventContent::plain("🎉🎊🥳 let's PARTY with wasm!! 🥳🎊🎉".to_string()),
if msg_body.contains("!party") {
let content = AnyMessageEventContent::RoomMessage(MessageEventContent::text_plain(
"🎉🎊🥳 let's PARTY!! 🥳🎊🎉",
));
self.0.room_send(&room_id, content, None).await.unwrap();
println!("sending");
self.0
// send our message to the room we found the "!party" command in
// the last parameter is an optional Uuid which we don't care about.
.room_send(room_id, content, None)
.await
.unwrap();
println!("message sent");
}
}
async fn on_sync_response(&self, response: SyncResponse) -> LoopCtrl {
console::log_1(&"Synced".to_string().into());
for (room_id, room) in response.rooms.join {
for event in room.timeline.events {
if let AnySyncRoomEvent::Message(AnySyncMessageEvent::RoomMessage(ev)) = event {
self.on_room_message(&room_id, ev).await
self.on_room_message(&room_id, &ev).await
}
}
}

View File

@ -34,5 +34,7 @@ default-features = false
features = ["sync"]
[target.'cfg(target_arch = "wasm32")'.dependencies]
futures = "0.3.12"
futures-locks = { version = "0.6.0", default-features = false }
wasm-bindgen-futures = "0.4"
uuid = { version = "0.8.2", default-features = false, features = ["v4", "wasm-bindgen"] }

View File

@ -0,0 +1,42 @@
//! Abstraction over an executor so we can spawn tasks under WASM the same way
//! we do usually.
#[cfg(target_arch = "wasm32")]
use std::{
pin::Pin,
task::{Context, Poll},
};
#[cfg(not(target_arch = "wasm32"))]
pub use tokio::spawn;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_futures::spawn_local;
#[cfg(target_arch = "wasm32")]
use futures::{future::RemoteHandle, Future, FutureExt};
#[cfg(target_arch = "wasm32")]
pub fn spawn<F, T>(future: F) -> JoinHandle<T>
where
F: Future<Output = T> + 'static,
{
let fut = future.unit_error();
let (fut, handle) = fut.remote_handle();
spawn_local(fut);
JoinHandle { handle }
}
#[cfg(target_arch = "wasm32")]
pub struct JoinHandle<T> {
handle: RemoteHandle<Result<T, ()>>,
}
#[cfg(target_arch = "wasm32")]
impl<T: 'static> Future for JoinHandle<T> {
type Output = Result<T, ()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.handle).poll(cx)
}
}

View File

@ -14,6 +14,7 @@ pub use ruma::{
pub use uuid;
pub mod deserialized_responses;
pub mod executor;
pub mod locks;
/// Super trait that is used for our store traits, this trait will differ if

View File

@ -29,6 +29,7 @@ serde_json = "1.0.61"
zeroize = { version = "1.2.0", features = ["zeroize_derive"] }
# Misc dependencies
futures = "0.3.12"
sled = { version = "0.34.6", optional = true }
thiserror = "1.0.23"
tracing = "0.1.22"
@ -44,14 +45,13 @@ byteorder = "1.4.2"
[dev-dependencies]
tokio = { version = "1.1.0", default-features = false, features = ["rt-multi-thread", "macros"] }
futures = "0.3.12"
proptest = "0.10.1"
serde_json = "1.0.61"
tempfile = "3.2.0"
http = "0.2.3"
matrix-sdk-test = { version = "0.2.0", path = "../matrix_sdk_test" }
indoc = "1.0.3"
criterion = { version = "0.3.4", features = ["async", "async_futures", "html_reports"] }
criterion = { version = "0.3.4", features = ["async", "async_tokio", "html_reports"] }
[target.'cfg(target_os = "linux")'.dev-dependencies]
pprof = { version = "0.4.2", features = ["flamegraph"] }

View File

@ -1,11 +1,10 @@
#[cfg(target_os = "linux")]
mod perf;
use std::convert::TryFrom;
use std::{convert::TryFrom, sync::Arc};
use criterion::{async_executor::FuturesExecutor, *};
use criterion::*;
use futures::executor::block_on;
use matrix_sdk_common::{
api::r0::{
keys::{claim_keys, get_keys},
@ -17,6 +16,7 @@ use matrix_sdk_common::{
use matrix_sdk_crypto::{EncryptionSettings, OlmMachine};
use matrix_sdk_test::response_from_file;
use serde_json::Value;
use tokio::runtime::Builder;
fn alice_id() -> UserId {
user_id!("@alice:example.org")
@ -40,7 +40,17 @@ fn keys_claim_response() -> claim_keys::Response {
claim_keys::Response::try_from(data).expect("Can't parse the keys upload response")
}
fn huge_keys_query_resopnse() -> get_keys::Response {
let data = include_bytes!("./keys_query_2000_members.json");
let data: Value = serde_json::from_slice(data).unwrap();
let data = response_from_file(&data);
get_keys::Response::try_from(data).expect("Can't parse the keys query response")
}
pub fn keys_query(c: &mut Criterion) {
let runtime = Builder::new_multi_thread()
.build()
.expect("Can't create runtime");
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
let response = keys_query_response();
let uuid = Uuid::new_v4();
@ -62,25 +72,26 @@ pub fn keys_query(c: &mut Criterion) {
BenchmarkId::new("memory store", &name),
&response,
|b, response| {
b.to_async(FuturesExecutor)
b.to_async(&runtime)
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
},
);
let dir = tempfile::tempdir().unwrap();
let machine = block_on(OlmMachine::new_with_default_store(
&alice_id(),
&alice_device_id(),
dir.path(),
None,
))
.unwrap();
let machine = runtime
.block_on(OlmMachine::new_with_default_store(
&alice_id(),
&alice_device_id(),
dir.path(),
None,
))
.unwrap();
group.bench_with_input(
BenchmarkId::new("sled store", &name),
&response,
|b, response| {
b.to_async(FuturesExecutor)
b.to_async(&runtime)
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
},
);
@ -89,6 +100,12 @@ pub fn keys_query(c: &mut Criterion) {
}
pub fn keys_claiming(c: &mut Criterion) {
let runtime = Arc::new(
Builder::new_multi_thread()
.build()
.expect("Can't create runtime"),
);
let keys_query_response = keys_query_response();
let uuid = Uuid::new_v4();
@ -111,10 +128,16 @@ pub fn keys_claiming(c: &mut Criterion) {
b.iter_batched(
|| {
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)).unwrap();
machine
runtime
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
.unwrap();
(machine, runtime.clone())
},
move |(machine, runtime)| {
runtime
.block_on(machine.mark_request_as_sent(&uuid, response))
.unwrap()
},
move |machine| block_on(machine.mark_request_as_sent(&uuid, response)).unwrap(),
BatchSize::SmallInput,
)
},
@ -127,17 +150,24 @@ pub fn keys_claiming(c: &mut Criterion) {
b.iter_batched(
|| {
let dir = tempfile::tempdir().unwrap();
let machine = block_on(OlmMachine::new_with_default_store(
&alice_id(),
&alice_device_id(),
dir.path(),
None,
))
.unwrap();
block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)).unwrap();
machine
let machine = runtime
.block_on(OlmMachine::new_with_default_store(
&alice_id(),
&alice_device_id(),
dir.path(),
None,
))
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
.unwrap();
(machine, runtime.clone())
},
move |(machine, runtime)| {
runtime
.block_on(machine.mark_request_as_sent(&uuid, response))
.unwrap()
},
move |machine| block_on(machine.mark_request_as_sent(&uuid, response)).unwrap(),
BatchSize::SmallInput,
)
},
@ -147,6 +177,10 @@ pub fn keys_claiming(c: &mut Criterion) {
}
pub fn room_key_sharing(c: &mut Criterion) {
let runtime = Builder::new_multi_thread()
.build()
.expect("Can't create runtime");
let keys_query_response = keys_query_response();
let uuid = Uuid::new_v4();
let response = keys_claim_response();
@ -161,15 +195,19 @@ pub fn room_key_sharing(c: &mut Criterion) {
.fold(0, |acc, d| acc + d.len());
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)).unwrap();
block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
let mut group = c.benchmark_group("Room key sharing");
group.throughput(Throughput::Elements(count as u64));
let name = format!("{} devices", count);
group.bench_function(BenchmarkId::new("memory store", &name), |b| {
b.to_async(FuturesExecutor).iter(|| async {
b.to_async(&runtime).iter(|| async {
let requests = machine
.share_group_session(&room_id, users.iter(), EncryptionSettings::default())
.await
@ -189,18 +227,23 @@ pub fn room_key_sharing(c: &mut Criterion) {
});
let dir = tempfile::tempdir().unwrap();
let machine = block_on(OlmMachine::new_with_default_store(
&alice_id(),
&alice_device_id(),
dir.path(),
None,
))
.unwrap();
block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)).unwrap();
block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
let machine = runtime
.block_on(OlmMachine::new_with_default_store(
&alice_id(),
&alice_device_id(),
dir.path(),
None,
))
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
group.bench_function(BenchmarkId::new("sled store", &name), |b| {
b.to_async(FuturesExecutor).iter(|| async {
b.to_async(&runtime).iter(|| async {
let requests = machine
.share_group_session(&room_id, users.iter(), EncryptionSettings::default())
.await
@ -222,6 +265,58 @@ pub fn room_key_sharing(c: &mut Criterion) {
group.finish()
}
pub fn devices_missing_sessions_collecting(c: &mut Criterion) {
let runtime = Builder::new_multi_thread()
.build()
.expect("Can't create runtime");
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
let response = huge_keys_query_resopnse();
let uuid = Uuid::new_v4();
let users: Vec<UserId> = response.device_keys.keys().cloned().collect();
let count = response
.device_keys
.values()
.fold(0, |acc, d| acc + d.len());
let mut group = c.benchmark_group("Devices missing sessions collecting");
group.throughput(Throughput::Elements(count as u64));
let name = format!("{} devices", count);
runtime
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
group.bench_function(BenchmarkId::new("memory store", &name), |b| {
b.to_async(&runtime).iter_with_large_drop(|| async {
machine.get_missing_sessions(users.iter()).await.unwrap()
})
});
let dir = tempfile::tempdir().unwrap();
let machine = runtime
.block_on(OlmMachine::new_with_default_store(
&alice_id(),
&alice_device_id(),
dir.path(),
None,
))
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
group.bench_function(BenchmarkId::new("sled store", &name), |b| {
b.to_async(&runtime)
.iter(|| async { machine.get_missing_sessions(users.iter()).await.unwrap() })
});
group.finish()
}
fn criterion() -> Criterion {
#[cfg(target_os = "linux")]
let criterion = Criterion::default().with_profiler(perf::FlamegraphProfiler::new(100));
@ -234,6 +329,6 @@ fn criterion() -> Criterion {
criterion_group! {
name = benches;
config = criterion();
targets = keys_query, keys_claiming, room_key_sharing
targets = keys_query, keys_claiming, room_key_sharing, devices_missing_sessions_collecting,
}
criterion_main!(benches);

File diff suppressed because it is too large Load Diff

View File

@ -12,17 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use futures::future::join_all;
use std::{
collections::{BTreeMap, HashSet},
convert::TryFrom,
sync::Arc,
};
use tracing::{info, trace, warn};
use tracing::{trace, warn};
use matrix_sdk_common::{
api::r0::keys::get_keys::Response as KeysQueryResponse,
encryption::DeviceKeys,
identifiers::{DeviceId, DeviceIdBox, UserId},
executor::spawn,
identifiers::{DeviceIdBox, UserId},
};
use crate::{
@ -35,6 +37,12 @@ use crate::{
store::{Changes, DeviceChanges, IdentityChanges, Result as StoreResult, Store},
};
enum DeviceChange {
New(ReadOnlyDevice),
Updated(ReadOnlyDevice),
None,
}
#[derive(Debug, Clone)]
pub(crate) struct IdentityManager {
user_id: Arc<UserId>,
@ -57,10 +65,6 @@ impl IdentityManager {
&self.user_id
}
fn device_id(&self) -> &DeviceId {
&self.device_id
}
/// Receive a successful keys query response.
///
/// Returns a list of devices newly discovered devices and devices that
@ -74,17 +78,8 @@ impl IdentityManager {
&self,
response: &KeysQueryResponse,
) -> OlmResult<(DeviceChanges, IdentityChanges)> {
// TODO create a enum that tells us how the device/identity changed,
// e.g. new/deleted/display name change.
//
// TODO create a struct that will hold the device/identity and the
// change enum and return the struct.
//
// TODO once outbound group sessions hold on to the set of users that
// received the session, invalidate the session if a user device
// got added/deleted.
let changed_devices = self
.handle_devices_from_key_query(&response.device_keys)
.handle_devices_from_key_query(response.device_keys.clone())
.await?;
let changed_identities = self.handle_cross_singing_keys(response).await?;
@ -94,11 +89,113 @@ impl IdentityManager {
..Default::default()
};
// TODO turn this into a single transaction.
self.store.save_changes(changes).await?;
let updated_users: Vec<&UserId> = response.device_keys.keys().collect();
for user_id in updated_users {
self.store.update_tracked_user(user_id, false).await?;
}
Ok((changed_devices, changed_identities))
}
async fn update_or_create_device(
store: Store,
device_keys: DeviceKeys,
) -> StoreResult<DeviceChange> {
let old_device = store
.get_readonly_device(&device_keys.user_id, &device_keys.device_id)
.await?;
if let Some(mut device) = old_device {
if let Err(e) = device.update_device(&device_keys) {
warn!(
"Failed to update the device keys for {} {}: {:?}",
device.user_id(),
device.device_id(),
e
);
Ok(DeviceChange::None)
} else {
Ok(DeviceChange::Updated(device))
}
} else {
match ReadOnlyDevice::try_from(&device_keys) {
Ok(d) => {
trace!("Adding a new device to the device store {:?}", d);
Ok(DeviceChange::New(d))
}
Err(e) => {
warn!(
"Failed to create a new device for {} {}: {:?}",
device_keys.user_id, device_keys.device_id, e
);
Ok(DeviceChange::None)
}
}
}
}
async fn update_user_devices(
store: Store,
own_user_id: Arc<UserId>,
own_device_id: Arc<DeviceIdBox>,
user_id: UserId,
device_map: BTreeMap<DeviceIdBox, DeviceKeys>,
) -> StoreResult<DeviceChanges> {
let mut changes = DeviceChanges::default();
let current_devices: HashSet<DeviceIdBox> = device_map.keys().cloned().collect();
let tasks = device_map
.into_iter()
.filter_map(|(device_id, device_keys)| {
// We don't need our own device in the device store.
if user_id == *own_user_id && device_id == *own_device_id {
None
} else if user_id != device_keys.user_id || device_id != device_keys.device_id {
warn!(
"Mismatch in device keys payload of device {}|{} from user {}|{}",
device_id, device_keys.device_id, user_id, device_keys.user_id
);
None
} else {
Some(spawn(Self::update_or_create_device(
store.clone(),
device_keys,
)))
}
});
let results = join_all(tasks).await;
for device in results {
let device = device.expect("Creating or updating a device panicked")?;
match device {
DeviceChange::New(d) => changes.new.push(d),
DeviceChange::Updated(d) => changes.changed.push(d),
DeviceChange::None => (),
}
}
let current_devices: HashSet<&DeviceIdBox> = current_devices.iter().collect();
let stored_devices = store.get_readonly_devices(&user_id).await?;
let stored_devices_set: HashSet<&DeviceIdBox> = stored_devices.keys().collect();
let deleted_devices_set = stored_devices_set.difference(&current_devices);
for device_id in deleted_devices_set {
if let Some(device) = stored_devices.get(*device_id) {
device.mark_as_deleted();
changes.deleted.push(device.clone());
}
}
Ok(changes)
}
/// Handle the device keys part of a key query response.
///
/// # Arguments
@ -110,69 +207,28 @@ impl IdentityManager {
/// they are new, one of their properties has changed or they got deleted.
async fn handle_devices_from_key_query(
&self,
device_keys_map: &BTreeMap<UserId, BTreeMap<DeviceIdBox, DeviceKeys>>,
device_keys_map: BTreeMap<UserId, BTreeMap<DeviceIdBox, DeviceKeys>>,
) -> StoreResult<DeviceChanges> {
let mut changes = DeviceChanges::default();
for (user_id, device_map) in device_keys_map {
// TODO move this out into the handle keys query response method
// since we might fail handle the new device at any point here or
// when updating the user identities.
self.store.update_tracked_user(user_id, false).await?;
let tasks = device_keys_map
.into_iter()
.map(|(user_id, device_keys_map)| {
spawn(Self::update_user_devices(
self.store.clone(),
self.user_id.clone(),
self.device_id.clone(),
user_id,
device_keys_map,
))
});
for (device_id, device_keys) in device_map.iter() {
// We don't need our own device in the device store.
if user_id == self.user_id() && &**device_id == self.device_id() {
continue;
}
let results = join_all(tasks).await;
if user_id != &device_keys.user_id || device_id != &device_keys.device_id {
warn!(
"Mismatch in device keys payload of device {}|{} from user {}|{}",
device_id, device_keys.device_id, user_id, device_keys.user_id
);
continue;
}
for result in results {
let change_fragment = result.expect("Panic while updating user devices")?;
let device = self.store.get_readonly_device(&user_id, device_id).await?;
if let Some(mut device) = device {
if let Err(e) = device.update_device(device_keys) {
warn!(
"Failed to update the device keys for {} {}: {:?}",
user_id, device_id, e
);
continue;
}
changes.changed.push(device);
} else {
let device = match ReadOnlyDevice::try_from(device_keys) {
Ok(d) => d,
Err(e) => {
warn!(
"Failed to create a new device for {} {}: {:?}",
user_id, device_id, e
);
continue;
}
};
info!("Adding a new device to the device store {:?}", device);
changes.new.push(device);
}
}
let current_devices: HashSet<&DeviceIdBox> = device_map.keys().collect();
let stored_devices = self.store.get_readonly_devices(&user_id).await?;
let stored_devices_set: HashSet<&DeviceIdBox> = stored_devices.keys().collect();
let deleted_devices_set = stored_devices_set.difference(&current_devices);
for device_id in deleted_devices_set {
if let Some(device) = stored_devices.get(*device_id) {
device.mark_as_deleted();
changes.deleted.push(device.clone());
}
}
changes.extend(change_fragment);
}
Ok(changes)

View File

@ -63,6 +63,14 @@ impl ToDeviceRequest {
pub fn txn_id_string(&self) -> String {
self.txn_id.to_string()
}
/// Get the number of unique messages this request contains.
///
/// *Note*: A single message may be sent to multiple devices, so this may or
/// may not be the number of devices that will receive the messages as well.
pub fn message_count(&self) -> usize {
self.messages.values().map(|d| d.len()).sum()
}
}
/// Request that will publish a cross signing identity.

View File

@ -17,6 +17,8 @@ use std::{
sync::Arc,
};
use futures::future::join_all;
use dashmap::DashMap;
use matrix_sdk_common::{
api::r0::to_device::DeviceIdOrAllDevices,
@ -24,6 +26,7 @@ use matrix_sdk_common::{
room::{encrypted::EncryptedEventContent, history_visibility::HistoryVisibility},
AnyMessageEventContent, EventType,
},
executor::spawn,
identifiers::{DeviceId, DeviceIdBox, RoomId, UserId},
uuid::Uuid,
};
@ -52,7 +55,7 @@ pub struct GroupSessionManager {
}
impl GroupSessionManager {
const MAX_TO_DEVICE_MESSAGES: usize = 20;
const MAX_TO_DEVICE_MESSAGES: usize = 250;
pub(crate) fn new(account: Account, store: Store) -> Self {
Self {
@ -188,35 +191,57 @@ impl GroupSessionManager {
/// Encrypt the given content for the given devices and create a to-device
/// requests that sends the encrypted content to them.
async fn encrypt_session_for(
&self,
content: Value,
devices: &[Device],
devices: Vec<Device>,
) -> OlmResult<(Uuid, ToDeviceRequest, Vec<Session>)> {
let mut messages = BTreeMap::new();
let mut changed_sessions = Vec::new();
for device in devices {
let encrypt = |device: Device, content: Value| async move {
let mut message = BTreeMap::new();
let encrypted = device.encrypt(EventType::RoomKey, content.clone()).await;
let (used_session, encrypted) = match encrypted {
Ok(c) => c,
let used_session = match encrypted {
Ok((session, encrypted)) => {
message
.entry(device.user_id().clone())
.or_insert_with(BTreeMap::new)
.insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
serde_json::value::to_raw_value(&encrypted)?,
);
Some(session)
}
// TODO we'll want to create m.room_key.withheld here.
Err(OlmError::MissingSession)
| Err(OlmError::EventError(EventError::MissingSenderKey)) => {
continue;
}
| Err(OlmError::EventError(EventError::MissingSenderKey)) => None,
Err(e) => return Err(e),
};
changed_sessions.push(used_session);
Ok((used_session, message))
};
messages
.entry(device.user_id().clone())
.or_insert_with(BTreeMap::new)
.insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
serde_json::value::to_raw_value(&encrypted)?,
);
let tasks: Vec<_> = devices
.iter()
.map(|d| spawn(encrypt(d.clone(), content.clone())))
.collect();
let results = join_all(tasks).await;
for result in results {
let (used_session, message) = result.expect("Encryption task panicked")?;
if let Some(session) = used_session {
changed_sessions.push(session);
}
for (user, device_messages) in message.into_iter() {
messages
.entry(user)
.or_insert_with(BTreeMap::new)
.extend(device_messages);
}
}
let id = Uuid::new_v4();
@ -227,6 +252,12 @@ impl GroupSessionManager {
messages,
};
trace!(
recipient_count = request.message_count(),
transaction_id = ?id,
"Created a to-device request carrying a room_key"
);
Ok((id, request, changed_sessions))
}
@ -334,6 +365,24 @@ impl GroupSessionManager {
Ok((should_rotate, devices))
}
pub async fn encrypt_request(
chunk: Vec<Device>,
content: Value,
outbound: OutboundGroupSession,
message_index: u32,
being_shared: Arc<DashMap<Uuid, OutboundGroupSession>>,
) -> OlmResult<Vec<Session>> {
let (id, request, used_sessions) =
Self::encrypt_session_for(content.clone(), chunk).await?;
if !request.messages.is_empty() {
outbound.add_request(id, request.into(), message_index);
being_shared.insert(id, outbound.clone());
}
Ok(used_sessions)
}
/// Get to-device requests to share a group session with users in a room.
///
/// # Arguments
@ -427,18 +476,23 @@ impl GroupSessionManager {
);
}
for device_map_chunk in devices.chunks(Self::MAX_TO_DEVICE_MESSAGES) {
let (id, request, used_sessions) = self
.encrypt_session_for(key_content.clone(), device_map_chunk)
.await?;
let tasks: Vec<_> = devices
.chunks(Self::MAX_TO_DEVICE_MESSAGES)
.map(|chunk| {
spawn(Self::encrypt_request(
chunk.to_vec(),
key_content.clone(),
outbound.clone(),
message_index,
self.outbound_sessions_being_shared.clone(),
))
})
.collect();
if !request.messages.is_empty() {
outbound.add_request(id, request.into(), message_index);
self.outbound_sessions_being_shared
.insert(id, outbound.clone());
}
for result in join_all(tasks).await {
let used_sessions: OlmResult<Vec<Session>> = result.expect("Encryption task paniced");
changes.sessions.extend(used_sessions);
changes.sessions.extend(used_sessions?);
}
let requests = outbound.pending_requests();
@ -474,3 +528,84 @@ impl GroupSessionManager {
Ok(requests)
}
}
#[cfg(test)]
mod test {
use std::convert::TryFrom;
use matrix_sdk_common::{
api::r0::keys::{claim_keys, get_keys},
identifiers::{room_id, user_id, DeviceIdBox, UserId},
uuid::Uuid,
};
use matrix_sdk_test::response_from_file;
use serde_json::Value;
use crate::{EncryptionSettings, OlmMachine};
fn alice_id() -> UserId {
user_id!("@alice:example.org")
}
fn alice_device_id() -> DeviceIdBox {
"JLAFKJWSCS".into()
}
fn keys_query_response() -> get_keys::Response {
let data = include_bytes!("../../benches/keys_query.json");
let data: Value = serde_json::from_slice(data).unwrap();
let data = response_from_file(&data);
get_keys::Response::try_from(data).expect("Can't parse the keys upload response")
}
fn keys_claim_response() -> claim_keys::Response {
let data = include_bytes!("../../benches/keys_claim.json");
let data: Value = serde_json::from_slice(data).unwrap();
let data = response_from_file(&data);
claim_keys::Response::try_from(data).expect("Can't parse the keys upload response")
}
async fn machine() -> OlmMachine {
let keys_query = keys_query_response();
let keys_claim = keys_claim_response();
let uuid = Uuid::new_v4();
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
machine
.mark_request_as_sent(&uuid, &keys_query)
.await
.unwrap();
machine
.mark_request_as_sent(&uuid, &keys_claim)
.await
.unwrap();
machine
}
#[tokio::test]
async fn test_sharing() {
let machine = machine().await;
let room_id = room_id!("!test:localhost");
let keys_claim = keys_claim_response();
let users: Vec<_> = keys_claim.one_time_keys.keys().collect();
let requests = machine
.share_group_session(
&room_id,
users.clone().into_iter(),
EncryptionSettings::default(),
)
.await
.unwrap();
let event_count: usize = requests.iter().map(|r| r.message_count()).sum();
// The keys claim response has a couple of one-time keys with invalid
// signatures, thus only 148 sessions are actually created, we check
// that all 148 valid sessions get an room key.
assert_eq!(event_count, 148);
}
}

View File

@ -195,9 +195,9 @@ impl SessionManager {
// Add the list of devices that the user wishes to establish sessions
// right now.
for user_id in users {
let user_devices = self.store.get_user_devices(user_id).await?;
let user_devices = self.store.get_readonly_devices(user_id).await?;
for device in user_devices.devices() {
for (device_id, device) in user_devices.into_iter() {
let sender_key = if let Some(k) = device.get_key(DeviceKeyAlgorithm::Curve25519) {
k
} else {
@ -216,10 +216,7 @@ impl SessionManager {
missing
.entry(user_id.to_owned())
.or_insert_with(BTreeMap::new)
.insert(
device.device_id().into(),
DeviceKeyAlgorithm::SignedCurve25519,
);
.insert(device_id, DeviceKeyAlgorithm::SignedCurve25519);
}
}
}

View File

@ -126,6 +126,15 @@ pub struct DeviceChanges {
pub deleted: Vec<ReadOnlyDevice>,
}
impl DeviceChanges {
/// Merge the given `DeviceChanges` into this instance of `DeviceChanges`.
pub fn extend(&mut self, other: DeviceChanges) {
self.new.extend(other.new);
self.changed.extend(other.changed);
self.deleted.extend(other.deleted);
}
}
impl Store {
pub fn new(
user_id: Arc<UserId>,

View File

@ -16,11 +16,11 @@ use std::{
collections::{HashMap, HashSet},
convert::TryFrom,
path::{Path, PathBuf},
sync::Arc,
sync::{Arc, RwLock},
};
use dashmap::DashSet;
use olm_rs::PicklingMode;
use olm_rs::{account::IdentityKeys, PicklingMode};
pub use sled::Error;
use sled::{
transaction::{ConflictableTransactionError, TransactionError},
@ -95,9 +95,17 @@ impl EncodeKey for (&str, &str, &str) {
}
}
#[derive(Clone, Debug)]
pub struct AccountInfo {
user_id: Arc<UserId>,
device_id: Arc<DeviceIdBox>,
identity_keys: Arc<IdentityKeys>,
}
/// An in-memory only store that will forget all the E2EE key once it's dropped.
#[derive(Clone)]
pub struct SledStore {
account_info: Arc<RwLock<Option<AccountInfo>>>,
path: Option<PathBuf>,
inner: Db,
pickle_key: Arc<PickleKey>,
@ -159,6 +167,10 @@ impl SledStore {
SledStore::open_helper(db, None, passphrase)
}
fn get_account_info(&self) -> Option<AccountInfo> {
self.account_info.read().unwrap().clone()
}
fn open_helper(db: Db, path: Option<PathBuf>, passphrase: Option<&str>) -> Result<Self> {
let account = db.open_tree("account")?;
let private_identity = db.open_tree("private_identity")?;
@ -184,6 +196,7 @@ impl SledStore {
};
Ok(Self {
account_info: RwLock::new(None).into(),
path,
inner: db,
pickle_key: pickle_key.into(),
@ -249,22 +262,18 @@ impl SledStore {
&self,
room_id: &RoomId,
) -> Result<Option<OutboundGroupSession>> {
let account = self
.load_account()
.await?
let account_info = self
.get_account_info()
.ok_or(CryptoStoreError::AccountUnset)?;
let device_id: Arc<DeviceIdBox> = account.device_id().to_owned().into();
let identity_keys = account.identity_keys;
self.outbound_group_sessions
.get(room_id.encode())?
.map(|p| serde_json::from_slice(&p).map_err(CryptoStoreError::Serialization))
.transpose()?
.map(|p| {
OutboundGroupSession::from_pickle(
device_id,
identity_keys,
account_info.device_id,
account_info.identity_keys,
p,
self.get_pickle_mode(),
)
@ -430,16 +439,31 @@ impl CryptoStore for SledStore {
self.load_tracked_users().await?;
Ok(Some(ReadOnlyAccount::from_pickle(
pickle,
self.get_pickle_mode(),
)?))
let account = ReadOnlyAccount::from_pickle(pickle, self.get_pickle_mode())?;
let account_info = AccountInfo {
user_id: account.user_id.clone(),
device_id: account.device_id.clone(),
identity_keys: account.identity_keys.clone(),
};
*self.account_info.write().unwrap() = Some(account_info);
Ok(Some(account))
} else {
Ok(None)
}
}
async fn save_account(&self, account: ReadOnlyAccount) -> Result<()> {
let account_info = AccountInfo {
user_id: account.user_id.clone(),
device_id: account.device_id.clone(),
identity_keys: account.identity_keys.clone(),
};
*self.account_info.write().unwrap() = Some(account_info);
let changes = Changes {
account: Some(account),
..Default::default()
@ -453,9 +477,8 @@ impl CryptoStore for SledStore {
}
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
let account = self
.load_account()
.await?
let account_info = self
.get_account_info()
.ok_or(CryptoStoreError::AccountUnset)?;
if self.session_cache.get(sender_key).is_none() {
@ -465,9 +488,9 @@ impl CryptoStore for SledStore {
.map(|s| serde_json::from_slice(&s?.1).map_err(CryptoStoreError::Serialization))
.map(|p| {
Session::from_pickle(
account.user_id.clone(),
account.device_id.clone(),
account.identity_keys.clone(),
account_info.user_id.clone(),
account_info.device_id.clone(),
account_info.identity_keys.clone(),
p?,
self.get_pickle_mode(),
)