Merge branch 'multithreaded-crypto'

master
Damir Jelić 2021-03-23 11:34:07 +01:00
commit 15d5b234ed
16 changed files with 40635 additions and 186 deletions

View File

@ -10,15 +10,15 @@ edition = "2018"
crate-type = ["cdylib"] crate-type = ["cdylib"]
[dependencies] [dependencies]
url = "2.1.1" url = "2.2.1"
wasm-bindgen = { version = "0.2.62", features = ["serde-serialize"] } wasm-bindgen = { version = "0.2.72", features = ["serde-serialize"] }
wasm-bindgen-futures = "0.4.12" wasm-bindgen-futures = "0.4.22"
console_error_panic_hook = "*" console_error_panic_hook = "0.1.6"
web-sys = { version = "0.3.39", features = ["console"] } web-sys = { version = "0.3.49", features = ["console"] }
[dependencies.matrix-sdk] [dependencies.matrix-sdk]
path = "../.." path = "../.."
default-features = false default-features = false
features = ["native-tls"] features = ["native-tls", "encryption"]
[workspace] [workspace]

View File

@ -7,6 +7,4 @@ You can build the example locally with:
and then visiting http://localhost:8080 in a browser should run the example! and then visiting http://localhost:8080 in a browser should run the example!
Note: Encryption isn't supported yet This example is loosely based off of [this example](https://github.com/seanmonstar/reqwest/tree/master/examples/wasm_github_fetch), an example usage of `fetch` from `wasm-bindgen`.
This example is loosely based off of [this example](https://github.com/seanmonstar/reqwest/tree/master/examples/wasm_github_fetch), an example usage of `fetch` from `wasm-bindgen`.

View File

@ -5,8 +5,8 @@
}, },
"devDependencies": { "devDependencies": {
"@wasm-tool/wasm-pack-plugin": "1.0.1", "@wasm-tool/wasm-pack-plugin": "1.0.1",
"text-encoding": "^0.7.0",
"html-webpack-plugin": "^3.2.0", "html-webpack-plugin": "^3.2.0",
"text-encoding": "^0.7.0",
"webpack": "^4.29.4", "webpack": "^4.29.4",
"webpack-cli": "^3.1.1", "webpack-cli": "^3.1.1",
"webpack-dev-server": "^3.1.0" "webpack-dev-server": "^3.1.0"

View File

@ -1,7 +1,7 @@
use matrix_sdk::{ use matrix_sdk::{
deserialized_responses::SyncResponse, deserialized_responses::SyncResponse,
events::{ events::{
room::message::{MessageEventContent, TextMessageEventContent}, room::message::{MessageEventContent, MessageType, TextMessageEventContent},
AnyMessageEventContent, AnySyncMessageEvent, AnySyncRoomEvent, SyncMessageEvent, AnyMessageEventContent, AnySyncMessageEvent, AnySyncRoomEvent, SyncMessageEvent,
}, },
identifiers::RoomId, identifiers::RoomId,
@ -17,35 +17,49 @@ impl WasmBot {
async fn on_room_message( async fn on_room_message(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
event: SyncMessageEvent<MessageEventContent>, event: &SyncMessageEvent<MessageEventContent>,
) { ) {
let msg_body = if let SyncMessageEvent { let msg_body = if let SyncMessageEvent {
content: MessageEventContent::Text(TextMessageEventContent { body: msg_body, .. }), content:
MessageEventContent {
msgtype: MessageType::Text(TextMessageEventContent { body: msg_body, .. }),
..
},
.. ..
} = event } = event
{ {
msg_body.clone() msg_body
} else { } else {
return; return;
}; };
console::log_1(&format!("Received message event {:?}", &msg_body).into()); console::log_1(&format!("Received message event {:?}", &msg_body).into());
if msg_body.starts_with("!party") { if msg_body.contains("!party") {
let content = AnyMessageEventContent::RoomMessage(MessageEventContent::Text( let content = AnyMessageEventContent::RoomMessage(MessageEventContent::text_plain(
TextMessageEventContent::plain("🎉🎊🥳 let's PARTY with wasm!! 🥳🎊🎉".to_string()), "🎉🎊🥳 let's PARTY!! 🥳🎊🎉",
)); ));
self.0.room_send(&room_id, content, None).await.unwrap(); println!("sending");
self.0
// send our message to the room we found the "!party" command in
// the last parameter is an optional Uuid which we don't care about.
.room_send(room_id, content, None)
.await
.unwrap();
println!("message sent");
} }
} }
async fn on_sync_response(&self, response: SyncResponse) -> LoopCtrl { async fn on_sync_response(&self, response: SyncResponse) -> LoopCtrl {
console::log_1(&"Synced".to_string().into()); console::log_1(&"Synced".to_string().into());
for (room_id, room) in response.rooms.join { for (room_id, room) in response.rooms.join {
for event in room.timeline.events { for event in room.timeline.events {
if let AnySyncRoomEvent::Message(AnySyncMessageEvent::RoomMessage(ev)) = event { if let AnySyncRoomEvent::Message(AnySyncMessageEvent::RoomMessage(ev)) = event {
self.on_room_message(&room_id, ev).await self.on_room_message(&room_id, &ev).await
} }
} }
} }

View File

@ -34,5 +34,7 @@ default-features = false
features = ["sync"] features = ["sync"]
[target.'cfg(target_arch = "wasm32")'.dependencies] [target.'cfg(target_arch = "wasm32")'.dependencies]
futures = "0.3.12"
futures-locks = { version = "0.6.0", default-features = false } futures-locks = { version = "0.6.0", default-features = false }
wasm-bindgen-futures = "0.4"
uuid = { version = "0.8.2", default-features = false, features = ["v4", "wasm-bindgen"] } uuid = { version = "0.8.2", default-features = false, features = ["v4", "wasm-bindgen"] }

View File

@ -0,0 +1,42 @@
//! Abstraction over an executor so we can spawn tasks under WASM the same way
//! we do usually.
#[cfg(target_arch = "wasm32")]
use std::{
pin::Pin,
task::{Context, Poll},
};
#[cfg(not(target_arch = "wasm32"))]
pub use tokio::spawn;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_futures::spawn_local;
#[cfg(target_arch = "wasm32")]
use futures::{future::RemoteHandle, Future, FutureExt};
#[cfg(target_arch = "wasm32")]
pub fn spawn<F, T>(future: F) -> JoinHandle<T>
where
F: Future<Output = T> + 'static,
{
let fut = future.unit_error();
let (fut, handle) = fut.remote_handle();
spawn_local(fut);
JoinHandle { handle }
}
#[cfg(target_arch = "wasm32")]
pub struct JoinHandle<T> {
handle: RemoteHandle<Result<T, ()>>,
}
#[cfg(target_arch = "wasm32")]
impl<T: 'static> Future for JoinHandle<T> {
type Output = Result<T, ()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.handle).poll(cx)
}
}

View File

@ -14,6 +14,7 @@ pub use ruma::{
pub use uuid; pub use uuid;
pub mod deserialized_responses; pub mod deserialized_responses;
pub mod executor;
pub mod locks; pub mod locks;
/// Super trait that is used for our store traits, this trait will differ if /// Super trait that is used for our store traits, this trait will differ if

View File

@ -29,6 +29,7 @@ serde_json = "1.0.61"
zeroize = { version = "1.2.0", features = ["zeroize_derive"] } zeroize = { version = "1.2.0", features = ["zeroize_derive"] }
# Misc dependencies # Misc dependencies
futures = "0.3.12"
sled = { version = "0.34.6", optional = true } sled = { version = "0.34.6", optional = true }
thiserror = "1.0.23" thiserror = "1.0.23"
tracing = "0.1.22" tracing = "0.1.22"
@ -44,14 +45,13 @@ byteorder = "1.4.2"
[dev-dependencies] [dev-dependencies]
tokio = { version = "1.1.0", default-features = false, features = ["rt-multi-thread", "macros"] } tokio = { version = "1.1.0", default-features = false, features = ["rt-multi-thread", "macros"] }
futures = "0.3.12"
proptest = "0.10.1" proptest = "0.10.1"
serde_json = "1.0.61" serde_json = "1.0.61"
tempfile = "3.2.0" tempfile = "3.2.0"
http = "0.2.3" http = "0.2.3"
matrix-sdk-test = { version = "0.2.0", path = "../matrix_sdk_test" } matrix-sdk-test = { version = "0.2.0", path = "../matrix_sdk_test" }
indoc = "1.0.3" indoc = "1.0.3"
criterion = { version = "0.3.4", features = ["async", "async_futures", "html_reports"] } criterion = { version = "0.3.4", features = ["async", "async_tokio", "html_reports"] }
[target.'cfg(target_os = "linux")'.dev-dependencies] [target.'cfg(target_os = "linux")'.dev-dependencies]
pprof = { version = "0.4.2", features = ["flamegraph"] } pprof = { version = "0.4.2", features = ["flamegraph"] }

View File

@ -1,11 +1,10 @@
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
mod perf; mod perf;
use std::convert::TryFrom; use std::{convert::TryFrom, sync::Arc};
use criterion::{async_executor::FuturesExecutor, *}; use criterion::*;
use futures::executor::block_on;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::{ api::r0::{
keys::{claim_keys, get_keys}, keys::{claim_keys, get_keys},
@ -17,6 +16,7 @@ use matrix_sdk_common::{
use matrix_sdk_crypto::{EncryptionSettings, OlmMachine}; use matrix_sdk_crypto::{EncryptionSettings, OlmMachine};
use matrix_sdk_test::response_from_file; use matrix_sdk_test::response_from_file;
use serde_json::Value; use serde_json::Value;
use tokio::runtime::Builder;
fn alice_id() -> UserId { fn alice_id() -> UserId {
user_id!("@alice:example.org") user_id!("@alice:example.org")
@ -40,7 +40,17 @@ fn keys_claim_response() -> claim_keys::Response {
claim_keys::Response::try_from(data).expect("Can't parse the keys upload response") claim_keys::Response::try_from(data).expect("Can't parse the keys upload response")
} }
fn huge_keys_query_resopnse() -> get_keys::Response {
let data = include_bytes!("./keys_query_2000_members.json");
let data: Value = serde_json::from_slice(data).unwrap();
let data = response_from_file(&data);
get_keys::Response::try_from(data).expect("Can't parse the keys query response")
}
pub fn keys_query(c: &mut Criterion) { pub fn keys_query(c: &mut Criterion) {
let runtime = Builder::new_multi_thread()
.build()
.expect("Can't create runtime");
let machine = OlmMachine::new(&alice_id(), &alice_device_id()); let machine = OlmMachine::new(&alice_id(), &alice_device_id());
let response = keys_query_response(); let response = keys_query_response();
let uuid = Uuid::new_v4(); let uuid = Uuid::new_v4();
@ -62,25 +72,26 @@ pub fn keys_query(c: &mut Criterion) {
BenchmarkId::new("memory store", &name), BenchmarkId::new("memory store", &name),
&response, &response,
|b, response| { |b, response| {
b.to_async(FuturesExecutor) b.to_async(&runtime)
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() }) .iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
}, },
); );
let dir = tempfile::tempdir().unwrap(); let dir = tempfile::tempdir().unwrap();
let machine = block_on(OlmMachine::new_with_default_store( let machine = runtime
&alice_id(), .block_on(OlmMachine::new_with_default_store(
&alice_device_id(), &alice_id(),
dir.path(), &alice_device_id(),
None, dir.path(),
)) None,
.unwrap(); ))
.unwrap();
group.bench_with_input( group.bench_with_input(
BenchmarkId::new("sled store", &name), BenchmarkId::new("sled store", &name),
&response, &response,
|b, response| { |b, response| {
b.to_async(FuturesExecutor) b.to_async(&runtime)
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() }) .iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
}, },
); );
@ -89,6 +100,12 @@ pub fn keys_query(c: &mut Criterion) {
} }
pub fn keys_claiming(c: &mut Criterion) { pub fn keys_claiming(c: &mut Criterion) {
let runtime = Arc::new(
Builder::new_multi_thread()
.build()
.expect("Can't create runtime"),
);
let keys_query_response = keys_query_response(); let keys_query_response = keys_query_response();
let uuid = Uuid::new_v4(); let uuid = Uuid::new_v4();
@ -111,10 +128,16 @@ pub fn keys_claiming(c: &mut Criterion) {
b.iter_batched( b.iter_batched(
|| { || {
let machine = OlmMachine::new(&alice_id(), &alice_device_id()); let machine = OlmMachine::new(&alice_id(), &alice_device_id());
block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)).unwrap(); runtime
machine .block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
.unwrap();
(machine, runtime.clone())
},
move |(machine, runtime)| {
runtime
.block_on(machine.mark_request_as_sent(&uuid, response))
.unwrap()
}, },
move |machine| block_on(machine.mark_request_as_sent(&uuid, response)).unwrap(),
BatchSize::SmallInput, BatchSize::SmallInput,
) )
}, },
@ -127,17 +150,24 @@ pub fn keys_claiming(c: &mut Criterion) {
b.iter_batched( b.iter_batched(
|| { || {
let dir = tempfile::tempdir().unwrap(); let dir = tempfile::tempdir().unwrap();
let machine = block_on(OlmMachine::new_with_default_store( let machine = runtime
&alice_id(), .block_on(OlmMachine::new_with_default_store(
&alice_device_id(), &alice_id(),
dir.path(), &alice_device_id(),
None, dir.path(),
)) None,
.unwrap(); ))
block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)).unwrap(); .unwrap();
machine runtime
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
.unwrap();
(machine, runtime.clone())
},
move |(machine, runtime)| {
runtime
.block_on(machine.mark_request_as_sent(&uuid, response))
.unwrap()
}, },
move |machine| block_on(machine.mark_request_as_sent(&uuid, response)).unwrap(),
BatchSize::SmallInput, BatchSize::SmallInput,
) )
}, },
@ -147,6 +177,10 @@ pub fn keys_claiming(c: &mut Criterion) {
} }
pub fn room_key_sharing(c: &mut Criterion) { pub fn room_key_sharing(c: &mut Criterion) {
let runtime = Builder::new_multi_thread()
.build()
.expect("Can't create runtime");
let keys_query_response = keys_query_response(); let keys_query_response = keys_query_response();
let uuid = Uuid::new_v4(); let uuid = Uuid::new_v4();
let response = keys_claim_response(); let response = keys_claim_response();
@ -161,15 +195,19 @@ pub fn room_key_sharing(c: &mut Criterion) {
.fold(0, |acc, d| acc + d.len()); .fold(0, |acc, d| acc + d.len());
let machine = OlmMachine::new(&alice_id(), &alice_device_id()); let machine = OlmMachine::new(&alice_id(), &alice_device_id());
block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)).unwrap(); runtime
block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap(); .block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
let mut group = c.benchmark_group("Room key sharing"); let mut group = c.benchmark_group("Room key sharing");
group.throughput(Throughput::Elements(count as u64)); group.throughput(Throughput::Elements(count as u64));
let name = format!("{} devices", count); let name = format!("{} devices", count);
group.bench_function(BenchmarkId::new("memory store", &name), |b| { group.bench_function(BenchmarkId::new("memory store", &name), |b| {
b.to_async(FuturesExecutor).iter(|| async { b.to_async(&runtime).iter(|| async {
let requests = machine let requests = machine
.share_group_session(&room_id, users.iter(), EncryptionSettings::default()) .share_group_session(&room_id, users.iter(), EncryptionSettings::default())
.await .await
@ -189,18 +227,23 @@ pub fn room_key_sharing(c: &mut Criterion) {
}); });
let dir = tempfile::tempdir().unwrap(); let dir = tempfile::tempdir().unwrap();
let machine = block_on(OlmMachine::new_with_default_store( let machine = runtime
&alice_id(), .block_on(OlmMachine::new_with_default_store(
&alice_device_id(), &alice_id(),
dir.path(), &alice_device_id(),
None, dir.path(),
)) None,
.unwrap(); ))
block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)).unwrap(); .unwrap();
block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap(); runtime
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
group.bench_function(BenchmarkId::new("sled store", &name), |b| { group.bench_function(BenchmarkId::new("sled store", &name), |b| {
b.to_async(FuturesExecutor).iter(|| async { b.to_async(&runtime).iter(|| async {
let requests = machine let requests = machine
.share_group_session(&room_id, users.iter(), EncryptionSettings::default()) .share_group_session(&room_id, users.iter(), EncryptionSettings::default())
.await .await
@ -222,6 +265,58 @@ pub fn room_key_sharing(c: &mut Criterion) {
group.finish() group.finish()
} }
pub fn devices_missing_sessions_collecting(c: &mut Criterion) {
let runtime = Builder::new_multi_thread()
.build()
.expect("Can't create runtime");
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
let response = huge_keys_query_resopnse();
let uuid = Uuid::new_v4();
let users: Vec<UserId> = response.device_keys.keys().cloned().collect();
let count = response
.device_keys
.values()
.fold(0, |acc, d| acc + d.len());
let mut group = c.benchmark_group("Devices missing sessions collecting");
group.throughput(Throughput::Elements(count as u64));
let name = format!("{} devices", count);
runtime
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
group.bench_function(BenchmarkId::new("memory store", &name), |b| {
b.to_async(&runtime).iter_with_large_drop(|| async {
machine.get_missing_sessions(users.iter()).await.unwrap()
})
});
let dir = tempfile::tempdir().unwrap();
let machine = runtime
.block_on(OlmMachine::new_with_default_store(
&alice_id(),
&alice_device_id(),
dir.path(),
None,
))
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
group.bench_function(BenchmarkId::new("sled store", &name), |b| {
b.to_async(&runtime)
.iter(|| async { machine.get_missing_sessions(users.iter()).await.unwrap() })
});
group.finish()
}
fn criterion() -> Criterion { fn criterion() -> Criterion {
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
let criterion = Criterion::default().with_profiler(perf::FlamegraphProfiler::new(100)); let criterion = Criterion::default().with_profiler(perf::FlamegraphProfiler::new(100));
@ -234,6 +329,6 @@ fn criterion() -> Criterion {
criterion_group! { criterion_group! {
name = benches; name = benches;
config = criterion(); config = criterion();
targets = keys_query, keys_claiming, room_key_sharing targets = keys_query, keys_claiming, room_key_sharing, devices_missing_sessions_collecting,
} }
criterion_main!(benches); criterion_main!(benches);

File diff suppressed because it is too large Load Diff

View File

@ -12,17 +12,19 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use futures::future::join_all;
use std::{ use std::{
collections::{BTreeMap, HashSet}, collections::{BTreeMap, HashSet},
convert::TryFrom, convert::TryFrom,
sync::Arc, sync::Arc,
}; };
use tracing::{info, trace, warn}; use tracing::{trace, warn};
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::keys::get_keys::Response as KeysQueryResponse, api::r0::keys::get_keys::Response as KeysQueryResponse,
encryption::DeviceKeys, encryption::DeviceKeys,
identifiers::{DeviceId, DeviceIdBox, UserId}, executor::spawn,
identifiers::{DeviceIdBox, UserId},
}; };
use crate::{ use crate::{
@ -35,6 +37,12 @@ use crate::{
store::{Changes, DeviceChanges, IdentityChanges, Result as StoreResult, Store}, store::{Changes, DeviceChanges, IdentityChanges, Result as StoreResult, Store},
}; };
enum DeviceChange {
New(ReadOnlyDevice),
Updated(ReadOnlyDevice),
None,
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct IdentityManager { pub(crate) struct IdentityManager {
user_id: Arc<UserId>, user_id: Arc<UserId>,
@ -57,10 +65,6 @@ impl IdentityManager {
&self.user_id &self.user_id
} }
fn device_id(&self) -> &DeviceId {
&self.device_id
}
/// Receive a successful keys query response. /// Receive a successful keys query response.
/// ///
/// Returns a list of devices newly discovered devices and devices that /// Returns a list of devices newly discovered devices and devices that
@ -74,17 +78,8 @@ impl IdentityManager {
&self, &self,
response: &KeysQueryResponse, response: &KeysQueryResponse,
) -> OlmResult<(DeviceChanges, IdentityChanges)> { ) -> OlmResult<(DeviceChanges, IdentityChanges)> {
// TODO create a enum that tells us how the device/identity changed,
// e.g. new/deleted/display name change.
//
// TODO create a struct that will hold the device/identity and the
// change enum and return the struct.
//
// TODO once outbound group sessions hold on to the set of users that
// received the session, invalidate the session if a user device
// got added/deleted.
let changed_devices = self let changed_devices = self
.handle_devices_from_key_query(&response.device_keys) .handle_devices_from_key_query(response.device_keys.clone())
.await?; .await?;
let changed_identities = self.handle_cross_singing_keys(response).await?; let changed_identities = self.handle_cross_singing_keys(response).await?;
@ -94,11 +89,113 @@ impl IdentityManager {
..Default::default() ..Default::default()
}; };
// TODO turn this into a single transaction.
self.store.save_changes(changes).await?; self.store.save_changes(changes).await?;
let updated_users: Vec<&UserId> = response.device_keys.keys().collect();
for user_id in updated_users {
self.store.update_tracked_user(user_id, false).await?;
}
Ok((changed_devices, changed_identities)) Ok((changed_devices, changed_identities))
} }
async fn update_or_create_device(
store: Store,
device_keys: DeviceKeys,
) -> StoreResult<DeviceChange> {
let old_device = store
.get_readonly_device(&device_keys.user_id, &device_keys.device_id)
.await?;
if let Some(mut device) = old_device {
if let Err(e) = device.update_device(&device_keys) {
warn!(
"Failed to update the device keys for {} {}: {:?}",
device.user_id(),
device.device_id(),
e
);
Ok(DeviceChange::None)
} else {
Ok(DeviceChange::Updated(device))
}
} else {
match ReadOnlyDevice::try_from(&device_keys) {
Ok(d) => {
trace!("Adding a new device to the device store {:?}", d);
Ok(DeviceChange::New(d))
}
Err(e) => {
warn!(
"Failed to create a new device for {} {}: {:?}",
device_keys.user_id, device_keys.device_id, e
);
Ok(DeviceChange::None)
}
}
}
}
async fn update_user_devices(
store: Store,
own_user_id: Arc<UserId>,
own_device_id: Arc<DeviceIdBox>,
user_id: UserId,
device_map: BTreeMap<DeviceIdBox, DeviceKeys>,
) -> StoreResult<DeviceChanges> {
let mut changes = DeviceChanges::default();
let current_devices: HashSet<DeviceIdBox> = device_map.keys().cloned().collect();
let tasks = device_map
.into_iter()
.filter_map(|(device_id, device_keys)| {
// We don't need our own device in the device store.
if user_id == *own_user_id && device_id == *own_device_id {
None
} else if user_id != device_keys.user_id || device_id != device_keys.device_id {
warn!(
"Mismatch in device keys payload of device {}|{} from user {}|{}",
device_id, device_keys.device_id, user_id, device_keys.user_id
);
None
} else {
Some(spawn(Self::update_or_create_device(
store.clone(),
device_keys,
)))
}
});
let results = join_all(tasks).await;
for device in results {
let device = device.expect("Creating or updating a device panicked")?;
match device {
DeviceChange::New(d) => changes.new.push(d),
DeviceChange::Updated(d) => changes.changed.push(d),
DeviceChange::None => (),
}
}
let current_devices: HashSet<&DeviceIdBox> = current_devices.iter().collect();
let stored_devices = store.get_readonly_devices(&user_id).await?;
let stored_devices_set: HashSet<&DeviceIdBox> = stored_devices.keys().collect();
let deleted_devices_set = stored_devices_set.difference(&current_devices);
for device_id in deleted_devices_set {
if let Some(device) = stored_devices.get(*device_id) {
device.mark_as_deleted();
changes.deleted.push(device.clone());
}
}
Ok(changes)
}
/// Handle the device keys part of a key query response. /// Handle the device keys part of a key query response.
/// ///
/// # Arguments /// # Arguments
@ -110,69 +207,28 @@ impl IdentityManager {
/// they are new, one of their properties has changed or they got deleted. /// they are new, one of their properties has changed or they got deleted.
async fn handle_devices_from_key_query( async fn handle_devices_from_key_query(
&self, &self,
device_keys_map: &BTreeMap<UserId, BTreeMap<DeviceIdBox, DeviceKeys>>, device_keys_map: BTreeMap<UserId, BTreeMap<DeviceIdBox, DeviceKeys>>,
) -> StoreResult<DeviceChanges> { ) -> StoreResult<DeviceChanges> {
let mut changes = DeviceChanges::default(); let mut changes = DeviceChanges::default();
for (user_id, device_map) in device_keys_map { let tasks = device_keys_map
// TODO move this out into the handle keys query response method .into_iter()
// since we might fail handle the new device at any point here or .map(|(user_id, device_keys_map)| {
// when updating the user identities. spawn(Self::update_user_devices(
self.store.update_tracked_user(user_id, false).await?; self.store.clone(),
self.user_id.clone(),
self.device_id.clone(),
user_id,
device_keys_map,
))
});
for (device_id, device_keys) in device_map.iter() { let results = join_all(tasks).await;
// We don't need our own device in the device store.
if user_id == self.user_id() && &**device_id == self.device_id() {
continue;
}
if user_id != &device_keys.user_id || device_id != &device_keys.device_id { for result in results {
warn!( let change_fragment = result.expect("Panic while updating user devices")?;
"Mismatch in device keys payload of device {}|{} from user {}|{}",
device_id, device_keys.device_id, user_id, device_keys.user_id
);
continue;
}
let device = self.store.get_readonly_device(&user_id, device_id).await?; changes.extend(change_fragment);
if let Some(mut device) = device {
if let Err(e) = device.update_device(device_keys) {
warn!(
"Failed to update the device keys for {} {}: {:?}",
user_id, device_id, e
);
continue;
}
changes.changed.push(device);
} else {
let device = match ReadOnlyDevice::try_from(device_keys) {
Ok(d) => d,
Err(e) => {
warn!(
"Failed to create a new device for {} {}: {:?}",
user_id, device_id, e
);
continue;
}
};
info!("Adding a new device to the device store {:?}", device);
changes.new.push(device);
}
}
let current_devices: HashSet<&DeviceIdBox> = device_map.keys().collect();
let stored_devices = self.store.get_readonly_devices(&user_id).await?;
let stored_devices_set: HashSet<&DeviceIdBox> = stored_devices.keys().collect();
let deleted_devices_set = stored_devices_set.difference(&current_devices);
for device_id in deleted_devices_set {
if let Some(device) = stored_devices.get(*device_id) {
device.mark_as_deleted();
changes.deleted.push(device.clone());
}
}
} }
Ok(changes) Ok(changes)

View File

@ -63,6 +63,14 @@ impl ToDeviceRequest {
pub fn txn_id_string(&self) -> String { pub fn txn_id_string(&self) -> String {
self.txn_id.to_string() self.txn_id.to_string()
} }
/// Get the number of unique messages this request contains.
///
/// *Note*: A single message may be sent to multiple devices, so this may or
/// may not be the number of devices that will receive the messages as well.
pub fn message_count(&self) -> usize {
self.messages.values().map(|d| d.len()).sum()
}
} }
/// Request that will publish a cross signing identity. /// Request that will publish a cross signing identity.

View File

@ -17,6 +17,8 @@ use std::{
sync::Arc, sync::Arc,
}; };
use futures::future::join_all;
use dashmap::DashMap; use dashmap::DashMap;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::to_device::DeviceIdOrAllDevices, api::r0::to_device::DeviceIdOrAllDevices,
@ -24,6 +26,7 @@ use matrix_sdk_common::{
room::{encrypted::EncryptedEventContent, history_visibility::HistoryVisibility}, room::{encrypted::EncryptedEventContent, history_visibility::HistoryVisibility},
AnyMessageEventContent, EventType, AnyMessageEventContent, EventType,
}, },
executor::spawn,
identifiers::{DeviceId, DeviceIdBox, RoomId, UserId}, identifiers::{DeviceId, DeviceIdBox, RoomId, UserId},
uuid::Uuid, uuid::Uuid,
}; };
@ -52,7 +55,7 @@ pub struct GroupSessionManager {
} }
impl GroupSessionManager { impl GroupSessionManager {
const MAX_TO_DEVICE_MESSAGES: usize = 20; const MAX_TO_DEVICE_MESSAGES: usize = 250;
pub(crate) fn new(account: Account, store: Store) -> Self { pub(crate) fn new(account: Account, store: Store) -> Self {
Self { Self {
@ -188,35 +191,57 @@ impl GroupSessionManager {
/// Encrypt the given content for the given devices and create a to-device /// Encrypt the given content for the given devices and create a to-device
/// requests that sends the encrypted content to them. /// requests that sends the encrypted content to them.
async fn encrypt_session_for( async fn encrypt_session_for(
&self,
content: Value, content: Value,
devices: &[Device], devices: Vec<Device>,
) -> OlmResult<(Uuid, ToDeviceRequest, Vec<Session>)> { ) -> OlmResult<(Uuid, ToDeviceRequest, Vec<Session>)> {
let mut messages = BTreeMap::new(); let mut messages = BTreeMap::new();
let mut changed_sessions = Vec::new(); let mut changed_sessions = Vec::new();
for device in devices { let encrypt = |device: Device, content: Value| async move {
let mut message = BTreeMap::new();
let encrypted = device.encrypt(EventType::RoomKey, content.clone()).await; let encrypted = device.encrypt(EventType::RoomKey, content.clone()).await;
let (used_session, encrypted) = match encrypted { let used_session = match encrypted {
Ok(c) => c, Ok((session, encrypted)) => {
message
.entry(device.user_id().clone())
.or_insert_with(BTreeMap::new)
.insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
serde_json::value::to_raw_value(&encrypted)?,
);
Some(session)
}
// TODO we'll want to create m.room_key.withheld here. // TODO we'll want to create m.room_key.withheld here.
Err(OlmError::MissingSession) Err(OlmError::MissingSession)
| Err(OlmError::EventError(EventError::MissingSenderKey)) => { | Err(OlmError::EventError(EventError::MissingSenderKey)) => None,
continue;
}
Err(e) => return Err(e), Err(e) => return Err(e),
}; };
changed_sessions.push(used_session); Ok((used_session, message))
};
messages let tasks: Vec<_> = devices
.entry(device.user_id().clone()) .iter()
.or_insert_with(BTreeMap::new) .map(|d| spawn(encrypt(d.clone(), content.clone())))
.insert( .collect();
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
serde_json::value::to_raw_value(&encrypted)?, let results = join_all(tasks).await;
);
for result in results {
let (used_session, message) = result.expect("Encryption task panicked")?;
if let Some(session) = used_session {
changed_sessions.push(session);
}
for (user, device_messages) in message.into_iter() {
messages
.entry(user)
.or_insert_with(BTreeMap::new)
.extend(device_messages);
}
} }
let id = Uuid::new_v4(); let id = Uuid::new_v4();
@ -227,6 +252,12 @@ impl GroupSessionManager {
messages, messages,
}; };
trace!(
recipient_count = request.message_count(),
transaction_id = ?id,
"Created a to-device request carrying a room_key"
);
Ok((id, request, changed_sessions)) Ok((id, request, changed_sessions))
} }
@ -334,6 +365,24 @@ impl GroupSessionManager {
Ok((should_rotate, devices)) Ok((should_rotate, devices))
} }
pub async fn encrypt_request(
chunk: Vec<Device>,
content: Value,
outbound: OutboundGroupSession,
message_index: u32,
being_shared: Arc<DashMap<Uuid, OutboundGroupSession>>,
) -> OlmResult<Vec<Session>> {
let (id, request, used_sessions) =
Self::encrypt_session_for(content.clone(), chunk).await?;
if !request.messages.is_empty() {
outbound.add_request(id, request.into(), message_index);
being_shared.insert(id, outbound.clone());
}
Ok(used_sessions)
}
/// Get to-device requests to share a group session with users in a room. /// Get to-device requests to share a group session with users in a room.
/// ///
/// # Arguments /// # Arguments
@ -427,18 +476,23 @@ impl GroupSessionManager {
); );
} }
for device_map_chunk in devices.chunks(Self::MAX_TO_DEVICE_MESSAGES) { let tasks: Vec<_> = devices
let (id, request, used_sessions) = self .chunks(Self::MAX_TO_DEVICE_MESSAGES)
.encrypt_session_for(key_content.clone(), device_map_chunk) .map(|chunk| {
.await?; spawn(Self::encrypt_request(
chunk.to_vec(),
key_content.clone(),
outbound.clone(),
message_index,
self.outbound_sessions_being_shared.clone(),
))
})
.collect();
if !request.messages.is_empty() { for result in join_all(tasks).await {
outbound.add_request(id, request.into(), message_index); let used_sessions: OlmResult<Vec<Session>> = result.expect("Encryption task paniced");
self.outbound_sessions_being_shared
.insert(id, outbound.clone());
}
changes.sessions.extend(used_sessions); changes.sessions.extend(used_sessions?);
} }
let requests = outbound.pending_requests(); let requests = outbound.pending_requests();
@ -474,3 +528,84 @@ impl GroupSessionManager {
Ok(requests) Ok(requests)
} }
} }
#[cfg(test)]
mod test {
use std::convert::TryFrom;
use matrix_sdk_common::{
api::r0::keys::{claim_keys, get_keys},
identifiers::{room_id, user_id, DeviceIdBox, UserId},
uuid::Uuid,
};
use matrix_sdk_test::response_from_file;
use serde_json::Value;
use crate::{EncryptionSettings, OlmMachine};
fn alice_id() -> UserId {
user_id!("@alice:example.org")
}
fn alice_device_id() -> DeviceIdBox {
"JLAFKJWSCS".into()
}
fn keys_query_response() -> get_keys::Response {
let data = include_bytes!("../../benches/keys_query.json");
let data: Value = serde_json::from_slice(data).unwrap();
let data = response_from_file(&data);
get_keys::Response::try_from(data).expect("Can't parse the keys upload response")
}
fn keys_claim_response() -> claim_keys::Response {
let data = include_bytes!("../../benches/keys_claim.json");
let data: Value = serde_json::from_slice(data).unwrap();
let data = response_from_file(&data);
claim_keys::Response::try_from(data).expect("Can't parse the keys upload response")
}
async fn machine() -> OlmMachine {
let keys_query = keys_query_response();
let keys_claim = keys_claim_response();
let uuid = Uuid::new_v4();
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
machine
.mark_request_as_sent(&uuid, &keys_query)
.await
.unwrap();
machine
.mark_request_as_sent(&uuid, &keys_claim)
.await
.unwrap();
machine
}
#[tokio::test]
async fn test_sharing() {
let machine = machine().await;
let room_id = room_id!("!test:localhost");
let keys_claim = keys_claim_response();
let users: Vec<_> = keys_claim.one_time_keys.keys().collect();
let requests = machine
.share_group_session(
&room_id,
users.clone().into_iter(),
EncryptionSettings::default(),
)
.await
.unwrap();
let event_count: usize = requests.iter().map(|r| r.message_count()).sum();
// The keys claim response has a couple of one-time keys with invalid
// signatures, thus only 148 sessions are actually created, we check
// that all 148 valid sessions get an room key.
assert_eq!(event_count, 148);
}
}

View File

@ -195,9 +195,9 @@ impl SessionManager {
// Add the list of devices that the user wishes to establish sessions // Add the list of devices that the user wishes to establish sessions
// right now. // right now.
for user_id in users { for user_id in users {
let user_devices = self.store.get_user_devices(user_id).await?; let user_devices = self.store.get_readonly_devices(user_id).await?;
for device in user_devices.devices() { for (device_id, device) in user_devices.into_iter() {
let sender_key = if let Some(k) = device.get_key(DeviceKeyAlgorithm::Curve25519) { let sender_key = if let Some(k) = device.get_key(DeviceKeyAlgorithm::Curve25519) {
k k
} else { } else {
@ -216,10 +216,7 @@ impl SessionManager {
missing missing
.entry(user_id.to_owned()) .entry(user_id.to_owned())
.or_insert_with(BTreeMap::new) .or_insert_with(BTreeMap::new)
.insert( .insert(device_id, DeviceKeyAlgorithm::SignedCurve25519);
device.device_id().into(),
DeviceKeyAlgorithm::SignedCurve25519,
);
} }
} }
} }

View File

@ -126,6 +126,15 @@ pub struct DeviceChanges {
pub deleted: Vec<ReadOnlyDevice>, pub deleted: Vec<ReadOnlyDevice>,
} }
impl DeviceChanges {
/// Merge the given `DeviceChanges` into this instance of `DeviceChanges`.
pub fn extend(&mut self, other: DeviceChanges) {
self.new.extend(other.new);
self.changed.extend(other.changed);
self.deleted.extend(other.deleted);
}
}
impl Store { impl Store {
pub fn new( pub fn new(
user_id: Arc<UserId>, user_id: Arc<UserId>,

View File

@ -16,11 +16,11 @@ use std::{
collections::{HashMap, HashSet}, collections::{HashMap, HashSet},
convert::TryFrom, convert::TryFrom,
path::{Path, PathBuf}, path::{Path, PathBuf},
sync::Arc, sync::{Arc, RwLock},
}; };
use dashmap::DashSet; use dashmap::DashSet;
use olm_rs::PicklingMode; use olm_rs::{account::IdentityKeys, PicklingMode};
pub use sled::Error; pub use sled::Error;
use sled::{ use sled::{
transaction::{ConflictableTransactionError, TransactionError}, transaction::{ConflictableTransactionError, TransactionError},
@ -95,9 +95,17 @@ impl EncodeKey for (&str, &str, &str) {
} }
} }
#[derive(Clone, Debug)]
pub struct AccountInfo {
user_id: Arc<UserId>,
device_id: Arc<DeviceIdBox>,
identity_keys: Arc<IdentityKeys>,
}
/// An in-memory only store that will forget all the E2EE key once it's dropped. /// An in-memory only store that will forget all the E2EE key once it's dropped.
#[derive(Clone)] #[derive(Clone)]
pub struct SledStore { pub struct SledStore {
account_info: Arc<RwLock<Option<AccountInfo>>>,
path: Option<PathBuf>, path: Option<PathBuf>,
inner: Db, inner: Db,
pickle_key: Arc<PickleKey>, pickle_key: Arc<PickleKey>,
@ -159,6 +167,10 @@ impl SledStore {
SledStore::open_helper(db, None, passphrase) SledStore::open_helper(db, None, passphrase)
} }
fn get_account_info(&self) -> Option<AccountInfo> {
self.account_info.read().unwrap().clone()
}
fn open_helper(db: Db, path: Option<PathBuf>, passphrase: Option<&str>) -> Result<Self> { fn open_helper(db: Db, path: Option<PathBuf>, passphrase: Option<&str>) -> Result<Self> {
let account = db.open_tree("account")?; let account = db.open_tree("account")?;
let private_identity = db.open_tree("private_identity")?; let private_identity = db.open_tree("private_identity")?;
@ -184,6 +196,7 @@ impl SledStore {
}; };
Ok(Self { Ok(Self {
account_info: RwLock::new(None).into(),
path, path,
inner: db, inner: db,
pickle_key: pickle_key.into(), pickle_key: pickle_key.into(),
@ -249,22 +262,18 @@ impl SledStore {
&self, &self,
room_id: &RoomId, room_id: &RoomId,
) -> Result<Option<OutboundGroupSession>> { ) -> Result<Option<OutboundGroupSession>> {
let account = self let account_info = self
.load_account() .get_account_info()
.await?
.ok_or(CryptoStoreError::AccountUnset)?; .ok_or(CryptoStoreError::AccountUnset)?;
let device_id: Arc<DeviceIdBox> = account.device_id().to_owned().into();
let identity_keys = account.identity_keys;
self.outbound_group_sessions self.outbound_group_sessions
.get(room_id.encode())? .get(room_id.encode())?
.map(|p| serde_json::from_slice(&p).map_err(CryptoStoreError::Serialization)) .map(|p| serde_json::from_slice(&p).map_err(CryptoStoreError::Serialization))
.transpose()? .transpose()?
.map(|p| { .map(|p| {
OutboundGroupSession::from_pickle( OutboundGroupSession::from_pickle(
device_id, account_info.device_id,
identity_keys, account_info.identity_keys,
p, p,
self.get_pickle_mode(), self.get_pickle_mode(),
) )
@ -430,16 +439,31 @@ impl CryptoStore for SledStore {
self.load_tracked_users().await?; self.load_tracked_users().await?;
Ok(Some(ReadOnlyAccount::from_pickle( let account = ReadOnlyAccount::from_pickle(pickle, self.get_pickle_mode())?;
pickle,
self.get_pickle_mode(), let account_info = AccountInfo {
)?)) user_id: account.user_id.clone(),
device_id: account.device_id.clone(),
identity_keys: account.identity_keys.clone(),
};
*self.account_info.write().unwrap() = Some(account_info);
Ok(Some(account))
} else { } else {
Ok(None) Ok(None)
} }
} }
async fn save_account(&self, account: ReadOnlyAccount) -> Result<()> { async fn save_account(&self, account: ReadOnlyAccount) -> Result<()> {
let account_info = AccountInfo {
user_id: account.user_id.clone(),
device_id: account.device_id.clone(),
identity_keys: account.identity_keys.clone(),
};
*self.account_info.write().unwrap() = Some(account_info);
let changes = Changes { let changes = Changes {
account: Some(account), account: Some(account),
..Default::default() ..Default::default()
@ -453,9 +477,8 @@ impl CryptoStore for SledStore {
} }
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> { async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
let account = self let account_info = self
.load_account() .get_account_info()
.await?
.ok_or(CryptoStoreError::AccountUnset)?; .ok_or(CryptoStoreError::AccountUnset)?;
if self.session_cache.get(sender_key).is_none() { if self.session_cache.get(sender_key).is_none() {
@ -465,9 +488,9 @@ impl CryptoStore for SledStore {
.map(|s| serde_json::from_slice(&s?.1).map_err(CryptoStoreError::Serialization)) .map(|s| serde_json::from_slice(&s?.1).map_err(CryptoStoreError::Serialization))
.map(|p| { .map(|p| {
Session::from_pickle( Session::from_pickle(
account.user_id.clone(), account_info.user_id.clone(),
account.device_id.clone(), account_info.device_id.clone(),
account.identity_keys.clone(), account_info.identity_keys.clone(),
p?, p?,
self.get_pickle_mode(), self.get_pickle_mode(),
) )