Merge branch 'multithreaded-crypto'
commit
15d5b234ed
|
@ -10,15 +10,15 @@ edition = "2018"
|
||||||
crate-type = ["cdylib"]
|
crate-type = ["cdylib"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
url = "2.1.1"
|
url = "2.2.1"
|
||||||
wasm-bindgen = { version = "0.2.62", features = ["serde-serialize"] }
|
wasm-bindgen = { version = "0.2.72", features = ["serde-serialize"] }
|
||||||
wasm-bindgen-futures = "0.4.12"
|
wasm-bindgen-futures = "0.4.22"
|
||||||
console_error_panic_hook = "*"
|
console_error_panic_hook = "0.1.6"
|
||||||
web-sys = { version = "0.3.39", features = ["console"] }
|
web-sys = { version = "0.3.49", features = ["console"] }
|
||||||
|
|
||||||
[dependencies.matrix-sdk]
|
[dependencies.matrix-sdk]
|
||||||
path = "../.."
|
path = "../.."
|
||||||
default-features = false
|
default-features = false
|
||||||
features = ["native-tls"]
|
features = ["native-tls", "encryption"]
|
||||||
|
|
||||||
[workspace]
|
[workspace]
|
||||||
|
|
|
@ -7,6 +7,4 @@ You can build the example locally with:
|
||||||
|
|
||||||
and then visiting http://localhost:8080 in a browser should run the example!
|
and then visiting http://localhost:8080 in a browser should run the example!
|
||||||
|
|
||||||
Note: Encryption isn't supported yet
|
This example is loosely based off of [this example](https://github.com/seanmonstar/reqwest/tree/master/examples/wasm_github_fetch), an example usage of `fetch` from `wasm-bindgen`.
|
||||||
|
|
||||||
This example is loosely based off of [this example](https://github.com/seanmonstar/reqwest/tree/master/examples/wasm_github_fetch), an example usage of `fetch` from `wasm-bindgen`.
|
|
||||||
|
|
|
@ -5,8 +5,8 @@
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@wasm-tool/wasm-pack-plugin": "1.0.1",
|
"@wasm-tool/wasm-pack-plugin": "1.0.1",
|
||||||
"text-encoding": "^0.7.0",
|
|
||||||
"html-webpack-plugin": "^3.2.0",
|
"html-webpack-plugin": "^3.2.0",
|
||||||
|
"text-encoding": "^0.7.0",
|
||||||
"webpack": "^4.29.4",
|
"webpack": "^4.29.4",
|
||||||
"webpack-cli": "^3.1.1",
|
"webpack-cli": "^3.1.1",
|
||||||
"webpack-dev-server": "^3.1.0"
|
"webpack-dev-server": "^3.1.0"
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
use matrix_sdk::{
|
use matrix_sdk::{
|
||||||
deserialized_responses::SyncResponse,
|
deserialized_responses::SyncResponse,
|
||||||
events::{
|
events::{
|
||||||
room::message::{MessageEventContent, TextMessageEventContent},
|
room::message::{MessageEventContent, MessageType, TextMessageEventContent},
|
||||||
AnyMessageEventContent, AnySyncMessageEvent, AnySyncRoomEvent, SyncMessageEvent,
|
AnyMessageEventContent, AnySyncMessageEvent, AnySyncRoomEvent, SyncMessageEvent,
|
||||||
},
|
},
|
||||||
identifiers::RoomId,
|
identifiers::RoomId,
|
||||||
|
@ -17,35 +17,49 @@ impl WasmBot {
|
||||||
async fn on_room_message(
|
async fn on_room_message(
|
||||||
&self,
|
&self,
|
||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
event: SyncMessageEvent<MessageEventContent>,
|
event: &SyncMessageEvent<MessageEventContent>,
|
||||||
) {
|
) {
|
||||||
let msg_body = if let SyncMessageEvent {
|
let msg_body = if let SyncMessageEvent {
|
||||||
content: MessageEventContent::Text(TextMessageEventContent { body: msg_body, .. }),
|
content:
|
||||||
|
MessageEventContent {
|
||||||
|
msgtype: MessageType::Text(TextMessageEventContent { body: msg_body, .. }),
|
||||||
|
..
|
||||||
|
},
|
||||||
..
|
..
|
||||||
} = event
|
} = event
|
||||||
{
|
{
|
||||||
msg_body.clone()
|
msg_body
|
||||||
} else {
|
} else {
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
|
||||||
console::log_1(&format!("Received message event {:?}", &msg_body).into());
|
console::log_1(&format!("Received message event {:?}", &msg_body).into());
|
||||||
|
|
||||||
if msg_body.starts_with("!party") {
|
if msg_body.contains("!party") {
|
||||||
let content = AnyMessageEventContent::RoomMessage(MessageEventContent::Text(
|
let content = AnyMessageEventContent::RoomMessage(MessageEventContent::text_plain(
|
||||||
TextMessageEventContent::plain("🎉🎊🥳 let's PARTY with wasm!! 🥳🎊🎉".to_string()),
|
"🎉🎊🥳 let's PARTY!! 🥳🎊🎉",
|
||||||
));
|
));
|
||||||
|
|
||||||
self.0.room_send(&room_id, content, None).await.unwrap();
|
println!("sending");
|
||||||
|
|
||||||
|
self.0
|
||||||
|
// send our message to the room we found the "!party" command in
|
||||||
|
// the last parameter is an optional Uuid which we don't care about.
|
||||||
|
.room_send(room_id, content, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
println!("message sent");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn on_sync_response(&self, response: SyncResponse) -> LoopCtrl {
|
async fn on_sync_response(&self, response: SyncResponse) -> LoopCtrl {
|
||||||
console::log_1(&"Synced".to_string().into());
|
console::log_1(&"Synced".to_string().into());
|
||||||
|
|
||||||
for (room_id, room) in response.rooms.join {
|
for (room_id, room) in response.rooms.join {
|
||||||
for event in room.timeline.events {
|
for event in room.timeline.events {
|
||||||
if let AnySyncRoomEvent::Message(AnySyncMessageEvent::RoomMessage(ev)) = event {
|
if let AnySyncRoomEvent::Message(AnySyncMessageEvent::RoomMessage(ev)) = event {
|
||||||
self.on_room_message(&room_id, ev).await
|
self.on_room_message(&room_id, &ev).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,5 +34,7 @@ default-features = false
|
||||||
features = ["sync"]
|
features = ["sync"]
|
||||||
|
|
||||||
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||||
|
futures = "0.3.12"
|
||||||
futures-locks = { version = "0.6.0", default-features = false }
|
futures-locks = { version = "0.6.0", default-features = false }
|
||||||
|
wasm-bindgen-futures = "0.4"
|
||||||
uuid = { version = "0.8.2", default-features = false, features = ["v4", "wasm-bindgen"] }
|
uuid = { version = "0.8.2", default-features = false, features = ["v4", "wasm-bindgen"] }
|
||||||
|
|
|
@ -0,0 +1,42 @@
|
||||||
|
//! Abstraction over an executor so we can spawn tasks under WASM the same way
|
||||||
|
//! we do usually.
|
||||||
|
#[cfg(target_arch = "wasm32")]
|
||||||
|
use std::{
|
||||||
|
pin::Pin,
|
||||||
|
task::{Context, Poll},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
|
pub use tokio::spawn;
|
||||||
|
|
||||||
|
#[cfg(target_arch = "wasm32")]
|
||||||
|
use wasm_bindgen_futures::spawn_local;
|
||||||
|
|
||||||
|
#[cfg(target_arch = "wasm32")]
|
||||||
|
use futures::{future::RemoteHandle, Future, FutureExt};
|
||||||
|
|
||||||
|
#[cfg(target_arch = "wasm32")]
|
||||||
|
pub fn spawn<F, T>(future: F) -> JoinHandle<T>
|
||||||
|
where
|
||||||
|
F: Future<Output = T> + 'static,
|
||||||
|
{
|
||||||
|
let fut = future.unit_error();
|
||||||
|
let (fut, handle) = fut.remote_handle();
|
||||||
|
spawn_local(fut);
|
||||||
|
|
||||||
|
JoinHandle { handle }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(target_arch = "wasm32")]
|
||||||
|
pub struct JoinHandle<T> {
|
||||||
|
handle: RemoteHandle<Result<T, ()>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(target_arch = "wasm32")]
|
||||||
|
impl<T: 'static> Future for JoinHandle<T> {
|
||||||
|
type Output = Result<T, ()>;
|
||||||
|
|
||||||
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
|
Pin::new(&mut self.handle).poll(cx)
|
||||||
|
}
|
||||||
|
}
|
|
@ -14,6 +14,7 @@ pub use ruma::{
|
||||||
pub use uuid;
|
pub use uuid;
|
||||||
|
|
||||||
pub mod deserialized_responses;
|
pub mod deserialized_responses;
|
||||||
|
pub mod executor;
|
||||||
pub mod locks;
|
pub mod locks;
|
||||||
|
|
||||||
/// Super trait that is used for our store traits, this trait will differ if
|
/// Super trait that is used for our store traits, this trait will differ if
|
||||||
|
|
|
@ -29,6 +29,7 @@ serde_json = "1.0.61"
|
||||||
zeroize = { version = "1.2.0", features = ["zeroize_derive"] }
|
zeroize = { version = "1.2.0", features = ["zeroize_derive"] }
|
||||||
|
|
||||||
# Misc dependencies
|
# Misc dependencies
|
||||||
|
futures = "0.3.12"
|
||||||
sled = { version = "0.34.6", optional = true }
|
sled = { version = "0.34.6", optional = true }
|
||||||
thiserror = "1.0.23"
|
thiserror = "1.0.23"
|
||||||
tracing = "0.1.22"
|
tracing = "0.1.22"
|
||||||
|
@ -44,14 +45,13 @@ byteorder = "1.4.2"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tokio = { version = "1.1.0", default-features = false, features = ["rt-multi-thread", "macros"] }
|
tokio = { version = "1.1.0", default-features = false, features = ["rt-multi-thread", "macros"] }
|
||||||
futures = "0.3.12"
|
|
||||||
proptest = "0.10.1"
|
proptest = "0.10.1"
|
||||||
serde_json = "1.0.61"
|
serde_json = "1.0.61"
|
||||||
tempfile = "3.2.0"
|
tempfile = "3.2.0"
|
||||||
http = "0.2.3"
|
http = "0.2.3"
|
||||||
matrix-sdk-test = { version = "0.2.0", path = "../matrix_sdk_test" }
|
matrix-sdk-test = { version = "0.2.0", path = "../matrix_sdk_test" }
|
||||||
indoc = "1.0.3"
|
indoc = "1.0.3"
|
||||||
criterion = { version = "0.3.4", features = ["async", "async_futures", "html_reports"] }
|
criterion = { version = "0.3.4", features = ["async", "async_tokio", "html_reports"] }
|
||||||
|
|
||||||
[target.'cfg(target_os = "linux")'.dev-dependencies]
|
[target.'cfg(target_os = "linux")'.dev-dependencies]
|
||||||
pprof = { version = "0.4.2", features = ["flamegraph"] }
|
pprof = { version = "0.4.2", features = ["flamegraph"] }
|
||||||
|
|
|
@ -1,11 +1,10 @@
|
||||||
#[cfg(target_os = "linux")]
|
#[cfg(target_os = "linux")]
|
||||||
mod perf;
|
mod perf;
|
||||||
|
|
||||||
use std::convert::TryFrom;
|
use std::{convert::TryFrom, sync::Arc};
|
||||||
|
|
||||||
use criterion::{async_executor::FuturesExecutor, *};
|
use criterion::*;
|
||||||
|
|
||||||
use futures::executor::block_on;
|
|
||||||
use matrix_sdk_common::{
|
use matrix_sdk_common::{
|
||||||
api::r0::{
|
api::r0::{
|
||||||
keys::{claim_keys, get_keys},
|
keys::{claim_keys, get_keys},
|
||||||
|
@ -17,6 +16,7 @@ use matrix_sdk_common::{
|
||||||
use matrix_sdk_crypto::{EncryptionSettings, OlmMachine};
|
use matrix_sdk_crypto::{EncryptionSettings, OlmMachine};
|
||||||
use matrix_sdk_test::response_from_file;
|
use matrix_sdk_test::response_from_file;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
use tokio::runtime::Builder;
|
||||||
|
|
||||||
fn alice_id() -> UserId {
|
fn alice_id() -> UserId {
|
||||||
user_id!("@alice:example.org")
|
user_id!("@alice:example.org")
|
||||||
|
@ -40,7 +40,17 @@ fn keys_claim_response() -> claim_keys::Response {
|
||||||
claim_keys::Response::try_from(data).expect("Can't parse the keys upload response")
|
claim_keys::Response::try_from(data).expect("Can't parse the keys upload response")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn huge_keys_query_resopnse() -> get_keys::Response {
|
||||||
|
let data = include_bytes!("./keys_query_2000_members.json");
|
||||||
|
let data: Value = serde_json::from_slice(data).unwrap();
|
||||||
|
let data = response_from_file(&data);
|
||||||
|
get_keys::Response::try_from(data).expect("Can't parse the keys query response")
|
||||||
|
}
|
||||||
|
|
||||||
pub fn keys_query(c: &mut Criterion) {
|
pub fn keys_query(c: &mut Criterion) {
|
||||||
|
let runtime = Builder::new_multi_thread()
|
||||||
|
.build()
|
||||||
|
.expect("Can't create runtime");
|
||||||
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
|
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
|
||||||
let response = keys_query_response();
|
let response = keys_query_response();
|
||||||
let uuid = Uuid::new_v4();
|
let uuid = Uuid::new_v4();
|
||||||
|
@ -62,25 +72,26 @@ pub fn keys_query(c: &mut Criterion) {
|
||||||
BenchmarkId::new("memory store", &name),
|
BenchmarkId::new("memory store", &name),
|
||||||
&response,
|
&response,
|
||||||
|b, response| {
|
|b, response| {
|
||||||
b.to_async(FuturesExecutor)
|
b.to_async(&runtime)
|
||||||
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
|
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
let dir = tempfile::tempdir().unwrap();
|
let dir = tempfile::tempdir().unwrap();
|
||||||
let machine = block_on(OlmMachine::new_with_default_store(
|
let machine = runtime
|
||||||
&alice_id(),
|
.block_on(OlmMachine::new_with_default_store(
|
||||||
&alice_device_id(),
|
&alice_id(),
|
||||||
dir.path(),
|
&alice_device_id(),
|
||||||
None,
|
dir.path(),
|
||||||
))
|
None,
|
||||||
.unwrap();
|
))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
group.bench_with_input(
|
group.bench_with_input(
|
||||||
BenchmarkId::new("sled store", &name),
|
BenchmarkId::new("sled store", &name),
|
||||||
&response,
|
&response,
|
||||||
|b, response| {
|
|b, response| {
|
||||||
b.to_async(FuturesExecutor)
|
b.to_async(&runtime)
|
||||||
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
|
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
@ -89,6 +100,12 @@ pub fn keys_query(c: &mut Criterion) {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn keys_claiming(c: &mut Criterion) {
|
pub fn keys_claiming(c: &mut Criterion) {
|
||||||
|
let runtime = Arc::new(
|
||||||
|
Builder::new_multi_thread()
|
||||||
|
.build()
|
||||||
|
.expect("Can't create runtime"),
|
||||||
|
);
|
||||||
|
|
||||||
let keys_query_response = keys_query_response();
|
let keys_query_response = keys_query_response();
|
||||||
let uuid = Uuid::new_v4();
|
let uuid = Uuid::new_v4();
|
||||||
|
|
||||||
|
@ -111,10 +128,16 @@ pub fn keys_claiming(c: &mut Criterion) {
|
||||||
b.iter_batched(
|
b.iter_batched(
|
||||||
|| {
|
|| {
|
||||||
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
|
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
|
||||||
block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)).unwrap();
|
runtime
|
||||||
machine
|
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
|
||||||
|
.unwrap();
|
||||||
|
(machine, runtime.clone())
|
||||||
|
},
|
||||||
|
move |(machine, runtime)| {
|
||||||
|
runtime
|
||||||
|
.block_on(machine.mark_request_as_sent(&uuid, response))
|
||||||
|
.unwrap()
|
||||||
},
|
},
|
||||||
move |machine| block_on(machine.mark_request_as_sent(&uuid, response)).unwrap(),
|
|
||||||
BatchSize::SmallInput,
|
BatchSize::SmallInput,
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
|
@ -127,17 +150,24 @@ pub fn keys_claiming(c: &mut Criterion) {
|
||||||
b.iter_batched(
|
b.iter_batched(
|
||||||
|| {
|
|| {
|
||||||
let dir = tempfile::tempdir().unwrap();
|
let dir = tempfile::tempdir().unwrap();
|
||||||
let machine = block_on(OlmMachine::new_with_default_store(
|
let machine = runtime
|
||||||
&alice_id(),
|
.block_on(OlmMachine::new_with_default_store(
|
||||||
&alice_device_id(),
|
&alice_id(),
|
||||||
dir.path(),
|
&alice_device_id(),
|
||||||
None,
|
dir.path(),
|
||||||
))
|
None,
|
||||||
.unwrap();
|
))
|
||||||
block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)).unwrap();
|
.unwrap();
|
||||||
machine
|
runtime
|
||||||
|
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
|
||||||
|
.unwrap();
|
||||||
|
(machine, runtime.clone())
|
||||||
|
},
|
||||||
|
move |(machine, runtime)| {
|
||||||
|
runtime
|
||||||
|
.block_on(machine.mark_request_as_sent(&uuid, response))
|
||||||
|
.unwrap()
|
||||||
},
|
},
|
||||||
move |machine| block_on(machine.mark_request_as_sent(&uuid, response)).unwrap(),
|
|
||||||
BatchSize::SmallInput,
|
BatchSize::SmallInput,
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
|
@ -147,6 +177,10 @@ pub fn keys_claiming(c: &mut Criterion) {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn room_key_sharing(c: &mut Criterion) {
|
pub fn room_key_sharing(c: &mut Criterion) {
|
||||||
|
let runtime = Builder::new_multi_thread()
|
||||||
|
.build()
|
||||||
|
.expect("Can't create runtime");
|
||||||
|
|
||||||
let keys_query_response = keys_query_response();
|
let keys_query_response = keys_query_response();
|
||||||
let uuid = Uuid::new_v4();
|
let uuid = Uuid::new_v4();
|
||||||
let response = keys_claim_response();
|
let response = keys_claim_response();
|
||||||
|
@ -161,15 +195,19 @@ pub fn room_key_sharing(c: &mut Criterion) {
|
||||||
.fold(0, |acc, d| acc + d.len());
|
.fold(0, |acc, d| acc + d.len());
|
||||||
|
|
||||||
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
|
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
|
||||||
block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)).unwrap();
|
runtime
|
||||||
block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
|
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
|
||||||
|
.unwrap();
|
||||||
|
runtime
|
||||||
|
.block_on(machine.mark_request_as_sent(&uuid, &response))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let mut group = c.benchmark_group("Room key sharing");
|
let mut group = c.benchmark_group("Room key sharing");
|
||||||
group.throughput(Throughput::Elements(count as u64));
|
group.throughput(Throughput::Elements(count as u64));
|
||||||
let name = format!("{} devices", count);
|
let name = format!("{} devices", count);
|
||||||
|
|
||||||
group.bench_function(BenchmarkId::new("memory store", &name), |b| {
|
group.bench_function(BenchmarkId::new("memory store", &name), |b| {
|
||||||
b.to_async(FuturesExecutor).iter(|| async {
|
b.to_async(&runtime).iter(|| async {
|
||||||
let requests = machine
|
let requests = machine
|
||||||
.share_group_session(&room_id, users.iter(), EncryptionSettings::default())
|
.share_group_session(&room_id, users.iter(), EncryptionSettings::default())
|
||||||
.await
|
.await
|
||||||
|
@ -189,18 +227,23 @@ pub fn room_key_sharing(c: &mut Criterion) {
|
||||||
});
|
});
|
||||||
|
|
||||||
let dir = tempfile::tempdir().unwrap();
|
let dir = tempfile::tempdir().unwrap();
|
||||||
let machine = block_on(OlmMachine::new_with_default_store(
|
let machine = runtime
|
||||||
&alice_id(),
|
.block_on(OlmMachine::new_with_default_store(
|
||||||
&alice_device_id(),
|
&alice_id(),
|
||||||
dir.path(),
|
&alice_device_id(),
|
||||||
None,
|
dir.path(),
|
||||||
))
|
None,
|
||||||
.unwrap();
|
))
|
||||||
block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)).unwrap();
|
.unwrap();
|
||||||
block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
|
runtime
|
||||||
|
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
|
||||||
|
.unwrap();
|
||||||
|
runtime
|
||||||
|
.block_on(machine.mark_request_as_sent(&uuid, &response))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
group.bench_function(BenchmarkId::new("sled store", &name), |b| {
|
group.bench_function(BenchmarkId::new("sled store", &name), |b| {
|
||||||
b.to_async(FuturesExecutor).iter(|| async {
|
b.to_async(&runtime).iter(|| async {
|
||||||
let requests = machine
|
let requests = machine
|
||||||
.share_group_session(&room_id, users.iter(), EncryptionSettings::default())
|
.share_group_session(&room_id, users.iter(), EncryptionSettings::default())
|
||||||
.await
|
.await
|
||||||
|
@ -222,6 +265,58 @@ pub fn room_key_sharing(c: &mut Criterion) {
|
||||||
group.finish()
|
group.finish()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn devices_missing_sessions_collecting(c: &mut Criterion) {
|
||||||
|
let runtime = Builder::new_multi_thread()
|
||||||
|
.build()
|
||||||
|
.expect("Can't create runtime");
|
||||||
|
|
||||||
|
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
|
||||||
|
let response = huge_keys_query_resopnse();
|
||||||
|
let uuid = Uuid::new_v4();
|
||||||
|
let users: Vec<UserId> = response.device_keys.keys().cloned().collect();
|
||||||
|
|
||||||
|
let count = response
|
||||||
|
.device_keys
|
||||||
|
.values()
|
||||||
|
.fold(0, |acc, d| acc + d.len());
|
||||||
|
|
||||||
|
let mut group = c.benchmark_group("Devices missing sessions collecting");
|
||||||
|
group.throughput(Throughput::Elements(count as u64));
|
||||||
|
|
||||||
|
let name = format!("{} devices", count);
|
||||||
|
|
||||||
|
runtime
|
||||||
|
.block_on(machine.mark_request_as_sent(&uuid, &response))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
group.bench_function(BenchmarkId::new("memory store", &name), |b| {
|
||||||
|
b.to_async(&runtime).iter_with_large_drop(|| async {
|
||||||
|
machine.get_missing_sessions(users.iter()).await.unwrap()
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
let dir = tempfile::tempdir().unwrap();
|
||||||
|
let machine = runtime
|
||||||
|
.block_on(OlmMachine::new_with_default_store(
|
||||||
|
&alice_id(),
|
||||||
|
&alice_device_id(),
|
||||||
|
dir.path(),
|
||||||
|
None,
|
||||||
|
))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
runtime
|
||||||
|
.block_on(machine.mark_request_as_sent(&uuid, &response))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
group.bench_function(BenchmarkId::new("sled store", &name), |b| {
|
||||||
|
b.to_async(&runtime)
|
||||||
|
.iter(|| async { machine.get_missing_sessions(users.iter()).await.unwrap() })
|
||||||
|
});
|
||||||
|
|
||||||
|
group.finish()
|
||||||
|
}
|
||||||
|
|
||||||
fn criterion() -> Criterion {
|
fn criterion() -> Criterion {
|
||||||
#[cfg(target_os = "linux")]
|
#[cfg(target_os = "linux")]
|
||||||
let criterion = Criterion::default().with_profiler(perf::FlamegraphProfiler::new(100));
|
let criterion = Criterion::default().with_profiler(perf::FlamegraphProfiler::new(100));
|
||||||
|
@ -234,6 +329,6 @@ fn criterion() -> Criterion {
|
||||||
criterion_group! {
|
criterion_group! {
|
||||||
name = benches;
|
name = benches;
|
||||||
config = criterion();
|
config = criterion();
|
||||||
targets = keys_query, keys_claiming, room_key_sharing
|
targets = keys_query, keys_claiming, room_key_sharing, devices_missing_sessions_collecting,
|
||||||
}
|
}
|
||||||
criterion_main!(benches);
|
criterion_main!(benches);
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -12,17 +12,19 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
use futures::future::join_all;
|
||||||
use std::{
|
use std::{
|
||||||
collections::{BTreeMap, HashSet},
|
collections::{BTreeMap, HashSet},
|
||||||
convert::TryFrom,
|
convert::TryFrom,
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
};
|
};
|
||||||
use tracing::{info, trace, warn};
|
use tracing::{trace, warn};
|
||||||
|
|
||||||
use matrix_sdk_common::{
|
use matrix_sdk_common::{
|
||||||
api::r0::keys::get_keys::Response as KeysQueryResponse,
|
api::r0::keys::get_keys::Response as KeysQueryResponse,
|
||||||
encryption::DeviceKeys,
|
encryption::DeviceKeys,
|
||||||
identifiers::{DeviceId, DeviceIdBox, UserId},
|
executor::spawn,
|
||||||
|
identifiers::{DeviceIdBox, UserId},
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -35,6 +37,12 @@ use crate::{
|
||||||
store::{Changes, DeviceChanges, IdentityChanges, Result as StoreResult, Store},
|
store::{Changes, DeviceChanges, IdentityChanges, Result as StoreResult, Store},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum DeviceChange {
|
||||||
|
New(ReadOnlyDevice),
|
||||||
|
Updated(ReadOnlyDevice),
|
||||||
|
None,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub(crate) struct IdentityManager {
|
pub(crate) struct IdentityManager {
|
||||||
user_id: Arc<UserId>,
|
user_id: Arc<UserId>,
|
||||||
|
@ -57,10 +65,6 @@ impl IdentityManager {
|
||||||
&self.user_id
|
&self.user_id
|
||||||
}
|
}
|
||||||
|
|
||||||
fn device_id(&self) -> &DeviceId {
|
|
||||||
&self.device_id
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Receive a successful keys query response.
|
/// Receive a successful keys query response.
|
||||||
///
|
///
|
||||||
/// Returns a list of devices newly discovered devices and devices that
|
/// Returns a list of devices newly discovered devices and devices that
|
||||||
|
@ -74,17 +78,8 @@ impl IdentityManager {
|
||||||
&self,
|
&self,
|
||||||
response: &KeysQueryResponse,
|
response: &KeysQueryResponse,
|
||||||
) -> OlmResult<(DeviceChanges, IdentityChanges)> {
|
) -> OlmResult<(DeviceChanges, IdentityChanges)> {
|
||||||
// TODO create a enum that tells us how the device/identity changed,
|
|
||||||
// e.g. new/deleted/display name change.
|
|
||||||
//
|
|
||||||
// TODO create a struct that will hold the device/identity and the
|
|
||||||
// change enum and return the struct.
|
|
||||||
//
|
|
||||||
// TODO once outbound group sessions hold on to the set of users that
|
|
||||||
// received the session, invalidate the session if a user device
|
|
||||||
// got added/deleted.
|
|
||||||
let changed_devices = self
|
let changed_devices = self
|
||||||
.handle_devices_from_key_query(&response.device_keys)
|
.handle_devices_from_key_query(response.device_keys.clone())
|
||||||
.await?;
|
.await?;
|
||||||
let changed_identities = self.handle_cross_singing_keys(response).await?;
|
let changed_identities = self.handle_cross_singing_keys(response).await?;
|
||||||
|
|
||||||
|
@ -94,11 +89,113 @@ impl IdentityManager {
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// TODO turn this into a single transaction.
|
||||||
self.store.save_changes(changes).await?;
|
self.store.save_changes(changes).await?;
|
||||||
|
let updated_users: Vec<&UserId> = response.device_keys.keys().collect();
|
||||||
|
|
||||||
|
for user_id in updated_users {
|
||||||
|
self.store.update_tracked_user(user_id, false).await?;
|
||||||
|
}
|
||||||
|
|
||||||
Ok((changed_devices, changed_identities))
|
Ok((changed_devices, changed_identities))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn update_or_create_device(
|
||||||
|
store: Store,
|
||||||
|
device_keys: DeviceKeys,
|
||||||
|
) -> StoreResult<DeviceChange> {
|
||||||
|
let old_device = store
|
||||||
|
.get_readonly_device(&device_keys.user_id, &device_keys.device_id)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if let Some(mut device) = old_device {
|
||||||
|
if let Err(e) = device.update_device(&device_keys) {
|
||||||
|
warn!(
|
||||||
|
"Failed to update the device keys for {} {}: {:?}",
|
||||||
|
device.user_id(),
|
||||||
|
device.device_id(),
|
||||||
|
e
|
||||||
|
);
|
||||||
|
Ok(DeviceChange::None)
|
||||||
|
} else {
|
||||||
|
Ok(DeviceChange::Updated(device))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
match ReadOnlyDevice::try_from(&device_keys) {
|
||||||
|
Ok(d) => {
|
||||||
|
trace!("Adding a new device to the device store {:?}", d);
|
||||||
|
Ok(DeviceChange::New(d))
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!(
|
||||||
|
"Failed to create a new device for {} {}: {:?}",
|
||||||
|
device_keys.user_id, device_keys.device_id, e
|
||||||
|
);
|
||||||
|
Ok(DeviceChange::None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn update_user_devices(
|
||||||
|
store: Store,
|
||||||
|
own_user_id: Arc<UserId>,
|
||||||
|
own_device_id: Arc<DeviceIdBox>,
|
||||||
|
user_id: UserId,
|
||||||
|
device_map: BTreeMap<DeviceIdBox, DeviceKeys>,
|
||||||
|
) -> StoreResult<DeviceChanges> {
|
||||||
|
let mut changes = DeviceChanges::default();
|
||||||
|
|
||||||
|
let current_devices: HashSet<DeviceIdBox> = device_map.keys().cloned().collect();
|
||||||
|
|
||||||
|
let tasks = device_map
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|(device_id, device_keys)| {
|
||||||
|
// We don't need our own device in the device store.
|
||||||
|
if user_id == *own_user_id && device_id == *own_device_id {
|
||||||
|
None
|
||||||
|
} else if user_id != device_keys.user_id || device_id != device_keys.device_id {
|
||||||
|
warn!(
|
||||||
|
"Mismatch in device keys payload of device {}|{} from user {}|{}",
|
||||||
|
device_id, device_keys.device_id, user_id, device_keys.user_id
|
||||||
|
);
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(spawn(Self::update_or_create_device(
|
||||||
|
store.clone(),
|
||||||
|
device_keys,
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let results = join_all(tasks).await;
|
||||||
|
|
||||||
|
for device in results {
|
||||||
|
let device = device.expect("Creating or updating a device panicked")?;
|
||||||
|
|
||||||
|
match device {
|
||||||
|
DeviceChange::New(d) => changes.new.push(d),
|
||||||
|
DeviceChange::Updated(d) => changes.changed.push(d),
|
||||||
|
DeviceChange::None => (),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let current_devices: HashSet<&DeviceIdBox> = current_devices.iter().collect();
|
||||||
|
let stored_devices = store.get_readonly_devices(&user_id).await?;
|
||||||
|
let stored_devices_set: HashSet<&DeviceIdBox> = stored_devices.keys().collect();
|
||||||
|
|
||||||
|
let deleted_devices_set = stored_devices_set.difference(¤t_devices);
|
||||||
|
|
||||||
|
for device_id in deleted_devices_set {
|
||||||
|
if let Some(device) = stored_devices.get(*device_id) {
|
||||||
|
device.mark_as_deleted();
|
||||||
|
changes.deleted.push(device.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(changes)
|
||||||
|
}
|
||||||
|
|
||||||
/// Handle the device keys part of a key query response.
|
/// Handle the device keys part of a key query response.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
|
@ -110,69 +207,28 @@ impl IdentityManager {
|
||||||
/// they are new, one of their properties has changed or they got deleted.
|
/// they are new, one of their properties has changed or they got deleted.
|
||||||
async fn handle_devices_from_key_query(
|
async fn handle_devices_from_key_query(
|
||||||
&self,
|
&self,
|
||||||
device_keys_map: &BTreeMap<UserId, BTreeMap<DeviceIdBox, DeviceKeys>>,
|
device_keys_map: BTreeMap<UserId, BTreeMap<DeviceIdBox, DeviceKeys>>,
|
||||||
) -> StoreResult<DeviceChanges> {
|
) -> StoreResult<DeviceChanges> {
|
||||||
let mut changes = DeviceChanges::default();
|
let mut changes = DeviceChanges::default();
|
||||||
|
|
||||||
for (user_id, device_map) in device_keys_map {
|
let tasks = device_keys_map
|
||||||
// TODO move this out into the handle keys query response method
|
.into_iter()
|
||||||
// since we might fail handle the new device at any point here or
|
.map(|(user_id, device_keys_map)| {
|
||||||
// when updating the user identities.
|
spawn(Self::update_user_devices(
|
||||||
self.store.update_tracked_user(user_id, false).await?;
|
self.store.clone(),
|
||||||
|
self.user_id.clone(),
|
||||||
|
self.device_id.clone(),
|
||||||
|
user_id,
|
||||||
|
device_keys_map,
|
||||||
|
))
|
||||||
|
});
|
||||||
|
|
||||||
for (device_id, device_keys) in device_map.iter() {
|
let results = join_all(tasks).await;
|
||||||
// We don't need our own device in the device store.
|
|
||||||
if user_id == self.user_id() && &**device_id == self.device_id() {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if user_id != &device_keys.user_id || device_id != &device_keys.device_id {
|
for result in results {
|
||||||
warn!(
|
let change_fragment = result.expect("Panic while updating user devices")?;
|
||||||
"Mismatch in device keys payload of device {}|{} from user {}|{}",
|
|
||||||
device_id, device_keys.device_id, user_id, device_keys.user_id
|
|
||||||
);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let device = self.store.get_readonly_device(&user_id, device_id).await?;
|
changes.extend(change_fragment);
|
||||||
|
|
||||||
if let Some(mut device) = device {
|
|
||||||
if let Err(e) = device.update_device(device_keys) {
|
|
||||||
warn!(
|
|
||||||
"Failed to update the device keys for {} {}: {:?}",
|
|
||||||
user_id, device_id, e
|
|
||||||
);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
changes.changed.push(device);
|
|
||||||
} else {
|
|
||||||
let device = match ReadOnlyDevice::try_from(device_keys) {
|
|
||||||
Ok(d) => d,
|
|
||||||
Err(e) => {
|
|
||||||
warn!(
|
|
||||||
"Failed to create a new device for {} {}: {:?}",
|
|
||||||
user_id, device_id, e
|
|
||||||
);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
info!("Adding a new device to the device store {:?}", device);
|
|
||||||
changes.new.push(device);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let current_devices: HashSet<&DeviceIdBox> = device_map.keys().collect();
|
|
||||||
let stored_devices = self.store.get_readonly_devices(&user_id).await?;
|
|
||||||
let stored_devices_set: HashSet<&DeviceIdBox> = stored_devices.keys().collect();
|
|
||||||
|
|
||||||
let deleted_devices_set = stored_devices_set.difference(¤t_devices);
|
|
||||||
|
|
||||||
for device_id in deleted_devices_set {
|
|
||||||
if let Some(device) = stored_devices.get(*device_id) {
|
|
||||||
device.mark_as_deleted();
|
|
||||||
changes.deleted.push(device.clone());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(changes)
|
Ok(changes)
|
||||||
|
|
|
@ -63,6 +63,14 @@ impl ToDeviceRequest {
|
||||||
pub fn txn_id_string(&self) -> String {
|
pub fn txn_id_string(&self) -> String {
|
||||||
self.txn_id.to_string()
|
self.txn_id.to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get the number of unique messages this request contains.
|
||||||
|
///
|
||||||
|
/// *Note*: A single message may be sent to multiple devices, so this may or
|
||||||
|
/// may not be the number of devices that will receive the messages as well.
|
||||||
|
pub fn message_count(&self) -> usize {
|
||||||
|
self.messages.values().map(|d| d.len()).sum()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Request that will publish a cross signing identity.
|
/// Request that will publish a cross signing identity.
|
||||||
|
|
|
@ -17,6 +17,8 @@ use std::{
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use futures::future::join_all;
|
||||||
|
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use matrix_sdk_common::{
|
use matrix_sdk_common::{
|
||||||
api::r0::to_device::DeviceIdOrAllDevices,
|
api::r0::to_device::DeviceIdOrAllDevices,
|
||||||
|
@ -24,6 +26,7 @@ use matrix_sdk_common::{
|
||||||
room::{encrypted::EncryptedEventContent, history_visibility::HistoryVisibility},
|
room::{encrypted::EncryptedEventContent, history_visibility::HistoryVisibility},
|
||||||
AnyMessageEventContent, EventType,
|
AnyMessageEventContent, EventType,
|
||||||
},
|
},
|
||||||
|
executor::spawn,
|
||||||
identifiers::{DeviceId, DeviceIdBox, RoomId, UserId},
|
identifiers::{DeviceId, DeviceIdBox, RoomId, UserId},
|
||||||
uuid::Uuid,
|
uuid::Uuid,
|
||||||
};
|
};
|
||||||
|
@ -52,7 +55,7 @@ pub struct GroupSessionManager {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GroupSessionManager {
|
impl GroupSessionManager {
|
||||||
const MAX_TO_DEVICE_MESSAGES: usize = 20;
|
const MAX_TO_DEVICE_MESSAGES: usize = 250;
|
||||||
|
|
||||||
pub(crate) fn new(account: Account, store: Store) -> Self {
|
pub(crate) fn new(account: Account, store: Store) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
@ -188,35 +191,57 @@ impl GroupSessionManager {
|
||||||
/// Encrypt the given content for the given devices and create a to-device
|
/// Encrypt the given content for the given devices and create a to-device
|
||||||
/// requests that sends the encrypted content to them.
|
/// requests that sends the encrypted content to them.
|
||||||
async fn encrypt_session_for(
|
async fn encrypt_session_for(
|
||||||
&self,
|
|
||||||
content: Value,
|
content: Value,
|
||||||
devices: &[Device],
|
devices: Vec<Device>,
|
||||||
) -> OlmResult<(Uuid, ToDeviceRequest, Vec<Session>)> {
|
) -> OlmResult<(Uuid, ToDeviceRequest, Vec<Session>)> {
|
||||||
let mut messages = BTreeMap::new();
|
let mut messages = BTreeMap::new();
|
||||||
let mut changed_sessions = Vec::new();
|
let mut changed_sessions = Vec::new();
|
||||||
|
|
||||||
for device in devices {
|
let encrypt = |device: Device, content: Value| async move {
|
||||||
|
let mut message = BTreeMap::new();
|
||||||
|
|
||||||
let encrypted = device.encrypt(EventType::RoomKey, content.clone()).await;
|
let encrypted = device.encrypt(EventType::RoomKey, content.clone()).await;
|
||||||
|
|
||||||
let (used_session, encrypted) = match encrypted {
|
let used_session = match encrypted {
|
||||||
Ok(c) => c,
|
Ok((session, encrypted)) => {
|
||||||
|
message
|
||||||
|
.entry(device.user_id().clone())
|
||||||
|
.or_insert_with(BTreeMap::new)
|
||||||
|
.insert(
|
||||||
|
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
|
||||||
|
serde_json::value::to_raw_value(&encrypted)?,
|
||||||
|
);
|
||||||
|
Some(session)
|
||||||
|
}
|
||||||
// TODO we'll want to create m.room_key.withheld here.
|
// TODO we'll want to create m.room_key.withheld here.
|
||||||
Err(OlmError::MissingSession)
|
Err(OlmError::MissingSession)
|
||||||
| Err(OlmError::EventError(EventError::MissingSenderKey)) => {
|
| Err(OlmError::EventError(EventError::MissingSenderKey)) => None,
|
||||||
continue;
|
|
||||||
}
|
|
||||||
Err(e) => return Err(e),
|
Err(e) => return Err(e),
|
||||||
};
|
};
|
||||||
|
|
||||||
changed_sessions.push(used_session);
|
Ok((used_session, message))
|
||||||
|
};
|
||||||
|
|
||||||
messages
|
let tasks: Vec<_> = devices
|
||||||
.entry(device.user_id().clone())
|
.iter()
|
||||||
.or_insert_with(BTreeMap::new)
|
.map(|d| spawn(encrypt(d.clone(), content.clone())))
|
||||||
.insert(
|
.collect();
|
||||||
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
|
|
||||||
serde_json::value::to_raw_value(&encrypted)?,
|
let results = join_all(tasks).await;
|
||||||
);
|
|
||||||
|
for result in results {
|
||||||
|
let (used_session, message) = result.expect("Encryption task panicked")?;
|
||||||
|
|
||||||
|
if let Some(session) = used_session {
|
||||||
|
changed_sessions.push(session);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (user, device_messages) in message.into_iter() {
|
||||||
|
messages
|
||||||
|
.entry(user)
|
||||||
|
.or_insert_with(BTreeMap::new)
|
||||||
|
.extend(device_messages);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let id = Uuid::new_v4();
|
let id = Uuid::new_v4();
|
||||||
|
@ -227,6 +252,12 @@ impl GroupSessionManager {
|
||||||
messages,
|
messages,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
trace!(
|
||||||
|
recipient_count = request.message_count(),
|
||||||
|
transaction_id = ?id,
|
||||||
|
"Created a to-device request carrying a room_key"
|
||||||
|
);
|
||||||
|
|
||||||
Ok((id, request, changed_sessions))
|
Ok((id, request, changed_sessions))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -334,6 +365,24 @@ impl GroupSessionManager {
|
||||||
Ok((should_rotate, devices))
|
Ok((should_rotate, devices))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn encrypt_request(
|
||||||
|
chunk: Vec<Device>,
|
||||||
|
content: Value,
|
||||||
|
outbound: OutboundGroupSession,
|
||||||
|
message_index: u32,
|
||||||
|
being_shared: Arc<DashMap<Uuid, OutboundGroupSession>>,
|
||||||
|
) -> OlmResult<Vec<Session>> {
|
||||||
|
let (id, request, used_sessions) =
|
||||||
|
Self::encrypt_session_for(content.clone(), chunk).await?;
|
||||||
|
|
||||||
|
if !request.messages.is_empty() {
|
||||||
|
outbound.add_request(id, request.into(), message_index);
|
||||||
|
being_shared.insert(id, outbound.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(used_sessions)
|
||||||
|
}
|
||||||
|
|
||||||
/// Get to-device requests to share a group session with users in a room.
|
/// Get to-device requests to share a group session with users in a room.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
|
@ -427,18 +476,23 @@ impl GroupSessionManager {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
for device_map_chunk in devices.chunks(Self::MAX_TO_DEVICE_MESSAGES) {
|
let tasks: Vec<_> = devices
|
||||||
let (id, request, used_sessions) = self
|
.chunks(Self::MAX_TO_DEVICE_MESSAGES)
|
||||||
.encrypt_session_for(key_content.clone(), device_map_chunk)
|
.map(|chunk| {
|
||||||
.await?;
|
spawn(Self::encrypt_request(
|
||||||
|
chunk.to_vec(),
|
||||||
|
key_content.clone(),
|
||||||
|
outbound.clone(),
|
||||||
|
message_index,
|
||||||
|
self.outbound_sessions_being_shared.clone(),
|
||||||
|
))
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
if !request.messages.is_empty() {
|
for result in join_all(tasks).await {
|
||||||
outbound.add_request(id, request.into(), message_index);
|
let used_sessions: OlmResult<Vec<Session>> = result.expect("Encryption task paniced");
|
||||||
self.outbound_sessions_being_shared
|
|
||||||
.insert(id, outbound.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
changes.sessions.extend(used_sessions);
|
changes.sessions.extend(used_sessions?);
|
||||||
}
|
}
|
||||||
|
|
||||||
let requests = outbound.pending_requests();
|
let requests = outbound.pending_requests();
|
||||||
|
@ -474,3 +528,84 @@ impl GroupSessionManager {
|
||||||
Ok(requests)
|
Ok(requests)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test {
|
||||||
|
use std::convert::TryFrom;
|
||||||
|
|
||||||
|
use matrix_sdk_common::{
|
||||||
|
api::r0::keys::{claim_keys, get_keys},
|
||||||
|
identifiers::{room_id, user_id, DeviceIdBox, UserId},
|
||||||
|
uuid::Uuid,
|
||||||
|
};
|
||||||
|
use matrix_sdk_test::response_from_file;
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
use crate::{EncryptionSettings, OlmMachine};
|
||||||
|
|
||||||
|
fn alice_id() -> UserId {
|
||||||
|
user_id!("@alice:example.org")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn alice_device_id() -> DeviceIdBox {
|
||||||
|
"JLAFKJWSCS".into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn keys_query_response() -> get_keys::Response {
|
||||||
|
let data = include_bytes!("../../benches/keys_query.json");
|
||||||
|
let data: Value = serde_json::from_slice(data).unwrap();
|
||||||
|
let data = response_from_file(&data);
|
||||||
|
get_keys::Response::try_from(data).expect("Can't parse the keys upload response")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn keys_claim_response() -> claim_keys::Response {
|
||||||
|
let data = include_bytes!("../../benches/keys_claim.json");
|
||||||
|
let data: Value = serde_json::from_slice(data).unwrap();
|
||||||
|
let data = response_from_file(&data);
|
||||||
|
claim_keys::Response::try_from(data).expect("Can't parse the keys upload response")
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn machine() -> OlmMachine {
|
||||||
|
let keys_query = keys_query_response();
|
||||||
|
let keys_claim = keys_claim_response();
|
||||||
|
let uuid = Uuid::new_v4();
|
||||||
|
|
||||||
|
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
|
||||||
|
|
||||||
|
machine
|
||||||
|
.mark_request_as_sent(&uuid, &keys_query)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
machine
|
||||||
|
.mark_request_as_sent(&uuid, &keys_claim)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
machine
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_sharing() {
|
||||||
|
let machine = machine().await;
|
||||||
|
let room_id = room_id!("!test:localhost");
|
||||||
|
let keys_claim = keys_claim_response();
|
||||||
|
|
||||||
|
let users: Vec<_> = keys_claim.one_time_keys.keys().collect();
|
||||||
|
|
||||||
|
let requests = machine
|
||||||
|
.share_group_session(
|
||||||
|
&room_id,
|
||||||
|
users.clone().into_iter(),
|
||||||
|
EncryptionSettings::default(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let event_count: usize = requests.iter().map(|r| r.message_count()).sum();
|
||||||
|
|
||||||
|
// The keys claim response has a couple of one-time keys with invalid
|
||||||
|
// signatures, thus only 148 sessions are actually created, we check
|
||||||
|
// that all 148 valid sessions get an room key.
|
||||||
|
assert_eq!(event_count, 148);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -195,9 +195,9 @@ impl SessionManager {
|
||||||
// Add the list of devices that the user wishes to establish sessions
|
// Add the list of devices that the user wishes to establish sessions
|
||||||
// right now.
|
// right now.
|
||||||
for user_id in users {
|
for user_id in users {
|
||||||
let user_devices = self.store.get_user_devices(user_id).await?;
|
let user_devices = self.store.get_readonly_devices(user_id).await?;
|
||||||
|
|
||||||
for device in user_devices.devices() {
|
for (device_id, device) in user_devices.into_iter() {
|
||||||
let sender_key = if let Some(k) = device.get_key(DeviceKeyAlgorithm::Curve25519) {
|
let sender_key = if let Some(k) = device.get_key(DeviceKeyAlgorithm::Curve25519) {
|
||||||
k
|
k
|
||||||
} else {
|
} else {
|
||||||
|
@ -216,10 +216,7 @@ impl SessionManager {
|
||||||
missing
|
missing
|
||||||
.entry(user_id.to_owned())
|
.entry(user_id.to_owned())
|
||||||
.or_insert_with(BTreeMap::new)
|
.or_insert_with(BTreeMap::new)
|
||||||
.insert(
|
.insert(device_id, DeviceKeyAlgorithm::SignedCurve25519);
|
||||||
device.device_id().into(),
|
|
||||||
DeviceKeyAlgorithm::SignedCurve25519,
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -126,6 +126,15 @@ pub struct DeviceChanges {
|
||||||
pub deleted: Vec<ReadOnlyDevice>,
|
pub deleted: Vec<ReadOnlyDevice>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl DeviceChanges {
|
||||||
|
/// Merge the given `DeviceChanges` into this instance of `DeviceChanges`.
|
||||||
|
pub fn extend(&mut self, other: DeviceChanges) {
|
||||||
|
self.new.extend(other.new);
|
||||||
|
self.changed.extend(other.changed);
|
||||||
|
self.deleted.extend(other.deleted);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Store {
|
impl Store {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
user_id: Arc<UserId>,
|
user_id: Arc<UserId>,
|
||||||
|
|
|
@ -16,11 +16,11 @@ use std::{
|
||||||
collections::{HashMap, HashSet},
|
collections::{HashMap, HashSet},
|
||||||
convert::TryFrom,
|
convert::TryFrom,
|
||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
sync::Arc,
|
sync::{Arc, RwLock},
|
||||||
};
|
};
|
||||||
|
|
||||||
use dashmap::DashSet;
|
use dashmap::DashSet;
|
||||||
use olm_rs::PicklingMode;
|
use olm_rs::{account::IdentityKeys, PicklingMode};
|
||||||
pub use sled::Error;
|
pub use sled::Error;
|
||||||
use sled::{
|
use sled::{
|
||||||
transaction::{ConflictableTransactionError, TransactionError},
|
transaction::{ConflictableTransactionError, TransactionError},
|
||||||
|
@ -95,9 +95,17 @@ impl EncodeKey for (&str, &str, &str) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct AccountInfo {
|
||||||
|
user_id: Arc<UserId>,
|
||||||
|
device_id: Arc<DeviceIdBox>,
|
||||||
|
identity_keys: Arc<IdentityKeys>,
|
||||||
|
}
|
||||||
|
|
||||||
/// An in-memory only store that will forget all the E2EE key once it's dropped.
|
/// An in-memory only store that will forget all the E2EE key once it's dropped.
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct SledStore {
|
pub struct SledStore {
|
||||||
|
account_info: Arc<RwLock<Option<AccountInfo>>>,
|
||||||
path: Option<PathBuf>,
|
path: Option<PathBuf>,
|
||||||
inner: Db,
|
inner: Db,
|
||||||
pickle_key: Arc<PickleKey>,
|
pickle_key: Arc<PickleKey>,
|
||||||
|
@ -159,6 +167,10 @@ impl SledStore {
|
||||||
SledStore::open_helper(db, None, passphrase)
|
SledStore::open_helper(db, None, passphrase)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_account_info(&self) -> Option<AccountInfo> {
|
||||||
|
self.account_info.read().unwrap().clone()
|
||||||
|
}
|
||||||
|
|
||||||
fn open_helper(db: Db, path: Option<PathBuf>, passphrase: Option<&str>) -> Result<Self> {
|
fn open_helper(db: Db, path: Option<PathBuf>, passphrase: Option<&str>) -> Result<Self> {
|
||||||
let account = db.open_tree("account")?;
|
let account = db.open_tree("account")?;
|
||||||
let private_identity = db.open_tree("private_identity")?;
|
let private_identity = db.open_tree("private_identity")?;
|
||||||
|
@ -184,6 +196,7 @@ impl SledStore {
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
|
account_info: RwLock::new(None).into(),
|
||||||
path,
|
path,
|
||||||
inner: db,
|
inner: db,
|
||||||
pickle_key: pickle_key.into(),
|
pickle_key: pickle_key.into(),
|
||||||
|
@ -249,22 +262,18 @@ impl SledStore {
|
||||||
&self,
|
&self,
|
||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
) -> Result<Option<OutboundGroupSession>> {
|
) -> Result<Option<OutboundGroupSession>> {
|
||||||
let account = self
|
let account_info = self
|
||||||
.load_account()
|
.get_account_info()
|
||||||
.await?
|
|
||||||
.ok_or(CryptoStoreError::AccountUnset)?;
|
.ok_or(CryptoStoreError::AccountUnset)?;
|
||||||
|
|
||||||
let device_id: Arc<DeviceIdBox> = account.device_id().to_owned().into();
|
|
||||||
let identity_keys = account.identity_keys;
|
|
||||||
|
|
||||||
self.outbound_group_sessions
|
self.outbound_group_sessions
|
||||||
.get(room_id.encode())?
|
.get(room_id.encode())?
|
||||||
.map(|p| serde_json::from_slice(&p).map_err(CryptoStoreError::Serialization))
|
.map(|p| serde_json::from_slice(&p).map_err(CryptoStoreError::Serialization))
|
||||||
.transpose()?
|
.transpose()?
|
||||||
.map(|p| {
|
.map(|p| {
|
||||||
OutboundGroupSession::from_pickle(
|
OutboundGroupSession::from_pickle(
|
||||||
device_id,
|
account_info.device_id,
|
||||||
identity_keys,
|
account_info.identity_keys,
|
||||||
p,
|
p,
|
||||||
self.get_pickle_mode(),
|
self.get_pickle_mode(),
|
||||||
)
|
)
|
||||||
|
@ -430,16 +439,31 @@ impl CryptoStore for SledStore {
|
||||||
|
|
||||||
self.load_tracked_users().await?;
|
self.load_tracked_users().await?;
|
||||||
|
|
||||||
Ok(Some(ReadOnlyAccount::from_pickle(
|
let account = ReadOnlyAccount::from_pickle(pickle, self.get_pickle_mode())?;
|
||||||
pickle,
|
|
||||||
self.get_pickle_mode(),
|
let account_info = AccountInfo {
|
||||||
)?))
|
user_id: account.user_id.clone(),
|
||||||
|
device_id: account.device_id.clone(),
|
||||||
|
identity_keys: account.identity_keys.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
*self.account_info.write().unwrap() = Some(account_info);
|
||||||
|
|
||||||
|
Ok(Some(account))
|
||||||
} else {
|
} else {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn save_account(&self, account: ReadOnlyAccount) -> Result<()> {
|
async fn save_account(&self, account: ReadOnlyAccount) -> Result<()> {
|
||||||
|
let account_info = AccountInfo {
|
||||||
|
user_id: account.user_id.clone(),
|
||||||
|
device_id: account.device_id.clone(),
|
||||||
|
identity_keys: account.identity_keys.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
*self.account_info.write().unwrap() = Some(account_info);
|
||||||
|
|
||||||
let changes = Changes {
|
let changes = Changes {
|
||||||
account: Some(account),
|
account: Some(account),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
|
@ -453,9 +477,8 @@ impl CryptoStore for SledStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
|
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
|
||||||
let account = self
|
let account_info = self
|
||||||
.load_account()
|
.get_account_info()
|
||||||
.await?
|
|
||||||
.ok_or(CryptoStoreError::AccountUnset)?;
|
.ok_or(CryptoStoreError::AccountUnset)?;
|
||||||
|
|
||||||
if self.session_cache.get(sender_key).is_none() {
|
if self.session_cache.get(sender_key).is_none() {
|
||||||
|
@ -465,9 +488,9 @@ impl CryptoStore for SledStore {
|
||||||
.map(|s| serde_json::from_slice(&s?.1).map_err(CryptoStoreError::Serialization))
|
.map(|s| serde_json::from_slice(&s?.1).map_err(CryptoStoreError::Serialization))
|
||||||
.map(|p| {
|
.map(|p| {
|
||||||
Session::from_pickle(
|
Session::from_pickle(
|
||||||
account.user_id.clone(),
|
account_info.user_id.clone(),
|
||||||
account.device_id.clone(),
|
account_info.device_id.clone(),
|
||||||
account.identity_keys.clone(),
|
account_info.identity_keys.clone(),
|
||||||
p?,
|
p?,
|
||||||
self.get_pickle_mode(),
|
self.get_pickle_mode(),
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue