Merge branch 'master' into sas-longer-flow
This commit is contained in:
commit
3f57a2a9f2
79 changed files with 2020 additions and 4180 deletions
2
.github/workflows/ci.yml
vendored
2
.github/workflows/ci.yml
vendored
|
@ -20,7 +20,7 @@ jobs:
|
|||
- name: Install rust
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: stable
|
||||
toolchain: nightly
|
||||
components: rustfmt
|
||||
profile: minimal
|
||||
override: true
|
||||
|
|
6
.rustfmt.toml
Normal file
6
.rustfmt.toml
Normal file
|
@ -0,0 +1,6 @@
|
|||
max_width = 100
|
||||
comment_width = 80
|
||||
wrap_comments = true
|
||||
imports_granularity = "Crate"
|
||||
use_small_heuristics = "Max"
|
||||
group_imports = "StdExternalCrate"
|
|
@ -1,5 +1,4 @@
|
|||
use std::{env, process::exit};
|
||||
use tokio::time::{sleep, Duration};
|
||||
|
||||
use matrix_sdk::{
|
||||
self, async_trait,
|
||||
|
@ -7,6 +6,7 @@ use matrix_sdk::{
|
|||
room::Room,
|
||||
Client, ClientConfig, EventHandler, SyncSettings,
|
||||
};
|
||||
use tokio::time::{sleep, Duration};
|
||||
use url::Url;
|
||||
|
||||
struct AutoJoinBot {
|
||||
|
@ -72,15 +72,11 @@ async fn login_and_sync(
|
|||
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
|
||||
let client = Client::new_with_config(homeserver_url, client_config).unwrap();
|
||||
|
||||
client
|
||||
.login(username, password, None, Some("autojoin bot"))
|
||||
.await?;
|
||||
client.login(username, password, None, Some("autojoin bot")).await?;
|
||||
|
||||
println!("logged in as {}", username);
|
||||
|
||||
client
|
||||
.set_event_handler(Box::new(AutoJoinBot::new(client.clone())))
|
||||
.await;
|
||||
client.set_event_handler(Box::new(AutoJoinBot::new(client.clone()))).await;
|
||||
|
||||
client.sync(SyncSettings::default()).await;
|
||||
|
||||
|
|
|
@ -69,24 +69,23 @@ async fn login_and_sync(
|
|||
// create a new Client with the given homeserver url and config
|
||||
let client = Client::new_with_config(homeserver_url, client_config).unwrap();
|
||||
|
||||
client
|
||||
.login(&username, &password, None, Some("command bot"))
|
||||
.await?;
|
||||
client.login(&username, &password, None, Some("command bot")).await?;
|
||||
|
||||
println!("logged in as {}", username);
|
||||
|
||||
// An initial sync to set up state and so our bot doesn't respond to old messages.
|
||||
// If the `StateStore` finds saved state in the location given the initial sync will
|
||||
// be skipped in favor of loading state from the store
|
||||
// An initial sync to set up state and so our bot doesn't respond to old
|
||||
// messages. If the `StateStore` finds saved state in the location given the
|
||||
// initial sync will be skipped in favor of loading state from the store
|
||||
client.sync_once(SyncSettings::default()).await.unwrap();
|
||||
// add our CommandBot to be notified of incoming messages, we do this after the initial
|
||||
// sync to avoid responding to messages before the bot was running.
|
||||
// add our CommandBot to be notified of incoming messages, we do this after the
|
||||
// initial sync to avoid responding to messages before the bot was running.
|
||||
client.set_event_handler(Box::new(CommandBot::new())).await;
|
||||
|
||||
// since we called `sync_once` before we entered our sync loop we must pass
|
||||
// that sync token to `sync`
|
||||
let settings = SyncSettings::default().token(client.sync_token().await.unwrap());
|
||||
// this keeps state from the server streaming in to CommandBot via the EventHandler trait
|
||||
// this keeps state from the server streaming in to CommandBot via the
|
||||
// EventHandler trait
|
||||
client.sync(settings).await;
|
||||
|
||||
Ok(())
|
||||
|
|
|
@ -5,12 +5,11 @@ use std::{
|
|||
sync::atomic::{AtomicBool, Ordering},
|
||||
};
|
||||
|
||||
use serde_json::json;
|
||||
use url::Url;
|
||||
|
||||
use matrix_sdk::{
|
||||
self, api::r0::uiaa::AuthData, identifiers::UserId, Client, LoopCtrl, SyncSettings,
|
||||
};
|
||||
use serde_json::json;
|
||||
use url::Url;
|
||||
|
||||
fn auth_data<'a>(user: &UserId, password: &str, session: Option<&'a str>) -> AuthData<'a> {
|
||||
let mut auth_parameters = BTreeMap::new();
|
||||
|
@ -22,11 +21,7 @@ fn auth_data<'a>(user: &UserId, password: &str, session: Option<&'a str>) -> Aut
|
|||
auth_parameters.insert("identifier".to_owned(), identifier);
|
||||
auth_parameters.insert("password".to_owned(), password.to_owned().into());
|
||||
|
||||
AuthData::DirectRequest {
|
||||
kind: "m.login.password",
|
||||
auth_parameters,
|
||||
session,
|
||||
}
|
||||
AuthData::DirectRequest { kind: "m.login.password", auth_parameters, session }
|
||||
}
|
||||
|
||||
async fn bootstrap(client: Client, user_id: UserId, password: String) {
|
||||
|
@ -34,9 +29,7 @@ async fn bootstrap(client: Client, user_id: UserId, password: String) {
|
|||
|
||||
let mut input = String::new();
|
||||
|
||||
io::stdin()
|
||||
.read_line(&mut input)
|
||||
.expect("error: unable to read user input");
|
||||
io::stdin().read_line(&mut input).expect("error: unable to read user input");
|
||||
|
||||
#[cfg(feature = "encryption")]
|
||||
if let Err(e) = client.bootstrap_cross_signing(None).await {
|
||||
|
@ -63,9 +56,7 @@ async fn login(
|
|||
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
|
||||
let client = Client::new(homeserver_url).unwrap();
|
||||
|
||||
let response = client
|
||||
.login(username, password, None, Some("rust-sdk"))
|
||||
.await?;
|
||||
let response = client.login(username, password, None, Some("rust-sdk")).await?;
|
||||
|
||||
let user_id = &response.user_id;
|
||||
let client_ref = &client;
|
||||
|
|
|
@ -6,7 +6,6 @@ use std::{
|
|||
Arc,
|
||||
},
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
use matrix_sdk::{
|
||||
self,
|
||||
|
@ -14,14 +13,13 @@ use matrix_sdk::{
|
|||
identifiers::UserId,
|
||||
Client, LoopCtrl, Sas, SyncSettings,
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
async fn wait_for_confirmation(client: Client, sas: Sas) {
|
||||
println!("Does the emoji match: {:?}", sas.emoji());
|
||||
|
||||
let mut input = String::new();
|
||||
io::stdin()
|
||||
.read_line(&mut input)
|
||||
.expect("error: unable to read user input");
|
||||
io::stdin().read_line(&mut input).expect("error: unable to read user input");
|
||||
|
||||
match input.trim().to_lowercase().as_ref() {
|
||||
"yes" | "true" | "ok" => {
|
||||
|
@ -68,9 +66,7 @@ async fn login(
|
|||
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
|
||||
let client = Client::new(homeserver_url).unwrap();
|
||||
|
||||
client
|
||||
.login(username, password, None, Some("rust-sdk"))
|
||||
.await?;
|
||||
client.login(username, password, None, Some("rust-sdk")).await?;
|
||||
|
||||
let client_ref = &client;
|
||||
let initial_sync = Arc::new(AtomicBool::from(true));
|
||||
|
@ -81,12 +77,7 @@ async fn login(
|
|||
let client = &client_ref;
|
||||
let initial = &initial_ref;
|
||||
|
||||
for event in response
|
||||
.to_device
|
||||
.events
|
||||
.iter()
|
||||
.filter_map(|e| e.deserialize().ok())
|
||||
{
|
||||
for event in response.to_device.events.iter().filter_map(|e| e.deserialize().ok()) {
|
||||
match event {
|
||||
AnyToDeviceEvent::KeyVerificationStart(e) => {
|
||||
let sas = client
|
||||
|
@ -129,11 +120,8 @@ async fn login(
|
|||
|
||||
if !initial.load(Ordering::SeqCst) {
|
||||
for (_room_id, room_info) in response.rooms.join {
|
||||
for event in room_info
|
||||
.timeline
|
||||
.events
|
||||
.iter()
|
||||
.filter_map(|e| e.event.deserialize().ok())
|
||||
for event in
|
||||
room_info.timeline.events.iter().filter_map(|e| e.event.deserialize().ok())
|
||||
{
|
||||
if let AnySyncRoomEvent::Message(event) = event {
|
||||
match event {
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
use std::{convert::TryFrom, env, process::exit};
|
||||
|
||||
use url::Url;
|
||||
|
||||
use matrix_sdk::{
|
||||
self,
|
||||
api::r0::profile,
|
||||
identifiers::{MxcUri, UserId},
|
||||
Client, Result as MatrixResult,
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct UserProfile {
|
||||
|
@ -29,10 +28,7 @@ async fn get_profile(client: Client, mxid: &UserId) -> MatrixResult<UserProfile>
|
|||
// Use the response and construct a UserProfile struct.
|
||||
// See https://docs.rs/ruma-client-api/0.9.0/ruma_client_api/r0/profile/get_profile/struct.Response.html
|
||||
// for details on the Response for this Request
|
||||
let user_profile = UserProfile {
|
||||
avatar_url: resp.avatar_url,
|
||||
displayname: resp.displayname,
|
||||
};
|
||||
let user_profile = UserProfile { avatar_url: resp.avatar_url, displayname: resp.displayname };
|
||||
Ok(user_profile)
|
||||
}
|
||||
|
||||
|
@ -44,9 +40,7 @@ async fn login(
|
|||
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
|
||||
let client = Client::new(homeserver_url).unwrap();
|
||||
|
||||
client
|
||||
.login(username, password, None, Some("rust-sdk"))
|
||||
.await?;
|
||||
client.login(username, password, None, Some("rust-sdk")).await?;
|
||||
|
||||
Ok(client)
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@ use std::{
|
|||
process::exit,
|
||||
sync::Arc,
|
||||
};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use matrix_sdk::{
|
||||
self, async_trait,
|
||||
|
@ -17,6 +16,7 @@ use matrix_sdk::{
|
|||
room::Room,
|
||||
Client, EventHandler, SyncSettings,
|
||||
};
|
||||
use tokio::sync::Mutex;
|
||||
use url::Url;
|
||||
|
||||
struct ImageBot {
|
||||
|
@ -52,9 +52,7 @@ impl EventHandler for ImageBot {
|
|||
println!("sending image");
|
||||
let mut image = self.image.lock().await;
|
||||
|
||||
room.send_attachment("cat", &mime::IMAGE_JPEG, &mut *image, None)
|
||||
.await
|
||||
.unwrap();
|
||||
room.send_attachment("cat", &mime::IMAGE_JPEG, &mut *image, None).await.unwrap();
|
||||
|
||||
image.seek(SeekFrom::Start(0)).unwrap();
|
||||
|
||||
|
@ -73,14 +71,10 @@ async fn login_and_sync(
|
|||
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
|
||||
let client = Client::new(homeserver_url).unwrap();
|
||||
|
||||
client
|
||||
.login(&username, &password, None, Some("command bot"))
|
||||
.await?;
|
||||
client.login(&username, &password, None, Some("command bot")).await?;
|
||||
|
||||
client.sync_once(SyncSettings::default()).await.unwrap();
|
||||
client
|
||||
.set_event_handler(Box::new(ImageBot::new(image)))
|
||||
.await;
|
||||
client.set_event_handler(Box::new(ImageBot::new(image))).await;
|
||||
|
||||
let settings = SyncSettings::default().token(client.sync_token().await.unwrap());
|
||||
client.sync(settings).await;
|
||||
|
@ -91,26 +85,19 @@ async fn login_and_sync(
|
|||
#[tokio::main]
|
||||
async fn main() -> Result<(), matrix_sdk::Error> {
|
||||
tracing_subscriber::fmt::init();
|
||||
let (homeserver_url, username, password, image_path) = match (
|
||||
env::args().nth(1),
|
||||
env::args().nth(2),
|
||||
env::args().nth(3),
|
||||
env::args().nth(4),
|
||||
) {
|
||||
(Some(a), Some(b), Some(c), Some(d)) => (a, b, c, d),
|
||||
_ => {
|
||||
eprintln!(
|
||||
"Usage: {} <homeserver_url> <username> <password> <image>",
|
||||
env::args().next().unwrap()
|
||||
);
|
||||
exit(1)
|
||||
}
|
||||
};
|
||||
let (homeserver_url, username, password, image_path) =
|
||||
match (env::args().nth(1), env::args().nth(2), env::args().nth(3), env::args().nth(4)) {
|
||||
(Some(a), Some(b), Some(c), Some(d)) => (a, b, c, d),
|
||||
_ => {
|
||||
eprintln!(
|
||||
"Usage: {} <homeserver_url> <username> <password> <image>",
|
||||
env::args().next().unwrap()
|
||||
);
|
||||
exit(1)
|
||||
}
|
||||
};
|
||||
|
||||
println!(
|
||||
"helloooo {} {} {} {:#?}",
|
||||
homeserver_url, username, password, image_path
|
||||
);
|
||||
println!("helloooo {} {} {} {:#?}", homeserver_url, username, password, image_path);
|
||||
let path = PathBuf::from(image_path);
|
||||
let image = File::open(path).expect("Can't open image file.");
|
||||
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
use std::{env, process::exit};
|
||||
use url::Url;
|
||||
|
||||
use matrix_sdk::{
|
||||
self, async_trait,
|
||||
|
@ -10,6 +9,7 @@ use matrix_sdk::{
|
|||
room::Room,
|
||||
Client, EventHandler, SyncSettings,
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
struct EventCallback;
|
||||
|
||||
|
@ -28,9 +28,7 @@ impl EventHandler for EventCallback {
|
|||
} = event
|
||||
{
|
||||
let member = room.get_member(&sender).await.unwrap().unwrap();
|
||||
let name = member
|
||||
.display_name()
|
||||
.unwrap_or_else(|| member.user_id().as_str());
|
||||
let name = member.display_name().unwrap_or_else(|| member.user_id().as_str());
|
||||
println!("{}: {}", name, msg_body);
|
||||
}
|
||||
}
|
||||
|
@ -47,9 +45,7 @@ async fn login(
|
|||
|
||||
client.set_event_handler(Box::new(EventCallback)).await;
|
||||
|
||||
client
|
||||
.login(username, password, None, Some("rust-sdk"))
|
||||
.await?;
|
||||
client.login(username, password, None, Some("rust-sdk")).await?;
|
||||
client.sync(SyncSettings::new()).await;
|
||||
|
||||
Ok(())
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -40,7 +40,8 @@ impl Deref for Device {
|
|||
impl Device {
|
||||
/// Start a interactive verification with this `Device`
|
||||
///
|
||||
/// Returns a `Sas` object that represents the interactive verification flow.
|
||||
/// Returns a `Sas` object that represents the interactive verification
|
||||
/// flow.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
|
@ -65,10 +66,7 @@ impl Device {
|
|||
let (sas, request) = self.inner.start_verification().await?;
|
||||
self.client.send_to_device(&request).await?;
|
||||
|
||||
Ok(Sas {
|
||||
inner: sas,
|
||||
client: self.client.clone(),
|
||||
})
|
||||
Ok(Sas { inner: sas, client: self.client.clone() })
|
||||
}
|
||||
|
||||
/// Is the device trusted.
|
||||
|
@ -102,10 +100,7 @@ pub struct UserDevices {
|
|||
impl UserDevices {
|
||||
/// Get the specific device with the given device id.
|
||||
pub fn get(&self, device_id: &DeviceId) -> Option<Device> {
|
||||
self.inner.get(device_id).map(|d| Device {
|
||||
inner: d,
|
||||
client: self.client.clone(),
|
||||
})
|
||||
self.inner.get(device_id).map(|d| Device { inner: d, client: self.client.clone() })
|
||||
}
|
||||
|
||||
/// Iterator over all the device ids of the user devices.
|
||||
|
@ -117,9 +112,6 @@ impl UserDevices {
|
|||
pub fn devices(&self) -> impl Iterator<Item = Device> + '_ {
|
||||
let client = self.client.clone();
|
||||
|
||||
self.inner.devices().map(move |d| Device {
|
||||
inner: d,
|
||||
client: client.clone(),
|
||||
})
|
||||
self.inner.devices().map(move |d| Device { inner: d, client: client.clone() })
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,7 +14,11 @@
|
|||
|
||||
//! Error conditions.
|
||||
|
||||
use std::io::Error as IoError;
|
||||
|
||||
use http::StatusCode;
|
||||
#[cfg(feature = "encryption")]
|
||||
use matrix_sdk_base::crypto::store::CryptoStoreError;
|
||||
use matrix_sdk_base::{Error as MatrixError, StoreError};
|
||||
use matrix_sdk_common::{
|
||||
api::{
|
||||
|
@ -26,12 +30,8 @@ use matrix_sdk_common::{
|
|||
};
|
||||
use reqwest::Error as ReqwestError;
|
||||
use serde_json::Error as JsonError;
|
||||
use std::io::Error as IoError;
|
||||
use thiserror::Error;
|
||||
|
||||
#[cfg(feature = "encryption")]
|
||||
use matrix_sdk_base::crypto::store::CryptoStoreError;
|
||||
|
||||
/// Result type of the rust-sdk.
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
|
@ -43,11 +43,13 @@ pub enum HttpError {
|
|||
#[error(transparent)]
|
||||
Reqwest(#[from] ReqwestError),
|
||||
|
||||
/// Queried endpoint requires authentication but was called on an anonymous client.
|
||||
/// Queried endpoint requires authentication but was called on an anonymous
|
||||
/// client.
|
||||
#[error("the queried endpoint requires authentication but was called before logging in")]
|
||||
AuthenticationRequired,
|
||||
|
||||
/// Client tried to force authentication but did not provide an access token.
|
||||
/// Client tried to force authentication but did not provide an access
|
||||
/// token.
|
||||
#[error("tried to force authentication but no access token was provided")]
|
||||
ForcedAuthenticationWithoutAccessToken,
|
||||
|
||||
|
@ -69,9 +71,10 @@ pub enum HttpError {
|
|||
|
||||
/// An error occurred while authenticating.
|
||||
///
|
||||
/// When registering or authenticating the Matrix server can send a `UiaaResponse`
|
||||
/// as the error type, this is a User-Interactive Authentication API response. This
|
||||
/// represents an error with information about how to authenticate the user.
|
||||
/// When registering or authenticating the Matrix server can send a
|
||||
/// `UiaaResponse` as the error type, this is a User-Interactive
|
||||
/// Authentication API response. This represents an error with
|
||||
/// information about how to authenticate the user.
|
||||
#[error(transparent)]
|
||||
UiaaError(#[from] FromHttpResponseError<UiaaError>),
|
||||
|
||||
|
@ -96,7 +99,8 @@ pub enum Error {
|
|||
#[error(transparent)]
|
||||
Http(#[from] HttpError),
|
||||
|
||||
/// Queried endpoint requires authentication but was called on an anonymous client.
|
||||
/// Queried endpoint requires authentication but was called on an anonymous
|
||||
/// client.
|
||||
#[error("the queried endpoint requires authentication but was called before logging in")]
|
||||
AuthenticationRequired,
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@ use std::ops::Deref;
|
|||
|
||||
use matrix_sdk_common::{
|
||||
api::r0::push::get_notifications::Notification,
|
||||
async_trait,
|
||||
events::{
|
||||
fully_read::FullyReadEventContent, AnySyncRoomEvent, GlobalAccountDataEvent,
|
||||
RoomAccountDataEvent,
|
||||
|
@ -56,7 +57,6 @@ use crate::{
|
|||
room::Room,
|
||||
Client,
|
||||
};
|
||||
use matrix_sdk_common::async_trait;
|
||||
|
||||
pub(crate) struct Handler {
|
||||
pub(crate) inner: Box<dyn EventHandler>,
|
||||
|
@ -77,50 +77,29 @@ impl Handler {
|
|||
}
|
||||
|
||||
pub(crate) async fn handle_sync(&self, response: &SyncResponse) {
|
||||
for event in response
|
||||
.account_data
|
||||
.events
|
||||
.iter()
|
||||
.filter_map(|e| e.deserialize().ok())
|
||||
{
|
||||
for event in response.account_data.events.iter().filter_map(|e| e.deserialize().ok()) {
|
||||
self.handle_account_data_event(&event).await;
|
||||
}
|
||||
|
||||
for (room_id, room_info) in &response.rooms.join {
|
||||
if let Some(room) = self.get_room(room_id) {
|
||||
for event in room_info
|
||||
.ephemeral
|
||||
.events
|
||||
.iter()
|
||||
.filter_map(|e| e.deserialize().ok())
|
||||
for event in room_info.ephemeral.events.iter().filter_map(|e| e.deserialize().ok())
|
||||
{
|
||||
self.handle_ephemeral_event(room.clone(), &event).await;
|
||||
}
|
||||
|
||||
for event in room_info
|
||||
.account_data
|
||||
.events
|
||||
.iter()
|
||||
.filter_map(|e| e.deserialize().ok())
|
||||
for event in
|
||||
room_info.account_data.events.iter().filter_map(|e| e.deserialize().ok())
|
||||
{
|
||||
self.handle_room_account_data_event(room.clone(), &event)
|
||||
.await;
|
||||
self.handle_room_account_data_event(room.clone(), &event).await;
|
||||
}
|
||||
|
||||
for event in room_info
|
||||
.state
|
||||
.events
|
||||
.iter()
|
||||
.filter_map(|e| e.deserialize().ok())
|
||||
{
|
||||
for event in room_info.state.events.iter().filter_map(|e| e.deserialize().ok()) {
|
||||
self.handle_state_event(room.clone(), &event).await;
|
||||
}
|
||||
|
||||
for event in room_info
|
||||
.timeline
|
||||
.events
|
||||
.iter()
|
||||
.filter_map(|e| e.event.deserialize().ok())
|
||||
for event in
|
||||
room_info.timeline.events.iter().filter_map(|e| e.event.deserialize().ok())
|
||||
{
|
||||
self.handle_timeline_event(room.clone(), &event).await;
|
||||
}
|
||||
|
@ -129,30 +108,18 @@ impl Handler {
|
|||
|
||||
for (room_id, room_info) in &response.rooms.leave {
|
||||
if let Some(room) = self.get_room(room_id) {
|
||||
for event in room_info
|
||||
.account_data
|
||||
.events
|
||||
.iter()
|
||||
.filter_map(|e| e.deserialize().ok())
|
||||
for event in
|
||||
room_info.account_data.events.iter().filter_map(|e| e.deserialize().ok())
|
||||
{
|
||||
self.handle_room_account_data_event(room.clone(), &event)
|
||||
.await;
|
||||
self.handle_room_account_data_event(room.clone(), &event).await;
|
||||
}
|
||||
|
||||
for event in room_info
|
||||
.state
|
||||
.events
|
||||
.iter()
|
||||
.filter_map(|e| e.deserialize().ok())
|
||||
{
|
||||
for event in room_info.state.events.iter().filter_map(|e| e.deserialize().ok()) {
|
||||
self.handle_state_event(room.clone(), &event).await;
|
||||
}
|
||||
|
||||
for event in room_info
|
||||
.timeline
|
||||
.events
|
||||
.iter()
|
||||
.filter_map(|e| e.event.deserialize().ok())
|
||||
for event in
|
||||
room_info.timeline.events.iter().filter_map(|e| e.event.deserialize().ok())
|
||||
{
|
||||
self.handle_timeline_event(room.clone(), &event).await;
|
||||
}
|
||||
|
@ -161,31 +128,22 @@ impl Handler {
|
|||
|
||||
for (room_id, room_info) in &response.rooms.invite {
|
||||
if let Some(room) = self.get_room(room_id) {
|
||||
for event in room_info
|
||||
.invite_state
|
||||
.events
|
||||
.iter()
|
||||
.filter_map(|e| e.deserialize().ok())
|
||||
for event in
|
||||
room_info.invite_state.events.iter().filter_map(|e| e.deserialize().ok())
|
||||
{
|
||||
self.handle_stripped_state_event(room.clone(), &event).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for event in response
|
||||
.presence
|
||||
.events
|
||||
.iter()
|
||||
.filter_map(|e| e.deserialize().ok())
|
||||
{
|
||||
for event in response.presence.events.iter().filter_map(|e| e.deserialize().ok()) {
|
||||
self.on_presence_event(&event).await;
|
||||
}
|
||||
|
||||
for (room_id, notifications) in &response.notifications {
|
||||
if let Some(room) = self.get_room(&room_id) {
|
||||
for notification in notifications {
|
||||
self.on_room_notification(room.clone(), notification.clone())
|
||||
.await;
|
||||
self.on_room_notification(room.clone(), notification.clone()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -249,8 +207,7 @@ impl Handler {
|
|||
self.on_room_tombstone(room, &tomb).await
|
||||
}
|
||||
AnySyncStateEvent::Custom(custom) => {
|
||||
self.on_custom_event(room, &CustomEvent::State(custom))
|
||||
.await
|
||||
self.on_custom_event(room, &CustomEvent::State(custom)).await
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
@ -268,8 +225,7 @@ impl Handler {
|
|||
}
|
||||
AnyStrippedStateEvent::RoomName(name) => self.on_stripped_state_name(room, &name).await,
|
||||
AnyStrippedStateEvent::RoomCanonicalAlias(canonical) => {
|
||||
self.on_stripped_state_canonical_alias(room, &canonical)
|
||||
.await
|
||||
self.on_stripped_state_canonical_alias(room, &canonical).await
|
||||
}
|
||||
AnyStrippedStateEvent::RoomAliases(aliases) => {
|
||||
self.on_stripped_state_aliases(room, &aliases).await
|
||||
|
@ -341,8 +297,9 @@ pub enum CustomEvent<'c> {
|
|||
StrippedState(&'c StrippedStateEvent<CustomEventContent>),
|
||||
}
|
||||
|
||||
/// This trait allows any type implementing `EventHandler` to specify event callbacks for each event.
|
||||
/// The `Client` calls each method when the corresponding event is received.
|
||||
/// This trait allows any type implementing `EventHandler` to specify event
|
||||
/// callbacks for each event. The `Client` calls each method when the
|
||||
/// corresponding event is received.
|
||||
///
|
||||
/// # Examples
|
||||
/// ```
|
||||
|
@ -427,8 +384,8 @@ pub trait EventHandler: Send + Sync {
|
|||
/// Fires when `Client` receives a `RoomEvent::Tombstone` event.
|
||||
async fn on_room_tombstone(&self, _: Room, _: &SyncStateEvent<TombstoneEventContent>) {}
|
||||
|
||||
/// Fires when `Client` receives room events that trigger notifications according to
|
||||
/// the push rules of the user.
|
||||
/// Fires when `Client` receives room events that trigger notifications
|
||||
/// according to the push rules of the user.
|
||||
async fn on_room_notification(&self, _: Room, _: Notification) {}
|
||||
|
||||
// `RoomEvent`s from `IncomingState`
|
||||
|
@ -453,7 +410,8 @@ pub trait EventHandler: Send + Sync {
|
|||
async fn on_state_join_rules(&self, _: Room, _: &SyncStateEvent<JoinRulesEventContent>) {}
|
||||
|
||||
// `AnyStrippedStateEvent`s
|
||||
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomMember` event.
|
||||
/// Fires when `Client` receives a
|
||||
/// `AnyStrippedStateEvent::StrippedRoomMember` event.
|
||||
async fn on_stripped_state_member(
|
||||
&self,
|
||||
_: Room,
|
||||
|
@ -461,32 +419,38 @@ pub trait EventHandler: Send + Sync {
|
|||
_: Option<MemberEventContent>,
|
||||
) {
|
||||
}
|
||||
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomName` event.
|
||||
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomName`
|
||||
/// event.
|
||||
async fn on_stripped_state_name(&self, _: Room, _: &StrippedStateEvent<NameEventContent>) {}
|
||||
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event.
|
||||
/// Fires when `Client` receives a
|
||||
/// `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event.
|
||||
async fn on_stripped_state_canonical_alias(
|
||||
&self,
|
||||
_: Room,
|
||||
_: &StrippedStateEvent<CanonicalAliasEventContent>,
|
||||
) {
|
||||
}
|
||||
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAliases` event.
|
||||
/// Fires when `Client` receives a
|
||||
/// `AnyStrippedStateEvent::StrippedRoomAliases` event.
|
||||
async fn on_stripped_state_aliases(
|
||||
&self,
|
||||
_: Room,
|
||||
_: &StrippedStateEvent<AliasesEventContent>,
|
||||
) {
|
||||
}
|
||||
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAvatar` event.
|
||||
/// Fires when `Client` receives a
|
||||
/// `AnyStrippedStateEvent::StrippedRoomAvatar` event.
|
||||
async fn on_stripped_state_avatar(&self, _: Room, _: &StrippedStateEvent<AvatarEventContent>) {}
|
||||
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomPowerLevels` event.
|
||||
/// Fires when `Client` receives a
|
||||
/// `AnyStrippedStateEvent::StrippedRoomPowerLevels` event.
|
||||
async fn on_stripped_state_power_levels(
|
||||
&self,
|
||||
_: Room,
|
||||
_: &StrippedStateEvent<PowerLevelsEventContent>,
|
||||
) {
|
||||
}
|
||||
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomJoinRules` event.
|
||||
/// Fires when `Client` receives a
|
||||
/// `AnyStrippedStateEvent::StrippedRoomJoinRules` event.
|
||||
async fn on_stripped_state_join_rules(
|
||||
&self,
|
||||
_: Room,
|
||||
|
@ -523,31 +487,33 @@ pub trait EventHandler: Send + Sync {
|
|||
/// Fires when `Client` receives a `NonRoomEvent::RoomAliases` event.
|
||||
async fn on_presence_event(&self, _: &PresenceEvent) {}
|
||||
|
||||
/// Fires when `Client` receives a `Event::Custom` event or if deserialization fails
|
||||
/// because the event was unknown to ruma.
|
||||
/// Fires when `Client` receives a `Event::Custom` event or if
|
||||
/// deserialization fails because the event was unknown to ruma.
|
||||
///
|
||||
/// The only guarantee this method can give about the event is that it is valid JSON.
|
||||
/// The only guarantee this method can give about the event is that it is
|
||||
/// valid JSON.
|
||||
async fn on_unrecognized_event(&self, _: Room, _: &RawJsonValue) {}
|
||||
|
||||
/// Fires when `Client` receives a `Event::Custom` event or if deserialization fails
|
||||
/// because the event was unknown to ruma.
|
||||
/// Fires when `Client` receives a `Event::Custom` event or if
|
||||
/// deserialization fails because the event was unknown to ruma.
|
||||
///
|
||||
/// The only guarantee this method can give about the event is that it is in the
|
||||
/// shape of a valid matrix event.
|
||||
/// The only guarantee this method can give about the event is that it is in
|
||||
/// the shape of a valid matrix event.
|
||||
async fn on_custom_event(&self, _: Room, _: &CustomEvent<'_>) {}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use matrix_sdk_common::{async_trait, locks::Mutex};
|
||||
use matrix_sdk_test::{async_test, test_json};
|
||||
use mockito::{mock, Matcher};
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub use wasm_bindgen_test::*;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EvHandlerTest(Arc<Mutex<Vec<String>>>);
|
||||
|
||||
|
@ -640,56 +606,50 @@ mod test {
|
|||
}
|
||||
|
||||
// `AnyStrippedStateEvent`s
|
||||
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomMember` event.
|
||||
/// Fires when `Client` receives a
|
||||
/// `AnyStrippedStateEvent::StrippedRoomMember` event.
|
||||
async fn on_stripped_state_member(
|
||||
&self,
|
||||
_: Room,
|
||||
_: &StrippedStateEvent<MemberEventContent>,
|
||||
_: Option<MemberEventContent>,
|
||||
) {
|
||||
self.0
|
||||
.lock()
|
||||
.await
|
||||
.push("stripped state member".to_string())
|
||||
self.0.lock().await.push("stripped state member".to_string())
|
||||
}
|
||||
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomName` event.
|
||||
/// Fires when `Client` receives a
|
||||
/// `AnyStrippedStateEvent::StrippedRoomName` event.
|
||||
async fn on_stripped_state_name(&self, _: Room, _: &StrippedStateEvent<NameEventContent>) {
|
||||
self.0.lock().await.push("stripped state name".to_string())
|
||||
}
|
||||
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event.
|
||||
/// Fires when `Client` receives a
|
||||
/// `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event.
|
||||
async fn on_stripped_state_canonical_alias(
|
||||
&self,
|
||||
_: Room,
|
||||
_: &StrippedStateEvent<CanonicalAliasEventContent>,
|
||||
) {
|
||||
self.0
|
||||
.lock()
|
||||
.await
|
||||
.push("stripped state canonical".to_string())
|
||||
self.0.lock().await.push("stripped state canonical".to_string())
|
||||
}
|
||||
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAliases` event.
|
||||
/// Fires when `Client` receives a
|
||||
/// `AnyStrippedStateEvent::StrippedRoomAliases` event.
|
||||
async fn on_stripped_state_aliases(
|
||||
&self,
|
||||
_: Room,
|
||||
_: &StrippedStateEvent<AliasesEventContent>,
|
||||
) {
|
||||
self.0
|
||||
.lock()
|
||||
.await
|
||||
.push("stripped state aliases".to_string())
|
||||
self.0.lock().await.push("stripped state aliases".to_string())
|
||||
}
|
||||
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAvatar` event.
|
||||
/// Fires when `Client` receives a
|
||||
/// `AnyStrippedStateEvent::StrippedRoomAvatar` event.
|
||||
async fn on_stripped_state_avatar(
|
||||
&self,
|
||||
_: Room,
|
||||
_: &StrippedStateEvent<AvatarEventContent>,
|
||||
) {
|
||||
self.0
|
||||
.lock()
|
||||
.await
|
||||
.push("stripped state avatar".to_string())
|
||||
self.0.lock().await.push("stripped state avatar".to_string())
|
||||
}
|
||||
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomPowerLevels` event.
|
||||
/// Fires when `Client` receives a
|
||||
/// `AnyStrippedStateEvent::StrippedRoomPowerLevels` event.
|
||||
async fn on_stripped_state_power_levels(
|
||||
&self,
|
||||
_: Room,
|
||||
|
@ -697,7 +657,8 @@ mod test {
|
|||
) {
|
||||
self.0.lock().await.push("stripped state power".to_string())
|
||||
}
|
||||
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomJoinRules` event.
|
||||
/// Fires when `Client` receives a
|
||||
/// `AnyStrippedStateEvent::StrippedRoomJoinRules` event.
|
||||
async fn on_stripped_state_join_rules(
|
||||
&self,
|
||||
_: Room,
|
||||
|
@ -768,14 +729,11 @@ mod test {
|
|||
}
|
||||
|
||||
async fn mock_sync(client: &Client, response: String) {
|
||||
let _m = mock(
|
||||
"GET",
|
||||
Matcher::Regex(r"^/_matrix/client/r0/sync\?.*$".to_string()),
|
||||
)
|
||||
.with_status(200)
|
||||
.match_header("authorization", "Bearer 1234")
|
||||
.with_body(response)
|
||||
.create();
|
||||
let _m = mock("GET", Matcher::Regex(r"^/_matrix/client/r0/sync\?.*$".to_string()))
|
||||
.with_status(200)
|
||||
.match_header("authorization", "Bearer 1234")
|
||||
.with_body(response)
|
||||
.create();
|
||||
|
||||
let sync_settings = SyncSettings::new().timeout(Duration::from_millis(3000));
|
||||
let _response = client.sync_once(sync_settings).await.unwrap();
|
||||
|
@ -823,14 +781,7 @@ mod test {
|
|||
mock_sync(&client, test_json::INVITE_SYNC.to_string()).await;
|
||||
|
||||
let v = test_vec.lock().await;
|
||||
assert_eq!(
|
||||
v.as_slice(),
|
||||
[
|
||||
"stripped state name",
|
||||
"stripped state member",
|
||||
"presence event"
|
||||
],
|
||||
)
|
||||
assert_eq!(v.as_slice(), ["stripped state name", "stripped state member", "presence event"],)
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
|
@ -897,15 +848,7 @@ mod test {
|
|||
mock_sync(&client, test_json::VOIP_SYNC.to_string()).await;
|
||||
|
||||
let v = test_vec.lock().await;
|
||||
assert_eq!(
|
||||
v.as_slice(),
|
||||
[
|
||||
"call invite",
|
||||
"call answer",
|
||||
"call candidates",
|
||||
"call hangup",
|
||||
],
|
||||
)
|
||||
assert_eq!(v.as_slice(), ["call invite", "call answer", "call candidates", "call hangup",],)
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
|
|
|
@ -12,6 +12,8 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#[cfg(all(not(target_arch = "wasm32")))]
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::{convert::TryFrom, fmt::Debug, sync::Arc};
|
||||
|
||||
#[cfg(all(not(target_arch = "wasm32")))]
|
||||
|
@ -19,16 +21,13 @@ use backoff::{future::retry, Error as RetryError, ExponentialBackoff};
|
|||
#[cfg(all(not(target_arch = "wasm32")))]
|
||||
use http::StatusCode;
|
||||
use http::{HeaderValue, Response as HttpResponse};
|
||||
use reqwest::{Client, Response};
|
||||
#[cfg(all(not(target_arch = "wasm32")))]
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use tracing::trace;
|
||||
use url::Url;
|
||||
|
||||
use matrix_sdk_common::{
|
||||
api::r0::media::create_content, async_trait, locks::RwLock, AsyncTraitDeps, AuthScheme,
|
||||
FromHttpResponseError, IncomingResponse, SendAccessToken,
|
||||
};
|
||||
use reqwest::{Client, Response};
|
||||
use tracing::trace;
|
||||
use url::Url;
|
||||
|
||||
use crate::{
|
||||
error::HttpError, Bytes, BytesMut, ClientConfig, OutgoingRequest, RequestConfig, Session,
|
||||
|
@ -39,13 +38,16 @@ use crate::{
|
|||
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
|
||||
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
|
||||
pub trait HttpSend: AsyncTraitDeps {
|
||||
/// The method abstracting sending request types and receiving response types.
|
||||
/// The method abstracting sending request types and receiving response
|
||||
/// types.
|
||||
///
|
||||
/// This is called by the client every time it wants to send anything to a homeserver.
|
||||
/// This is called by the client every time it wants to send anything to a
|
||||
/// homeserver.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `request` - The http request that has been converted from a ruma `Request`.
|
||||
/// * `request` - The http request that has been converted from a ruma
|
||||
/// `Request`.
|
||||
///
|
||||
/// * `request_config` - The config used for this request.
|
||||
///
|
||||
|
@ -122,8 +124,7 @@ impl HttpClient {
|
|||
let request = if !self.request_config.assert_identity {
|
||||
self.try_into_http_request(request, session, config).await?
|
||||
} else {
|
||||
self.try_into_http_request_with_identy_assertion(request, session, config)
|
||||
.await?
|
||||
self.try_into_http_request_with_identy_assertion(request, session, config).await?
|
||||
};
|
||||
|
||||
self.inner.send_request(request, config).await
|
||||
|
@ -202,9 +203,7 @@ impl HttpClient {
|
|||
request: create_content::Request<'_>,
|
||||
config: Option<RequestConfig>,
|
||||
) -> Result<create_content::Response, HttpError> {
|
||||
let response = self
|
||||
.send_request(request, self.session.clone(), config)
|
||||
.await?;
|
||||
let response = self.send_request(request, self.session.clone(), config).await?;
|
||||
Ok(create_content::Response::try_from_http_response(response)?)
|
||||
}
|
||||
|
||||
|
@ -217,9 +216,7 @@ impl HttpClient {
|
|||
Request: OutgoingRequest + Debug,
|
||||
HttpError: From<FromHttpResponseError<Request::EndpointError>>,
|
||||
{
|
||||
let response = self
|
||||
.send_request(request, self.session.clone(), config)
|
||||
.await?;
|
||||
let response = self.send_request(request, self.session.clone(), config).await?;
|
||||
|
||||
trace!("Got response: {:?}", response);
|
||||
|
||||
|
@ -256,9 +253,7 @@ pub(crate) fn client_with_config(config: &ClientConfig) -> Result<Client, HttpEr
|
|||
|
||||
headers.insert(reqwest::header::USER_AGENT, user_agent);
|
||||
|
||||
http_client
|
||||
.default_headers(headers)
|
||||
.timeout(config.request_config.timeout)
|
||||
http_client.default_headers(headers).timeout(config.request_config.timeout)
|
||||
};
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
|
@ -274,9 +269,7 @@ async fn response_to_http_response(
|
|||
let status = response.status();
|
||||
|
||||
let mut http_builder = HttpResponse::builder().status(status);
|
||||
let headers = http_builder
|
||||
.headers_mut()
|
||||
.expect("Can't get the response builder headers");
|
||||
let headers = http_builder.headers_mut().expect("Can't get the response builder headers");
|
||||
|
||||
for (k, v) in response.headers_mut().drain() {
|
||||
if let Some(key) = k {
|
||||
|
@ -286,9 +279,7 @@ async fn response_to_http_response(
|
|||
|
||||
let body = response.bytes().await?;
|
||||
|
||||
Ok(http_builder
|
||||
.body(body)
|
||||
.expect("Can't construct a response using the given body"))
|
||||
Ok(http_builder.body(body).expect("Can't construct a response using the given body"))
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "wasm32"))]
|
||||
|
@ -329,18 +320,12 @@ async fn send_request(
|
|||
};
|
||||
|
||||
// Turn errors into permanent errors when the retry limit is reached
|
||||
let error_type = if stop {
|
||||
RetryError::Permanent
|
||||
} else {
|
||||
RetryError::Transient
|
||||
};
|
||||
let error_type = if stop { RetryError::Permanent } else { RetryError::Transient };
|
||||
|
||||
let request = request.try_clone().ok_or(HttpError::UnableToCloneRequest)?;
|
||||
|
||||
let response = client
|
||||
.execute(request)
|
||||
.await
|
||||
.map_err(|e| error_type(HttpError::Reqwest(e)))?;
|
||||
let response =
|
||||
client.execute(request).await.map_err(|e| error_type(HttpError::Reqwest(e)))?;
|
||||
|
||||
let status_code = response.status();
|
||||
// TODO TOO_MANY_REQUESTS will have a retry timeout which we should
|
||||
|
|
|
@ -17,19 +17,21 @@
|
|||
//!
|
||||
//! # Enabling logging
|
||||
//!
|
||||
//! Users of the matrix-sdk crate can enable log output by depending on the `tracing-subscriber`
|
||||
//! crate and including the following line in their application (e.g. at the start of `main`):
|
||||
//! Users of the matrix-sdk crate can enable log output by depending on the
|
||||
//! `tracing-subscriber` crate and including the following line in their
|
||||
//! application (e.g. at the start of `main`):
|
||||
//!
|
||||
//! ```rust
|
||||
//! tracing_subscriber::fmt::init();
|
||||
//! ```
|
||||
//!
|
||||
//! The log output is controlled via the `RUST_LOG` environment variable by setting it to one of
|
||||
//! the `error`, `warn`, `info`, `debug` or `trace` levels. The output is printed to stdout.
|
||||
//! The log output is controlled via the `RUST_LOG` environment variable by
|
||||
//! setting it to one of the `error`, `warn`, `info`, `debug` or `trace` levels.
|
||||
//! The output is printed to stdout.
|
||||
//!
|
||||
//! The `RUST_LOG` variable also supports a more advanced syntax for filtering log output more
|
||||
//! precisely, for instance with crate-level granularity. For more information on this, check out
|
||||
//! the [tracing_subscriber
|
||||
//! The `RUST_LOG` variable also supports a more advanced syntax for filtering
|
||||
//! log output more precisely, for instance with crate-level granularity. For
|
||||
//! more information on this, check out the [tracing_subscriber
|
||||
//! documentation](https://tracing.rs/tracing_subscriber/filter/struct.envfilter).
|
||||
//!
|
||||
//! # Crate Feature Flags
|
||||
|
@ -44,10 +46,13 @@
|
|||
//! * `markdown`: Support for sending markdown formatted messages.
|
||||
//! * `socks`: Enables SOCKS support in reqwest, the default HTTP client.
|
||||
//! * `sso_login`: Enables SSO login with a local http server.
|
||||
//! * `require_auth_for_profile_requests`: Whether to send the access token in the authentication
|
||||
//! header when calling endpoints that retrieve profile data. This matches the synapse
|
||||
//! configuration `require_auth_for_profile_requests`. Enabled by default.
|
||||
//! * `appservice`: Enables low-level appservice functionality. For an high-level API there's the
|
||||
//! * `require_auth_for_profile_requests`: Whether to send the access token in
|
||||
//! the authentication
|
||||
//! header when calling endpoints that retrieve profile data. This matches the
|
||||
//! synapse configuration `require_auth_for_profile_requests`. Enabled by
|
||||
//! default.
|
||||
//! * `appservice`: Enables low-level appservice functionality. For an
|
||||
//! high-level API there's the
|
||||
//! `matrix-sdk-appservice` crate
|
||||
|
||||
#![deny(
|
||||
|
@ -71,6 +76,7 @@ compile_error!("only one of 'native-tls' or 'rustls-tls' features can be enabled
|
|||
#[cfg(all(feature = "sso_login", target_arch = "wasm32"))]
|
||||
compile_error!("'sso_login' cannot be enabled on 'wasm32' arch");
|
||||
|
||||
pub use bytes::{Bytes, BytesMut};
|
||||
#[cfg(feature = "encryption")]
|
||||
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
|
||||
pub use matrix_sdk_base::crypto::{EncryptionInfo, LocalTrust};
|
||||
|
@ -78,8 +84,6 @@ pub use matrix_sdk_base::{
|
|||
Error as BaseError, Room as BaseRoom, RoomInfo, RoomMember as BaseRoomMember, RoomType,
|
||||
Session, StateChanges, StoreError,
|
||||
};
|
||||
|
||||
pub use bytes::{Bytes, BytesMut};
|
||||
pub use matrix_sdk_common::*;
|
||||
pub use reqwest;
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use std::{ops::Deref, sync::Arc};
|
||||
|
||||
use matrix_sdk_base::{deserialized_responses::MembersResponse, identifiers::UserId};
|
||||
use matrix_sdk_common::{
|
||||
api::r0::{
|
||||
|
@ -8,11 +10,10 @@ use matrix_sdk_common::{
|
|||
locks::Mutex,
|
||||
};
|
||||
|
||||
use std::{ops::Deref, sync::Arc};
|
||||
|
||||
use crate::{BaseRoom, Client, Result, RoomMember};
|
||||
|
||||
/// A struct containing methodes that are common for Joined, Invited and Left Rooms
|
||||
/// A struct containing methodes that are common for Joined, Invited and Left
|
||||
/// Rooms
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Common {
|
||||
inner: BaseRoom,
|
||||
|
@ -36,10 +37,7 @@ impl Common {
|
|||
/// * `room` - The underlaying room.
|
||||
pub fn new(client: Client, room: BaseRoom) -> Self {
|
||||
// TODO: Make this private
|
||||
Self {
|
||||
inner: room,
|
||||
client,
|
||||
}
|
||||
Self { inner: room, client }
|
||||
}
|
||||
|
||||
/// Leave this room.
|
||||
|
@ -111,9 +109,9 @@ impl Common {
|
|||
}
|
||||
}
|
||||
|
||||
/// Sends a request to `/_matrix/client/r0/rooms/{room_id}/messages` and returns
|
||||
/// a `get_message_events::Response` that contains a chunk of room and state events
|
||||
/// (`AnyRoomEvent` and `AnyStateEvent`).
|
||||
/// Sends a request to `/_matrix/client/r0/rooms/{room_id}/messages` and
|
||||
/// returns a `get_message_events::Response` that contains a chunk of
|
||||
/// room and state events (`AnyRoomEvent` and `AnyStateEvent`).
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
|
@ -152,35 +150,25 @@ impl Common {
|
|||
|
||||
pub(crate) async fn request_members(&self) -> Result<Option<MembersResponse>> {
|
||||
#[allow(clippy::map_clone)]
|
||||
if let Some(mutex) = self
|
||||
.client
|
||||
.members_request_locks
|
||||
.get(self.inner.room_id())
|
||||
.map(|m| m.clone())
|
||||
if let Some(mutex) =
|
||||
self.client.members_request_locks.get(self.inner.room_id()).map(|m| m.clone())
|
||||
{
|
||||
mutex.lock().await;
|
||||
|
||||
Ok(None)
|
||||
} else {
|
||||
let mutex = Arc::new(Mutex::new(()));
|
||||
self.client
|
||||
.members_request_locks
|
||||
.insert(self.inner.room_id().clone(), mutex.clone());
|
||||
self.client.members_request_locks.insert(self.inner.room_id().clone(), mutex.clone());
|
||||
|
||||
let _guard = mutex.lock().await;
|
||||
|
||||
let request = get_member_events::Request::new(self.inner.room_id());
|
||||
let response = self.client.send(request, None).await?;
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.base_client
|
||||
.receive_members(self.inner.room_id(), &response)
|
||||
.await?;
|
||||
let response =
|
||||
self.client.base_client.receive_members(self.inner.room_id(), &response).await?;
|
||||
|
||||
self.client
|
||||
.members_request_locks
|
||||
.remove(self.inner.room_id());
|
||||
self.client.members_request_locks.remove(self.inner.room_id());
|
||||
|
||||
Ok(Some(response))
|
||||
}
|
||||
|
@ -248,9 +236,9 @@ impl Common {
|
|||
|
||||
/// Get all the joined members of this room.
|
||||
///
|
||||
/// *Note*: This method will not fetch the members from the homeserver if the
|
||||
/// member list isn't synchronized due to member lazy loading. Thus, members
|
||||
/// could be missing from the list.
|
||||
/// *Note*: This method will not fetch the members from the homeserver if
|
||||
/// the member list isn't synchronized due to member lazy loading. Thus,
|
||||
/// members could be missing from the list.
|
||||
///
|
||||
/// Use [joined_members()](#method.joined_members) if you want to ensure to
|
||||
/// always get the full member list.
|
||||
|
@ -284,9 +272,9 @@ impl Common {
|
|||
|
||||
/// Get a specific member of this room.
|
||||
///
|
||||
/// *Note*: This method will not fetch the members from the homeserver if the
|
||||
/// member list isn't synchronized due to member lazy loading. Thus, members
|
||||
/// could be missing.
|
||||
/// *Note*: This method will not fetch the members from the homeserver if
|
||||
/// the member list isn't synchronized due to member lazy loading. Thus,
|
||||
/// members could be missing.
|
||||
///
|
||||
/// Use [get_member()](#method.get_member) if you want to ensure to always
|
||||
/// have the full member list to chose from.
|
||||
|
@ -295,7 +283,6 @@ impl Common {
|
|||
///
|
||||
/// * `user_id` - The ID of the user that should be fetched out of the
|
||||
/// store.
|
||||
///
|
||||
pub async fn get_member_no_sync(&self, user_id: &UserId) -> Result<Option<RoomMember>> {
|
||||
Ok(self
|
||||
.inner
|
||||
|
@ -304,7 +291,8 @@ impl Common {
|
|||
.map(|member| RoomMember::new(self.client.clone(), member)))
|
||||
}
|
||||
|
||||
/// Get all members for this room, includes invited, joined and left members.
|
||||
/// Get all members for this room, includes invited, joined and left
|
||||
/// members.
|
||||
///
|
||||
/// *Note*: This method will fetch the members from the homeserver if the
|
||||
/// member list isn't synchronized due to member lazy loading. Because of
|
||||
|
@ -317,11 +305,12 @@ impl Common {
|
|||
self.members_no_sync().await
|
||||
}
|
||||
|
||||
/// Get all members for this room, includes invited, joined and left members.
|
||||
/// Get all members for this room, includes invited, joined and left
|
||||
/// members.
|
||||
///
|
||||
/// *Note*: This method will not fetch the members from the homeserver if the
|
||||
/// member list isn't synchronized due to member lazy loading. Thus, members
|
||||
/// could be missing.
|
||||
/// *Note*: This method will not fetch the members from the homeserver if
|
||||
/// the member list isn't synchronized due to member lazy loading. Thus,
|
||||
/// members could be missing.
|
||||
///
|
||||
/// Use [members()](#method.members) if you want to ensure to always get
|
||||
/// the full member list.
|
||||
|
|
|
@ -1,17 +1,20 @@
|
|||
use crate::{room::Common, BaseRoom, Client, Result, RoomType};
|
||||
use std::ops::Deref;
|
||||
|
||||
use crate::{room::Common, BaseRoom, Client, Result, RoomType};
|
||||
|
||||
/// A room in the invited state.
|
||||
///
|
||||
/// This struct contains all methodes specific to a `Room` with type `RoomType::Invited`.
|
||||
/// Operations may fail once the underlaying `Room` changes `RoomType`.
|
||||
/// This struct contains all methodes specific to a `Room` with type
|
||||
/// `RoomType::Invited`. Operations may fail once the underlaying `Room` changes
|
||||
/// `RoomType`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Invited {
|
||||
pub(crate) inner: Common,
|
||||
}
|
||||
|
||||
impl Invited {
|
||||
/// Create a new `room::Invited` if the underlaying `Room` has type `RoomType::Invited`.
|
||||
/// Create a new `room::Invited` if the underlaying `Room` has type
|
||||
/// `RoomType::Invited`.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `client` - The client used to make requests.
|
||||
|
@ -20,9 +23,7 @@ impl Invited {
|
|||
pub fn new(client: Client, room: BaseRoom) -> Option<Self> {
|
||||
// TODO: Make this private
|
||||
if room.room_type() == RoomType::Invited {
|
||||
Some(Self {
|
||||
inner: Common::new(client, room),
|
||||
})
|
||||
Some(Self { inner: Common::new(client, room) })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
use crate::{room::Common, BaseRoom, Client, Result, RoomType};
|
||||
#[cfg(feature = "encryption")]
|
||||
use std::sync::Arc;
|
||||
use std::{io::Read, ops::Deref};
|
||||
|
||||
#[cfg(feature = "encryption")]
|
||||
use std::sync::Arc;
|
||||
|
||||
use matrix_sdk_base::crypto::AttachmentEncryptor;
|
||||
#[cfg(feature = "encryption")]
|
||||
use matrix_sdk_common::locks::Mutex;
|
||||
use matrix_sdk_common::{
|
||||
api::r0::{
|
||||
membership::{
|
||||
|
@ -34,25 +36,20 @@ use matrix_sdk_common::{
|
|||
receipt::ReceiptType,
|
||||
uuid::Uuid,
|
||||
};
|
||||
|
||||
use mime::{self, Mime};
|
||||
|
||||
#[cfg(feature = "encryption")]
|
||||
use matrix_sdk_common::locks::Mutex;
|
||||
|
||||
#[cfg(feature = "encryption")]
|
||||
use matrix_sdk_base::crypto::AttachmentEncryptor;
|
||||
|
||||
#[cfg(feature = "encryption")]
|
||||
use tracing::instrument;
|
||||
|
||||
use crate::{room::Common, BaseRoom, Client, Result, RoomType};
|
||||
|
||||
const TYPING_NOTICE_TIMEOUT: Duration = Duration::from_secs(4);
|
||||
const TYPING_NOTICE_RESEND_TIMEOUT: Duration = Duration::from_secs(3);
|
||||
|
||||
/// A room in the joined state.
|
||||
///
|
||||
/// The `JoinedRoom` contains all methodes specific to a `Room` with type `RoomType::Joined`.
|
||||
/// Operations may fail once the underlaying `Room` changes `RoomType`.
|
||||
/// The `JoinedRoom` contains all methodes specific to a `Room` with type
|
||||
/// `RoomType::Joined`. Operations may fail once the underlaying `Room` changes
|
||||
/// `RoomType`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Joined {
|
||||
pub(crate) inner: Common,
|
||||
|
@ -67,7 +64,8 @@ impl Deref for Joined {
|
|||
}
|
||||
|
||||
impl Joined {
|
||||
/// Create a new `room::Joined` if the underlaying `BaseRoom` has type `RoomType::Joined`.
|
||||
/// Create a new `room::Joined` if the underlaying `BaseRoom` has type
|
||||
/// `RoomType::Joined`.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `client` - The client used to make requests.
|
||||
|
@ -76,9 +74,7 @@ impl Joined {
|
|||
pub fn new(client: Client, room: BaseRoom) -> Option<Self> {
|
||||
// TODO: Make this private
|
||||
if room.room_type() == RoomType::Joined {
|
||||
Some(Self {
|
||||
inner: Common::new(client, room),
|
||||
})
|
||||
Some(Self { inner: Common::new(client, room) })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
@ -97,9 +93,7 @@ impl Joined {
|
|||
///
|
||||
/// * `reason` - The reason for banning this user.
|
||||
pub async fn ban_user(&self, user_id: &UserId, reason: Option<&str>) -> Result<()> {
|
||||
let request = assign!(ban_user::Request::new(self.inner.room_id(), user_id), {
|
||||
reason
|
||||
});
|
||||
let request = assign!(ban_user::Request::new(self.inner.room_id(), user_id), { reason });
|
||||
self.client.send(request, None).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -108,13 +102,12 @@ impl Joined {
|
|||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `user_id` - The `UserId` of the user that should be kicked out of the room.
|
||||
/// * `user_id` - The `UserId` of the user that should be kicked out of the
|
||||
/// room.
|
||||
///
|
||||
/// * `reason` - Optional reason why the room member is being kicked out.
|
||||
pub async fn kick_user(&self, user_id: &UserId, reason: Option<&str>) -> Result<()> {
|
||||
let request = assign!(kick_user::Request::new(self.inner.room_id(), user_id), {
|
||||
reason
|
||||
});
|
||||
let request = assign!(kick_user::Request::new(self.inner.room_id(), user_id), { reason });
|
||||
self.client.send(request, None).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -148,9 +141,10 @@ impl Joined {
|
|||
|
||||
/// Activate typing notice for this room.
|
||||
///
|
||||
/// The typing notice remains active for 4s. It can be deactivate at any point by setting
|
||||
/// typing to `false`. If this method is called while the typing notice is active nothing will happen.
|
||||
/// This method can be called on every key stroke, since it will do nothing while typing is
|
||||
/// The typing notice remains active for 4s. It can be deactivate at any
|
||||
/// point by setting typing to `false`. If this method is called while
|
||||
/// the typing notice is active nothing will happen. This method can be
|
||||
/// called on every key stroke, since it will do nothing while typing is
|
||||
/// active.
|
||||
///
|
||||
/// # Arguments
|
||||
|
@ -183,21 +177,23 @@ impl Joined {
|
|||
/// # });
|
||||
/// ```
|
||||
pub async fn typing_notice(&self, typing: bool) -> Result<()> {
|
||||
// Only send a request to the homeserver if the old timeout has elapsed or the typing
|
||||
// notice changed state within the TYPING_NOTICE_TIMEOUT
|
||||
// Only send a request to the homeserver if the old timeout has elapsed
|
||||
// or the typing notice changed state within the
|
||||
// TYPING_NOTICE_TIMEOUT
|
||||
let send =
|
||||
if let Some(typing_time) = self.client.typing_notice_times.get(self.inner.room_id()) {
|
||||
if typing_time.elapsed() > TYPING_NOTICE_RESEND_TIMEOUT {
|
||||
// We always reactivate the typing notice if typing is true or we may need to
|
||||
// deactivate it if it's currently active if typing is false
|
||||
// We always reactivate the typing notice if typing is true or
|
||||
// we may need to deactivate it if it's
|
||||
// currently active if typing is false
|
||||
typing || typing_time.elapsed() <= TYPING_NOTICE_TIMEOUT
|
||||
} else {
|
||||
// Only send a request when we need to deactivate typing
|
||||
!typing
|
||||
}
|
||||
} else {
|
||||
// Typing notice is currently deactivated, therefore, send a request only when it's
|
||||
// about to be activated
|
||||
// Typing notice is currently deactivated, therefore, send a request
|
||||
// only when it's about to be activated
|
||||
typing
|
||||
};
|
||||
|
||||
|
@ -220,11 +216,13 @@ impl Joined {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Send a request to notify this room that the user has read specific event.
|
||||
/// Send a request to notify this room that the user has read specific
|
||||
/// event.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `event_id` - The `EventId` specifies the event to set the read receipt on.
|
||||
/// * `event_id` - The `EventId` specifies the event to set the read receipt
|
||||
/// on.
|
||||
pub async fn read_receipt(&self, event_id: &EventId) -> Result<()> {
|
||||
let request =
|
||||
create_receipt::Request::new(self.inner.room_id(), ReceiptType::Read, event_id);
|
||||
|
@ -233,22 +231,23 @@ impl Joined {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Send a request to notify this room that the user has read up to specific event.
|
||||
/// Send a request to notify this room that the user has read up to specific
|
||||
/// event.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * fully_read - The `EventId` of the event the user has read to.
|
||||
///
|
||||
/// * read_receipt - An `EventId` to specify the event to set the read receipt on.
|
||||
/// * read_receipt - An `EventId` to specify the event to set the read
|
||||
/// receipt on.
|
||||
pub async fn read_marker(
|
||||
&self,
|
||||
fully_read: &EventId,
|
||||
read_receipt: Option<&EventId>,
|
||||
) -> Result<()> {
|
||||
let request = assign!(
|
||||
set_read_marker::Request::new(self.inner.room_id(), fully_read),
|
||||
{ read_receipt }
|
||||
);
|
||||
let request = assign!(set_read_marker::Request::new(self.inner.room_id(), fully_read), {
|
||||
read_receipt
|
||||
});
|
||||
|
||||
self.client.send(request, None).await?;
|
||||
Ok(())
|
||||
|
@ -266,11 +265,8 @@ impl Joined {
|
|||
// TODO expose this publicly so people can pre-share a group session if
|
||||
// e.g. a user starts to type a message for a room.
|
||||
#[allow(clippy::map_clone)]
|
||||
if let Some(mutex) = self
|
||||
.client
|
||||
.group_session_locks
|
||||
.get(self.inner.room_id())
|
||||
.map(|m| m.clone())
|
||||
if let Some(mutex) =
|
||||
self.client.group_session_locks.get(self.inner.room_id()).map(|m| m.clone())
|
||||
{
|
||||
// If a group session share request is already going on,
|
||||
// await the release of the lock.
|
||||
|
@ -279,23 +275,14 @@ impl Joined {
|
|||
// Otherwise create a new lock and share the group
|
||||
// session.
|
||||
let mutex = Arc::new(Mutex::new(()));
|
||||
self.client
|
||||
.group_session_locks
|
||||
.insert(self.inner.room_id().clone(), mutex.clone());
|
||||
self.client.group_session_locks.insert(self.inner.room_id().clone(), mutex.clone());
|
||||
|
||||
let _guard = mutex.lock().await;
|
||||
|
||||
{
|
||||
let joined = self
|
||||
.client
|
||||
.store()
|
||||
.get_joined_user_ids(self.inner.room_id())
|
||||
.await?;
|
||||
let invited = self
|
||||
.client
|
||||
.store()
|
||||
.get_invited_user_ids(self.inner.room_id())
|
||||
.await?;
|
||||
let joined = self.client.store().get_joined_user_ids(self.inner.room_id()).await?;
|
||||
let invited =
|
||||
self.client.store().get_invited_user_ids(self.inner.room_id()).await?;
|
||||
let members = joined.iter().chain(&invited);
|
||||
self.client.claim_one_time_keys(members).await?;
|
||||
};
|
||||
|
@ -308,10 +295,7 @@ impl Joined {
|
|||
// session as using it would end up in undecryptable
|
||||
// messages.
|
||||
if let Err(r) = response {
|
||||
self.client
|
||||
.base_client
|
||||
.invalidate_group_session(self.inner.room_id())
|
||||
.await?;
|
||||
self.client.base_client.invalidate_group_session(self.inner.room_id()).await?;
|
||||
return Err(r);
|
||||
}
|
||||
}
|
||||
|
@ -328,19 +312,13 @@ impl Joined {
|
|||
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
|
||||
#[instrument]
|
||||
async fn share_group_session(&self) -> Result<()> {
|
||||
let mut requests = self
|
||||
.client
|
||||
.base_client
|
||||
.share_group_session(self.inner.room_id())
|
||||
.await?;
|
||||
let mut requests =
|
||||
self.client.base_client.share_group_session(self.inner.room_id()).await?;
|
||||
|
||||
for request in requests.drain(..) {
|
||||
let response = self.client.send_to_device(&request).await?;
|
||||
|
||||
self.client
|
||||
.base_client
|
||||
.mark_request_as_sent(&request.txn_id, &response)
|
||||
.await?;
|
||||
self.client.base_client.mark_request_as_sent(&request.txn_id, &response).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -407,10 +385,7 @@ impl Joined {
|
|||
|
||||
self.preshare_group_session().await?;
|
||||
AnyMessageEventContent::RoomEncrypted(
|
||||
self.client
|
||||
.base_client
|
||||
.encrypt(self.inner.room_id(), content)
|
||||
.await?,
|
||||
self.client.base_client.encrypt(self.inner.room_id(), content).await?,
|
||||
)
|
||||
} else {
|
||||
content.into()
|
||||
|
@ -430,8 +405,9 @@ impl Joined {
|
|||
/// If the room is encrypted and the encryption feature is enabled the
|
||||
/// upload will be encrypted.
|
||||
///
|
||||
/// This is a convenience method that calls the [`Client::upload()`](#Client::method.upload)
|
||||
/// and afterwards the [`send()`](#method.send).
|
||||
/// This is a convenience method that calls the
|
||||
/// [`Client::upload()`](#Client::method.upload) and afterwards the
|
||||
/// [`send()`](#method.send).
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `body` - A textual representation of the media that is going to be
|
||||
|
@ -538,11 +514,8 @@ impl Joined {
|
|||
}),
|
||||
};
|
||||
|
||||
self.send(
|
||||
AnyMessageEventContent::RoomMessage(MessageEventContent::new(content)),
|
||||
txn_id,
|
||||
)
|
||||
.await
|
||||
self.send(AnyMessageEventContent::RoomMessage(MessageEventContent::new(content)), txn_id)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Send a room state event to the homeserver.
|
||||
|
@ -639,10 +612,10 @@ impl Joined {
|
|||
txn_id: Option<Uuid>,
|
||||
) -> Result<redact_event::Response> {
|
||||
let txn_id = txn_id.unwrap_or_else(Uuid::new_v4).to_string();
|
||||
let request = assign!(
|
||||
redact_event::Request::new(self.inner.room_id(), event_id, &txn_id),
|
||||
{ reason }
|
||||
);
|
||||
let request =
|
||||
assign!(redact_event::Request::new(self.inner.room_id(), event_id, &txn_id), {
|
||||
reason
|
||||
});
|
||||
|
||||
self.client.send(request, None).await
|
||||
}
|
||||
|
|
|
@ -1,19 +1,22 @@
|
|||
use crate::{room::Common, BaseRoom, Client, Result, RoomType};
|
||||
use std::ops::Deref;
|
||||
|
||||
use matrix_sdk_common::api::r0::membership::forget_room;
|
||||
|
||||
use crate::{room::Common, BaseRoom, Client, Result, RoomType};
|
||||
|
||||
/// A room in the left state.
|
||||
///
|
||||
/// This struct contains all methodes specific to a `Room` with type `RoomType::Left`.
|
||||
/// Operations may fail once the underlaying `Room` changes `RoomType`.
|
||||
/// This struct contains all methodes specific to a `Room` with type
|
||||
/// `RoomType::Left`. Operations may fail once the underlaying `Room` changes
|
||||
/// `RoomType`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Left {
|
||||
pub(crate) inner: Common,
|
||||
}
|
||||
|
||||
impl Left {
|
||||
/// Create a new `room::Left` if the underlaying `Room` has type `RoomType::Left`.
|
||||
/// Create a new `room::Left` if the underlaying `Room` has type
|
||||
/// `RoomType::Left`.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `client` - The client used to make requests.
|
||||
|
@ -22,9 +25,7 @@ impl Left {
|
|||
pub fn new(client: Client, room: BaseRoom) -> Option<Self> {
|
||||
// TODO: Make this private
|
||||
if room.room_type() == RoomType::Left {
|
||||
Some(Self {
|
||||
inner: Common::new(client, room),
|
||||
})
|
||||
Some(Self { inner: Common::new(client, room) })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use matrix_sdk_common::api::r0::media::{get_content, get_content_thumbnail};
|
||||
|
||||
use std::ops::Deref;
|
||||
|
||||
use matrix_sdk_common::api::r0::media::{get_content, get_content_thumbnail};
|
||||
|
||||
use crate::{BaseRoomMember, Client, Result};
|
||||
|
||||
/// The high-level `RoomMember` representation
|
||||
|
@ -21,10 +21,7 @@ impl Deref for RoomMember {
|
|||
|
||||
impl RoomMember {
|
||||
pub(crate) fn new(client: Client, member: BaseRoomMember) -> Self {
|
||||
Self {
|
||||
inner: member,
|
||||
client,
|
||||
}
|
||||
Self { inner: member, client }
|
||||
}
|
||||
|
||||
/// Gets the avatar of this member, if set.
|
||||
|
|
|
@ -26,11 +26,7 @@ impl AppserviceEventHandler {
|
|||
#[async_trait]
|
||||
impl EventHandler for AppserviceEventHandler {
|
||||
async fn on_room_member(&self, room: Room, event: &SyncStateEvent<MemberEventContent>) {
|
||||
if !self
|
||||
.appservice
|
||||
.user_id_is_in_namespace(&event.state_key)
|
||||
.unwrap()
|
||||
{
|
||||
if !self.appservice.user_id_is_in_namespace(&event.state_key).unwrap() {
|
||||
dbg!("not an appservice user");
|
||||
return;
|
||||
}
|
||||
|
@ -38,11 +34,7 @@ impl EventHandler for AppserviceEventHandler {
|
|||
if let MembershipState::Invite = event.content.membership {
|
||||
let user_id = UserId::try_from(event.state_key.clone()).unwrap();
|
||||
|
||||
let client = self
|
||||
.appservice
|
||||
.client_with_localpart(user_id.localpart())
|
||||
.await
|
||||
.unwrap();
|
||||
let client = self.appservice.client_with_localpart(user_id.localpart()).await.unwrap();
|
||||
|
||||
client.join_room_by_id(room.room_id()).await.unwrap();
|
||||
}
|
||||
|
@ -51,10 +43,7 @@ impl EventHandler for AppserviceEventHandler {
|
|||
|
||||
#[actix_web::main]
|
||||
pub async fn main() -> std::io::Result<()> {
|
||||
env::set_var(
|
||||
"RUST_LOG",
|
||||
"actix_web=debug,actix_server=info,matrix_sdk=debug",
|
||||
);
|
||||
env::set_var("RUST_LOG", "actix_web=debug,actix_server=info,matrix_sdk=debug");
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let homeserver_url = "http://localhost:8008";
|
||||
|
@ -62,16 +51,11 @@ pub async fn main() -> std::io::Result<()> {
|
|||
let registration =
|
||||
AppserviceRegistration::try_from_yaml_file("./tests/registration.yaml").unwrap();
|
||||
|
||||
let appservice = Appservice::new(homeserver_url, server_name, registration)
|
||||
.await
|
||||
.unwrap();
|
||||
let appservice = Appservice::new(homeserver_url, server_name, registration).await.unwrap();
|
||||
|
||||
let event_handler = AppserviceEventHandler::new(appservice.clone());
|
||||
|
||||
appservice
|
||||
.client()
|
||||
.set_event_handler(Box::new(event_handler))
|
||||
.await;
|
||||
appservice.client().set_event_handler(Box::new(event_handler)).await;
|
||||
|
||||
HttpServer::new(move || App::new().service(appservice.actix_service()))
|
||||
.bind(("0.0.0.0", 8090))?
|
||||
|
|
|
@ -17,6 +17,7 @@ use std::{
|
|||
pin::Pin,
|
||||
};
|
||||
|
||||
pub use actix_web::Scope;
|
||||
use actix_web::{
|
||||
dev::Payload,
|
||||
error::PayloadError,
|
||||
|
@ -30,8 +31,6 @@ use futures::Future;
|
|||
use futures_util::{TryFutureExt, TryStreamExt};
|
||||
use matrix_sdk::api_appservice as api;
|
||||
|
||||
pub use actix_web::Scope;
|
||||
|
||||
use crate::{error::Error, Appservice};
|
||||
|
||||
pub async fn run_server(
|
||||
|
@ -53,10 +52,7 @@ pub fn get_scope() -> Scope {
|
|||
}
|
||||
|
||||
fn gen_scope(scope: &str) -> Scope {
|
||||
web::scope(scope)
|
||||
.service(push_transactions)
|
||||
.service(query_user_id)
|
||||
.service(query_room_alias)
|
||||
web::scope(scope).service(push_transactions).service(query_user_id).service(query_room_alias)
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
|
@ -69,11 +65,7 @@ async fn push_transactions(
|
|||
return Ok(HttpResponse::Unauthorized().finish());
|
||||
}
|
||||
|
||||
appservice
|
||||
.client()
|
||||
.receive_transaction(request.incoming)
|
||||
.await
|
||||
.unwrap();
|
||||
appservice.client().receive_transaction(request.incoming).await.unwrap();
|
||||
|
||||
Ok(HttpResponse::Ok().json("{}"))
|
||||
}
|
||||
|
@ -136,13 +128,9 @@ impl<T: matrix_sdk::IncomingRequest> FromRequest for IncomingRequest<T> {
|
|||
uri
|
||||
};
|
||||
|
||||
let mut builder = http::request::Builder::new()
|
||||
.method(request.method())
|
||||
.uri(uri);
|
||||
let mut builder = http::request::Builder::new().method(request.method()).uri(uri);
|
||||
|
||||
let headers = builder
|
||||
.headers_mut()
|
||||
.ok_or(Error::UnknownHttpRequestBuilder)?;
|
||||
let headers = builder.headers_mut().ok_or(Error::UnknownHttpRequestBuilder)?;
|
||||
for (key, value) in request.headers().iter() {
|
||||
headers.append(key, value.to_owned());
|
||||
}
|
||||
|
@ -158,10 +146,7 @@ impl<T: matrix_sdk::IncomingRequest> FromRequest for IncomingRequest<T> {
|
|||
let access_token = match request.uri().query() {
|
||||
Some(query) => {
|
||||
let query: Vec<(String, String)> = matrix_sdk::urlencoded::from_str(query)?;
|
||||
query
|
||||
.into_iter()
|
||||
.find(|(key, _)| key == "access_token")
|
||||
.map(|(_, value)| value)
|
||||
query.into_iter().find(|(key, _)| key == "access_token").map(|(_, value)| value)
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
|
|
@ -14,11 +14,14 @@
|
|||
|
||||
//! Matrix [Application Service] library
|
||||
//!
|
||||
//! The appservice crate aims to provide a batteries-included experience. That means that we
|
||||
//! * ship with functionality to configure your webserver crate or simply run the webserver for you
|
||||
//! The appservice crate aims to provide a batteries-included experience. That
|
||||
//! means that we
|
||||
//! * ship with functionality to configure your webserver crate or simply run
|
||||
//! the webserver for you
|
||||
//! * receive and validate requests from the homeserver correctly
|
||||
//! * allow calling the homeserver with proper virtual user identity assertion
|
||||
//! * have the goal to have a consistent room state available by leveraging the stores that the matrix-sdk provides
|
||||
//! * have the goal to have a consistent room state available by leveraging the
|
||||
//! stores that the matrix-sdk provides
|
||||
//!
|
||||
//! # Quickstart
|
||||
//!
|
||||
|
@ -62,6 +65,8 @@ use std::{
|
|||
};
|
||||
|
||||
use http::Uri;
|
||||
#[doc(inline)]
|
||||
pub use matrix_sdk::api_appservice as api;
|
||||
use matrix_sdk::{
|
||||
api::{
|
||||
error::ErrorKind,
|
||||
|
@ -81,9 +86,6 @@ use regex::Regex;
|
|||
use tracing::error;
|
||||
use tracing::warn;
|
||||
|
||||
#[doc(inline)]
|
||||
pub use matrix_sdk::api_appservice as api;
|
||||
|
||||
#[cfg(feature = "actix")]
|
||||
mod actix;
|
||||
mod error;
|
||||
|
@ -104,9 +106,7 @@ impl AppserviceRegistration {
|
|||
///
|
||||
/// See the fields of [`Registration`] for the required format
|
||||
pub fn try_from_yaml_str(value: impl AsRef<str>) -> Result<Self> {
|
||||
Ok(Self {
|
||||
inner: serde_yaml::from_str(value.as_ref())?,
|
||||
})
|
||||
Ok(Self { inner: serde_yaml::from_str(value.as_ref())? })
|
||||
}
|
||||
|
||||
/// Try to load registration from yaml file
|
||||
|
@ -115,9 +115,7 @@ impl AppserviceRegistration {
|
|||
pub fn try_from_yaml_file(path: impl Into<PathBuf>) -> Result<Self> {
|
||||
let file = File::open(path.into())?;
|
||||
|
||||
Ok(Self {
|
||||
inner: serde_yaml::from_reader(file)?,
|
||||
})
|
||||
Ok(Self { inner: serde_yaml::from_reader(file)? })
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -177,8 +175,10 @@ impl Appservice {
|
|||
/// # Arguments
|
||||
///
|
||||
/// * `homeserver_url` - The homeserver that the client should connect to.
|
||||
/// * `server_name` - The server name to use when constructing user ids from the localpart.
|
||||
/// * `registration` - The [Appservice Registration] to use when interacting with the homserver.
|
||||
/// * `server_name` - The server name to use when constructing user ids from
|
||||
/// the localpart.
|
||||
/// * `registration` - The [Appservice Registration] to use when interacting
|
||||
/// with the homserver.
|
||||
///
|
||||
/// [Appservice Registration]: https://matrix.org/docs/spec/application_service/r0.1.2#registration
|
||||
pub async fn new(
|
||||
|
@ -209,8 +209,9 @@ impl Appservice {
|
|||
|
||||
/// Get `Client` for the given `localpart`
|
||||
///
|
||||
/// If the `localpart` is covered by the `namespaces` in the [registration] all requests to the
|
||||
/// homeserver will [assert the identity] to the according virtual user.
|
||||
/// If the `localpart` is covered by the `namespaces` in the [registration]
|
||||
/// all requests to the homeserver will [assert the identity] to the
|
||||
/// according virtual user.
|
||||
///
|
||||
/// [registration]: https://matrix.org/docs/spec/application_service/r0.1.2#registration
|
||||
/// [assert the identity]:
|
||||
|
@ -291,7 +292,8 @@ impl Appservice {
|
|||
|
||||
/// Get the host and port from the registration URL
|
||||
///
|
||||
/// If no port is found it falls back to scheme defaults: 80 for http and 443 for https
|
||||
/// If no port is found it falls back to scheme defaults: 80 for http and
|
||||
/// 443 for https
|
||||
pub fn get_host_and_port_from_registration(&self) -> Result<(Host, Port)> {
|
||||
let uri = Uri::try_from(&self.registration.url)?;
|
||||
|
||||
|
@ -315,9 +317,11 @@ impl Appservice {
|
|||
actix::get_scope().data(self.clone())
|
||||
}
|
||||
|
||||
/// Convenience method that runs an http server depending on the selected server feature
|
||||
/// Convenience method that runs an http server depending on the selected
|
||||
/// server feature
|
||||
///
|
||||
/// This is a blocking call that tries to listen on the provided host and port
|
||||
/// This is a blocking call that tries to listen on the provided host and
|
||||
/// port
|
||||
pub async fn run(&self, host: impl AsRef<str>, port: impl Into<u16>) -> Result<()> {
|
||||
#[cfg(feature = "actix")]
|
||||
{
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
#[cfg(feature = "actix")]
|
||||
mod actix {
|
||||
use actix_web::{test, App};
|
||||
use matrix_sdk_appservice::*;
|
||||
use std::env;
|
||||
|
||||
use actix_web::{test, App};
|
||||
use matrix_sdk_appservice::*;
|
||||
|
||||
async fn appservice() -> Appservice {
|
||||
env::set_var(
|
||||
"RUST_LOG",
|
||||
"mockito=debug,matrix_sdk=debug,ruma=debug,actix_web=debug",
|
||||
);
|
||||
env::set_var("RUST_LOG", "mockito=debug,matrix_sdk=debug,ruma=debug,actix_web=debug");
|
||||
let _ = tracing_subscriber::fmt::try_init();
|
||||
|
||||
Appservice::new(
|
||||
|
@ -109,7 +107,8 @@ mod actix {
|
|||
|
||||
let resp = test::call_service(&app, req).await;
|
||||
|
||||
// TODO: this should actually return a 401 but is 500 because something in the extractor fails
|
||||
// TODO: this should actually return a 401 but is 500 because something in the
|
||||
// extractor fails
|
||||
assert_eq!(resp.status(), 500);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -76,10 +76,7 @@ async fn test_event_handler() -> Result<()> {
|
|||
}
|
||||
}
|
||||
|
||||
appservice
|
||||
.client()
|
||||
.set_event_handler(Box::new(Example::new()))
|
||||
.await;
|
||||
appservice.client().set_event_handler(Box::new(Example::new())).await;
|
||||
|
||||
let event = serde_json::from_value::<AnyStateEvent>(member_json()).unwrap();
|
||||
let event: Raw<AnyRoomEvent> = AnyRoomEvent::State(event).into();
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
use std::{convert::TryFrom, fmt::Debug, sync::Arc};
|
||||
|
||||
use futures::executor::block_on;
|
||||
use serde::Serialize;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use atty::Stream;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use clap::{App as Argparse, AppSettings as ArgParseSettings, Arg, ArgMatches, SubCommand};
|
||||
use futures::executor::block_on;
|
||||
use matrix_sdk_base::{
|
||||
events::EventType,
|
||||
identifiers::{RoomId, UserId},
|
||||
RoomInfo, Store,
|
||||
};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use rustyline::{
|
||||
completion::{Completer, Pair},
|
||||
|
@ -18,7 +21,7 @@ use rustyline::{
|
|||
};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use rustyline_derive::Helper;
|
||||
|
||||
use serde::Serialize;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use syntect::{
|
||||
dumps::from_binary,
|
||||
|
@ -28,12 +31,6 @@ use syntect::{
|
|||
util::{as_24_bit_terminal_escaped, LinesWithEndings},
|
||||
};
|
||||
|
||||
use matrix_sdk_base::{
|
||||
events::EventType,
|
||||
identifiers::{RoomId, UserId},
|
||||
RoomInfo, Store,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
struct Inspector {
|
||||
|
@ -79,17 +76,8 @@ impl InspectorHelper {
|
|||
fn complete_event_types(&self, arg: Option<&&str>) -> Vec<Pair> {
|
||||
Self::EVENT_TYPES
|
||||
.iter()
|
||||
.map(|t| Pair {
|
||||
display: t.to_string(),
|
||||
replacement: format!("{} ", t),
|
||||
})
|
||||
.filter(|r| {
|
||||
if let Some(arg) = arg {
|
||||
r.replacement.starts_with(arg)
|
||||
} else {
|
||||
true
|
||||
}
|
||||
})
|
||||
.map(|t| Pair { display: t.to_string(), replacement: format!("{} ", t) })
|
||||
.filter(|r| if let Some(arg) = arg { r.replacement.starts_with(arg) } else { true })
|
||||
.collect()
|
||||
}
|
||||
|
||||
|
@ -102,13 +90,7 @@ impl InspectorHelper {
|
|||
display: r.room_id.to_string(),
|
||||
replacement: format!("{} ", r.room_id.to_string()),
|
||||
})
|
||||
.filter(|r| {
|
||||
if let Some(arg) = arg {
|
||||
r.replacement.starts_with(arg)
|
||||
} else {
|
||||
true
|
||||
}
|
||||
})
|
||||
.filter(|r| if let Some(arg) = arg { r.replacement.starts_with(arg) } else { true })
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
@ -127,15 +109,9 @@ impl Completer for InspectorHelper {
|
|||
|
||||
let commands = vec![
|
||||
("get-state", "get a state event in the given room"),
|
||||
(
|
||||
"get-profiles",
|
||||
"get all the stored profiles in the given room",
|
||||
),
|
||||
("get-profiles", "get all the stored profiles in the given room"),
|
||||
("list-rooms", "list all rooms"),
|
||||
(
|
||||
"get-members",
|
||||
"get all the membership events in the given room",
|
||||
),
|
||||
("get-members", "get all the membership events in the given room"),
|
||||
]
|
||||
.iter()
|
||||
.map(|(r, d)| Pair {
|
||||
|
@ -154,19 +130,13 @@ impl Completer for InspectorHelper {
|
|||
} else {
|
||||
Ok((
|
||||
0,
|
||||
commands
|
||||
.into_iter()
|
||||
.filter(|c| c.replacement.starts_with(args[0]))
|
||||
.collect(),
|
||||
commands.into_iter().filter(|c| c.replacement.starts_with(args[0])).collect(),
|
||||
))
|
||||
}
|
||||
} else if args.len() == 2 {
|
||||
if args[0] == "get-state" {
|
||||
if line.ends_with(' ') {
|
||||
Ok((
|
||||
args[0].len() + args[1].len() + 2,
|
||||
self.complete_event_types(args.get(2)),
|
||||
))
|
||||
Ok((args[0].len() + args[1].len() + 2, self.complete_event_types(args.get(2))))
|
||||
} else {
|
||||
Ok((args[0].len() + 1, self.complete_rooms(args.get(1))))
|
||||
}
|
||||
|
@ -177,10 +147,7 @@ impl Completer for InspectorHelper {
|
|||
}
|
||||
} else if args.len() == 3 {
|
||||
if args[0] == "get-state" {
|
||||
Ok((
|
||||
args[0].len() + args[1].len() + 2,
|
||||
self.complete_event_types(args.get(2)),
|
||||
))
|
||||
Ok((args[0].len() + args[1].len() + 2, self.complete_event_types(args.get(2))))
|
||||
} else {
|
||||
Ok((pos, vec![]))
|
||||
}
|
||||
|
@ -216,12 +183,7 @@ impl Printer {
|
|||
let syntax_set: SyntaxSet = from_binary(include_bytes!("./syntaxes.bin"));
|
||||
let themes: ThemeSet = from_binary(include_bytes!("./themes.bin"));
|
||||
|
||||
Self {
|
||||
ps: syntax_set.into(),
|
||||
ts: themes.into(),
|
||||
json,
|
||||
color,
|
||||
}
|
||||
Self { ps: syntax_set.into(), ts: themes.into(), json, color }
|
||||
}
|
||||
|
||||
fn pretty_print_struct<T: Debug + Serialize>(&self, data: &T) {
|
||||
|
@ -232,13 +194,9 @@ impl Printer {
|
|||
};
|
||||
|
||||
let syntax = if self.json {
|
||||
self.ps
|
||||
.find_syntax_by_extension("rs")
|
||||
.expect("Can't find rust syntax extension")
|
||||
self.ps.find_syntax_by_extension("rs").expect("Can't find rust syntax extension")
|
||||
} else {
|
||||
self.ps
|
||||
.find_syntax_by_extension("json")
|
||||
.expect("Can't find json syntax extension")
|
||||
self.ps.find_syntax_by_extension("json").expect("Can't find json syntax extension")
|
||||
};
|
||||
|
||||
if self.color {
|
||||
|
@ -305,11 +263,7 @@ impl Inspector {
|
|||
}
|
||||
|
||||
async fn get_display_name_owners(&self, room_id: RoomId, display_name: String) {
|
||||
let users = self
|
||||
.store
|
||||
.get_users_with_display_name(&room_id, &display_name)
|
||||
.await
|
||||
.unwrap();
|
||||
let users = self.store.get_users_with_display_name(&room_id, &display_name).await.unwrap();
|
||||
self.printer.pretty_print_struct(&users);
|
||||
}
|
||||
|
||||
|
@ -326,22 +280,14 @@ impl Inspector {
|
|||
let joined: Vec<UserId> = self.store.get_joined_user_ids(&room_id).await.unwrap();
|
||||
|
||||
for member in joined {
|
||||
let event = self
|
||||
.store
|
||||
.get_member_event(&room_id, &member)
|
||||
.await
|
||||
.unwrap();
|
||||
let event = self.store.get_member_event(&room_id, &member).await.unwrap();
|
||||
self.printer.pretty_print_struct(&event);
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_state(&self, room_id: RoomId, event_type: EventType) {
|
||||
self.printer.pretty_print_struct(
|
||||
&self
|
||||
.store
|
||||
.get_state_event(&room_id, event_type, "")
|
||||
.await
|
||||
.unwrap(),
|
||||
&self.store.get_state_event(&room_id, event_type, "").await.unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -350,35 +296,25 @@ impl Inspector {
|
|||
SubCommand::with_name("list-rooms"),
|
||||
SubCommand::with_name("get-members").arg(
|
||||
Arg::with_name("room-id").required(true).validator(|r| {
|
||||
RoomId::try_from(r)
|
||||
.map(|_| ())
|
||||
.map_err(|_| "Invalid room id given".to_owned())
|
||||
RoomId::try_from(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned())
|
||||
}),
|
||||
),
|
||||
SubCommand::with_name("get-profiles").arg(
|
||||
Arg::with_name("room-id").required(true).validator(|r| {
|
||||
RoomId::try_from(r)
|
||||
.map(|_| ())
|
||||
.map_err(|_| "Invalid room id given".to_owned())
|
||||
RoomId::try_from(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned())
|
||||
}),
|
||||
),
|
||||
SubCommand::with_name("get-display-names")
|
||||
.arg(Arg::with_name("room-id").required(true).validator(|r| {
|
||||
RoomId::try_from(r)
|
||||
.map(|_| ())
|
||||
.map_err(|_| "Invalid room id given".to_owned())
|
||||
RoomId::try_from(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned())
|
||||
}))
|
||||
.arg(Arg::with_name("display-name").required(true)),
|
||||
SubCommand::with_name("get-state")
|
||||
.arg(Arg::with_name("room-id").required(true).validator(|r| {
|
||||
RoomId::try_from(r)
|
||||
.map(|_| ())
|
||||
.map_err(|_| "Invalid room id given".to_owned())
|
||||
RoomId::try_from(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned())
|
||||
}))
|
||||
.arg(Arg::with_name("event-type").required(true).validator(|e| {
|
||||
EventType::try_from(e)
|
||||
.map(|_| ())
|
||||
.map_err(|_| "Invalid event type".to_string())
|
||||
EventType::try_from(e).map(|_| ()).map_err(|_| "Invalid event type".to_string())
|
||||
})),
|
||||
]
|
||||
}
|
||||
|
|
|
@ -87,20 +87,21 @@ pub struct AdditionalUnsignedData {
|
|||
pub prev_content: Option<Raw<MemberEventContent>>,
|
||||
}
|
||||
|
||||
/// Transform state event by hoisting `prev_content` field from `unsigned` to the top level.
|
||||
/// Transform state event by hoisting `prev_content` field from `unsigned` to
|
||||
/// the top level.
|
||||
///
|
||||
/// Due to a [bug in synapse][synapse-bug], `prev_content` often ends up in `unsigned` contrary to
|
||||
/// the C2S spec. Some more discussion can be found [here][discussion]. Until this is fixed in
|
||||
/// synapse or handled in Ruma, we use this to hoist up `prev_content` to the top level.
|
||||
/// Due to a [bug in synapse][synapse-bug], `prev_content` often ends up in
|
||||
/// `unsigned` contrary to the C2S spec. Some more discussion can be found
|
||||
/// [here][discussion]. Until this is fixed in synapse or handled in Ruma, we
|
||||
/// use this to hoist up `prev_content` to the top level.
|
||||
///
|
||||
/// [synapse-bug]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668>
|
||||
/// [discussion]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668>
|
||||
pub fn hoist_and_deserialize_state_event(
|
||||
event: &Raw<AnySyncStateEvent>,
|
||||
) -> StdResult<AnySyncStateEvent, serde_json::Error> {
|
||||
let prev_content = serde_json::from_str::<AdditionalEventData>(event.json().get())?
|
||||
.unsigned
|
||||
.prev_content;
|
||||
let prev_content =
|
||||
serde_json::from_str::<AdditionalEventData>(event.json().get())?.unsigned.prev_content;
|
||||
|
||||
let mut ev = event.deserialize()?;
|
||||
|
||||
|
@ -116,9 +117,8 @@ pub fn hoist_and_deserialize_state_event(
|
|||
fn hoist_member_event(
|
||||
event: &Raw<StateEvent<MemberEventContent>>,
|
||||
) -> StdResult<StateEvent<MemberEventContent>, serde_json::Error> {
|
||||
let prev_content = serde_json::from_str::<AdditionalEventData>(event.json().get())?
|
||||
.unsigned
|
||||
.prev_content;
|
||||
let prev_content =
|
||||
serde_json::from_str::<AdditionalEventData>(event.json().get())?.unsigned.prev_content;
|
||||
|
||||
let mut e = event.deserialize()?;
|
||||
|
||||
|
@ -340,7 +340,8 @@ impl BaseClient {
|
|||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `response` - A successful login response that contains our access token
|
||||
/// * `response` - A successful login response that contains our access
|
||||
/// token
|
||||
/// and device id.
|
||||
pub async fn receive_login_response(
|
||||
&self,
|
||||
|
@ -440,9 +441,7 @@ impl BaseClient {
|
|||
AnySyncRoomEvent::State(s) => match s {
|
||||
AnySyncStateEvent::RoomMember(member) => {
|
||||
if let Ok(member) = MemberEvent::try_from(member.clone()) {
|
||||
ambiguity_cache
|
||||
.handle_event(changes, room_id, &member)
|
||||
.await?;
|
||||
ambiguity_cache.handle_event(changes, room_id, &member).await?;
|
||||
|
||||
match member.content.membership {
|
||||
MembershipState::Join | MembershipState::Invite => {
|
||||
|
@ -500,8 +499,7 @@ impl BaseClient {
|
|||
}
|
||||
|
||||
if let Some(context) = &mut push_context {
|
||||
self.update_push_room_context(context, user_id, room_info, changes)
|
||||
.await;
|
||||
self.update_push_room_context(context, user_id, room_info, changes).await;
|
||||
} else {
|
||||
push_context = self.get_push_room_context(room, room_info, changes).await?;
|
||||
}
|
||||
|
@ -521,10 +519,13 @@ impl BaseClient {
|
|||
),
|
||||
);
|
||||
}
|
||||
// TODO if there is an Action::SetTweak(Tweak::Highlight) we need to store
|
||||
// its value with the event so a client can show if the event is highlighted
|
||||
// TODO if there is an
|
||||
// Action::SetTweak(Tweak::Highlight) we need to store
|
||||
// its value with the event so a client can show if the
|
||||
// event is highlighted
|
||||
// in the UI.
|
||||
// Requires the possibility to associate custom data with events and to
|
||||
// Requires the possibility to associate custom data
|
||||
// with events and to
|
||||
// store them.
|
||||
}
|
||||
}
|
||||
|
@ -762,18 +763,14 @@ impl BaseClient {
|
|||
let mut changes = StateChanges::new(next_batch.clone());
|
||||
let mut ambiguity_cache = AmbiguityCache::new(self.store.clone());
|
||||
|
||||
self.handle_account_data(&account_data.events, &mut changes)
|
||||
.await;
|
||||
self.handle_account_data(&account_data.events, &mut changes).await;
|
||||
|
||||
let push_rules = self.get_push_rules(&changes).await?;
|
||||
|
||||
let mut new_rooms = Rooms::default();
|
||||
|
||||
for (room_id, new_info) in rooms.join {
|
||||
let room = self
|
||||
.store
|
||||
.get_or_create_room(&room_id, RoomType::Joined)
|
||||
.await;
|
||||
let room = self.store.get_or_create_room(&room_id, RoomType::Joined).await;
|
||||
let mut room_info = room.clone_info();
|
||||
room_info.mark_as_joined();
|
||||
|
||||
|
@ -844,10 +841,7 @@ impl BaseClient {
|
|||
}
|
||||
|
||||
for (room_id, new_info) in rooms.leave {
|
||||
let room = self
|
||||
.store
|
||||
.get_or_create_room(&room_id, RoomType::Left)
|
||||
.await;
|
||||
let room = self.store.get_or_create_room(&room_id, RoomType::Left).await;
|
||||
let mut room_info = room.clone_info();
|
||||
room_info.mark_as_left();
|
||||
|
||||
|
@ -876,18 +870,14 @@ impl BaseClient {
|
|||
.await;
|
||||
|
||||
changes.add_room(room_info);
|
||||
new_rooms.leave.insert(
|
||||
room_id,
|
||||
LeftRoom::new(timeline, new_info.state, new_info.account_data),
|
||||
);
|
||||
new_rooms
|
||||
.leave
|
||||
.insert(room_id, LeftRoom::new(timeline, new_info.state, new_info.account_data));
|
||||
}
|
||||
|
||||
for (room_id, new_info) in rooms.invite {
|
||||
{
|
||||
let room = self
|
||||
.store
|
||||
.get_or_create_room(&room_id, RoomType::Invited)
|
||||
.await;
|
||||
let room = self.store.get_or_create_room(&room_id, RoomType::Invited).await;
|
||||
let mut room_info = room.clone_info();
|
||||
room_info.mark_as_invited();
|
||||
changes.add_room(room_info);
|
||||
|
@ -934,9 +924,7 @@ impl BaseClient {
|
|||
.into_iter()
|
||||
.map(|(k, v)| (k, v.into()))
|
||||
.collect(),
|
||||
ambiguity_changes: AmbiguityChanges {
|
||||
changes: ambiguity_cache.changes,
|
||||
},
|
||||
ambiguity_changes: AmbiguityChanges { changes: ambiguity_cache.changes },
|
||||
notifications: changes.notifications,
|
||||
};
|
||||
|
||||
|
@ -968,11 +956,7 @@ impl BaseClient {
|
|||
let members: Vec<MemberEvent> = response
|
||||
.chunk
|
||||
.iter()
|
||||
.filter_map(|e| {
|
||||
hoist_member_event(e)
|
||||
.ok()
|
||||
.and_then(|e| MemberEvent::try_from(e).ok())
|
||||
})
|
||||
.filter_map(|e| hoist_member_event(e).ok().and_then(|e| MemberEvent::try_from(e).ok()))
|
||||
.collect();
|
||||
let mut ambiguity_cache = AmbiguityCache::new(self.store.clone());
|
||||
|
||||
|
@ -986,12 +970,7 @@ impl BaseClient {
|
|||
let mut user_ids = BTreeSet::new();
|
||||
|
||||
for member in &members {
|
||||
if self
|
||||
.store
|
||||
.get_member_event(&room_id, &member.state_key)
|
||||
.await?
|
||||
.is_none()
|
||||
{
|
||||
if self.store.get_member_event(&room_id, &member.state_key).await?.is_none() {
|
||||
#[cfg(feature = "encryption")]
|
||||
match member.content.membership {
|
||||
MembershipState::Join | MembershipState::Invite => {
|
||||
|
@ -1000,9 +979,7 @@ impl BaseClient {
|
|||
_ => (),
|
||||
}
|
||||
|
||||
ambiguity_cache
|
||||
.handle_event(&changes, room_id, &member)
|
||||
.await?;
|
||||
ambiguity_cache.handle_event(&changes, room_id, &member).await?;
|
||||
|
||||
if member.state_key == member.sender {
|
||||
changes
|
||||
|
@ -1036,9 +1013,7 @@ impl BaseClient {
|
|||
|
||||
Ok(MembersResponse {
|
||||
chunk: members,
|
||||
ambiguity_changes: AmbiguityChanges {
|
||||
changes: ambiguity_cache.changes,
|
||||
},
|
||||
ambiguity_changes: AmbiguityChanges { changes: ambiguity_cache.changes },
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -1050,7 +1025,8 @@ impl BaseClient {
|
|||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `filter_name` - The name that should be used to persist the filter id in
|
||||
/// * `filter_name` - The name that should be used to persist the filter id
|
||||
/// in
|
||||
/// the store.
|
||||
///
|
||||
/// * `response` - The successful filter upload response containing the
|
||||
|
@ -1062,10 +1038,7 @@ impl BaseClient {
|
|||
filter_name: &str,
|
||||
response: &api::filter::create_filter::Response,
|
||||
) -> Result<()> {
|
||||
Ok(self
|
||||
.store
|
||||
.save_filter(filter_name, &response.filter_id)
|
||||
.await?)
|
||||
Ok(self.store.save_filter(filter_name, &response.filter_id).await?)
|
||||
}
|
||||
|
||||
/// Get the filter id of a previously uploaded filter.
|
||||
|
@ -1224,18 +1197,14 @@ impl BaseClient {
|
|||
/// # Arguments
|
||||
///
|
||||
/// * `flow_id` - The unique id that identifies a interactive verification
|
||||
/// flow. For in-room verifications this will be the event id of the
|
||||
/// *m.key.verification.request* event that started the flow, for the
|
||||
/// to-device verification flows this will be the transaction id of the
|
||||
/// *m.key.verification.start* event.
|
||||
/// flow. For in-room verifications this will be the event id of the
|
||||
/// *m.key.verification.request* event that started the flow, for the
|
||||
/// to-device verification flows this will be the transaction id of the
|
||||
/// *m.key.verification.start* event.
|
||||
#[cfg(feature = "encryption")]
|
||||
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
|
||||
pub async fn get_verification(&self, flow_id: &str) -> Option<Sas> {
|
||||
self.olm
|
||||
.lock()
|
||||
.await
|
||||
.as_ref()
|
||||
.and_then(|o| o.get_verification(flow_id))
|
||||
self.olm.lock().await.as_ref().and_then(|o| o.get_verification(flow_id))
|
||||
}
|
||||
|
||||
/// Get a specific device of a user.
|
||||
|
@ -1284,10 +1253,12 @@ impl BaseClient {
|
|||
|
||||
/// Get the user login session.
|
||||
///
|
||||
/// If the client is currently logged in, this will return a `matrix_sdk::Session` object which
|
||||
/// can later be given to `restore_login`.
|
||||
/// If the client is currently logged in, this will return a
|
||||
/// `matrix_sdk::Session` object which can later be given to
|
||||
/// `restore_login`.
|
||||
///
|
||||
/// Returns a session object if the client is logged in. Otherwise returns `None`.
|
||||
/// Returns a session object if the client is logged in. Otherwise returns
|
||||
/// `None`.
|
||||
pub async fn get_session(&self) -> Option<Session> {
|
||||
self.session.read().await.clone()
|
||||
}
|
||||
|
@ -1349,8 +1320,9 @@ impl BaseClient {
|
|||
|
||||
/// Get the push rules.
|
||||
///
|
||||
/// Gets the push rules from `changes` if they have been updated, otherwise get them from the
|
||||
/// store. As a fallback, uses `Ruleset::server_default` if the user is logged in.
|
||||
/// Gets the push rules from `changes` if they have been updated, otherwise
|
||||
/// get them from the store. As a fallback, uses
|
||||
/// `Ruleset::server_default` if the user is logged in.
|
||||
pub async fn get_push_rules(&self, changes: &StateChanges) -> Result<Ruleset> {
|
||||
if let Some(AnyGlobalAccountDataEvent::PushRules(event)) = changes
|
||||
.account_data
|
||||
|
@ -1374,11 +1346,11 @@ impl BaseClient {
|
|||
|
||||
/// Get the push context for the given room.
|
||||
///
|
||||
/// Tries to get the data from `changes` or the up to date `room_info`. Loads the data from the
|
||||
/// store otherwise.
|
||||
/// Tries to get the data from `changes` or the up to date `room_info`.
|
||||
/// Loads the data from the store otherwise.
|
||||
///
|
||||
/// Returns `None` if some data couldn't be found. This should only happen in brand new rooms,
|
||||
/// while we process its state.
|
||||
/// Returns `None` if some data couldn't be found. This should only happen
|
||||
/// in brand new rooms, while we process its state.
|
||||
pub async fn get_push_room_context(
|
||||
&self,
|
||||
room: &Room,
|
||||
|
@ -1390,16 +1362,10 @@ impl BaseClient {
|
|||
|
||||
let member_count = room_info.active_members_count();
|
||||
|
||||
let user_display_name = if let Some(member) = changes
|
||||
.members
|
||||
.get(room_id)
|
||||
.and_then(|members| members.get(user_id))
|
||||
let user_display_name = if let Some(member) =
|
||||
changes.members.get(room_id).and_then(|members| members.get(user_id))
|
||||
{
|
||||
member
|
||||
.content
|
||||
.displayname
|
||||
.clone()
|
||||
.unwrap_or_else(|| user_id.localpart().to_owned())
|
||||
member.content.displayname.clone().unwrap_or_else(|| user_id.localpart().to_owned())
|
||||
} else if let Some(member) = room.get_member(user_id).await? {
|
||||
member.name().to_owned()
|
||||
} else {
|
||||
|
@ -1449,16 +1415,10 @@ impl BaseClient {
|
|||
|
||||
push_rules.member_count = UInt::new(room_info.active_members_count()).unwrap_or(UInt::MAX);
|
||||
|
||||
if let Some(member) = changes
|
||||
.members
|
||||
.get(room_id)
|
||||
.and_then(|members| members.get(user_id))
|
||||
if let Some(member) = changes.members.get(room_id).and_then(|members| members.get(user_id))
|
||||
{
|
||||
push_rules.user_display_name = member
|
||||
.content
|
||||
.displayname
|
||||
.clone()
|
||||
.unwrap_or_else(|| user_id.localpart().to_owned())
|
||||
push_rules.user_display_name =
|
||||
member.content.displayname.clone().unwrap_or_else(|| user_id.localpart().to_owned())
|
||||
}
|
||||
|
||||
if let Some(AnySyncStateEvent::RoomPowerLevels(event)) = changes
|
||||
|
|
|
@ -15,12 +15,12 @@
|
|||
|
||||
//! Error conditions.
|
||||
|
||||
use serde_json::Error as JsonError;
|
||||
use std::io::Error as IoError;
|
||||
use thiserror::Error;
|
||||
|
||||
#[cfg(feature = "encryption")]
|
||||
use matrix_sdk_crypto::{CryptoStoreError, MegolmError, OlmError};
|
||||
use serde_json::Error as JsonError;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Result type of the rust-sdk.
|
||||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
||||
|
@ -28,7 +28,8 @@ pub type Result<T, E = Error> = std::result::Result<T, E>;
|
|||
/// Internal representation of errors.
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
/// Queried endpoint requires authentication but was called on an anonymous client.
|
||||
/// Queried endpoint requires authentication but was called on an anonymous
|
||||
/// client.
|
||||
#[error("the queried endpoint requires authentication but was called before logging in")]
|
||||
AuthenticationRequired,
|
||||
|
||||
|
|
|
@ -36,11 +36,12 @@
|
|||
)]
|
||||
#![cfg_attr(feature = "docs", feature(doc_cfg))]
|
||||
|
||||
pub use matrix_sdk_common::*;
|
||||
|
||||
pub use crate::{
|
||||
error::{Error, Result},
|
||||
session::Session,
|
||||
};
|
||||
pub use matrix_sdk_common::*;
|
||||
|
||||
mod client;
|
||||
mod error;
|
||||
|
@ -48,11 +49,9 @@ mod rooms;
|
|||
mod session;
|
||||
mod store;
|
||||
|
||||
pub use rooms::{Room, RoomInfo, RoomMember, RoomType};
|
||||
pub use store::{StateChanges, StateStore, Store, StoreError};
|
||||
|
||||
pub use client::{BaseClient, BaseClientConfig};
|
||||
|
||||
#[cfg(feature = "encryption")]
|
||||
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
|
||||
pub use matrix_sdk_crypto as crypto;
|
||||
pub use rooms::{Room, RoomInfo, RoomMember, RoomType};
|
||||
pub use store::{StateChanges, StateStore, Store, StoreError};
|
||||
|
|
|
@ -1,27 +1,22 @@
|
|||
mod members;
|
||||
mod normal;
|
||||
|
||||
use matrix_sdk_common::{
|
||||
events::room::{
|
||||
create::CreateEventContent, guest_access::GuestAccess,
|
||||
history_visibility::HistoryVisibility, join_rules::JoinRule,
|
||||
},
|
||||
identifiers::{MxcUri, UserId},
|
||||
};
|
||||
pub use normal::{Room, RoomInfo, RoomType};
|
||||
|
||||
pub use members::RoomMember;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cmp::max;
|
||||
|
||||
use matrix_sdk_common::{
|
||||
events::{
|
||||
room::{encryption::EncryptionEventContent, tombstone::TombstoneEventContent},
|
||||
room::{
|
||||
create::CreateEventContent, encryption::EncryptionEventContent,
|
||||
guest_access::GuestAccess, history_visibility::HistoryVisibility, join_rules::JoinRule,
|
||||
tombstone::TombstoneEventContent,
|
||||
},
|
||||
AnyStateEventContent,
|
||||
},
|
||||
identifiers::RoomAliasId,
|
||||
identifiers::{MxcUri, RoomAliasId, UserId},
|
||||
};
|
||||
pub use members::RoomMember;
|
||||
pub use normal::{Room, RoomInfo, RoomType};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A base room info struct that is the backbone of normal as well as stripped
|
||||
/// rooms. Holds all the state events that are important to present a room to
|
||||
|
@ -71,20 +66,12 @@ impl BaseRoomInfo {
|
|||
let invited_joined = (invited_member_count + joined_member_count).saturating_sub(1);
|
||||
|
||||
if heroes_count >= invited_joined {
|
||||
let mut names = heroes
|
||||
.iter()
|
||||
.take(3)
|
||||
.map(|mem| mem.name())
|
||||
.collect::<Vec<&str>>();
|
||||
let mut names = heroes.iter().take(3).map(|mem| mem.name()).collect::<Vec<&str>>();
|
||||
// stabilize ordering
|
||||
names.sort_unstable();
|
||||
names.join(", ")
|
||||
} else if heroes_count < invited_joined && invited_joined > 1 {
|
||||
let mut names = heroes
|
||||
.iter()
|
||||
.take(3)
|
||||
.map(|mem| mem.name())
|
||||
.collect::<Vec<&str>>();
|
||||
let mut names = heroes.iter().take(3).map(|mem| mem.name()).collect::<Vec<&str>>();
|
||||
names.sort_unstable();
|
||||
|
||||
// TODO: What length does the spec want us to use here and in
|
||||
|
@ -149,10 +136,8 @@ impl BaseRoomInfo {
|
|||
true
|
||||
}
|
||||
AnyStateEventContent::RoomPowerLevels(p) => {
|
||||
let max_power_level = p
|
||||
.users
|
||||
.values()
|
||||
.fold(self.max_power_level, |acc, p| max(acc, (*p).into()));
|
||||
let max_power_level =
|
||||
p.users.values().fold(self.max_power_level, |acc, p| max(acc, (*p).into()));
|
||||
self.max_power_level = max_power_level;
|
||||
true
|
||||
}
|
||||
|
|
|
@ -37,14 +37,14 @@ use matrix_sdk_common::{
|
|||
use serde::{Deserialize, Serialize};
|
||||
use tracing::info;
|
||||
|
||||
use super::{BaseRoomInfo, RoomMember};
|
||||
use crate::{
|
||||
deserialized_responses::UnreadNotificationsCount,
|
||||
store::{Result as StoreResult, StateStore},
|
||||
};
|
||||
|
||||
use super::{BaseRoomInfo, RoomMember};
|
||||
|
||||
/// The underlying room data structure collecting state for joined, left and invtied rooms.
|
||||
/// The underlying room data structure collecting state for joined, left and
|
||||
/// invtied rooms.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Room {
|
||||
room_id: Arc<RoomId>,
|
||||
|
@ -135,7 +135,8 @@ impl Room {
|
|||
|
||||
/// Check if the room has it's members fully synced.
|
||||
///
|
||||
/// Members might be missing if lazy member loading was enabled for the sync.
|
||||
/// Members might be missing if lazy member loading was enabled for the
|
||||
/// sync.
|
||||
///
|
||||
/// Returns true if no members are missing, false otherwise.
|
||||
pub fn are_members_synced(&self) -> bool {
|
||||
|
@ -200,12 +201,7 @@ impl Room {
|
|||
|
||||
/// Get the history visibility policy of this room.
|
||||
pub fn history_visibility(&self) -> HistoryVisibility {
|
||||
self.inner
|
||||
.read()
|
||||
.unwrap()
|
||||
.base_info
|
||||
.history_visibility
|
||||
.clone()
|
||||
self.inner.read().unwrap().base_info.history_visibility.clone()
|
||||
}
|
||||
|
||||
/// Is the room considered to be public.
|
||||
|
@ -367,9 +363,7 @@ impl Room {
|
|||
);
|
||||
|
||||
let inner = self.inner.read().unwrap();
|
||||
Ok(inner
|
||||
.base_info
|
||||
.calculate_room_name(joined, invited, members))
|
||||
Ok(inner.base_info.calculate_room_name(joined, invited, members))
|
||||
}
|
||||
|
||||
pub(crate) fn clone_info(&self) -> RoomInfo {
|
||||
|
@ -394,11 +388,8 @@ impl Room {
|
|||
return Ok(None);
|
||||
};
|
||||
|
||||
let presence = self
|
||||
.store
|
||||
.get_presence_event(user_id)
|
||||
.await?
|
||||
.and_then(|e| e.deserialize().ok());
|
||||
let presence =
|
||||
self.store.get_presence_event(user_id).await?.and_then(|e| e.deserialize().ok());
|
||||
let profile = self.store.get_profile(self.room_id(), user_id).await?;
|
||||
let max_power_level = self.max_power_level();
|
||||
let is_room_creator = self
|
||||
|
@ -411,28 +402,24 @@ impl Room {
|
|||
.map(|c| &c.creator == user_id)
|
||||
.unwrap_or(false);
|
||||
|
||||
let power = self
|
||||
.store
|
||||
.get_state_event(self.room_id(), EventType::RoomPowerLevels, "")
|
||||
.await?
|
||||
.and_then(|e| e.deserialize().ok())
|
||||
.and_then(|e| {
|
||||
if let AnySyncStateEvent::RoomPowerLevels(e) = e {
|
||||
Some(e)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
let power =
|
||||
self.store
|
||||
.get_state_event(self.room_id(), EventType::RoomPowerLevels, "")
|
||||
.await?
|
||||
.and_then(|e| e.deserialize().ok())
|
||||
.and_then(|e| {
|
||||
if let AnySyncStateEvent::RoomPowerLevels(e) = e {
|
||||
Some(e)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
let ambiguous = self
|
||||
.store
|
||||
.get_users_with_display_name(
|
||||
self.room_id(),
|
||||
member_event
|
||||
.content
|
||||
.displayname
|
||||
.as_deref()
|
||||
.unwrap_or_else(|| user_id.localpart()),
|
||||
member_event.content.displayname.as_deref().unwrap_or_else(|| user_id.localpart()),
|
||||
)
|
||||
.await?
|
||||
.len()
|
||||
|
@ -558,8 +545,6 @@ impl RoomInfo {
|
|||
///
|
||||
/// The return value is saturated at `u64::MAX`.
|
||||
pub fn active_members_count(&self) -> u64 {
|
||||
self.summary
|
||||
.joined_member_count
|
||||
.saturating_add(self.summary.invited_member_count)
|
||||
self.summary.joined_member_count.saturating_add(self.summary.invited_member_count)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,9 +15,8 @@
|
|||
|
||||
//! User sessions.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use matrix_sdk_common::identifiers::{DeviceId, UserId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A user session, containing an access token and information about the
|
||||
/// associated user account.
|
||||
|
|
|
@ -19,12 +19,10 @@ use matrix_sdk_common::{
|
|||
events::room::member::MembershipState,
|
||||
identifiers::{EventId, RoomId, UserId},
|
||||
};
|
||||
|
||||
use tracing::trace;
|
||||
|
||||
use crate::Store;
|
||||
|
||||
use super::{Result, StateChanges};
|
||||
use crate::Store;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AmbiguityCache {
|
||||
|
@ -51,11 +49,8 @@ impl AmbiguityMap {
|
|||
}
|
||||
|
||||
fn add(&mut self, user_id: UserId) -> Option<UserId> {
|
||||
let ambiguous_user = if self.user_count() == 1 {
|
||||
self.users.iter().next().cloned()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let ambiguous_user =
|
||||
if self.user_count() == 1 { self.users.iter().next().cloned() } else { None };
|
||||
|
||||
self.users.insert(user_id);
|
||||
|
||||
|
@ -73,11 +68,7 @@ impl AmbiguityMap {
|
|||
|
||||
impl AmbiguityCache {
|
||||
pub fn new(store: Store) -> Self {
|
||||
Self {
|
||||
store,
|
||||
cache: BTreeMap::new(),
|
||||
changes: BTreeMap::new(),
|
||||
}
|
||||
Self { store, cache: BTreeMap::new(), changes: BTreeMap::new() }
|
||||
}
|
||||
|
||||
pub async fn handle_event(
|
||||
|
@ -115,12 +106,9 @@ impl AmbiguityCache {
|
|||
return Ok(());
|
||||
}
|
||||
|
||||
let disambiguated_member = old_map
|
||||
.as_mut()
|
||||
.and_then(|o| o.remove(&member_event.state_key));
|
||||
let ambiguated_member = new_map
|
||||
.as_mut()
|
||||
.and_then(|n| n.add(member_event.state_key.clone()));
|
||||
let disambiguated_member = old_map.as_mut().and_then(|o| o.remove(&member_event.state_key));
|
||||
let ambiguated_member =
|
||||
new_map.as_mut().and_then(|n| n.add(member_event.state_key.clone()));
|
||||
let ambiguous = new_map.as_ref().map(|n| n.is_ambiguous()).unwrap_or(false);
|
||||
|
||||
self.update(room_id, old_map, new_map);
|
||||
|
@ -131,11 +119,7 @@ impl AmbiguityCache {
|
|||
member_ambiguous: ambiguous,
|
||||
};
|
||||
|
||||
trace!(
|
||||
"Handling display name ambiguity for {}: {:#?}",
|
||||
member_event.state_key,
|
||||
change
|
||||
);
|
||||
trace!("Handling display name ambiguity for {}: {:#?}", member_event.state_key, change);
|
||||
|
||||
self.add_change(room_id, member_event.event_id.clone(), change);
|
||||
|
||||
|
@ -148,10 +132,7 @@ impl AmbiguityCache {
|
|||
old_map: Option<AmbiguityMap>,
|
||||
new_map: Option<AmbiguityMap>,
|
||||
) {
|
||||
let entry = self
|
||||
.cache
|
||||
.entry(room_id.clone())
|
||||
.or_insert_with(BTreeMap::new);
|
||||
let entry = self.cache.entry(room_id.clone()).or_insert_with(BTreeMap::new);
|
||||
|
||||
if let Some(old) = old_map {
|
||||
entry.insert(old.display_name, old.users);
|
||||
|
@ -163,10 +144,7 @@ impl AmbiguityCache {
|
|||
}
|
||||
|
||||
fn add_change(&mut self, room_id: &RoomId, event_id: EventId, change: AmbiguityChange) {
|
||||
self.changes
|
||||
.entry(room_id.clone())
|
||||
.or_insert_with(BTreeMap::new)
|
||||
.insert(event_id, change);
|
||||
self.changes.entry(room_id.clone()).or_insert_with(BTreeMap::new).insert(event_id, change);
|
||||
}
|
||||
|
||||
async fn get(
|
||||
|
@ -177,16 +155,12 @@ impl AmbiguityCache {
|
|||
) -> Result<(Option<AmbiguityMap>, Option<AmbiguityMap>)> {
|
||||
use MembershipState::*;
|
||||
|
||||
let old_event = if let Some(m) = changes
|
||||
.members
|
||||
.get(room_id)
|
||||
.and_then(|m| m.get(&member_event.state_key))
|
||||
let old_event = if let Some(m) =
|
||||
changes.members.get(room_id).and_then(|m| m.get(&member_event.state_key))
|
||||
{
|
||||
Some(m.clone())
|
||||
} else {
|
||||
self.store
|
||||
.get_member_event(room_id, &member_event.state_key)
|
||||
.await?
|
||||
self.store.get_member_event(room_id, &member_event.state_key).await?
|
||||
};
|
||||
|
||||
let old_display_name = if let Some(event) = old_event {
|
||||
|
@ -218,23 +192,15 @@ impl AmbiguityCache {
|
|||
};
|
||||
|
||||
let old_map = if let Some(old_name) = old_display_name.as_deref() {
|
||||
let old_display_name_map = if let Some(u) = self
|
||||
.cache
|
||||
.entry(room_id.clone())
|
||||
.or_insert_with(BTreeMap::new)
|
||||
.get(old_name)
|
||||
let old_display_name_map = if let Some(u) =
|
||||
self.cache.entry(room_id.clone()).or_insert_with(BTreeMap::new).get(old_name)
|
||||
{
|
||||
u.clone()
|
||||
} else {
|
||||
self.store
|
||||
.get_users_with_display_name(&room_id, &old_name)
|
||||
.await?
|
||||
self.store.get_users_with_display_name(&room_id, &old_name).await?
|
||||
};
|
||||
|
||||
Some(AmbiguityMap {
|
||||
display_name: old_name.to_string(),
|
||||
users: old_display_name_map,
|
||||
})
|
||||
Some(AmbiguityMap { display_name: old_name.to_string(), users: old_display_name_map })
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
@ -246,8 +212,9 @@ impl AmbiguityCache {
|
|||
.as_deref()
|
||||
.unwrap_or_else(|| member_event.state_key.localpart());
|
||||
|
||||
// We don't allow other users to set the display name, so if we have
|
||||
// a more trusted version of the display name use that.
|
||||
// We don't allow other users to set the display name, so if we
|
||||
// have a more trusted version of the display
|
||||
// name use that.
|
||||
let new_display_name = if member_event.sender.as_str() == member_event.state_key {
|
||||
new
|
||||
} else if let Some(old) = old_display_name.as_deref() {
|
||||
|
@ -264,9 +231,7 @@ impl AmbiguityCache {
|
|||
{
|
||||
u.clone()
|
||||
} else {
|
||||
self.store
|
||||
.get_users_with_display_name(&room_id, &new_display_name)
|
||||
.await?
|
||||
self.store.get_users_with_display_name(&room_id, &new_display_name).await?
|
||||
};
|
||||
|
||||
Some(AmbiguityMap {
|
||||
|
|
|
@ -30,12 +30,10 @@ use matrix_sdk_common::{
|
|||
instant::Instant,
|
||||
Raw,
|
||||
};
|
||||
|
||||
use tracing::info;
|
||||
|
||||
use crate::deserialized_responses::{MemberEvent, StrippedMemberEvent};
|
||||
|
||||
use super::{Result, RoomInfo, StateChanges, StateStore};
|
||||
use crate::deserialized_responses::{MemberEvent, StrippedMemberEvent};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MemoryStore {
|
||||
|
@ -82,8 +80,7 @@ impl MemoryStore {
|
|||
}
|
||||
|
||||
async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> {
|
||||
self.filters
|
||||
.insert(filter_name.to_string(), filter_id.to_string());
|
||||
self.filters.insert(filter_name.to_string(), filter_id.to_string());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -164,8 +161,7 @@ impl MemoryStore {
|
|||
}
|
||||
|
||||
for (event_type, event) in &changes.account_data {
|
||||
self.account_data
|
||||
.insert(event_type.to_string(), event.clone());
|
||||
self.account_data.insert(event_type.to_string(), event.clone());
|
||||
}
|
||||
|
||||
for (room, events) in &changes.room_account_data {
|
||||
|
@ -199,8 +195,7 @@ impl MemoryStore {
|
|||
}
|
||||
|
||||
for (room_id, info) in &changes.invited_room_info {
|
||||
self.stripped_room_info
|
||||
.insert(room_id.clone(), info.clone());
|
||||
self.stripped_room_info.insert(room_id.clone(), info.clone());
|
||||
}
|
||||
|
||||
for (room, events) in &changes.stripped_members {
|
||||
|
@ -243,8 +238,7 @@ impl MemoryStore {
|
|||
) -> Result<Option<Raw<AnySyncStateEvent>>> {
|
||||
#[allow(clippy::map_clone)]
|
||||
Ok(self.room_state.get(room_id).and_then(|e| {
|
||||
e.get(event_type.as_ref())
|
||||
.and_then(|s| s.get(state_key).map(|e| e.clone()))
|
||||
e.get(event_type.as_ref()).and_then(|s| s.get(state_key).map(|e| e.clone()))
|
||||
}))
|
||||
}
|
||||
|
||||
|
@ -254,10 +248,7 @@ impl MemoryStore {
|
|||
user_id: &UserId,
|
||||
) -> Result<Option<MemberEventContent>> {
|
||||
#[allow(clippy::map_clone)]
|
||||
Ok(self
|
||||
.profiles
|
||||
.get(room_id)
|
||||
.and_then(|p| p.get(user_id).map(|p| p.clone())))
|
||||
Ok(self.profiles.get(room_id).and_then(|p| p.get(user_id).map(|p| p.clone())))
|
||||
}
|
||||
|
||||
async fn get_member_event(
|
||||
|
@ -266,10 +257,7 @@ impl MemoryStore {
|
|||
state_key: &UserId,
|
||||
) -> Result<Option<MemberEvent>> {
|
||||
#[allow(clippy::map_clone)]
|
||||
Ok(self
|
||||
.members
|
||||
.get(room_id)
|
||||
.and_then(|m| m.get(state_key).map(|m| m.clone())))
|
||||
Ok(self.members.get(room_id).and_then(|m| m.get(state_key).map(|m| m.clone())))
|
||||
}
|
||||
|
||||
fn get_user_ids(&self, room_id: &RoomId) -> Vec<UserId> {
|
||||
|
@ -310,10 +298,7 @@ impl MemoryStore {
|
|||
&self,
|
||||
event_type: EventType,
|
||||
) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> {
|
||||
Ok(self
|
||||
.account_data
|
||||
.get(event_type.as_ref())
|
||||
.map(|e| e.clone()))
|
||||
Ok(self.account_data.get(event_type.as_ref()).map(|e| e.clone()))
|
||||
}
|
||||
|
||||
async fn get_room_account_data_event(
|
||||
|
|
|
@ -12,15 +12,14 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#[cfg(feature = "sled_state_store")]
|
||||
use std::path::Path;
|
||||
use std::{
|
||||
collections::{BTreeMap, BTreeSet},
|
||||
ops::Deref,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
#[cfg(feature = "sled_state_store")]
|
||||
use std::path::Path;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use matrix_sdk_common::{
|
||||
api::r0::push::get_notifications::Notification,
|
||||
|
@ -201,7 +200,8 @@ pub trait StateStore: AsyncTraitDeps {
|
|||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `room_id` - The id of the room for which the room account data event should
|
||||
/// * `room_id` - The id of the room for which the room account data event
|
||||
/// should
|
||||
/// be fetched.
|
||||
///
|
||||
/// * `event_type` - The event type of the room account data event.
|
||||
|
@ -298,20 +298,16 @@ impl Store {
|
|||
|
||||
/// Get all the rooms this store knows about.
|
||||
pub fn get_rooms(&self) -> Vec<Room> {
|
||||
self.rooms
|
||||
.iter()
|
||||
.filter_map(|r| self.get_room(r.key()))
|
||||
.collect()
|
||||
self.rooms.iter().filter_map(|r| self.get_room(r.key())).collect()
|
||||
}
|
||||
|
||||
/// Get the room with the given room id.
|
||||
pub fn get_room(&self, room_id: &RoomId) -> Option<Room> {
|
||||
self.get_bare_room(room_id)
|
||||
.and_then(|r| match r.room_type() {
|
||||
RoomType::Joined => Some(r),
|
||||
RoomType::Left => Some(r),
|
||||
RoomType::Invited => self.get_stripped_room(room_id),
|
||||
})
|
||||
self.get_bare_room(room_id).and_then(|r| match r.room_type() {
|
||||
RoomType::Joined => Some(r),
|
||||
RoomType::Left => Some(r),
|
||||
RoomType::Invited => self.get_stripped_room(room_id),
|
||||
})
|
||||
}
|
||||
|
||||
fn get_stripped_room(&self, room_id: &RoomId) -> Option<Room> {
|
||||
|
@ -321,10 +317,7 @@ impl Store {
|
|||
|
||||
pub(crate) async fn get_or_create_stripped_room(&self, room_id: &RoomId) -> Room {
|
||||
let session = self.session.read().await;
|
||||
let user_id = &session
|
||||
.as_ref()
|
||||
.expect("Creating room while not being logged in")
|
||||
.user_id;
|
||||
let user_id = &session.as_ref().expect("Creating room while not being logged in").user_id;
|
||||
|
||||
self.stripped_rooms
|
||||
.entry(room_id.clone())
|
||||
|
@ -334,10 +327,7 @@ impl Store {
|
|||
|
||||
pub(crate) async fn get_or_create_room(&self, room_id: &RoomId, room_type: RoomType) -> Room {
|
||||
let session = self.session.read().await;
|
||||
let user_id = &session
|
||||
.as_ref()
|
||||
.expect("Creating room while not being logged in")
|
||||
.user_id;
|
||||
let user_id = &session.as_ref().expect("Creating room while not being logged in").user_id;
|
||||
|
||||
self.rooms
|
||||
.entry(room_id.clone())
|
||||
|
@ -359,7 +349,8 @@ impl Deref for Store {
|
|||
pub struct StateChanges {
|
||||
/// The sync token that relates to this update.
|
||||
pub sync_token: Option<String>,
|
||||
/// A user session, containing an access token and information about the associated user account.
|
||||
/// A user session, containing an access token and information about the
|
||||
/// associated user account.
|
||||
pub session: Option<Session>,
|
||||
/// A mapping of event type string to `AnyBasicEvent`.
|
||||
pub account_data: BTreeMap<String, Raw<AnyGlobalAccountDataEvent>>,
|
||||
|
@ -371,14 +362,16 @@ pub struct StateChanges {
|
|||
/// A mapping of `RoomId` to a map of users and their `MemberEventContent`.
|
||||
pub profiles: BTreeMap<RoomId, BTreeMap<UserId, MemberEventContent>>,
|
||||
|
||||
/// A mapping of `RoomId` to a map of event type string to a state key and `AnySyncStateEvent`.
|
||||
/// A mapping of `RoomId` to a map of event type string to a state key and
|
||||
/// `AnySyncStateEvent`.
|
||||
pub state: BTreeMap<RoomId, BTreeMap<String, BTreeMap<String, Raw<AnySyncStateEvent>>>>,
|
||||
/// A mapping of `RoomId` to a map of event type string to `AnyBasicEvent`.
|
||||
pub room_account_data: BTreeMap<RoomId, BTreeMap<String, Raw<AnyRoomAccountDataEvent>>>,
|
||||
/// A map of `RoomId` to `RoomInfo`.
|
||||
pub room_infos: BTreeMap<RoomId, RoomInfo>,
|
||||
|
||||
/// A mapping of `RoomId` to a map of event type to a map of state key to `AnyStrippedStateEvent`.
|
||||
/// A mapping of `RoomId` to a map of event type to a map of state key to
|
||||
/// `AnyStrippedStateEvent`.
|
||||
pub stripped_state:
|
||||
BTreeMap<RoomId, BTreeMap<String, BTreeMap<String, Raw<AnyStrippedStateEvent>>>>,
|
||||
/// A mapping of `RoomId` to a map of users and their `StrippedMemberEvent`.
|
||||
|
@ -396,10 +389,7 @@ pub struct StateChanges {
|
|||
impl StateChanges {
|
||||
/// Create a new `StateChanges` struct with the given sync_token.
|
||||
pub fn new(sync_token: String) -> Self {
|
||||
Self {
|
||||
sync_token: Some(sync_token),
|
||||
..Default::default()
|
||||
}
|
||||
Self { sync_token: Some(sync_token), ..Default::default() }
|
||||
}
|
||||
|
||||
/// Update the `StateChanges` struct with the given `PresenceEvent`.
|
||||
|
@ -409,14 +399,12 @@ impl StateChanges {
|
|||
|
||||
/// Update the `StateChanges` struct with the given `RoomInfo`.
|
||||
pub fn add_room(&mut self, room: RoomInfo) {
|
||||
self.room_infos
|
||||
.insert(room.room_id.as_ref().to_owned(), room);
|
||||
self.room_infos.insert(room.room_id.as_ref().to_owned(), room);
|
||||
}
|
||||
|
||||
/// Update the `StateChanges` struct with the given `RoomInfo`.
|
||||
pub fn add_stripped_room(&mut self, room: RoomInfo) {
|
||||
self.invited_room_info
|
||||
.insert(room.room_id.as_ref().to_owned(), room);
|
||||
self.invited_room_info.insert(room.room_id.as_ref().to_owned(), room);
|
||||
}
|
||||
|
||||
/// Update the `StateChanges` struct with the given `AnyBasicEvent`.
|
||||
|
@ -425,11 +413,11 @@ impl StateChanges {
|
|||
event: AnyGlobalAccountDataEvent,
|
||||
raw_event: Raw<AnyGlobalAccountDataEvent>,
|
||||
) {
|
||||
self.account_data
|
||||
.insert(event.content().event_type().to_owned(), raw_event);
|
||||
self.account_data.insert(event.content().event_type().to_owned(), raw_event);
|
||||
}
|
||||
|
||||
/// Update the `StateChanges` struct with the given room with a new `AnyBasicEvent`.
|
||||
/// Update the `StateChanges` struct with the given room with a new
|
||||
/// `AnyBasicEvent`.
|
||||
pub fn add_room_account_data(
|
||||
&mut self,
|
||||
room_id: &RoomId,
|
||||
|
@ -442,7 +430,8 @@ impl StateChanges {
|
|||
.insert(event.content().event_type().to_owned(), raw_event);
|
||||
}
|
||||
|
||||
/// Update the `StateChanges` struct with the given room with a new `StrippedMemberEvent`.
|
||||
/// Update the `StateChanges` struct with the given room with a new
|
||||
/// `StrippedMemberEvent`.
|
||||
pub fn add_stripped_member(&mut self, room_id: &RoomId, event: StrippedMemberEvent) {
|
||||
let user_id = event.state_key.clone();
|
||||
|
||||
|
@ -452,7 +441,8 @@ impl StateChanges {
|
|||
.insert(user_id, event);
|
||||
}
|
||||
|
||||
/// Update the `StateChanges` struct with the given room with a new `AnySyncStateEvent`.
|
||||
/// Update the `StateChanges` struct with the given room with a new
|
||||
/// `AnySyncStateEvent`.
|
||||
pub fn add_state_event(
|
||||
&mut self,
|
||||
room_id: &RoomId,
|
||||
|
@ -467,11 +457,9 @@ impl StateChanges {
|
|||
.insert(event.state_key().to_string(), raw_event);
|
||||
}
|
||||
|
||||
/// Update the `StateChanges` struct with the given room with a new `Notification`.
|
||||
/// Update the `StateChanges` struct with the given room with a new
|
||||
/// `Notification`.
|
||||
pub fn add_notification(&mut self, room_id: &RoomId, notification: Notification) {
|
||||
self.notifications
|
||||
.entry(room_id.to_owned())
|
||||
.or_insert_with(Vec::new)
|
||||
.push(notification);
|
||||
self.notifications.entry(room_id.to_owned()).or_insert_with(Vec::new).push(notification);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,18 +37,15 @@ use matrix_sdk_common::{
|
|||
Raw,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use sled::{
|
||||
transaction::{ConflictableTransactionError, TransactionError},
|
||||
Config, Db, Transactional, Tree,
|
||||
};
|
||||
use tracing::info;
|
||||
|
||||
use crate::deserialized_responses::MemberEvent;
|
||||
|
||||
use self::store_key::{EncryptedEvent, StoreKey};
|
||||
|
||||
use super::{Result, RoomInfo, StateChanges, StateStore, StoreError};
|
||||
use crate::deserialized_responses::MemberEvent;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub enum DatabaseType {
|
||||
|
@ -111,13 +108,7 @@ impl EncodeKey for &str {
|
|||
|
||||
impl EncodeKey for (&str, &str) {
|
||||
fn encode(&self) -> Vec<u8> {
|
||||
[
|
||||
self.0.as_bytes(),
|
||||
&[Self::SEPARATOR],
|
||||
self.1.as_bytes(),
|
||||
&[Self::SEPARATOR],
|
||||
]
|
||||
.concat()
|
||||
[self.0.as_bytes(), &[Self::SEPARATOR], self.1.as_bytes(), &[Self::SEPARATOR]].concat()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -167,9 +158,7 @@ impl std::fmt::Debug for SledStore {
|
|||
if let Some(path) = &self.path {
|
||||
f.debug_struct("SledStore").field("path", &path).finish()
|
||||
} else {
|
||||
f.debug_struct("SledStore")
|
||||
.field("path", &"memory store")
|
||||
.finish()
|
||||
f.debug_struct("SledStore").field("path", &"memory store").finish()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -239,8 +228,7 @@ impl SledStore {
|
|||
} else {
|
||||
let key = StoreKey::new().map_err::<StoreError, _>(|e| e.into())?;
|
||||
let encrypted_key = DatabaseType::Encrypted(
|
||||
key.export(passphrase)
|
||||
.map_err::<StoreError, _>(|e| e.into())?,
|
||||
key.export(passphrase).map_err::<StoreError, _>(|e| e.into())?,
|
||||
);
|
||||
db.insert("store_key".encode(), serde_json::to_vec(&encrypted_key)?)?;
|
||||
key
|
||||
|
@ -278,8 +266,7 @@ impl SledStore {
|
|||
}
|
||||
|
||||
pub async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> {
|
||||
self.session
|
||||
.insert(("filter", filter_name).encode(), filter_id)?;
|
||||
self.session.insert(("filter", filter_name).encode(), filter_id)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -479,11 +466,7 @@ impl SledStore {
|
|||
}
|
||||
|
||||
pub async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<Raw<PresenceEvent>>> {
|
||||
Ok(self
|
||||
.presence
|
||||
.get(user_id.encode())?
|
||||
.map(|e| self.deserialize_event(&e))
|
||||
.transpose()?)
|
||||
Ok(self.presence.get(user_id.encode())?.map(|e| self.deserialize_event(&e)).transpose()?)
|
||||
}
|
||||
|
||||
pub async fn get_state_event(
|
||||
|
@ -534,14 +517,10 @@ impl SledStore {
|
|||
&self,
|
||||
room_id: &RoomId,
|
||||
) -> impl Stream<Item = Result<UserId>> {
|
||||
stream::iter(
|
||||
self.invited_user_ids
|
||||
.scan_prefix(room_id.encode())
|
||||
.map(|u| {
|
||||
UserId::try_from(String::from_utf8_lossy(&u?.1).to_string())
|
||||
.map_err(StoreError::Identifier)
|
||||
}),
|
||||
)
|
||||
stream::iter(self.invited_user_ids.scan_prefix(room_id.encode()).map(|u| {
|
||||
UserId::try_from(String::from_utf8_lossy(&u?.1).to_string())
|
||||
.map_err(StoreError::Identifier)
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn get_joined_user_ids(
|
||||
|
@ -557,9 +536,7 @@ impl SledStore {
|
|||
pub async fn get_room_infos(&self) -> impl Stream<Item = Result<RoomInfo>> {
|
||||
let db = self.clone();
|
||||
stream::iter(
|
||||
self.room_info
|
||||
.iter()
|
||||
.map(move |r| db.deserialize_event(&r?.1).map_err(|e| e.into())),
|
||||
self.room_info.iter().map(move |r| db.deserialize_event(&r?.1).map_err(|e| e.into())),
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -683,8 +660,7 @@ impl StateStore for SledStore {
|
|||
room_id: &RoomId,
|
||||
display_name: &str,
|
||||
) -> Result<BTreeSet<UserId>> {
|
||||
self.get_users_with_display_name(room_id, display_name)
|
||||
.await
|
||||
self.get_users_with_display_name(room_id, display_name).await
|
||||
}
|
||||
|
||||
async fn get_account_data_event(
|
||||
|
@ -770,11 +746,7 @@ mod test {
|
|||
let room_id = room_id!("!test:localhost");
|
||||
let user_id = user_id();
|
||||
|
||||
assert!(store
|
||||
.get_member_event(&room_id, &user_id)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_none());
|
||||
assert!(store.get_member_event(&room_id, &user_id).await.unwrap().is_none());
|
||||
let mut changes = StateChanges::default();
|
||||
changes
|
||||
.members
|
||||
|
@ -783,11 +755,7 @@ mod test {
|
|||
.insert(user_id.clone(), membership_event());
|
||||
|
||||
store.save_changes(&changes).await.unwrap();
|
||||
assert!(store
|
||||
.get_member_event(&room_id, &user_id)
|
||||
.await
|
||||
.unwrap()
|
||||
.is_some());
|
||||
assert!(store.get_member_event(&room_id, &user_id).await.unwrap().is_some());
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
|
|
|
@ -21,11 +21,10 @@ use chacha20poly1305::{
|
|||
use hmac::Hmac;
|
||||
use pbkdf2::pbkdf2;
|
||||
use rand::{thread_rng, Error as RngError, Fill};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::Sha256;
|
||||
use zeroize::{Zeroize, Zeroizing};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::StoreError;
|
||||
|
||||
const VERSION: u8 = 1;
|
||||
|
@ -76,9 +75,11 @@ pub struct EncryptedEvent {
|
|||
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||
pub enum KdfInfo {
|
||||
Pbkdf2ToChaCha20Poly1305 {
|
||||
/// The number of PBKDF rounds that were used when deriving the store key.
|
||||
/// The number of PBKDF rounds that were used when deriving the store
|
||||
/// key.
|
||||
rounds: u32,
|
||||
/// The salt that was used when the passphrase was expanded into a store key.
|
||||
/// The salt that was used when the passphrase was expanded into a store
|
||||
/// key.
|
||||
kdf_salt: Vec<u8>,
|
||||
},
|
||||
}
|
||||
|
@ -170,10 +171,7 @@ impl StoreKey {
|
|||
cipher.encrypt(Nonce::from_slice(nonce.as_ref()), self.inner.as_slice())?;
|
||||
|
||||
Ok(EncryptedStoreKey {
|
||||
kdf_info: KdfInfo::Pbkdf2ToChaCha20Poly1305 {
|
||||
rounds: KDF_ROUNDS,
|
||||
kdf_salt: salt,
|
||||
},
|
||||
kdf_info: KdfInfo::Pbkdf2ToChaCha20Poly1305 { rounds: KDF_ROUNDS, kdf_salt: salt },
|
||||
ciphertext_info: CipherTextInfo::ChaCha20Poly1305 { nonce, ciphertext },
|
||||
})
|
||||
}
|
||||
|
@ -196,11 +194,7 @@ impl StoreKey {
|
|||
|
||||
let ciphertext = cipher.encrypt(xnonce, event.as_ref())?;
|
||||
|
||||
Ok(EncryptedEvent {
|
||||
version: VERSION,
|
||||
ciphertext,
|
||||
nonce,
|
||||
})
|
||||
Ok(EncryptedEvent { version: VERSION, ciphertext, nonce })
|
||||
}
|
||||
|
||||
pub fn decrypt<T: for<'b> Deserialize<'b>>(&self, event: EncryptedEvent) -> Result<T, Error> {
|
||||
|
@ -248,9 +242,10 @@ impl StoreKey {
|
|||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::StoreKey;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use super::StoreKey;
|
||||
|
||||
#[test]
|
||||
fn generating() {
|
||||
StoreKey::new().unwrap();
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use std::{collections::BTreeMap, convert::TryFrom, time::SystemTime};
|
||||
|
||||
use ruma::{
|
||||
api::client::r0::sync::sync_events::{
|
||||
Ephemeral, InvitedRoom, Presence, RoomAccountData, State, ToDevice,
|
||||
|
@ -6,7 +8,6 @@ use ruma::{
|
|||
DeviceIdBox,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{collections::BTreeMap, convert::TryFrom, time::SystemTime};
|
||||
|
||||
use super::{
|
||||
api::r0::{
|
||||
|
@ -103,16 +104,14 @@ pub struct SyncRoomEvent {
|
|||
|
||||
impl From<Raw<AnySyncRoomEvent>> for SyncRoomEvent {
|
||||
fn from(inner: Raw<AnySyncRoomEvent>) -> Self {
|
||||
Self {
|
||||
encryption_info: None,
|
||||
event: inner,
|
||||
}
|
||||
Self { encryption_info: None, event: inner }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
|
||||
pub struct SyncResponse {
|
||||
/// The batch token to supply in the `since` param of the next `/sync` request.
|
||||
/// The batch token to supply in the `since` param of the next `/sync`
|
||||
/// request.
|
||||
pub next_batch: String,
|
||||
/// Updates to rooms.
|
||||
pub rooms: Rooms,
|
||||
|
@ -137,10 +136,7 @@ pub struct SyncResponse {
|
|||
|
||||
impl SyncResponse {
|
||||
pub fn new(next_batch: String) -> Self {
|
||||
Self {
|
||||
next_batch,
|
||||
..Default::default()
|
||||
}
|
||||
Self { next_batch, ..Default::default() }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -161,14 +157,15 @@ pub struct JoinedRoom {
|
|||
pub unread_notifications: UnreadNotificationsCount,
|
||||
/// The timeline of messages and state changes in the room.
|
||||
pub timeline: Timeline,
|
||||
/// Updates to the state, between the time indicated by the `since` parameter, and the start
|
||||
/// of the `timeline` (or all state up to the start of the `timeline`, if `since` is not
|
||||
/// given, or `full_state` is true).
|
||||
/// Updates to the state, between the time indicated by the `since`
|
||||
/// parameter, and the start of the `timeline` (or all state up to the
|
||||
/// start of the `timeline`, if `since` is not given, or `full_state` is
|
||||
/// true).
|
||||
pub state: State,
|
||||
/// The private data that this user has attached to this room.
|
||||
pub account_data: RoomAccountData,
|
||||
/// The ephemeral events in the room that aren't recorded in the timeline or state of the
|
||||
/// room. e.g. typing.
|
||||
/// The ephemeral events in the room that aren't recorded in the timeline or
|
||||
/// state of the room. e.g. typing.
|
||||
pub ephemeral: Ephemeral,
|
||||
}
|
||||
|
||||
|
@ -180,20 +177,15 @@ impl JoinedRoom {
|
|||
ephemeral: Ephemeral,
|
||||
unread_notifications: UnreadNotificationsCount,
|
||||
) -> Self {
|
||||
Self {
|
||||
unread_notifications,
|
||||
timeline,
|
||||
state,
|
||||
account_data,
|
||||
ephemeral,
|
||||
}
|
||||
Self { unread_notifications, timeline, state, account_data, ephemeral }
|
||||
}
|
||||
}
|
||||
|
||||
/// Counts of unread notifications for a room.
|
||||
#[derive(Copy, Clone, Debug, Default, Deserialize, Serialize)]
|
||||
pub struct UnreadNotificationsCount {
|
||||
/// The number of unread notifications for this room with the highlight flag set.
|
||||
/// The number of unread notifications for this room with the highlight flag
|
||||
/// set.
|
||||
pub highlight_count: u64,
|
||||
/// The total number of unread notifications for this room.
|
||||
pub notification_count: u64,
|
||||
|
@ -203,10 +195,7 @@ impl From<RumaUnreadNotificationsCount> for UnreadNotificationsCount {
|
|||
fn from(notifications: RumaUnreadNotificationsCount) -> Self {
|
||||
Self {
|
||||
highlight_count: notifications.highlight_count.map(|c| c.into()).unwrap_or(0),
|
||||
notification_count: notifications
|
||||
.notification_count
|
||||
.map(|c| c.into())
|
||||
.unwrap_or(0),
|
||||
notification_count: notifications.notification_count.map(|c| c.into()).unwrap_or(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -216,9 +205,10 @@ pub struct LeftRoom {
|
|||
/// The timeline of messages and state changes in the room up to the point
|
||||
/// when the user left.
|
||||
pub timeline: Timeline,
|
||||
/// Updates to the state, between the time indicated by the `since` parameter, and the start
|
||||
/// of the `timeline` (or all state up to the start of the `timeline`, if `since` is not
|
||||
/// given, or `full_state` is true).
|
||||
/// Updates to the state, between the time indicated by the `since`
|
||||
/// parameter, and the start of the `timeline` (or all state up to the
|
||||
/// start of the `timeline`, if `since` is not given, or `full_state` is
|
||||
/// true).
|
||||
pub state: State,
|
||||
/// The private data that this user has attached to this room.
|
||||
pub account_data: RoomAccountData,
|
||||
|
@ -226,18 +216,15 @@ pub struct LeftRoom {
|
|||
|
||||
impl LeftRoom {
|
||||
pub fn new(timeline: Timeline, state: State, account_data: RoomAccountData) -> Self {
|
||||
Self {
|
||||
timeline,
|
||||
state,
|
||||
account_data,
|
||||
}
|
||||
Self { timeline, state, account_data }
|
||||
}
|
||||
}
|
||||
|
||||
/// Events in the room.
|
||||
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
|
||||
pub struct Timeline {
|
||||
/// True if the number of events returned was limited by the `limit` on the filter.
|
||||
/// True if the number of events returned was limited by the `limit` on the
|
||||
/// filter.
|
||||
pub limited: bool,
|
||||
|
||||
/// A token that can be supplied to to the `from` parameter of the
|
||||
|
@ -250,11 +237,7 @@ pub struct Timeline {
|
|||
|
||||
impl Timeline {
|
||||
pub fn new(limited: bool, prev_batch: Option<String>) -> Self {
|
||||
Self {
|
||||
limited,
|
||||
prev_batch,
|
||||
..Default::default()
|
||||
}
|
||||
Self { limited, prev_batch, ..Default::default() }
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -6,14 +6,12 @@ use std::{
|
|||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub use tokio::spawn;
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
use wasm_bindgen_futures::spawn_local;
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
use futures::{future::RemoteHandle, Future, FutureExt};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub use tokio::spawn;
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
use wasm_bindgen_futures::spawn_local;
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub fn spawn<F, T>(future: F) -> JoinHandle<T>
|
||||
|
|
|
@ -17,7 +17,6 @@ pub use ruma::{
|
|||
serde::{CanonicalJsonValue, Raw},
|
||||
thirdparty, uint, Int, Outgoing, UInt,
|
||||
};
|
||||
|
||||
pub use uuid;
|
||||
|
||||
pub mod deserialized_responses;
|
||||
|
|
|
@ -4,6 +4,5 @@
|
|||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub use futures_locks::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub use tokio::sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
|
||||
|
|
|
@ -4,7 +4,6 @@ mod perf;
|
|||
use std::sync::Arc;
|
||||
|
||||
use criterion::*;
|
||||
|
||||
use matrix_sdk_common::{
|
||||
api::r0::{
|
||||
keys::{claim_keys, get_keys},
|
||||
|
@ -50,17 +49,12 @@ fn huge_keys_query_resopnse() -> get_keys::Response {
|
|||
}
|
||||
|
||||
pub fn keys_query(c: &mut Criterion) {
|
||||
let runtime = Builder::new_multi_thread()
|
||||
.build()
|
||||
.expect("Can't create runtime");
|
||||
let runtime = Builder::new_multi_thread().build().expect("Can't create runtime");
|
||||
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
|
||||
let response = keys_query_response();
|
||||
let uuid = Uuid::new_v4();
|
||||
|
||||
let count = response
|
||||
.device_keys
|
||||
.values()
|
||||
.fold(0, |acc, d| acc + d.len())
|
||||
let count = response.device_keys.values().fold(0, |acc, d| acc + d.len())
|
||||
+ response.master_keys.len()
|
||||
+ response.self_signing_keys.len()
|
||||
+ response.user_signing_keys.len();
|
||||
|
@ -70,14 +64,10 @@ pub fn keys_query(c: &mut Criterion) {
|
|||
|
||||
let name = format!("{} device and cross signing keys", count);
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("memory store", &name),
|
||||
&response,
|
||||
|b, response| {
|
||||
b.to_async(&runtime)
|
||||
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
|
||||
},
|
||||
);
|
||||
group.bench_with_input(BenchmarkId::new("memory store", &name), &response, |b, response| {
|
||||
b.to_async(&runtime)
|
||||
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
|
||||
});
|
||||
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let machine = runtime
|
||||
|
@ -89,99 +79,74 @@ pub fn keys_query(c: &mut Criterion) {
|
|||
))
|
||||
.unwrap();
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("sled store", &name),
|
||||
&response,
|
||||
|b, response| {
|
||||
b.to_async(&runtime)
|
||||
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
|
||||
},
|
||||
);
|
||||
group.bench_with_input(BenchmarkId::new("sled store", &name), &response, |b, response| {
|
||||
b.to_async(&runtime)
|
||||
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
|
||||
});
|
||||
|
||||
group.finish()
|
||||
}
|
||||
|
||||
pub fn keys_claiming(c: &mut Criterion) {
|
||||
let runtime = Arc::new(
|
||||
Builder::new_multi_thread()
|
||||
.build()
|
||||
.expect("Can't create runtime"),
|
||||
);
|
||||
let runtime = Arc::new(Builder::new_multi_thread().build().expect("Can't create runtime"));
|
||||
|
||||
let keys_query_response = keys_query_response();
|
||||
let uuid = Uuid::new_v4();
|
||||
|
||||
let response = keys_claim_response();
|
||||
|
||||
let count = response
|
||||
.one_time_keys
|
||||
.values()
|
||||
.fold(0, |acc, d| acc + d.len());
|
||||
let count = response.one_time_keys.values().fold(0, |acc, d| acc + d.len());
|
||||
|
||||
let mut group = c.benchmark_group("Olm session creation");
|
||||
group.throughput(Throughput::Elements(count as u64));
|
||||
|
||||
let name = format!("{} one-time keys", count);
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("memory store", &name),
|
||||
&response,
|
||||
|b, response| {
|
||||
b.iter_batched(
|
||||
|| {
|
||||
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
|
||||
runtime
|
||||
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
|
||||
.unwrap();
|
||||
(machine, runtime.clone())
|
||||
},
|
||||
move |(machine, runtime)| {
|
||||
runtime
|
||||
.block_on(machine.mark_request_as_sent(&uuid, response))
|
||||
.unwrap()
|
||||
},
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
},
|
||||
);
|
||||
group.bench_with_input(BenchmarkId::new("memory store", &name), &response, |b, response| {
|
||||
b.iter_batched(
|
||||
|| {
|
||||
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
|
||||
runtime
|
||||
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
|
||||
.unwrap();
|
||||
(machine, runtime.clone())
|
||||
},
|
||||
move |(machine, runtime)| {
|
||||
runtime.block_on(machine.mark_request_as_sent(&uuid, response)).unwrap()
|
||||
},
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
});
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("sled store", &name),
|
||||
&response,
|
||||
|b, response| {
|
||||
b.iter_batched(
|
||||
|| {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let machine = runtime
|
||||
.block_on(OlmMachine::new_with_default_store(
|
||||
&alice_id(),
|
||||
&alice_device_id(),
|
||||
dir.path(),
|
||||
None,
|
||||
))
|
||||
.unwrap();
|
||||
runtime
|
||||
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
|
||||
.unwrap();
|
||||
(machine, runtime.clone())
|
||||
},
|
||||
move |(machine, runtime)| {
|
||||
runtime
|
||||
.block_on(machine.mark_request_as_sent(&uuid, response))
|
||||
.unwrap()
|
||||
},
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
},
|
||||
);
|
||||
group.bench_with_input(BenchmarkId::new("sled store", &name), &response, |b, response| {
|
||||
b.iter_batched(
|
||||
|| {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let machine = runtime
|
||||
.block_on(OlmMachine::new_with_default_store(
|
||||
&alice_id(),
|
||||
&alice_device_id(),
|
||||
dir.path(),
|
||||
None,
|
||||
))
|
||||
.unwrap();
|
||||
runtime
|
||||
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
|
||||
.unwrap();
|
||||
(machine, runtime.clone())
|
||||
},
|
||||
move |(machine, runtime)| {
|
||||
runtime.block_on(machine.mark_request_as_sent(&uuid, response)).unwrap()
|
||||
},
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
});
|
||||
|
||||
group.finish()
|
||||
}
|
||||
|
||||
pub fn room_key_sharing(c: &mut Criterion) {
|
||||
let runtime = Builder::new_multi_thread()
|
||||
.build()
|
||||
.expect("Can't create runtime");
|
||||
let runtime = Builder::new_multi_thread().build().expect("Can't create runtime");
|
||||
|
||||
let keys_query_response = keys_query_response();
|
||||
let uuid = Uuid::new_v4();
|
||||
|
@ -191,18 +156,11 @@ pub fn room_key_sharing(c: &mut Criterion) {
|
|||
let to_device_response = ToDeviceResponse::new();
|
||||
let users: Vec<UserId> = keys_query_response.device_keys.keys().cloned().collect();
|
||||
|
||||
let count = response
|
||||
.one_time_keys
|
||||
.values()
|
||||
.fold(0, |acc, d| acc + d.len());
|
||||
let count = response.one_time_keys.values().fold(0, |acc, d| acc + d.len());
|
||||
|
||||
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
|
||||
runtime
|
||||
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
|
||||
.unwrap();
|
||||
runtime
|
||||
.block_on(machine.mark_request_as_sent(&uuid, &response))
|
||||
.unwrap();
|
||||
runtime.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)).unwrap();
|
||||
runtime.block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
|
||||
|
||||
let mut group = c.benchmark_group("Room key sharing");
|
||||
group.throughput(Throughput::Elements(count as u64));
|
||||
|
@ -218,10 +176,7 @@ pub fn room_key_sharing(c: &mut Criterion) {
|
|||
assert!(!requests.is_empty());
|
||||
|
||||
for request in requests {
|
||||
machine
|
||||
.mark_request_as_sent(&request.txn_id, &to_device_response)
|
||||
.await
|
||||
.unwrap();
|
||||
machine.mark_request_as_sent(&request.txn_id, &to_device_response).await.unwrap();
|
||||
}
|
||||
|
||||
machine.invalidate_group_session(&room_id).await.unwrap();
|
||||
|
@ -237,12 +192,8 @@ pub fn room_key_sharing(c: &mut Criterion) {
|
|||
None,
|
||||
))
|
||||
.unwrap();
|
||||
runtime
|
||||
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
|
||||
.unwrap();
|
||||
runtime
|
||||
.block_on(machine.mark_request_as_sent(&uuid, &response))
|
||||
.unwrap();
|
||||
runtime.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)).unwrap();
|
||||
runtime.block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
|
||||
|
||||
group.bench_function(BenchmarkId::new("sled store", &name), |b| {
|
||||
b.to_async(&runtime).iter(|| async {
|
||||
|
@ -254,10 +205,7 @@ pub fn room_key_sharing(c: &mut Criterion) {
|
|||
assert!(!requests.is_empty());
|
||||
|
||||
for request in requests {
|
||||
machine
|
||||
.mark_request_as_sent(&request.txn_id, &to_device_response)
|
||||
.await
|
||||
.unwrap();
|
||||
machine.mark_request_as_sent(&request.txn_id, &to_device_response).await.unwrap();
|
||||
}
|
||||
|
||||
machine.invalidate_group_session(&room_id).await.unwrap();
|
||||
|
@ -268,28 +216,21 @@ pub fn room_key_sharing(c: &mut Criterion) {
|
|||
}
|
||||
|
||||
pub fn devices_missing_sessions_collecting(c: &mut Criterion) {
|
||||
let runtime = Builder::new_multi_thread()
|
||||
.build()
|
||||
.expect("Can't create runtime");
|
||||
let runtime = Builder::new_multi_thread().build().expect("Can't create runtime");
|
||||
|
||||
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
|
||||
let response = huge_keys_query_resopnse();
|
||||
let uuid = Uuid::new_v4();
|
||||
let users: Vec<UserId> = response.device_keys.keys().cloned().collect();
|
||||
|
||||
let count = response
|
||||
.device_keys
|
||||
.values()
|
||||
.fold(0, |acc, d| acc + d.len());
|
||||
let count = response.device_keys.values().fold(0, |acc, d| acc + d.len());
|
||||
|
||||
let mut group = c.benchmark_group("Devices missing sessions collecting");
|
||||
group.throughput(Throughput::Elements(count as u64));
|
||||
|
||||
let name = format!("{} devices", count);
|
||||
|
||||
runtime
|
||||
.block_on(machine.mark_request_as_sent(&uuid, &response))
|
||||
.unwrap();
|
||||
runtime.block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
|
||||
|
||||
group.bench_function(BenchmarkId::new("memory store", &name), |b| {
|
||||
b.to_async(&runtime).iter_with_large_drop(|| async {
|
||||
|
@ -307,9 +248,7 @@ pub fn devices_missing_sessions_collecting(c: &mut Criterion) {
|
|||
))
|
||||
.unwrap();
|
||||
|
||||
runtime
|
||||
.block_on(machine.mark_request_as_sent(&uuid, &response))
|
||||
.unwrap();
|
||||
runtime.block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
|
||||
|
||||
group.bench_function(BenchmarkId::new("sled store", &name), |b| {
|
||||
b.to_async(&runtime)
|
||||
|
|
|
@ -6,8 +6,9 @@ use std::{fs::File, os::raw::c_int, path::Path};
|
|||
use criterion::profiler::Profiler;
|
||||
use pprof::ProfilerGuard;
|
||||
|
||||
/// Small custom profiler that can be used with Criterion to create a flamegraph for benchmarks.
|
||||
/// Also see [the Criterion documentation on this][custom-profiler].
|
||||
/// Small custom profiler that can be used with Criterion to create a flamegraph
|
||||
/// for benchmarks. Also see [the Criterion documentation on
|
||||
/// this][custom-profiler].
|
||||
///
|
||||
/// ## Example on how to enable the custom profiler:
|
||||
///
|
||||
|
@ -30,12 +31,12 @@ use pprof::ProfilerGuard;
|
|||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// The neat thing about this is that it will sample _only_ the benchmark, and not other stuff like
|
||||
/// the setup process.
|
||||
/// The neat thing about this is that it will sample _only_ the benchmark, and
|
||||
/// not other stuff like the setup process.
|
||||
///
|
||||
/// Further, it will only kick in if `--profile-time <time>` is passed to the benchmark binary.
|
||||
/// A flamegraph will be created for each individual benchmark in its report directory under
|
||||
/// `profile/flamegraph.svg`.
|
||||
/// Further, it will only kick in if `--profile-time <time>` is passed to the
|
||||
/// benchmark binary. A flamegraph will be created for each individual benchmark
|
||||
/// in its report directory under `profile/flamegraph.svg`.
|
||||
///
|
||||
/// [custom-profiler]: https://bheisler.github.io/criterion.rs/book/user_guide/profiling.html#implementing-in-process-profiling-hooks
|
||||
pub struct FlamegraphProfiler<'a> {
|
||||
|
@ -45,10 +46,7 @@ pub struct FlamegraphProfiler<'a> {
|
|||
|
||||
impl<'a> FlamegraphProfiler<'a> {
|
||||
pub fn new(frequency: c_int) -> Self {
|
||||
FlamegraphProfiler {
|
||||
frequency,
|
||||
active_profiler: None,
|
||||
}
|
||||
FlamegraphProfiler { frequency, active_profiler: None }
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -17,21 +17,17 @@ use std::{
|
|||
io::{Error as IoError, ErrorKind, Read},
|
||||
};
|
||||
|
||||
use thiserror::Error;
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use matrix_sdk_common::events::room::JsonWebKey;
|
||||
|
||||
use getrandom::getrandom;
|
||||
|
||||
use aes_ctr::{
|
||||
cipher::{NewStreamCipher, SyncStreamCipher},
|
||||
Aes256Ctr,
|
||||
};
|
||||
use base64::DecodeError;
|
||||
use getrandom::getrandom;
|
||||
use matrix_sdk_common::events::room::JsonWebKey;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use thiserror::Error;
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
use crate::utilities::{decode, decode_url_safe, encode, encode_url_safe};
|
||||
|
||||
|
@ -59,10 +55,7 @@ impl<'a, R: Read> Read for AttachmentDecryptor<'a, R> {
|
|||
if hash.as_slice() == self.expected_hash.as_slice() {
|
||||
Ok(0)
|
||||
} else {
|
||||
Err(IoError::new(
|
||||
ErrorKind::Other,
|
||||
"Hash missmatch while decrypting",
|
||||
))
|
||||
Err(IoError::new(ErrorKind::Other, "Hash missmatch while decrypting"))
|
||||
}
|
||||
} else {
|
||||
self.sha.update(&buf[0..read_bytes]);
|
||||
|
@ -130,23 +123,14 @@ impl<'a, R: Read + 'a> AttachmentDecryptor<'a, R> {
|
|||
return Err(DecryptorError::UnknownVersion);
|
||||
}
|
||||
|
||||
let hash = decode(
|
||||
info.hashes
|
||||
.get("sha256")
|
||||
.ok_or(DecryptorError::MissingHash)?,
|
||||
)?;
|
||||
let hash = decode(info.hashes.get("sha256").ok_or(DecryptorError::MissingHash)?)?;
|
||||
let key = Zeroizing::from(decode_url_safe(info.web_key.k)?);
|
||||
let iv = decode(info.iv)?;
|
||||
|
||||
let sha = Sha256::default();
|
||||
let aes = Aes256Ctr::new_var(&key, &iv).map_err(|_| DecryptorError::KeyNonceLength)?;
|
||||
|
||||
Ok(AttachmentDecryptor {
|
||||
inner_reader: input,
|
||||
expected_hash: hash,
|
||||
sha,
|
||||
aes,
|
||||
})
|
||||
Ok(AttachmentDecryptor { inner_reader: input, expected_hash: hash, sha, aes })
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -168,9 +152,7 @@ impl<'a, R: Read + 'a> Read for AttachmentEncryptor<'a, R> {
|
|||
|
||||
if read_bytes == 0 {
|
||||
let hash = self.sha.finalize_reset();
|
||||
self.hashes
|
||||
.entry("sha256".to_owned())
|
||||
.or_insert_with(|| encode(hash));
|
||||
self.hashes.entry("sha256".to_owned()).or_insert_with(|| encode(hash));
|
||||
Ok(0)
|
||||
} else {
|
||||
self.aes.apply_keystream(&mut buf[0..read_bytes]);
|
||||
|
@ -244,9 +226,7 @@ impl<'a, R: Read + 'a> AttachmentEncryptor<'a, R> {
|
|||
/// Consume the encryptor and get the encryption key.
|
||||
pub fn finish(mut self) -> EncryptionInfo {
|
||||
let hash = self.sha.finalize();
|
||||
self.hashes
|
||||
.entry("sha256".to_owned())
|
||||
.or_insert_with(|| encode(hash));
|
||||
self.hashes.entry("sha256".to_owned()).or_insert_with(|| encode(hash));
|
||||
|
||||
EncryptionInfo {
|
||||
version: VERSION.to_string(),
|
||||
|
@ -274,10 +254,12 @@ pub struct EncryptionInfo {
|
|||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::{AttachmentDecryptor, AttachmentEncryptor, EncryptionInfo};
|
||||
use serde_json::json;
|
||||
use std::io::{Cursor, Read};
|
||||
|
||||
use serde_json::json;
|
||||
|
||||
use super::{AttachmentDecryptor, AttachmentEncryptor, EncryptionInfo};
|
||||
|
||||
const EXAMPLE_DATA: &[u8] = &[
|
||||
179, 154, 118, 127, 186, 127, 110, 33, 203, 33, 33, 134, 67, 100, 173, 46, 235, 27, 215,
|
||||
172, 36, 26, 75, 47, 33, 160,
|
||||
|
|
|
@ -12,20 +12,19 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use serde_json::Error as SerdeError;
|
||||
use std::io::{Cursor, Read, Seek, SeekFrom};
|
||||
use thiserror::Error;
|
||||
|
||||
use byteorder::{BigEndian, ReadBytesExt};
|
||||
use getrandom::getrandom;
|
||||
|
||||
use aes_ctr::{
|
||||
cipher::{NewStreamCipher, SyncStreamCipher},
|
||||
Aes256Ctr,
|
||||
};
|
||||
use byteorder::{BigEndian, ReadBytesExt};
|
||||
use getrandom::getrandom;
|
||||
use hmac::{Hmac, Mac, NewMac};
|
||||
use pbkdf2::pbkdf2;
|
||||
use serde_json::Error as SerdeError;
|
||||
use sha2::{Sha256, Sha512};
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::{
|
||||
olm::ExportedRoomKey,
|
||||
|
@ -99,14 +98,10 @@ pub fn decrypt_key_export(
|
|||
return Err(KeyExportError::InvalidHeaders);
|
||||
}
|
||||
|
||||
let payload: String = x
|
||||
.lines()
|
||||
.filter(|l| !(l.starts_with(HEADER) || l.starts_with(FOOTER)))
|
||||
.collect();
|
||||
let payload: String =
|
||||
x.lines().filter(|l| !(l.starts_with(HEADER) || l.starts_with(FOOTER))).collect();
|
||||
|
||||
Ok(serde_json::from_str(&decrypt_helper(
|
||||
&payload, passphrase,
|
||||
)?)?)
|
||||
Ok(serde_json::from_str(&decrypt_helper(&payload, passphrase)?)?)
|
||||
}
|
||||
|
||||
/// Encrypt the list of exported room keys using the given passphrase.
|
||||
|
@ -231,12 +226,12 @@ fn decrypt_helper(ciphertext: &str, passphrase: &str) -> Result<String, KeyExpor
|
|||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use indoc::indoc;
|
||||
use proptest::prelude::*;
|
||||
use std::io::Cursor;
|
||||
|
||||
use indoc::indoc;
|
||||
use matrix_sdk_common::identifiers::room_id;
|
||||
use matrix_sdk_test::async_test;
|
||||
use proptest::prelude::*;
|
||||
|
||||
use super::{decode, decrypt_helper, decrypt_key_export, encrypt_helper, encrypt_key_export};
|
||||
use crate::machine::test::get_prepared_machine;
|
||||
|
@ -261,10 +256,7 @@ mod test {
|
|||
"};
|
||||
|
||||
fn export_wihtout_headers() -> String {
|
||||
TEST_EXPORT
|
||||
.lines()
|
||||
.filter(|l| !l.starts_with("-----"))
|
||||
.collect()
|
||||
TEST_EXPORT.lines().filter(|l| !l.starts_with("-----")).collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -301,14 +293,8 @@ mod test {
|
|||
let (machine, _) = get_prepared_machine().await;
|
||||
let room_id = room_id!("!test:localhost");
|
||||
|
||||
machine
|
||||
.create_outbound_group_session_with_defaults(&room_id)
|
||||
.await
|
||||
.unwrap();
|
||||
let export = machine
|
||||
.export_keys(|s| s.room_id() == &room_id)
|
||||
.await
|
||||
.unwrap();
|
||||
machine.create_outbound_group_session_with_defaults(&room_id).await.unwrap();
|
||||
let export = machine.export_keys(|s| s.room_id() == &room_id).await.unwrap();
|
||||
|
||||
assert!(!export.is_empty());
|
||||
|
||||
|
@ -316,10 +302,7 @@ mod test {
|
|||
let decrypted = decrypt_key_export(Cursor::new(encrypted), "1234").unwrap();
|
||||
|
||||
assert_eq!(export, decrypted);
|
||||
assert_eq!(
|
||||
machine.import_keys(decrypted, |_, _| {}).await.unwrap(),
|
||||
(0, 1)
|
||||
);
|
||||
assert_eq!(machine.import_keys(decrypted, |_, _| {}).await.unwrap(), (0, 1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
@ -39,24 +39,17 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
|||
use serde_json::{json, Value};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::{
|
||||
olm::{InboundGroupSession, PrivateCrossSigningIdentity, Session},
|
||||
store::{Changes, DeviceChanges},
|
||||
OutgoingVerificationRequest,
|
||||
};
|
||||
#[cfg(test)]
|
||||
use crate::{OlmMachine, ReadOnlyAccount};
|
||||
|
||||
use super::{atomic_bool_deserializer, atomic_bool_serializer};
|
||||
use crate::{
|
||||
error::{EventError, OlmError, OlmResult, SignatureError},
|
||||
identities::{OwnUserIdentity, UserIdentities},
|
||||
olm::Utility,
|
||||
store::{CryptoStore, Result as StoreResult},
|
||||
olm::{InboundGroupSession, PrivateCrossSigningIdentity, Session, Utility},
|
||||
store::{Changes, CryptoStore, DeviceChanges, Result as StoreResult},
|
||||
verification::VerificationMachine,
|
||||
Sas, ToDeviceRequest,
|
||||
OutgoingVerificationRequest, Sas, ToDeviceRequest,
|
||||
};
|
||||
|
||||
use super::{atomic_bool_deserializer, atomic_bool_serializer};
|
||||
#[cfg(test)]
|
||||
use crate::{OlmMachine, ReadOnlyAccount};
|
||||
|
||||
/// A read-only version of a `Device`.
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
|
@ -120,9 +113,7 @@ pub struct Device {
|
|||
|
||||
impl std::fmt::Debug for Device {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Device")
|
||||
.field("device", &self.inner)
|
||||
.finish()
|
||||
f.debug_struct("Device").field("device", &self.inner).finish()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -139,10 +130,7 @@ impl Device {
|
|||
///
|
||||
/// Returns a `Sas` object and to-device request that needs to be sent out.
|
||||
pub async fn start_verification(&self) -> StoreResult<(Sas, ToDeviceRequest)> {
|
||||
let (sas, request) = self
|
||||
.verification_machine
|
||||
.start_sas(self.inner.clone())
|
||||
.await?;
|
||||
let (sas, request) = self.verification_machine.start_sas(self.inner.clone()).await?;
|
||||
|
||||
if let OutgoingVerificationRequest::ToDevice(r) = request {
|
||||
Ok((sas, r))
|
||||
|
@ -162,8 +150,7 @@ impl Device {
|
|||
|
||||
/// Get the trust state of the device.
|
||||
pub fn trust_state(&self) -> bool {
|
||||
self.inner
|
||||
.trust_state(&self.own_identity, &self.device_owner_identity)
|
||||
self.inner.trust_state(&self.own_identity, &self.device_owner_identity)
|
||||
}
|
||||
|
||||
/// Set the local trust state of the device to the given state.
|
||||
|
@ -178,10 +165,7 @@ impl Device {
|
|||
self.inner.set_trust_state(trust_state);
|
||||
|
||||
let changes = Changes {
|
||||
devices: DeviceChanges {
|
||||
changed: vec![self.inner.clone()],
|
||||
..Default::default()
|
||||
},
|
||||
devices: DeviceChanges { changed: vec![self.inner.clone()], ..Default::default() },
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
|
@ -200,9 +184,7 @@ impl Device {
|
|||
event_type: EventType,
|
||||
content: Value,
|
||||
) -> OlmResult<(Session, EncryptedEventContent)> {
|
||||
self.inner
|
||||
.encrypt(&**self.verification_machine.store, event_type, content)
|
||||
.await
|
||||
self.inner.encrypt(&**self.verification_machine.store, event_type, content).await
|
||||
}
|
||||
|
||||
/// Encrypt the given inbound group session as a forwarded room key for this
|
||||
|
@ -261,9 +243,7 @@ impl UserDevices {
|
|||
/// Returns true if there is at least one devices of this user that is
|
||||
/// considered to be verified, false otherwise.
|
||||
pub fn is_any_verified(&self) -> bool {
|
||||
self.inner
|
||||
.values()
|
||||
.any(|d| d.trust_state(&self.own_identity, &self.device_owner_identity))
|
||||
self.inner.values().any(|d| d.trust_state(&self.own_identity, &self.device_owner_identity))
|
||||
}
|
||||
|
||||
/// Iterator over all the device ids of the user devices.
|
||||
|
@ -348,8 +328,7 @@ impl ReadOnlyDevice {
|
|||
|
||||
/// Get the key of the given key algorithm belonging to this device.
|
||||
pub fn get_key(&self, algorithm: DeviceKeyAlgorithm) -> Option<&String> {
|
||||
self.keys
|
||||
.get(&DeviceKeyId::from_parts(algorithm, &self.device_id))
|
||||
self.keys.get(&DeviceKeyId::from_parts(algorithm, &self.device_id))
|
||||
}
|
||||
|
||||
/// Get a map containing all the device keys.
|
||||
|
@ -496,9 +475,8 @@ impl ReadOnlyDevice {
|
|||
}
|
||||
|
||||
fn is_signed_by_device(&self, json: &mut Value) -> Result<(), SignatureError> {
|
||||
let signing_key = self
|
||||
.get_key(DeviceKeyAlgorithm::Ed25519)
|
||||
.ok_or(SignatureError::MissingSigningKey)?;
|
||||
let signing_key =
|
||||
self.get_key(DeviceKeyAlgorithm::Ed25519).ok_or(SignatureError::MissingSigningKey)?;
|
||||
|
||||
let utility = Utility::new();
|
||||
|
||||
|
@ -590,14 +568,15 @@ impl PartialEq for ReadOnlyDevice {
|
|||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod test {
|
||||
use serde_json::json;
|
||||
use std::convert::TryFrom;
|
||||
|
||||
use crate::identities::{LocalTrust, ReadOnlyDevice};
|
||||
use matrix_sdk_common::{
|
||||
encryption::DeviceKeys,
|
||||
identifiers::{user_id, DeviceKeyAlgorithm},
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::identities::{LocalTrust, ReadOnlyDevice};
|
||||
|
||||
fn device_keys() -> DeviceKeys {
|
||||
let device_keys = json!({
|
||||
|
@ -640,10 +619,7 @@ pub(crate) mod test {
|
|||
assert_eq!(device_id, device.device_id());
|
||||
assert_eq!(device.algorithms.len(), 2);
|
||||
assert_eq!(LocalTrust::Unset, device.local_trust_state());
|
||||
assert_eq!(
|
||||
"Alice's mobile phone",
|
||||
device.display_name().as_ref().unwrap()
|
||||
);
|
||||
assert_eq!("Alice's mobile phone", device.display_name().as_ref().unwrap());
|
||||
assert_eq!(
|
||||
device.get_key(DeviceKeyAlgorithm::Curve25519).unwrap(),
|
||||
"xfgbLIC5WAl1OIkpOzoxpCe8FsRDT6nch7NQsOb15nc"
|
||||
|
@ -658,10 +634,7 @@ pub(crate) mod test {
|
|||
fn update_a_device() {
|
||||
let mut device = get_device();
|
||||
|
||||
assert_eq!(
|
||||
"Alice's mobile phone",
|
||||
device.display_name().as_ref().unwrap()
|
||||
);
|
||||
assert_eq!("Alice's mobile phone", device.display_name().as_ref().unwrap());
|
||||
|
||||
let display_name = "Alice's work computer".to_owned();
|
||||
|
||||
|
|
|
@ -12,20 +12,20 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use futures::future::join_all;
|
||||
use std::{
|
||||
collections::{BTreeMap, HashSet},
|
||||
convert::TryFrom,
|
||||
sync::Arc,
|
||||
};
|
||||
use tracing::{trace, warn};
|
||||
|
||||
use futures::future::join_all;
|
||||
use matrix_sdk_common::{
|
||||
api::r0::keys::get_keys::Response as KeysQueryResponse,
|
||||
encryption::DeviceKeys,
|
||||
executor::spawn,
|
||||
identifiers::{DeviceIdBox, UserId},
|
||||
};
|
||||
use tracing::{trace, warn};
|
||||
|
||||
use crate::{
|
||||
error::OlmResult,
|
||||
|
@ -54,11 +54,7 @@ impl IdentityManager {
|
|||
const MAX_KEY_QUERY_USERS: usize = 250;
|
||||
|
||||
pub fn new(user_id: Arc<UserId>, device_id: Arc<DeviceIdBox>, store: Store) -> Self {
|
||||
IdentityManager {
|
||||
user_id,
|
||||
device_id,
|
||||
store,
|
||||
}
|
||||
IdentityManager { user_id, device_id, store }
|
||||
}
|
||||
|
||||
fn user_id(&self) -> &UserId {
|
||||
|
@ -78,9 +74,8 @@ impl IdentityManager {
|
|||
&self,
|
||||
response: &KeysQueryResponse,
|
||||
) -> OlmResult<(DeviceChanges, IdentityChanges)> {
|
||||
let changed_devices = self
|
||||
.handle_devices_from_key_query(response.device_keys.clone())
|
||||
.await?;
|
||||
let changed_devices =
|
||||
self.handle_devices_from_key_query(response.device_keys.clone()).await?;
|
||||
let changed_identities = self.handle_cross_singing_keys(response).await?;
|
||||
|
||||
let changes = Changes {
|
||||
|
@ -104,9 +99,8 @@ impl IdentityManager {
|
|||
store: Store,
|
||||
device_keys: DeviceKeys,
|
||||
) -> StoreResult<DeviceChange> {
|
||||
let old_device = store
|
||||
.get_readonly_device(&device_keys.user_id, &device_keys.device_id)
|
||||
.await?;
|
||||
let old_device =
|
||||
store.get_readonly_device(&device_keys.user_id, &device_keys.device_id).await?;
|
||||
|
||||
if let Some(mut device) = old_device {
|
||||
if let Err(e) = device.update_device(&device_keys) {
|
||||
|
@ -148,25 +142,20 @@ impl IdentityManager {
|
|||
|
||||
let current_devices: HashSet<DeviceIdBox> = device_map.keys().cloned().collect();
|
||||
|
||||
let tasks = device_map
|
||||
.into_iter()
|
||||
.filter_map(|(device_id, device_keys)| {
|
||||
// We don't need our own device in the device store.
|
||||
if user_id == *own_user_id && device_id == *own_device_id {
|
||||
None
|
||||
} else if user_id != device_keys.user_id || device_id != device_keys.device_id {
|
||||
warn!(
|
||||
"Mismatch in device keys payload of device {}|{} from user {}|{}",
|
||||
device_id, device_keys.device_id, user_id, device_keys.user_id
|
||||
);
|
||||
None
|
||||
} else {
|
||||
Some(spawn(Self::update_or_create_device(
|
||||
store.clone(),
|
||||
device_keys,
|
||||
)))
|
||||
}
|
||||
});
|
||||
let tasks = device_map.into_iter().filter_map(|(device_id, device_keys)| {
|
||||
// We don't need our own device in the device store.
|
||||
if user_id == *own_user_id && device_id == *own_device_id {
|
||||
None
|
||||
} else if user_id != device_keys.user_id || device_id != device_keys.device_id {
|
||||
warn!(
|
||||
"Mismatch in device keys payload of device {}|{} from user {}|{}",
|
||||
device_id, device_keys.device_id, user_id, device_keys.user_id
|
||||
);
|
||||
None
|
||||
} else {
|
||||
Some(spawn(Self::update_or_create_device(store.clone(), device_keys)))
|
||||
}
|
||||
});
|
||||
|
||||
let results = join_all(tasks).await;
|
||||
|
||||
|
@ -211,17 +200,15 @@ impl IdentityManager {
|
|||
) -> StoreResult<DeviceChanges> {
|
||||
let mut changes = DeviceChanges::default();
|
||||
|
||||
let tasks = device_keys_map
|
||||
.into_iter()
|
||||
.map(|(user_id, device_keys_map)| {
|
||||
spawn(Self::update_user_devices(
|
||||
self.store.clone(),
|
||||
self.user_id.clone(),
|
||||
self.device_id.clone(),
|
||||
user_id,
|
||||
device_keys_map,
|
||||
))
|
||||
});
|
||||
let tasks = device_keys_map.into_iter().map(|(user_id, device_keys_map)| {
|
||||
spawn(Self::update_user_devices(
|
||||
self.store.clone(),
|
||||
self.user_id.clone(),
|
||||
self.device_id.clone(),
|
||||
user_id,
|
||||
device_keys_map,
|
||||
))
|
||||
});
|
||||
|
||||
let results = join_all(tasks).await;
|
||||
|
||||
|
@ -254,10 +241,7 @@ impl IdentityManager {
|
|||
let self_signing = if let Some(s) = response.self_signing_keys.get(user_id) {
|
||||
SelfSigningPubkey::from(s)
|
||||
} else {
|
||||
warn!(
|
||||
"User identity for user {} didn't contain a self signing pubkey",
|
||||
user_id
|
||||
);
|
||||
warn!("User identity for user {} didn't contain a self signing pubkey", user_id);
|
||||
continue;
|
||||
};
|
||||
|
||||
|
@ -276,13 +260,11 @@ impl IdentityManager {
|
|||
continue;
|
||||
};
|
||||
|
||||
identity
|
||||
.update(master_key, self_signing, user_signing)
|
||||
.map(|_| (i, false))
|
||||
identity.update(master_key, self_signing, user_signing).map(|_| (i, false))
|
||||
}
|
||||
UserIdentities::Other(ref mut identity) => {
|
||||
identity.update(master_key, self_signing).map(|_| (i, false))
|
||||
}
|
||||
UserIdentities::Other(ref mut identity) => identity
|
||||
.update(master_key, self_signing)
|
||||
.map(|_| (i, false)),
|
||||
}
|
||||
} else if user_id == self.user_id() {
|
||||
if let Some(s) = response.user_signing_keys.get(user_id) {
|
||||
|
@ -310,10 +292,7 @@ impl IdentityManager {
|
|||
continue;
|
||||
}
|
||||
} else if master_key.user_id() != user_id || self_signing.user_id() != user_id {
|
||||
warn!(
|
||||
"User id mismatch in one of the cross signing keys for user {}",
|
||||
user_id
|
||||
);
|
||||
warn!("User id mismatch in one of the cross signing keys for user {}", user_id);
|
||||
continue;
|
||||
} else {
|
||||
UserIdentity::new(master_key, self_signing)
|
||||
|
@ -322,11 +301,7 @@ impl IdentityManager {
|
|||
|
||||
match result {
|
||||
Ok((i, new)) => {
|
||||
trace!(
|
||||
"Updated or created new user identity for {}: {:?}",
|
||||
user_id,
|
||||
i
|
||||
);
|
||||
trace!("Updated or created new user identity for {}: {:?}", user_id, i);
|
||||
if new {
|
||||
changes.new.push(i);
|
||||
} else {
|
||||
|
@ -334,10 +309,7 @@ impl IdentityManager {
|
|||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Couldn't update or create new user identity for {}: {:?}",
|
||||
user_id, e
|
||||
);
|
||||
warn!("Couldn't update or create new user identity for {}: {:?}", user_id, e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
@ -424,9 +396,7 @@ pub(crate) mod test {
|
|||
locks::Mutex,
|
||||
IncomingResponse,
|
||||
};
|
||||
|
||||
use matrix_sdk_test::async_test;
|
||||
|
||||
use serde_json::json;
|
||||
|
||||
use crate::{
|
||||
|
@ -637,10 +607,7 @@ pub(crate) mod test {
|
|||
let devices = manager.store.get_user_devices(&other_user).await.unwrap();
|
||||
assert_eq!(devices.devices().count(), 0);
|
||||
|
||||
manager
|
||||
.receive_keys_query_response(&other_key_query())
|
||||
.await
|
||||
.unwrap();
|
||||
manager.receive_keys_query_response(&other_key_query()).await.unwrap();
|
||||
|
||||
let devices = manager.store.get_user_devices(&other_user).await.unwrap();
|
||||
assert_eq!(devices.devices().count(), 1);
|
||||
|
@ -651,12 +618,7 @@ pub(crate) mod test {
|
|||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let identity = manager
|
||||
.store
|
||||
.get_user_identity(&other_user)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let identity = manager.store.get_user_identity(&other_user).await.unwrap().unwrap();
|
||||
let identity = identity.other().unwrap();
|
||||
|
||||
assert!(identity.is_device_signed(&device).is_ok())
|
||||
|
@ -669,10 +631,7 @@ pub(crate) mod test {
|
|||
let devices = manager.store.get_user_devices(&other_user).await.unwrap();
|
||||
assert_eq!(devices.devices().count(), 0);
|
||||
|
||||
manager
|
||||
.receive_keys_query_response(&other_key_query())
|
||||
.await
|
||||
.unwrap();
|
||||
manager.receive_keys_query_response(&other_key_query()).await.unwrap();
|
||||
|
||||
let devices = manager.store.get_user_devices(&other_user).await.unwrap();
|
||||
assert_eq!(devices.devices().count(), 1);
|
||||
|
@ -683,12 +642,7 @@ pub(crate) mod test {
|
|||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let identity = manager
|
||||
.store
|
||||
.get_user_identity(&other_user)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let identity = manager.store.get_user_identity(&other_user).await.unwrap().unwrap();
|
||||
let identity = identity.other().unwrap();
|
||||
|
||||
assert!(identity.is_device_signed(&device).is_ok())
|
||||
|
|
|
@ -29,10 +29,10 @@
|
|||
//!
|
||||
//! ## User
|
||||
//!
|
||||
//! Cross-signing capable devices will upload 3 additional (master, self-signing,
|
||||
//! user-signing) public keys which represent the user identity owning all the
|
||||
//! devices. This is represented in two ways, as a `UserIdentity` for other
|
||||
//! users and as `OwnUserIdentity` for our own user.
|
||||
//! Cross-signing capable devices will upload 3 additional (master,
|
||||
//! self-signing, user-signing) public keys which represent the user identity
|
||||
//! owning all the devices. This is represented in two ways, as a `UserIdentity`
|
||||
//! for other users and as `OwnUserIdentity` for our own user.
|
||||
//!
|
||||
//! This is done because the server will only give us access to 2 of the 3
|
||||
//! additional public keys for other users, while it will give us access to all
|
||||
|
@ -44,19 +44,19 @@ pub(crate) mod device;
|
|||
mod manager;
|
||||
pub(crate) mod user;
|
||||
|
||||
pub use device::{Device, LocalTrust, ReadOnlyDevice, UserDevices};
|
||||
pub(crate) use manager::IdentityManager;
|
||||
pub use user::{
|
||||
MasterPubkey, OwnUserIdentity, SelfSigningPubkey, UserIdentities, UserIdentity,
|
||||
UserSigningPubkey,
|
||||
};
|
||||
|
||||
use serde::{Deserialize, Deserializer, Serializer};
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
};
|
||||
|
||||
pub use device::{Device, LocalTrust, ReadOnlyDevice, UserDevices};
|
||||
pub(crate) use manager::IdentityManager;
|
||||
use serde::{Deserialize, Deserializer, Serializer};
|
||||
pub use user::{
|
||||
MasterPubkey, OwnUserIdentity, SelfSigningPubkey, UserIdentities, UserIdentity,
|
||||
UserSigningPubkey,
|
||||
};
|
||||
|
||||
// These methods are only here because Serialize and Deserialize don't seem to
|
||||
// be implemented for WASM.
|
||||
fn atomic_bool_serializer<S>(x: &AtomicBool, s: S) -> Result<S::Ok, S::Error>
|
||||
|
|
|
@ -21,20 +21,18 @@ use std::{
|
|||
},
|
||||
};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::to_value;
|
||||
|
||||
use matrix_sdk_common::{
|
||||
api::r0::keys::{CrossSigningKey, KeyUsage},
|
||||
identifiers::{DeviceKeyId, UserId},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::to_value;
|
||||
|
||||
use super::{atomic_bool_deserializer, atomic_bool_serializer};
|
||||
#[cfg(test)]
|
||||
use crate::olm::PrivateCrossSigningIdentity;
|
||||
use crate::{error::SignatureError, olm::Utility, ReadOnlyDevice};
|
||||
|
||||
use super::{atomic_bool_deserializer, atomic_bool_serializer};
|
||||
|
||||
/// Wrapper for a cross signing key marking it as the master key.
|
||||
///
|
||||
/// Master keys are used to sign other cross signing keys, the self signing and
|
||||
|
@ -227,12 +225,7 @@ impl MasterPubkey {
|
|||
&self,
|
||||
subkey: impl Into<CrossSigningSubKeys<'a>>,
|
||||
) -> Result<(), SignatureError> {
|
||||
let (key_id, key) = self
|
||||
.0
|
||||
.keys
|
||||
.iter()
|
||||
.next()
|
||||
.ok_or(SignatureError::MissingSigningKey)?;
|
||||
let (key_id, key) = self.0.keys.iter().next().ok_or(SignatureError::MissingSigningKey)?;
|
||||
|
||||
let key_id = DeviceKeyId::try_from(key_id.as_str())?;
|
||||
|
||||
|
@ -289,12 +282,7 @@ impl UserSigningPubkey {
|
|||
&self,
|
||||
master_key: &MasterPubkey,
|
||||
) -> Result<(), SignatureError> {
|
||||
let (key_id, key) = self
|
||||
.0
|
||||
.keys
|
||||
.iter()
|
||||
.next()
|
||||
.ok_or(SignatureError::MissingSigningKey)?;
|
||||
let (key_id, key) = self.0.keys.iter().next().ok_or(SignatureError::MissingSigningKey)?;
|
||||
|
||||
// TODO check that the usage is OK.
|
||||
|
||||
|
@ -337,12 +325,7 @@ impl SelfSigningPubkey {
|
|||
/// Returns an empty result if the signature check succeeded, otherwise a
|
||||
/// SignatureError indicating why the check failed.
|
||||
pub(crate) fn verify_device(&self, device: &ReadOnlyDevice) -> Result<(), SignatureError> {
|
||||
let (key_id, key) = self
|
||||
.0
|
||||
.keys
|
||||
.iter()
|
||||
.next()
|
||||
.ok_or(SignatureError::MissingSigningKey)?;
|
||||
let (key_id, key) = self.0.keys.iter().next().ok_or(SignatureError::MissingSigningKey)?;
|
||||
|
||||
// TODO check that the usage is OK.
|
||||
|
||||
|
@ -474,37 +457,16 @@ impl UserIdentity {
|
|||
) -> Result<Self, SignatureError> {
|
||||
master_key.verify_subkey(&self_signing_key)?;
|
||||
|
||||
Ok(Self {
|
||||
user_id: Arc::new(master_key.0.user_id.clone()),
|
||||
master_key,
|
||||
self_signing_key,
|
||||
})
|
||||
Ok(Self { user_id: Arc::new(master_key.0.user_id.clone()), master_key, self_signing_key })
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub async fn from_private(identity: &PrivateCrossSigningIdentity) -> Self {
|
||||
let master_key = identity
|
||||
.master_key
|
||||
.lock()
|
||||
.await
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.public_key
|
||||
.clone();
|
||||
let self_signing_key = identity
|
||||
.self_signing_key
|
||||
.lock()
|
||||
.await
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.public_key
|
||||
.clone();
|
||||
let master_key = identity.master_key.lock().await.as_ref().unwrap().public_key.clone();
|
||||
let self_signing_key =
|
||||
identity.self_signing_key.lock().await.as_ref().unwrap().public_key.clone();
|
||||
|
||||
Self {
|
||||
user_id: Arc::new(identity.user_id().clone()),
|
||||
master_key,
|
||||
self_signing_key,
|
||||
}
|
||||
Self { user_id: Arc::new(identity.user_id().clone()), master_key, self_signing_key }
|
||||
}
|
||||
|
||||
/// Get the user id of this identity.
|
||||
|
@ -646,8 +608,7 @@ impl OwnUserIdentity {
|
|||
/// Returns an empty result if the signature check succeeded, otherwise a
|
||||
/// SignatureError indicating why the check failed.
|
||||
pub fn is_identity_signed(&self, identity: &UserIdentity) -> Result<(), SignatureError> {
|
||||
self.user_signing_key
|
||||
.verify_master_key(&identity.master_key)
|
||||
self.user_signing_key.verify_master_key(&identity.master_key)
|
||||
}
|
||||
|
||||
/// Check if the given device has been signed by this identity.
|
||||
|
@ -719,6 +680,12 @@ impl OwnUserIdentity {
|
|||
pub(crate) mod test {
|
||||
use std::{convert::TryFrom, sync::Arc};
|
||||
|
||||
use matrix_sdk_common::{
|
||||
api::r0::keys::get_keys::Response as KeyQueryResponse, identifiers::user_id, locks::Mutex,
|
||||
};
|
||||
use matrix_sdk_test::async_test;
|
||||
|
||||
use super::{OwnUserIdentity, UserIdentities, UserIdentity};
|
||||
use crate::{
|
||||
identities::{
|
||||
manager::test::{other_key_query, own_key_query},
|
||||
|
@ -729,13 +696,6 @@ pub(crate) mod test {
|
|||
verification::VerificationMachine,
|
||||
};
|
||||
|
||||
use matrix_sdk_common::{
|
||||
api::r0::keys::get_keys::Response as KeyQueryResponse, identifiers::user_id, locks::Mutex,
|
||||
};
|
||||
use matrix_sdk_test::async_test;
|
||||
|
||||
use super::{OwnUserIdentity, UserIdentities, UserIdentity};
|
||||
|
||||
fn device(response: &KeyQueryResponse) -> (ReadOnlyDevice, ReadOnlyDevice) {
|
||||
let mut devices = response.device_keys.values().next().unwrap().values();
|
||||
let first = ReadOnlyDevice::try_from(devices.next().unwrap()).unwrap();
|
||||
|
@ -793,9 +753,8 @@ pub(crate) mod test {
|
|||
assert!(identity.is_device_signed(&first).is_err());
|
||||
assert!(identity.is_device_signed(&second).is_ok());
|
||||
|
||||
let private_identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(
|
||||
second.user_id().clone(),
|
||||
)));
|
||||
let private_identity =
|
||||
Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(second.user_id().clone())));
|
||||
let verification_machine = VerificationMachine::new(
|
||||
ReadOnlyAccount::new(second.user_id(), second.device_id()),
|
||||
private_identity.clone(),
|
||||
|
|
|
@ -20,13 +20,9 @@
|
|||
// If we don't trust the device store an object that remembers the request and
|
||||
// let the users introspect that object.
|
||||
|
||||
use dashmap::{mapref::entry::Entry, DashMap, DashSet};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::value::to_raw_value;
|
||||
use std::{collections::BTreeMap, sync::Arc};
|
||||
use thiserror::Error;
|
||||
use tracing::{error, info, trace, warn};
|
||||
|
||||
use dashmap::{mapref::entry::Entry, DashMap, DashSet};
|
||||
use matrix_sdk_common::{
|
||||
api::r0::to_device::DeviceIdOrAllDevices,
|
||||
events::{
|
||||
|
@ -37,6 +33,10 @@ use matrix_sdk_common::{
|
|||
identifiers::{DeviceId, DeviceIdBox, EventEncryptionAlgorithm, RoomId, UserId},
|
||||
uuid::Uuid,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::value::to_raw_value;
|
||||
use thiserror::Error;
|
||||
use tracing::{error, info, trace, warn};
|
||||
|
||||
use crate::{
|
||||
error::{OlmError, OlmResult},
|
||||
|
@ -105,10 +105,8 @@ impl WaitQueue {
|
|||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
) -> Vec<(
|
||||
(UserId, DeviceIdBox, String),
|
||||
ToDeviceEvent<RoomKeyRequestToDeviceEventContent>,
|
||||
)> {
|
||||
) -> Vec<((UserId, DeviceIdBox, String), ToDeviceEvent<RoomKeyRequestToDeviceEventContent>)>
|
||||
{
|
||||
self.requests_ids_waiting
|
||||
.remove(&(user_id.to_owned(), device_id.into()))
|
||||
.map(|(_, request_ids)| {
|
||||
|
@ -204,12 +202,7 @@ fn wrap_key_request_content(
|
|||
Ok(OutgoingRequest {
|
||||
request_id: id,
|
||||
request: Arc::new(
|
||||
ToDeviceRequest {
|
||||
event_type: EventType::RoomKeyRequest,
|
||||
txn_id: id,
|
||||
messages,
|
||||
}
|
||||
.into(),
|
||||
ToDeviceRequest { event_type: EventType::RoomKeyRequest, txn_id: id, messages }.into(),
|
||||
),
|
||||
})
|
||||
}
|
||||
|
@ -241,10 +234,7 @@ impl KeyRequestMachine {
|
|||
.await?
|
||||
.into_iter()
|
||||
.filter(|i| !i.sent_out)
|
||||
.map(|info| {
|
||||
info.to_request(self.device_id())
|
||||
.map_err(CryptoStoreError::from)
|
||||
})
|
||||
.map(|info| info.to_request(self.device_id()).map_err(CryptoStoreError::from))
|
||||
.collect()
|
||||
}
|
||||
|
||||
|
@ -262,11 +252,8 @@ impl KeyRequestMachine {
|
|||
&self,
|
||||
) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
|
||||
let mut key_requests = self.load_outgoing_requests().await?;
|
||||
let key_forwards: Vec<OutgoingRequest> = self
|
||||
.outgoing_to_device_requests
|
||||
.iter()
|
||||
.map(|i| i.value().clone())
|
||||
.collect();
|
||||
let key_forwards: Vec<OutgoingRequest> =
|
||||
self.outgoing_to_device_requests.iter().map(|i| i.value().clone()).collect();
|
||||
key_requests.extend(key_forwards);
|
||||
|
||||
Ok(key_requests)
|
||||
|
@ -281,8 +268,7 @@ impl KeyRequestMachine {
|
|||
let device_id = event.content.requesting_device_id.clone();
|
||||
let request_id = event.content.request_id.clone();
|
||||
|
||||
self.incoming_key_requests
|
||||
.insert((sender, device_id, request_id), event.clone());
|
||||
self.incoming_key_requests.insert((sender, device_id, request_id), event.clone());
|
||||
}
|
||||
|
||||
/// Handle all the incoming key requests that are queued up and empty our
|
||||
|
@ -401,10 +387,8 @@ impl KeyRequestMachine {
|
|||
return Ok(None);
|
||||
};
|
||||
|
||||
let device = self
|
||||
.store
|
||||
.get_device(&event.sender, &event.content.requesting_device_id)
|
||||
.await?;
|
||||
let device =
|
||||
self.store.get_device(&event.sender, &event.content.requesting_device_id).await?;
|
||||
|
||||
if let Some(device) = device {
|
||||
match self.should_share_key(&device, &session).await {
|
||||
|
@ -461,30 +445,22 @@ impl KeyRequestMachine {
|
|||
device: &Device,
|
||||
message_index: Option<u32>,
|
||||
) -> OlmResult<Session> {
|
||||
let (used_session, content) = device
|
||||
.encrypt_session(session.clone(), message_index)
|
||||
.await?;
|
||||
let (used_session, content) =
|
||||
device.encrypt_session(session.clone(), message_index).await?;
|
||||
|
||||
let id = Uuid::new_v4();
|
||||
let mut messages = BTreeMap::new();
|
||||
|
||||
messages
|
||||
.entry(device.user_id().to_owned())
|
||||
.or_insert_with(BTreeMap::new)
|
||||
.insert(
|
||||
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
|
||||
to_raw_value(&content)?,
|
||||
);
|
||||
messages.entry(device.user_id().to_owned()).or_insert_with(BTreeMap::new).insert(
|
||||
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
|
||||
to_raw_value(&content)?,
|
||||
);
|
||||
|
||||
let request = OutgoingRequest {
|
||||
request_id: id,
|
||||
request: Arc::new(
|
||||
ToDeviceRequest {
|
||||
event_type: EventType::RoomEncrypted,
|
||||
txn_id: id,
|
||||
messages,
|
||||
}
|
||||
.into(),
|
||||
ToDeviceRequest { event_type: EventType::RoomEncrypted, txn_id: id, messages }
|
||||
.into(),
|
||||
),
|
||||
};
|
||||
|
||||
|
@ -542,8 +518,8 @@ impl KeyRequestMachine {
|
|||
} else {
|
||||
Err(KeyshareDecision::OutboundSessionNotShared)
|
||||
}
|
||||
// Else just check if it's one of our own devices that requested the key and
|
||||
// check if the device is trusted.
|
||||
// Else just check if it's one of our own devices that requested the key
|
||||
// and check if the device is trusted.
|
||||
} else if device.user_id() == self.user_id() {
|
||||
own_device_check()
|
||||
// Otherwise, there's not enough info to decide if we can safely share
|
||||
|
@ -711,9 +687,7 @@ impl KeyRequestMachine {
|
|||
|
||||
/// Delete the given outgoing key info.
|
||||
async fn delete_key_info(&self, info: &OutgoingKeyRequest) -> Result<(), CryptoStoreError> {
|
||||
self.store
|
||||
.delete_outgoing_key_request(info.request_id)
|
||||
.await
|
||||
self.store.delete_outgoing_key_request(info.request_id).await
|
||||
}
|
||||
|
||||
/// Mark the outgoing request as sent.
|
||||
|
@ -736,20 +710,15 @@ impl KeyRequestMachine {
|
|||
/// This will queue up a request cancelation.
|
||||
async fn mark_as_done(&self, key_info: OutgoingKeyRequest) -> Result<(), CryptoStoreError> {
|
||||
// TODO perhaps only remove the key info if the first known index is 0.
|
||||
trace!(
|
||||
"Successfully received a forwarded room key for {:#?}",
|
||||
key_info
|
||||
);
|
||||
trace!("Successfully received a forwarded room key for {:#?}", key_info);
|
||||
|
||||
self.outgoing_to_device_requests
|
||||
.remove(&key_info.request_id);
|
||||
self.outgoing_to_device_requests.remove(&key_info.request_id);
|
||||
// TODO return the key info instead of deleting it so the sync handler
|
||||
// can delete it in one transaction.
|
||||
self.delete_key_info(&key_info).await?;
|
||||
|
||||
let request = key_info.to_cancelation(self.device_id())?;
|
||||
self.outgoing_to_device_requests
|
||||
.insert(request.request_id, request);
|
||||
self.outgoing_to_device_requests.insert(request.request_id, request);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -801,10 +770,7 @@ impl KeyRequestMachine {
|
|||
);
|
||||
}
|
||||
|
||||
Ok((
|
||||
Some(AnyToDeviceEvent::ForwardedRoomKey(event.clone())),
|
||||
session,
|
||||
))
|
||||
Ok((Some(AnyToDeviceEvent::ForwardedRoomKey(event.clone())), session))
|
||||
} else {
|
||||
info!(
|
||||
"Received a forwarded room key from {}, but no key info was found.",
|
||||
|
@ -817,6 +783,8 @@ impl KeyRequestMachine {
|
|||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::{convert::TryInto, sync::Arc};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use matrix_sdk_common::{
|
||||
api::r0::to_device::DeviceIdOrAllDevices,
|
||||
|
@ -829,8 +797,8 @@ mod test {
|
|||
locks::Mutex,
|
||||
};
|
||||
use matrix_sdk_test::async_test;
|
||||
use std::{convert::TryInto, sync::Arc};
|
||||
|
||||
use super::{KeyRequestMachine, KeyshareDecision};
|
||||
use crate::{
|
||||
identities::{LocalTrust, ReadOnlyDevice},
|
||||
olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount},
|
||||
|
@ -839,8 +807,6 @@ mod test {
|
|||
verification::VerificationMachine,
|
||||
};
|
||||
|
||||
use super::{KeyRequestMachine, KeyshareDecision};
|
||||
|
||||
fn alice_id() -> UserId {
|
||||
user_id!("@alice:example.org")
|
||||
}
|
||||
|
@ -919,11 +885,7 @@ mod test {
|
|||
async fn create_machine() {
|
||||
let machine = get_machine().await;
|
||||
|
||||
assert!(machine
|
||||
.outgoing_to_device_requests()
|
||||
.await
|
||||
.unwrap()
|
||||
.is_empty());
|
||||
assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty());
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
|
@ -931,16 +893,10 @@ mod test {
|
|||
let machine = get_machine().await;
|
||||
let account = account();
|
||||
|
||||
let (_, session) = account
|
||||
.create_group_session_pair_with_defaults(&room_id())
|
||||
.await
|
||||
.unwrap();
|
||||
let (_, session) =
|
||||
account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
|
||||
|
||||
assert!(machine
|
||||
.outgoing_to_device_requests()
|
||||
.await
|
||||
.unwrap()
|
||||
.is_empty());
|
||||
assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty());
|
||||
let (cancel, request) = machine
|
||||
.request_key(session.room_id(), &session.sender_key, session.session_id())
|
||||
.await
|
||||
|
@ -948,10 +904,7 @@ mod test {
|
|||
|
||||
assert!(cancel.is_none());
|
||||
|
||||
machine
|
||||
.mark_outgoing_request_as_sent(request.request_id)
|
||||
.await
|
||||
.unwrap();
|
||||
machine.mark_outgoing_request_as_sent(request.request_id).await.unwrap();
|
||||
|
||||
let (cancel, _) = machine
|
||||
.request_key(session.room_id(), &session.sender_key, session.session_id())
|
||||
|
@ -972,16 +925,10 @@ mod test {
|
|||
alice_device.set_trust_state(LocalTrust::Verified);
|
||||
machine.store.save_devices(&[alice_device]).await.unwrap();
|
||||
|
||||
let (_, session) = account
|
||||
.create_group_session_pair_with_defaults(&room_id())
|
||||
.await
|
||||
.unwrap();
|
||||
let (_, session) =
|
||||
account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
|
||||
|
||||
assert!(machine
|
||||
.outgoing_to_device_requests()
|
||||
.await
|
||||
.unwrap()
|
||||
.is_empty());
|
||||
assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty());
|
||||
machine
|
||||
.create_outgoing_key_request(
|
||||
session.room_id(),
|
||||
|
@ -990,15 +937,8 @@ mod test {
|
|||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!machine
|
||||
.outgoing_to_device_requests()
|
||||
.await
|
||||
.unwrap()
|
||||
.is_empty());
|
||||
assert_eq!(
|
||||
machine.outgoing_to_device_requests().await.unwrap().len(),
|
||||
1
|
||||
);
|
||||
assert!(!machine.outgoing_to_device_requests().await.unwrap().is_empty());
|
||||
assert_eq!(machine.outgoing_to_device_requests().await.unwrap().len(), 1);
|
||||
|
||||
machine
|
||||
.create_outgoing_key_request(
|
||||
|
@ -1014,15 +954,8 @@ mod test {
|
|||
|
||||
let request = requests.get(0).unwrap();
|
||||
|
||||
machine
|
||||
.mark_outgoing_request_as_sent(request.request_id)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(machine
|
||||
.outgoing_to_device_requests()
|
||||
.await
|
||||
.unwrap()
|
||||
.is_empty());
|
||||
machine.mark_outgoing_request_as_sent(request.request_id).await.unwrap();
|
||||
assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty());
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
|
@ -1037,10 +970,8 @@ mod test {
|
|||
alice_device.set_trust_state(LocalTrust::Verified);
|
||||
machine.store.save_devices(&[alice_device]).await.unwrap();
|
||||
|
||||
let (_, session) = account
|
||||
.create_group_session_pair_with_defaults(&room_id())
|
||||
.await
|
||||
.unwrap();
|
||||
let (_, session) =
|
||||
account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
|
||||
machine
|
||||
.create_outgoing_key_request(
|
||||
session.room_id(),
|
||||
|
@ -1060,10 +991,7 @@ mod test {
|
|||
|
||||
let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap();
|
||||
|
||||
let mut event = ToDeviceEvent {
|
||||
sender: alice_id(),
|
||||
content,
|
||||
};
|
||||
let mut event = ToDeviceEvent { sender: alice_id(), content };
|
||||
|
||||
assert!(
|
||||
machine
|
||||
|
@ -1078,19 +1006,13 @@ mod test {
|
|||
.is_none()
|
||||
);
|
||||
|
||||
let (_, first_session) = machine
|
||||
.receive_forwarded_room_key(&session.sender_key, &mut event)
|
||||
.await
|
||||
.unwrap();
|
||||
let (_, first_session) =
|
||||
machine.receive_forwarded_room_key(&session.sender_key, &mut event).await.unwrap();
|
||||
let first_session = first_session.unwrap();
|
||||
|
||||
assert_eq!(first_session.first_known_index(), 10);
|
||||
|
||||
machine
|
||||
.store
|
||||
.save_inbound_group_sessions(&[first_session.clone()])
|
||||
.await
|
||||
.unwrap();
|
||||
machine.store.save_inbound_group_sessions(&[first_session.clone()]).await.unwrap();
|
||||
|
||||
// Get the cancel request.
|
||||
let request = machine.outgoing_to_device_requests.iter().next().unwrap();
|
||||
|
@ -1110,24 +1032,16 @@ mod test {
|
|||
let requests = machine.outgoing_to_device_requests().await.unwrap();
|
||||
let request = &requests[0];
|
||||
|
||||
machine
|
||||
.mark_outgoing_request_as_sent(request.request_id)
|
||||
.await
|
||||
.unwrap();
|
||||
machine.mark_outgoing_request_as_sent(request.request_id).await.unwrap();
|
||||
|
||||
let export = session.export_at_index(15).await;
|
||||
|
||||
let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap();
|
||||
|
||||
let mut event = ToDeviceEvent {
|
||||
sender: alice_id(),
|
||||
content,
|
||||
};
|
||||
let mut event = ToDeviceEvent { sender: alice_id(), content };
|
||||
|
||||
let (_, second_session) = machine
|
||||
.receive_forwarded_room_key(&session.sender_key, &mut event)
|
||||
.await
|
||||
.unwrap();
|
||||
let (_, second_session) =
|
||||
machine.receive_forwarded_room_key(&session.sender_key, &mut event).await.unwrap();
|
||||
|
||||
assert!(second_session.is_none());
|
||||
|
||||
|
@ -1135,15 +1049,10 @@ mod test {
|
|||
|
||||
let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap();
|
||||
|
||||
let mut event = ToDeviceEvent {
|
||||
sender: alice_id(),
|
||||
content,
|
||||
};
|
||||
let mut event = ToDeviceEvent { sender: alice_id(), content };
|
||||
|
||||
let (_, second_session) = machine
|
||||
.receive_forwarded_room_key(&session.sender_key, &mut event)
|
||||
.await
|
||||
.unwrap();
|
||||
let (_, second_session) =
|
||||
machine.receive_forwarded_room_key(&session.sender_key, &mut event).await.unwrap();
|
||||
|
||||
assert_eq!(second_session.unwrap().first_known_index(), 0);
|
||||
}
|
||||
|
@ -1153,17 +1062,11 @@ mod test {
|
|||
let machine = get_machine().await;
|
||||
let account = account();
|
||||
|
||||
let own_device = machine
|
||||
.store
|
||||
.get_device(&alice_id(), &alice_device_id())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let own_device =
|
||||
machine.store.get_device(&alice_id(), &alice_device_id()).await.unwrap().unwrap();
|
||||
|
||||
let (outbound, inbound) = account
|
||||
.create_group_session_pair_with_defaults(&room_id())
|
||||
.await
|
||||
.unwrap();
|
||||
let (outbound, inbound) =
|
||||
account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
|
||||
|
||||
// We don't share keys with untrusted devices.
|
||||
assert_eq!(
|
||||
|
@ -1175,20 +1078,13 @@ mod test {
|
|||
);
|
||||
own_device.set_trust_state(LocalTrust::Verified);
|
||||
// Now we do want to share the keys.
|
||||
assert!(machine
|
||||
.should_share_key(&own_device, &inbound)
|
||||
.await
|
||||
.is_ok());
|
||||
assert!(machine.should_share_key(&own_device, &inbound).await.is_ok());
|
||||
|
||||
let bob_device = ReadOnlyDevice::from_account(&bob_account()).await;
|
||||
machine.store.save_devices(&[bob_device]).await.unwrap();
|
||||
|
||||
let bob_device = machine
|
||||
.store
|
||||
.get_device(&bob_id(), &bob_device_id())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let bob_device =
|
||||
machine.store.get_device(&bob_id(), &bob_device_id()).await.unwrap().unwrap();
|
||||
|
||||
// We don't share sessions with other user's devices if no outbound
|
||||
// session was provided.
|
||||
|
@ -1231,17 +1127,12 @@ mod test {
|
|||
|
||||
// We now share the session, since it was shared before.
|
||||
outbound.mark_shared_with(bob_device.user_id(), bob_device.device_id());
|
||||
assert!(machine
|
||||
.should_share_key(&bob_device, &inbound)
|
||||
.await
|
||||
.is_ok());
|
||||
assert!(machine.should_share_key(&bob_device, &inbound).await.is_ok());
|
||||
|
||||
// But we don't share some other session that doesn't match our outbound
|
||||
// session
|
||||
let (_, other_inbound) = account
|
||||
.create_group_session_pair_with_defaults(&room_id())
|
||||
.await
|
||||
.unwrap();
|
||||
let (_, other_inbound) =
|
||||
account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
machine
|
||||
|
@ -1255,10 +1146,7 @@ mod test {
|
|||
#[async_test]
|
||||
async fn key_share_cycle() {
|
||||
let alice_machine = get_machine().await;
|
||||
let alice_account = Account {
|
||||
inner: account(),
|
||||
store: alice_machine.store.clone(),
|
||||
};
|
||||
let alice_account = Account { inner: account(), store: alice_machine.store.clone() };
|
||||
|
||||
let bob_machine = bob_machine();
|
||||
let bob_account = bob_account();
|
||||
|
@ -1268,11 +1156,7 @@ mod test {
|
|||
|
||||
// We need a trusted device, otherwise we won't request keys
|
||||
alice_device.set_trust_state(LocalTrust::Verified);
|
||||
alice_machine
|
||||
.store
|
||||
.save_devices(&[alice_device])
|
||||
.await
|
||||
.unwrap();
|
||||
alice_machine.store.save_devices(&[alice_device]).await.unwrap();
|
||||
|
||||
// Create Olm sessions for our two accounts.
|
||||
let (alice_session, bob_session) = alice_account.create_session_for(&bob_account).await;
|
||||
|
@ -1282,37 +1166,15 @@ mod test {
|
|||
|
||||
// Populate our stores with Olm sessions and a Megolm session.
|
||||
|
||||
alice_machine
|
||||
.store
|
||||
.save_sessions(&[alice_session])
|
||||
.await
|
||||
.unwrap();
|
||||
alice_machine
|
||||
.store
|
||||
.save_devices(&[bob_device])
|
||||
.await
|
||||
.unwrap();
|
||||
bob_machine
|
||||
.store
|
||||
.save_sessions(&[bob_session])
|
||||
.await
|
||||
.unwrap();
|
||||
bob_machine
|
||||
.store
|
||||
.save_devices(&[alice_device])
|
||||
.await
|
||||
.unwrap();
|
||||
alice_machine.store.save_sessions(&[alice_session]).await.unwrap();
|
||||
alice_machine.store.save_devices(&[bob_device]).await.unwrap();
|
||||
bob_machine.store.save_sessions(&[bob_session]).await.unwrap();
|
||||
bob_machine.store.save_devices(&[alice_device]).await.unwrap();
|
||||
|
||||
let (group_session, inbound_group_session) = bob_account
|
||||
.create_group_session_pair_with_defaults(&room_id())
|
||||
.await
|
||||
.unwrap();
|
||||
let (group_session, inbound_group_session) =
|
||||
bob_account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
|
||||
|
||||
bob_machine
|
||||
.store
|
||||
.save_inbound_group_sessions(&[inbound_group_session])
|
||||
.await
|
||||
.unwrap();
|
||||
bob_machine.store.save_inbound_group_sessions(&[inbound_group_session]).await.unwrap();
|
||||
|
||||
// Alice wants to request the outbound group session from bob.
|
||||
alice_machine
|
||||
|
@ -1326,9 +1188,7 @@ mod test {
|
|||
group_session.mark_shared_with(&alice_id(), &alice_device_id());
|
||||
|
||||
// Put the outbound session into bobs store.
|
||||
bob_machine
|
||||
.outbound_group_sessions
|
||||
.insert(group_session.clone());
|
||||
bob_machine.outbound_group_sessions.insert(group_session.clone());
|
||||
|
||||
// Get the request and convert it into a event.
|
||||
let requests = alice_machine.outgoing_to_device_requests().await.unwrap();
|
||||
|
@ -1346,15 +1206,9 @@ mod test {
|
|||
let content: RoomKeyRequestToDeviceEventContent =
|
||||
serde_json::from_str(content.get()).unwrap();
|
||||
|
||||
alice_machine
|
||||
.mark_outgoing_request_as_sent(id)
|
||||
.await
|
||||
.unwrap();
|
||||
alice_machine.mark_outgoing_request_as_sent(id).await.unwrap();
|
||||
|
||||
let event = ToDeviceEvent {
|
||||
sender: alice_id(),
|
||||
content,
|
||||
};
|
||||
let event = ToDeviceEvent { sender: alice_id(), content };
|
||||
|
||||
// Bob doesn't have any outgoing requests.
|
||||
assert!(bob_machine.outgoing_to_device_requests.is_empty());
|
||||
|
@ -1383,10 +1237,7 @@ mod test {
|
|||
|
||||
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();
|
||||
|
||||
let event = ToDeviceEvent {
|
||||
sender: bob_id(),
|
||||
content,
|
||||
};
|
||||
let event = ToDeviceEvent { sender: bob_id(), content };
|
||||
|
||||
// Check that alice doesn't have the session.
|
||||
assert!(alice_machine
|
||||
|
@ -1407,11 +1258,7 @@ mod test {
|
|||
.receive_forwarded_room_key(&decrypted.sender_key, &mut e)
|
||||
.await
|
||||
.unwrap();
|
||||
alice_machine
|
||||
.store
|
||||
.save_inbound_group_sessions(&[session.unwrap()])
|
||||
.await
|
||||
.unwrap();
|
||||
alice_machine.store.save_inbound_group_sessions(&[session.unwrap()]).await.unwrap();
|
||||
} else {
|
||||
panic!("Invalid decrypted event type");
|
||||
}
|
||||
|
@ -1434,10 +1281,7 @@ mod test {
|
|||
#[async_test]
|
||||
async fn key_share_cycle_without_session() {
|
||||
let alice_machine = get_machine().await;
|
||||
let alice_account = Account {
|
||||
inner: account(),
|
||||
store: alice_machine.store.clone(),
|
||||
};
|
||||
let alice_account = Account { inner: account(), store: alice_machine.store.clone() };
|
||||
|
||||
let bob_machine = bob_machine();
|
||||
let bob_account = bob_account();
|
||||
|
@ -1447,11 +1291,7 @@ mod test {
|
|||
|
||||
// We need a trusted device, otherwise we won't request keys
|
||||
alice_device.set_trust_state(LocalTrust::Verified);
|
||||
alice_machine
|
||||
.store
|
||||
.save_devices(&[alice_device])
|
||||
.await
|
||||
.unwrap();
|
||||
alice_machine.store.save_devices(&[alice_device]).await.unwrap();
|
||||
|
||||
// Create Olm sessions for our two accounts.
|
||||
let (alice_session, bob_session) = alice_account.create_session_for(&bob_account).await;
|
||||
|
@ -1461,27 +1301,13 @@ mod test {
|
|||
|
||||
// Populate our stores with Olm sessions and a Megolm session.
|
||||
|
||||
alice_machine
|
||||
.store
|
||||
.save_devices(&[bob_device])
|
||||
.await
|
||||
.unwrap();
|
||||
bob_machine
|
||||
.store
|
||||
.save_devices(&[alice_device])
|
||||
.await
|
||||
.unwrap();
|
||||
alice_machine.store.save_devices(&[bob_device]).await.unwrap();
|
||||
bob_machine.store.save_devices(&[alice_device]).await.unwrap();
|
||||
|
||||
let (group_session, inbound_group_session) = bob_account
|
||||
.create_group_session_pair_with_defaults(&room_id())
|
||||
.await
|
||||
.unwrap();
|
||||
let (group_session, inbound_group_session) =
|
||||
bob_account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
|
||||
|
||||
bob_machine
|
||||
.store
|
||||
.save_inbound_group_sessions(&[inbound_group_session])
|
||||
.await
|
||||
.unwrap();
|
||||
bob_machine.store.save_inbound_group_sessions(&[inbound_group_session]).await.unwrap();
|
||||
|
||||
// Alice wants to request the outbound group session from bob.
|
||||
alice_machine
|
||||
|
@ -1495,9 +1321,7 @@ mod test {
|
|||
group_session.mark_shared_with(&alice_id(), &alice_device_id());
|
||||
|
||||
// Put the outbound session into bobs store.
|
||||
bob_machine
|
||||
.outbound_group_sessions
|
||||
.insert(group_session.clone());
|
||||
bob_machine.outbound_group_sessions.insert(group_session.clone());
|
||||
|
||||
// Get the request and convert it into a event.
|
||||
let requests = alice_machine.outgoing_to_device_requests().await.unwrap();
|
||||
|
@ -1515,22 +1339,12 @@ mod test {
|
|||
let content: RoomKeyRequestToDeviceEventContent =
|
||||
serde_json::from_str(content.get()).unwrap();
|
||||
|
||||
alice_machine
|
||||
.mark_outgoing_request_as_sent(id)
|
||||
.await
|
||||
.unwrap();
|
||||
alice_machine.mark_outgoing_request_as_sent(id).await.unwrap();
|
||||
|
||||
let event = ToDeviceEvent {
|
||||
sender: alice_id(),
|
||||
content,
|
||||
};
|
||||
let event = ToDeviceEvent { sender: alice_id(), content };
|
||||
|
||||
// Bob doesn't have any outgoing requests.
|
||||
assert!(bob_machine
|
||||
.outgoing_to_device_requests()
|
||||
.await
|
||||
.unwrap()
|
||||
.is_empty());
|
||||
assert!(bob_machine.outgoing_to_device_requests().await.unwrap().is_empty());
|
||||
assert!(bob_machine.users_for_key_claim.is_empty());
|
||||
assert!(bob_machine.wait_queue.is_empty());
|
||||
|
||||
|
@ -1538,35 +1352,19 @@ mod test {
|
|||
bob_machine.receive_incoming_key_request(&event);
|
||||
bob_machine.collect_incoming_key_requests().await.unwrap();
|
||||
// Bob doens't have an outgoing requests since we're lacking a session.
|
||||
assert!(bob_machine
|
||||
.outgoing_to_device_requests()
|
||||
.await
|
||||
.unwrap()
|
||||
.is_empty());
|
||||
assert!(bob_machine.outgoing_to_device_requests().await.unwrap().is_empty());
|
||||
assert!(!bob_machine.users_for_key_claim.is_empty());
|
||||
assert!(!bob_machine.wait_queue.is_empty());
|
||||
|
||||
// We create a session now.
|
||||
alice_machine
|
||||
.store
|
||||
.save_sessions(&[alice_session])
|
||||
.await
|
||||
.unwrap();
|
||||
bob_machine
|
||||
.store
|
||||
.save_sessions(&[bob_session])
|
||||
.await
|
||||
.unwrap();
|
||||
alice_machine.store.save_sessions(&[alice_session]).await.unwrap();
|
||||
bob_machine.store.save_sessions(&[bob_session]).await.unwrap();
|
||||
|
||||
bob_machine.retry_keyshare(&alice_id(), &alice_device_id());
|
||||
assert!(bob_machine.users_for_key_claim.is_empty());
|
||||
bob_machine.collect_incoming_key_requests().await.unwrap();
|
||||
// Bob now has an outgoing requests.
|
||||
assert!(!bob_machine
|
||||
.outgoing_to_device_requests()
|
||||
.await
|
||||
.unwrap()
|
||||
.is_empty());
|
||||
assert!(!bob_machine.outgoing_to_device_requests().await.unwrap().is_empty());
|
||||
assert!(bob_machine.wait_queue.is_empty());
|
||||
|
||||
// Get the request and convert it to a encrypted to-device event.
|
||||
|
@ -1588,10 +1386,7 @@ mod test {
|
|||
|
||||
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();
|
||||
|
||||
let event = ToDeviceEvent {
|
||||
sender: bob_id(),
|
||||
content,
|
||||
};
|
||||
let event = ToDeviceEvent { sender: bob_id(), content };
|
||||
|
||||
// Check that alice doesn't have the session.
|
||||
assert!(alice_machine
|
||||
|
@ -1612,11 +1407,7 @@ mod test {
|
|||
.receive_forwarded_room_key(&decrypted.sender_key, &mut e)
|
||||
.await
|
||||
.unwrap();
|
||||
alice_machine
|
||||
.store
|
||||
.save_inbound_group_sessions(&[session.unwrap()])
|
||||
.await
|
||||
.unwrap();
|
||||
alice_machine.store.save_inbound_group_sessions(&[session.unwrap()]).await.unwrap();
|
||||
} else {
|
||||
panic!("Invalid decrypted event type");
|
||||
}
|
||||
|
|
|
@ -17,8 +17,6 @@ use std::path::Path;
|
|||
use std::{collections::BTreeMap, mem, sync::Arc};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tracing::{debug, error, info, trace, warn};
|
||||
|
||||
use matrix_sdk_common::{
|
||||
api::r0::{
|
||||
keys::{
|
||||
|
@ -43,6 +41,7 @@ use matrix_sdk_common::{
|
|||
uuid::Uuid,
|
||||
UInt,
|
||||
};
|
||||
use tracing::{debug, error, info, trace, warn};
|
||||
|
||||
#[cfg(feature = "sled_cryptostore")]
|
||||
use crate::store::sled::SledStore;
|
||||
|
@ -148,19 +147,12 @@ impl OlmMachine {
|
|||
let store = Arc::new(store);
|
||||
let verification_machine =
|
||||
VerificationMachine::new(account.clone(), user_identity.clone(), store.clone());
|
||||
let store = Store::new(
|
||||
user_id.clone(),
|
||||
user_identity.clone(),
|
||||
store,
|
||||
verification_machine.clone(),
|
||||
);
|
||||
let store =
|
||||
Store::new(user_id.clone(), user_identity.clone(), store, verification_machine.clone());
|
||||
let device_id: Arc<DeviceIdBox> = Arc::new(device_id);
|
||||
let users_for_key_claim = Arc::new(DashMap::new());
|
||||
|
||||
let account = Account {
|
||||
inner: account,
|
||||
store: store.clone(),
|
||||
};
|
||||
let account = Account { inner: account, store: store.clone() };
|
||||
|
||||
let group_session_manager = GroupSessionManager::new(account.clone(), store.clone());
|
||||
|
||||
|
@ -244,9 +236,7 @@ impl OlmMachine {
|
|||
}
|
||||
};
|
||||
|
||||
Ok(OlmMachine::new_helper(
|
||||
&user_id, device_id, store, account, identity,
|
||||
))
|
||||
Ok(OlmMachine::new_helper(&user_id, device_id, store, account, identity))
|
||||
}
|
||||
|
||||
/// Create a new machine with the default crypto store.
|
||||
|
@ -296,19 +286,16 @@ impl OlmMachine {
|
|||
pub async fn outgoing_requests(&self) -> StoreResult<Vec<OutgoingRequest>> {
|
||||
let mut requests = Vec::new();
|
||||
|
||||
if let Some(r) = self.keys_for_upload().await.map(|r| OutgoingRequest {
|
||||
request_id: Uuid::new_v4(),
|
||||
request: Arc::new(r.into()),
|
||||
}) {
|
||||
if let Some(r) = self
|
||||
.keys_for_upload()
|
||||
.await
|
||||
.map(|r| OutgoingRequest { request_id: Uuid::new_v4(), request: Arc::new(r.into()) })
|
||||
{
|
||||
requests.push(r);
|
||||
}
|
||||
|
||||
for request in self
|
||||
.identity_manager
|
||||
.users_for_key_query()
|
||||
.await
|
||||
.into_iter()
|
||||
.map(|r| OutgoingRequest {
|
||||
for request in
|
||||
self.identity_manager.users_for_key_query().await.into_iter().map(|r| OutgoingRequest {
|
||||
request_id: Uuid::new_v4(),
|
||||
request: Arc::new(r.into()),
|
||||
})
|
||||
|
@ -317,12 +304,7 @@ impl OlmMachine {
|
|||
}
|
||||
|
||||
requests.append(&mut self.verification_machine.outgoing_messages());
|
||||
requests.append(
|
||||
&mut self
|
||||
.key_request_machine
|
||||
.outgoing_to_device_requests()
|
||||
.await?,
|
||||
);
|
||||
requests.append(&mut self.key_request_machine.outgoing_to_device_requests().await?);
|
||||
|
||||
Ok(requests)
|
||||
}
|
||||
|
@ -373,10 +355,7 @@ impl OlmMachine {
|
|||
let identity = self.user_identity.lock().await;
|
||||
identity.mark_as_shared();
|
||||
|
||||
let changes = Changes {
|
||||
private_identity: Some(identity.clone()),
|
||||
..Default::default()
|
||||
};
|
||||
let changes = Changes { private_identity: Some(identity.clone()), ..Default::default() };
|
||||
|
||||
self.store.save_changes(changes).await
|
||||
}
|
||||
|
@ -406,10 +385,7 @@ impl OlmMachine {
|
|||
);
|
||||
|
||||
let changes = Changes {
|
||||
identities: IdentityChanges {
|
||||
new: vec![public.into()],
|
||||
..Default::default()
|
||||
},
|
||||
identities: IdentityChanges { new: vec![public.into()], ..Default::default() },
|
||||
private_identity: Some(identity.clone()),
|
||||
..Default::default()
|
||||
};
|
||||
|
@ -421,10 +397,8 @@ impl OlmMachine {
|
|||
info!("Trying to upload the existing cross signing identity");
|
||||
let request = identity.as_upload_request().await;
|
||||
// TODO remove this expect.
|
||||
let signature_request = identity
|
||||
.sign_account(&self.account)
|
||||
.await
|
||||
.expect("Can't sign device keys");
|
||||
let signature_request =
|
||||
identity.sign_account(&self.account).await.expect("Can't sign device keys");
|
||||
Ok((request, signature_request))
|
||||
}
|
||||
}
|
||||
|
@ -518,9 +492,7 @@ impl OlmMachine {
|
|||
///
|
||||
/// * `response` - The response containing the claimed one-time keys.
|
||||
async fn receive_keys_claim_response(&self, response: &KeysClaimResponse) -> OlmResult<()> {
|
||||
self.session_manager
|
||||
.receive_keys_claim_response(response)
|
||||
.await
|
||||
self.session_manager.receive_keys_claim_response(response).await
|
||||
}
|
||||
|
||||
/// Receive a successful keys query response.
|
||||
|
@ -536,9 +508,7 @@ impl OlmMachine {
|
|||
&self,
|
||||
response: &KeysQueryResponse,
|
||||
) -> OlmResult<(DeviceChanges, IdentityChanges)> {
|
||||
self.identity_manager
|
||||
.receive_keys_query_response(response)
|
||||
.await
|
||||
self.identity_manager.receive_keys_query_response(response).await
|
||||
}
|
||||
|
||||
/// Get a request to upload E2EE keys to the server.
|
||||
|
@ -676,9 +646,7 @@ impl OlmMachine {
|
|||
/// Returns true if a session was invalidated, false if there was no session
|
||||
/// to invalidate.
|
||||
pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
|
||||
self.group_session_manager
|
||||
.invalidate_group_session(room_id)
|
||||
.await
|
||||
self.group_session_manager.invalidate_group_session(room_id).await
|
||||
}
|
||||
|
||||
/// Get to-device requests to share a group session with users in a room.
|
||||
|
@ -695,9 +663,7 @@ impl OlmMachine {
|
|||
users: impl Iterator<Item = &UserId>,
|
||||
encryption_settings: impl Into<EncryptionSettings>,
|
||||
) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
|
||||
self.group_session_manager
|
||||
.share_group_session(room_id, users, encryption_settings)
|
||||
.await
|
||||
self.group_session_manager.share_group_session(room_id, users, encryption_settings).await
|
||||
}
|
||||
|
||||
/// Receive and properly handle a decrypted to-device event.
|
||||
|
@ -716,18 +682,15 @@ impl OlmMachine {
|
|||
let event = match decrypted.event.deserialize() {
|
||||
Ok(e) => e,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Decrypted to-device event failed to be parsed correctly {:?}",
|
||||
e
|
||||
);
|
||||
warn!("Decrypted to-device event failed to be parsed correctly {:?}", e);
|
||||
return Ok((None, None));
|
||||
}
|
||||
};
|
||||
|
||||
match event {
|
||||
AnyToDeviceEvent::RoomKey(mut e) => Ok(self
|
||||
.add_room_key(&decrypted.sender_key, &decrypted.signing_key, &mut e)
|
||||
.await?),
|
||||
AnyToDeviceEvent::RoomKey(mut e) => {
|
||||
Ok(self.add_room_key(&decrypted.sender_key, &decrypted.signing_key, &mut e).await?)
|
||||
}
|
||||
AnyToDeviceEvent::ForwardedRoomKey(mut e) => Ok(self
|
||||
.key_request_machine
|
||||
.receive_forwarded_room_key(&decrypted.sender_key, &mut e)
|
||||
|
@ -748,14 +711,9 @@ impl OlmMachine {
|
|||
/// Mark an outgoing to-device requests as sent.
|
||||
async fn mark_to_device_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> {
|
||||
self.verification_machine.mark_request_as_sent(request_id);
|
||||
self.key_request_machine
|
||||
.mark_outgoing_request_as_sent(*request_id)
|
||||
.await?;
|
||||
self.group_session_manager
|
||||
.mark_request_as_sent(request_id)
|
||||
.await?;
|
||||
self.session_manager
|
||||
.mark_outgoing_request_as_sent(request_id);
|
||||
self.key_request_machine.mark_outgoing_request_as_sent(*request_id).await?;
|
||||
self.group_session_manager.mark_request_as_sent(request_id).await?;
|
||||
self.session_manager.mark_outgoing_request_as_sent(request_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -830,10 +788,8 @@ impl OlmMachine {
|
|||
|
||||
// Always save the account, a new session might get created which also
|
||||
// touches the account.
|
||||
let mut changes = Changes {
|
||||
account: Some(self.account.inner.clone()),
|
||||
..Default::default()
|
||||
};
|
||||
let mut changes =
|
||||
Changes { account: Some(self.account.inner.clone()), ..Default::default() };
|
||||
|
||||
self.update_one_time_key_count(one_time_keys_counts).await;
|
||||
|
||||
|
@ -850,10 +806,7 @@ impl OlmMachine {
|
|||
Ok(e) => e,
|
||||
Err(e) => {
|
||||
// Skip invalid events.
|
||||
warn!(
|
||||
"Received an invalid to-device event {:?} {:?}",
|
||||
e, raw_event
|
||||
);
|
||||
warn!("Received an invalid to-device event {:?} {:?}", e, raw_event);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
@ -865,10 +818,7 @@ impl OlmMachine {
|
|||
let decrypted = match self.decrypt_to_device_event(&e).await {
|
||||
Ok(e) => e,
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"Failed to decrypt to-device event from {} {}",
|
||||
e.sender, err
|
||||
);
|
||||
warn!("Failed to decrypt to-device event from {} {}", e.sender, err);
|
||||
|
||||
if let OlmError::SessionWedged(sender, curve_key) = err {
|
||||
if let Err(e) = self
|
||||
|
@ -916,10 +866,7 @@ impl OlmMachine {
|
|||
events.push(raw_event);
|
||||
}
|
||||
|
||||
let changed_sessions = self
|
||||
.key_request_machine
|
||||
.collect_incoming_key_requests()
|
||||
.await?;
|
||||
let changed_sessions = self.key_request_machine.collect_incoming_key_requests().await?;
|
||||
|
||||
changes.sessions.extend(changed_sessions);
|
||||
|
||||
|
@ -1036,25 +983,16 @@ impl OlmMachine {
|
|||
// TODO check if this is from a verified device.
|
||||
let (decrypted_event, _) = session.decrypt(event).await?;
|
||||
|
||||
trace!(
|
||||
"Successfully decrypted a Megolm event {:?}",
|
||||
decrypted_event
|
||||
);
|
||||
trace!("Successfully decrypted a Megolm event {:?}", decrypted_event);
|
||||
|
||||
if let Ok(e) = decrypted_event.deserialize() {
|
||||
self.verification_machine
|
||||
.receive_room_event(room_id, &e)
|
||||
.await?;
|
||||
self.verification_machine.receive_room_event(room_id, &e).await?;
|
||||
}
|
||||
|
||||
let encryption_info = self
|
||||
.get_encryption_info(&session, &event.sender, &content.device_id)
|
||||
.await?;
|
||||
let encryption_info =
|
||||
self.get_encryption_info(&session, &event.sender, &content.device_id).await?;
|
||||
|
||||
Ok(SyncRoomEvent {
|
||||
encryption_info: Some(encryption_info),
|
||||
event: decrypted_event,
|
||||
})
|
||||
Ok(SyncRoomEvent { encryption_info: Some(encryption_info), event: decrypted_event })
|
||||
}
|
||||
|
||||
/// Update the tracked users.
|
||||
|
@ -1210,17 +1148,11 @@ impl OlmMachine {
|
|||
|
||||
let num_sessions = sessions.len();
|
||||
|
||||
let changes = Changes {
|
||||
inbound_group_sessions: sessions,
|
||||
..Default::default()
|
||||
};
|
||||
let changes = Changes { inbound_group_sessions: sessions, ..Default::default() };
|
||||
|
||||
self.store.save_changes(changes).await?;
|
||||
|
||||
info!(
|
||||
"Successfully imported {} inbound group sessions",
|
||||
num_sessions
|
||||
);
|
||||
info!("Successfully imported {} inbound group sessions", num_sessions);
|
||||
|
||||
Ok((num_sessions, total_sessions))
|
||||
}
|
||||
|
@ -1288,15 +1220,6 @@ pub(crate) mod test {
|
|||
};
|
||||
|
||||
use http::Response;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::{
|
||||
machine::OlmMachine,
|
||||
olm::Utility,
|
||||
verification::test::{outgoing_request_to_event, request_to_event},
|
||||
EncryptionSettings, ReadOnlyDevice, ToDeviceRequest,
|
||||
};
|
||||
|
||||
use matrix_sdk_common::{
|
||||
api::r0::keys::{claim_keys, get_keys, upload_keys, OneTimeKey},
|
||||
events::{
|
||||
|
@ -1313,6 +1236,14 @@ pub(crate) mod test {
|
|||
IncomingResponse, Raw,
|
||||
};
|
||||
use matrix_sdk_test::test_json;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::{
|
||||
machine::OlmMachine,
|
||||
olm::Utility,
|
||||
verification::test::{outgoing_request_to_event, request_to_event},
|
||||
EncryptionSettings, ReadOnlyDevice, ToDeviceRequest,
|
||||
};
|
||||
|
||||
/// These keys need to be periodically uploaded to the server.
|
||||
type OneTimeKeys = BTreeMap<DeviceKeyId, OneTimeKey>;
|
||||
|
@ -1332,10 +1263,7 @@ pub(crate) mod test {
|
|||
}
|
||||
|
||||
pub fn response_from_file(json: &serde_json::Value) -> Response<Vec<u8>> {
|
||||
Response::builder()
|
||||
.status(200)
|
||||
.body(json.to_string().as_bytes().to_vec())
|
||||
.unwrap()
|
||||
Response::builder().status(200).body(json.to_string().as_bytes().to_vec()).unwrap()
|
||||
}
|
||||
|
||||
fn keys_upload_response() -> upload_keys::Response {
|
||||
|
@ -1354,15 +1282,7 @@ pub(crate) mod test {
|
|||
let to_device_request = &requests[0];
|
||||
|
||||
let content: Raw<EncryptedEventContent> = serde_json::from_str(
|
||||
to_device_request
|
||||
.messages
|
||||
.values()
|
||||
.next()
|
||||
.unwrap()
|
||||
.values()
|
||||
.next()
|
||||
.unwrap()
|
||||
.get(),
|
||||
to_device_request.messages.values().next().unwrap().values().next().unwrap().get(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
|
@ -1372,15 +1292,9 @@ pub(crate) mod test {
|
|||
pub(crate) async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) {
|
||||
let machine = OlmMachine::new(&user_id(), &alice_device_id());
|
||||
machine.account.inner.update_uploaded_key_count(0);
|
||||
let request = machine
|
||||
.keys_for_upload()
|
||||
.await
|
||||
.expect("Can't prepare initial key upload");
|
||||
let request = machine.keys_for_upload().await.expect("Can't prepare initial key upload");
|
||||
let response = keys_upload_response();
|
||||
machine
|
||||
.receive_keys_upload_response(&response)
|
||||
.await
|
||||
.unwrap();
|
||||
machine.receive_keys_upload_response(&response).await.unwrap();
|
||||
|
||||
(machine, request.one_time_keys.unwrap())
|
||||
}
|
||||
|
@ -1389,10 +1303,7 @@ pub(crate) mod test {
|
|||
let (machine, otk) = get_prepared_machine().await;
|
||||
let response = keys_query_response();
|
||||
|
||||
machine
|
||||
.receive_keys_query_response(&response)
|
||||
.await
|
||||
.unwrap();
|
||||
machine.receive_keys_query_response(&response).await.unwrap();
|
||||
|
||||
(machine, otk)
|
||||
}
|
||||
|
@ -1435,28 +1346,15 @@ pub(crate) mod test {
|
|||
async fn get_machine_pair_with_setup_sessions() -> (OlmMachine, OlmMachine) {
|
||||
let (alice, bob) = get_machine_pair_with_session().await;
|
||||
|
||||
let bob_device = alice
|
||||
.get_device(&bob.user_id, &bob.device_id)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let bob_device = alice.get_device(&bob.user_id, &bob.device_id).await.unwrap().unwrap();
|
||||
|
||||
let (session, content) = bob_device
|
||||
.encrypt(EventType::Dummy, json!({}))
|
||||
.await
|
||||
.unwrap();
|
||||
let (session, content) = bob_device.encrypt(EventType::Dummy, json!({})).await.unwrap();
|
||||
alice.store.save_sessions(&[session]).await.unwrap();
|
||||
|
||||
let event = ToDeviceEvent {
|
||||
sender: alice.user_id().clone(),
|
||||
content,
|
||||
};
|
||||
let event = ToDeviceEvent { sender: alice.user_id().clone(), content };
|
||||
|
||||
let decrypted = bob.decrypt_to_device_event(&event).await.unwrap();
|
||||
bob.store
|
||||
.save_sessions(&[decrypted.session.session()])
|
||||
.await
|
||||
.unwrap();
|
||||
bob.store.save_sessions(&[decrypted.session.session()]).await.unwrap();
|
||||
|
||||
(alice, bob)
|
||||
}
|
||||
|
@ -1472,34 +1370,18 @@ pub(crate) mod test {
|
|||
let machine = OlmMachine::new(&user_id(), &alice_device_id());
|
||||
let mut response = keys_upload_response();
|
||||
|
||||
response
|
||||
.one_time_key_counts
|
||||
.remove(&DeviceKeyAlgorithm::SignedCurve25519)
|
||||
.unwrap();
|
||||
response.one_time_key_counts.remove(&DeviceKeyAlgorithm::SignedCurve25519).unwrap();
|
||||
|
||||
assert!(machine.should_upload_keys().await);
|
||||
machine
|
||||
.receive_keys_upload_response(&response)
|
||||
.await
|
||||
.unwrap();
|
||||
machine.receive_keys_upload_response(&response).await.unwrap();
|
||||
assert!(machine.should_upload_keys().await);
|
||||
|
||||
response
|
||||
.one_time_key_counts
|
||||
.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(10));
|
||||
machine
|
||||
.receive_keys_upload_response(&response)
|
||||
.await
|
||||
.unwrap();
|
||||
response.one_time_key_counts.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(10));
|
||||
machine.receive_keys_upload_response(&response).await.unwrap();
|
||||
assert!(machine.should_upload_keys().await);
|
||||
|
||||
response
|
||||
.one_time_key_counts
|
||||
.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50));
|
||||
machine
|
||||
.receive_keys_upload_response(&response)
|
||||
.await
|
||||
.unwrap();
|
||||
response.one_time_key_counts.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50));
|
||||
machine.receive_keys_upload_response(&response).await.unwrap();
|
||||
assert!(!machine.should_upload_keys().await);
|
||||
}
|
||||
|
||||
|
@ -1511,20 +1393,12 @@ pub(crate) mod test {
|
|||
|
||||
assert!(machine.should_upload_keys().await);
|
||||
|
||||
machine
|
||||
.receive_keys_upload_response(&response)
|
||||
.await
|
||||
.unwrap();
|
||||
machine.receive_keys_upload_response(&response).await.unwrap();
|
||||
assert!(machine.should_upload_keys().await);
|
||||
assert!(machine.account.generate_one_time_keys().await.is_ok());
|
||||
|
||||
response
|
||||
.one_time_key_counts
|
||||
.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50));
|
||||
machine
|
||||
.receive_keys_upload_response(&response)
|
||||
.await
|
||||
.unwrap();
|
||||
response.one_time_key_counts.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50));
|
||||
machine.receive_keys_upload_response(&response).await.unwrap();
|
||||
assert!(machine.account.generate_one_time_keys().await.is_err());
|
||||
}
|
||||
|
||||
|
@ -1551,14 +1425,8 @@ pub(crate) mod test {
|
|||
let machine = OlmMachine::new(&user_id(), &alice_device_id());
|
||||
let room_id = room_id!("!test:example.org");
|
||||
|
||||
machine
|
||||
.create_outbound_group_session_with_defaults(&room_id)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(machine
|
||||
.group_session_manager
|
||||
.get_outbound_group_session(&room_id)
|
||||
.is_some());
|
||||
machine.create_outbound_group_session_with_defaults(&room_id).await.unwrap();
|
||||
assert!(machine.group_session_manager.get_outbound_group_session(&room_id).is_some());
|
||||
|
||||
machine.invalidate_group_session(&room_id).await.unwrap();
|
||||
|
||||
|
@ -1614,10 +1482,8 @@ pub(crate) mod test {
|
|||
let identity_keys = machine.account.identity_keys();
|
||||
let ed25519_key = identity_keys.ed25519();
|
||||
|
||||
let mut request = machine
|
||||
.keys_for_upload()
|
||||
.await
|
||||
.expect("Can't prepare initial key upload");
|
||||
let mut request =
|
||||
machine.keys_for_upload().await.expect("Can't prepare initial key upload");
|
||||
|
||||
let utility = Utility::new();
|
||||
let ret = utility.verify_json(
|
||||
|
@ -1640,15 +1506,10 @@ pub(crate) mod test {
|
|||
let mut response = keys_upload_response();
|
||||
response.one_time_key_counts.insert(
|
||||
DeviceKeyAlgorithm::SignedCurve25519,
|
||||
(request.one_time_keys.unwrap().len() as u64)
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
(request.one_time_keys.unwrap().len() as u64).try_into().unwrap(),
|
||||
);
|
||||
|
||||
machine
|
||||
.receive_keys_upload_response(&response)
|
||||
.await
|
||||
.unwrap();
|
||||
machine.receive_keys_upload_response(&response).await.unwrap();
|
||||
|
||||
let ret = machine.keys_for_upload().await;
|
||||
assert!(ret.is_none());
|
||||
|
@ -1664,17 +1525,9 @@ pub(crate) mod test {
|
|||
let alice_devices = machine.store.get_user_devices(&alice_id).await.unwrap();
|
||||
assert!(alice_devices.devices().peekable().peek().is_none());
|
||||
|
||||
machine
|
||||
.receive_keys_query_response(&response)
|
||||
.await
|
||||
.unwrap();
|
||||
machine.receive_keys_query_response(&response).await.unwrap();
|
||||
|
||||
let device = machine
|
||||
.store
|
||||
.get_device(&alice_id, alice_device_id)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let device = machine.store.get_device(&alice_id, alice_device_id).await.unwrap().unwrap();
|
||||
assert_eq!(device.user_id(), &alice_id);
|
||||
assert_eq!(device.device_id(), alice_device_id);
|
||||
}
|
||||
|
@ -1686,11 +1539,8 @@ pub(crate) mod test {
|
|||
let alice = alice_id();
|
||||
let alice_device = alice_device_id();
|
||||
|
||||
let (_, missing_sessions) = machine
|
||||
.get_missing_sessions(&mut [alice.clone()].iter())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let (_, missing_sessions) =
|
||||
machine.get_missing_sessions(&mut [alice.clone()].iter()).await.unwrap().unwrap();
|
||||
|
||||
assert!(missing_sessions.one_time_keys.contains_key(&alice));
|
||||
let user_sessions = missing_sessions.one_time_keys.get(&alice).unwrap();
|
||||
|
@ -1713,10 +1563,7 @@ pub(crate) mod test {
|
|||
|
||||
let response = claim_keys::Response::new(one_time_keys);
|
||||
|
||||
alice_machine
|
||||
.receive_keys_claim_response(&response)
|
||||
.await
|
||||
.unwrap();
|
||||
alice_machine.receive_keys_claim_response(&response).await.unwrap();
|
||||
|
||||
let session = alice_machine
|
||||
.store
|
||||
|
@ -1732,28 +1579,14 @@ pub(crate) mod test {
|
|||
async fn test_olm_encryption() {
|
||||
let (alice, bob) = get_machine_pair_with_session().await;
|
||||
|
||||
let bob_device = alice
|
||||
.get_device(&bob.user_id, &bob.device_id)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let bob_device = alice.get_device(&bob.user_id, &bob.device_id).await.unwrap().unwrap();
|
||||
|
||||
let event = ToDeviceEvent {
|
||||
sender: alice.user_id().clone(),
|
||||
content: bob_device
|
||||
.encrypt(EventType::Dummy, json!({}))
|
||||
.await
|
||||
.unwrap()
|
||||
.1,
|
||||
content: bob_device.encrypt(EventType::Dummy, json!({})).await.unwrap().1,
|
||||
};
|
||||
|
||||
let event = bob
|
||||
.decrypt_to_device_event(&event)
|
||||
.await
|
||||
.unwrap()
|
||||
.event
|
||||
.deserialize()
|
||||
.unwrap();
|
||||
let event = bob.decrypt_to_device_event(&event).await.unwrap().event.deserialize().unwrap();
|
||||
|
||||
if let AnyToDeviceEvent::Dummy(e) = event {
|
||||
assert_eq!(&e.sender, alice.user_id());
|
||||
|
@ -1782,17 +1615,12 @@ pub(crate) mod test {
|
|||
content: to_device_requests_to_content(to_device_requests),
|
||||
};
|
||||
|
||||
let alice_session = alice
|
||||
.group_session_manager
|
||||
.get_outbound_group_session(&room_id)
|
||||
.unwrap();
|
||||
let alice_session =
|
||||
alice.group_session_manager.get_outbound_group_session(&room_id).unwrap();
|
||||
|
||||
let decrypted = bob.decrypt_to_device_event(&event).await.unwrap();
|
||||
|
||||
bob.store
|
||||
.save_sessions(&[decrypted.session.session()])
|
||||
.await
|
||||
.unwrap();
|
||||
bob.store.save_sessions(&[decrypted.session.session()]).await.unwrap();
|
||||
bob.store
|
||||
.save_inbound_group_sessions(&[decrypted.inbound_group_session.unwrap()])
|
||||
.await
|
||||
|
@ -1837,25 +1665,16 @@ pub(crate) mod test {
|
|||
content: to_device_requests_to_content(to_device_requests),
|
||||
};
|
||||
|
||||
let group_session = bob
|
||||
.decrypt_to_device_event(&event)
|
||||
.await
|
||||
.unwrap()
|
||||
.inbound_group_session;
|
||||
bob.store
|
||||
.save_inbound_group_sessions(&[group_session.unwrap()])
|
||||
.await
|
||||
.unwrap();
|
||||
let group_session =
|
||||
bob.decrypt_to_device_event(&event).await.unwrap().inbound_group_session;
|
||||
bob.store.save_inbound_group_sessions(&[group_session.unwrap()]).await.unwrap();
|
||||
|
||||
let plaintext = "It is a secret to everybody";
|
||||
|
||||
let content = MessageEventContent::text_plain(plaintext);
|
||||
|
||||
let encrypted_content = alice
|
||||
.encrypt(
|
||||
&room_id,
|
||||
AnyMessageEventContent::RoomMessage(content.clone()),
|
||||
)
|
||||
.encrypt(&room_id, AnyMessageEventContent::RoomMessage(content.clone()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
@ -1867,13 +1686,8 @@ pub(crate) mod test {
|
|||
unsigned: Unsigned::default(),
|
||||
};
|
||||
|
||||
let decrypted_event = bob
|
||||
.decrypt_room_event(&event, &room_id)
|
||||
.await
|
||||
.unwrap()
|
||||
.event
|
||||
.deserialize()
|
||||
.unwrap();
|
||||
let decrypted_event =
|
||||
bob.decrypt_room_event(&event, &room_id).await.unwrap().event.deserialize().unwrap();
|
||||
|
||||
if let AnySyncRoomEvent::Message(AnySyncMessageEvent::RoomMessage(SyncMessageEvent {
|
||||
sender,
|
||||
|
@ -1912,10 +1726,7 @@ pub(crate) mod test {
|
|||
let device_id = machine.device_id().to_owned();
|
||||
let ed25519_key = machine.identity_keys().ed25519().to_owned();
|
||||
|
||||
machine
|
||||
.receive_keys_upload_response(&keys_upload_response())
|
||||
.await
|
||||
.unwrap();
|
||||
machine.receive_keys_upload_response(&keys_upload_response()).await.unwrap();
|
||||
|
||||
drop(machine);
|
||||
|
||||
|
@ -1937,11 +1748,7 @@ pub(crate) mod test {
|
|||
async fn interactive_verification() {
|
||||
let (alice, bob) = get_machine_pair_with_setup_sessions().await;
|
||||
|
||||
let bob_device = alice
|
||||
.get_device(bob.user_id(), bob.device_id())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let bob_device = alice.get_device(bob.user_id(), bob.device_id()).await.unwrap().unwrap();
|
||||
|
||||
assert!(!bob_device.is_trusted());
|
||||
|
||||
|
@ -1955,10 +1762,7 @@ pub(crate) mod test {
|
|||
assert!(alice_sas.emoji().is_none());
|
||||
assert!(bob_sas.emoji().is_none());
|
||||
|
||||
let event = bob_sas
|
||||
.accept()
|
||||
.map(|r| request_to_event(bob.user_id(), &r))
|
||||
.unwrap();
|
||||
let event = bob_sas.accept().map(|r| request_to_event(bob.user_id(), &r)).unwrap();
|
||||
|
||||
alice.handle_verification_event(&event).await;
|
||||
|
||||
|
@ -2007,11 +1811,8 @@ pub(crate) mod test {
|
|||
assert!(alice_sas.is_done());
|
||||
assert!(bob_device.is_trusted());
|
||||
|
||||
let alice_device = bob
|
||||
.get_device(alice.user_id(), alice.device_id())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let alice_device =
|
||||
bob.get_device(alice.user_id(), alice.device_id()).await.unwrap().unwrap();
|
||||
|
||||
assert!(!alice_device.is_trusted());
|
||||
bob.handle_verification_event(&event).await;
|
||||
|
|
|
@ -12,10 +12,6 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use matrix_sdk_common::events::ToDeviceEvent;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
convert::{TryFrom, TryInto},
|
||||
|
@ -26,7 +22,6 @@ use std::{
|
|||
Arc,
|
||||
},
|
||||
};
|
||||
use tracing::{debug, trace, warn};
|
||||
|
||||
#[cfg(test)]
|
||||
use matrix_sdk_common::events::EventType;
|
||||
|
@ -37,7 +32,7 @@ use matrix_sdk_common::{
|
|||
encryption::DeviceKeys,
|
||||
events::{
|
||||
room::encrypted::{EncryptedEventContent, EncryptedEventScheme},
|
||||
AnyToDeviceEvent,
|
||||
AnyToDeviceEvent, ToDeviceEvent,
|
||||
},
|
||||
identifiers::{
|
||||
DeviceId, DeviceIdBox, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, RoomId,
|
||||
|
@ -53,7 +48,15 @@ use olm_rs::{
|
|||
session::{OlmMessage, PreKeyMessage},
|
||||
PicklingMode,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use sha2::{Digest, Sha256};
|
||||
use tracing::{debug, trace, warn};
|
||||
|
||||
use super::{
|
||||
EncryptionSettings, InboundGroupSession, OutboundGroupSession, PrivateCrossSigningIdentity,
|
||||
Session,
|
||||
};
|
||||
use crate::{
|
||||
error::{EventError, OlmResult, SessionCreationError},
|
||||
identities::ReadOnlyDevice,
|
||||
|
@ -63,11 +66,6 @@ use crate::{
|
|||
OlmError,
|
||||
};
|
||||
|
||||
use super::{
|
||||
EncryptionSettings, InboundGroupSession, OutboundGroupSession, PrivateCrossSigningIdentity,
|
||||
Session,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Account {
|
||||
pub(crate) inner: ReadOnlyAccount,
|
||||
|
@ -141,10 +139,8 @@ impl Account {
|
|||
|
||||
// Try to find a ciphertext that was meant for our device.
|
||||
if let Some(ciphertext) = own_ciphertext {
|
||||
let message_type: u8 = ciphertext
|
||||
.message_type
|
||||
.try_into()
|
||||
.map_err(|_| EventError::UnsupportedOlmType)?;
|
||||
let message_type: u8 =
|
||||
ciphertext.message_type.try_into().map_err(|_| EventError::UnsupportedOlmType)?;
|
||||
|
||||
let sha = Sha256::new()
|
||||
.chain(&content.sender_key)
|
||||
|
@ -162,20 +158,18 @@ impl Account {
|
|||
.map_err(|_| EventError::UnsupportedOlmType)?;
|
||||
|
||||
// Decrypt the OlmMessage and get a Ruma event out of it.
|
||||
let (session, event, signing_key) = match self
|
||||
.decrypt_olm_message(&event.sender, &content.sender_key, message)
|
||||
.await
|
||||
{
|
||||
Ok(d) => d,
|
||||
Err(OlmError::SessionWedged(user_id, sender_key)) => {
|
||||
if self.store.is_message_known(&message_hash).await? {
|
||||
return Err(OlmError::ReplayedMessage(user_id, sender_key));
|
||||
} else {
|
||||
return Err(OlmError::SessionWedged(user_id, sender_key));
|
||||
let (session, event, signing_key) =
|
||||
match self.decrypt_olm_message(&event.sender, &content.sender_key, message).await {
|
||||
Ok(d) => d,
|
||||
Err(OlmError::SessionWedged(user_id, sender_key)) => {
|
||||
if self.store.is_message_known(&message_hash).await? {
|
||||
return Err(OlmError::ReplayedMessage(user_id, sender_key));
|
||||
} else {
|
||||
return Err(OlmError::SessionWedged(user_id, sender_key));
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
|
||||
debug!("Decrypted a to-device event {:?}", event);
|
||||
|
||||
|
@ -210,9 +204,8 @@ impl Account {
|
|||
}
|
||||
self.inner.mark_as_shared();
|
||||
|
||||
let one_time_key_count = response
|
||||
.one_time_key_counts
|
||||
.get(&DeviceKeyAlgorithm::SignedCurve25519);
|
||||
let one_time_key_count =
|
||||
response.one_time_key_counts.get(&DeviceKeyAlgorithm::SignedCurve25519);
|
||||
|
||||
let count: u64 = one_time_key_count.map_or(0, |c| (*c).into());
|
||||
debug!(
|
||||
|
@ -297,9 +290,8 @@ impl Account {
|
|||
message: OlmMessage,
|
||||
) -> OlmResult<(SessionType, Raw<AnyToDeviceEvent>, String)> {
|
||||
// First try to decrypt using an existing session.
|
||||
let (session, plaintext) = if let Some(d) = self
|
||||
.try_decrypt_olm_message(sender, sender_key, &message)
|
||||
.await?
|
||||
let (session, plaintext) = if let Some(d) =
|
||||
self.try_decrypt_olm_message(sender, sender_key, &message).await?
|
||||
{
|
||||
// Decryption succeeded, de-structure the session/plaintext out of
|
||||
// the Option.
|
||||
|
@ -316,32 +308,26 @@ impl Account {
|
|||
available sessions {} {}",
|
||||
sender, sender_key
|
||||
);
|
||||
return Err(OlmError::SessionWedged(
|
||||
sender.to_owned(),
|
||||
sender_key.to_owned(),
|
||||
));
|
||||
return Err(OlmError::SessionWedged(sender.to_owned(), sender_key.to_owned()));
|
||||
}
|
||||
|
||||
OlmMessage::PreKey(m) => {
|
||||
// Create the new session.
|
||||
let session = match self
|
||||
.inner
|
||||
.create_inbound_session(sender_key, m.clone())
|
||||
.await
|
||||
{
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Failed to create a new Olm session for {} {}
|
||||
let session =
|
||||
match self.inner.create_inbound_session(sender_key, m.clone()).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Failed to create a new Olm session for {} {}
|
||||
from a prekey message: {}",
|
||||
sender, sender_key, e
|
||||
);
|
||||
return Err(OlmError::SessionWedged(
|
||||
sender.to_owned(),
|
||||
sender_key.to_owned(),
|
||||
));
|
||||
}
|
||||
};
|
||||
sender, sender_key, e
|
||||
);
|
||||
return Err(OlmError::SessionWedged(
|
||||
sender.to_owned(),
|
||||
sender_key.to_owned(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
session
|
||||
}
|
||||
|
@ -428,9 +414,8 @@ impl Account {
|
|||
return Err(EventError::MissmatchedKeys.into());
|
||||
}
|
||||
|
||||
let signing_key = keys
|
||||
.get(&DeviceKeyAlgorithm::Ed25519)
|
||||
.ok_or(EventError::MissingSigningKey)?;
|
||||
let signing_key =
|
||||
keys.get(&DeviceKeyAlgorithm::Ed25519).ok_or(EventError::MissingSigningKey)?;
|
||||
|
||||
Ok((
|
||||
Raw::from(serde_json::from_value::<AnyToDeviceEvent>(decrypted_json)?),
|
||||
|
@ -547,8 +532,7 @@ impl ReadOnlyAccount {
|
|||
/// * `new_count` - The new count that was reported by the server.
|
||||
pub(crate) fn update_uploaded_key_count(&self, new_count: u64) {
|
||||
let key_count = i64::try_from(new_count).unwrap_or(i64::MAX);
|
||||
self.uploaded_signed_key_count
|
||||
.store(key_count, Ordering::Relaxed);
|
||||
self.uploaded_signed_key_count.store(key_count, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Get the currently known uploaded key count.
|
||||
|
@ -631,19 +615,12 @@ impl ReadOnlyAccount {
|
|||
/// Returns None if no keys need to be uploaded.
|
||||
pub(crate) async fn keys_for_upload(
|
||||
&self,
|
||||
) -> Option<(
|
||||
Option<DeviceKeys>,
|
||||
Option<BTreeMap<DeviceKeyId, OneTimeKey>>,
|
||||
)> {
|
||||
) -> Option<(Option<DeviceKeys>, Option<BTreeMap<DeviceKeyId, OneTimeKey>>)> {
|
||||
if !self.should_upload_keys().await {
|
||||
return None;
|
||||
}
|
||||
|
||||
let device_keys = if !self.shared() {
|
||||
Some(self.device_keys().await)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let device_keys = if !self.shared() { Some(self.device_keys().await) } else { None };
|
||||
|
||||
let one_time_keys = self.signed_one_time_keys().await.ok();
|
||||
|
||||
|
@ -666,7 +643,8 @@ impl ReadOnlyAccount {
|
|||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `pickle_mode` - The mode that was used to pickle the account, either an
|
||||
/// * `pickle_mode` - The mode that was used to pickle the account, either
|
||||
/// an
|
||||
/// unencrypted mode or an encrypted using passphrase.
|
||||
pub async fn pickle(&self, pickle_mode: PicklingMode) -> PickledAccount {
|
||||
let pickle = AccountPickle(self.inner.lock().await.pickle(pickle_mode));
|
||||
|
@ -686,7 +664,8 @@ impl ReadOnlyAccount {
|
|||
///
|
||||
/// * `pickle` - The pickled version of the Account.
|
||||
///
|
||||
/// * `pickle_mode` - The mode that was used to pickle the account, either an
|
||||
/// * `pickle_mode` - The mode that was used to pickle the account, either
|
||||
/// an
|
||||
/// unencrypted mode or an encrypted using passphrase.
|
||||
pub fn from_pickle(
|
||||
pickle: PickledAccount,
|
||||
|
@ -742,25 +721,17 @@ impl ReadOnlyAccount {
|
|||
"keys": device_keys.keys,
|
||||
});
|
||||
|
||||
device_keys
|
||||
.signatures
|
||||
.entry(self.user_id().clone())
|
||||
.or_insert_with(BTreeMap::new)
|
||||
.insert(
|
||||
DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, &self.device_id),
|
||||
self.sign_json(json_device_keys).await,
|
||||
);
|
||||
device_keys.signatures.entry(self.user_id().clone()).or_insert_with(BTreeMap::new).insert(
|
||||
DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, &self.device_id),
|
||||
self.sign_json(json_device_keys).await,
|
||||
);
|
||||
|
||||
device_keys
|
||||
}
|
||||
|
||||
pub(crate) async fn bootstrap_cross_signing(
|
||||
&self,
|
||||
) -> (
|
||||
PrivateCrossSigningIdentity,
|
||||
UploadSigningKeysRequest,
|
||||
SignatureUploadRequest,
|
||||
) {
|
||||
) -> (PrivateCrossSigningIdentity, UploadSigningKeysRequest, SignatureUploadRequest) {
|
||||
PrivateCrossSigningIdentity::new_with_account(self).await
|
||||
}
|
||||
|
||||
|
@ -873,8 +844,8 @@ impl ReadOnlyAccount {
|
|||
/// # Arguments
|
||||
/// * `device` - The other account's device.
|
||||
///
|
||||
/// * `key_map` - A map from the algorithm and device id to the one-time
|
||||
/// key that the other account created and shared with us.
|
||||
/// * `key_map` - A map from the algorithm and device id to the one-time key
|
||||
/// that the other account created and shared with us.
|
||||
pub(crate) async fn create_outbound_session(
|
||||
&self,
|
||||
device: ReadOnlyDevice,
|
||||
|
@ -911,24 +882,20 @@ impl ReadOnlyAccount {
|
|||
)
|
||||
})?;
|
||||
|
||||
let curve_key = device
|
||||
.get_key(DeviceKeyAlgorithm::Curve25519)
|
||||
.ok_or_else(|| {
|
||||
SessionCreationError::DeviceMissingCurveKey(
|
||||
device.user_id().to_owned(),
|
||||
device.device_id().into(),
|
||||
)
|
||||
})?;
|
||||
let curve_key = device.get_key(DeviceKeyAlgorithm::Curve25519).ok_or_else(|| {
|
||||
SessionCreationError::DeviceMissingCurveKey(
|
||||
device.user_id().to_owned(),
|
||||
device.device_id().into(),
|
||||
)
|
||||
})?;
|
||||
|
||||
self.create_outbound_session_helper(curve_key, &one_time_key)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
SessionCreationError::OlmError(
|
||||
device.user_id().to_owned(),
|
||||
device.device_id().into(),
|
||||
e,
|
||||
)
|
||||
})
|
||||
self.create_outbound_session_helper(curve_key, &one_time_key).await.map_err(|e| {
|
||||
SessionCreationError::OlmError(
|
||||
device.user_id().to_owned(),
|
||||
device.device_id().into(),
|
||||
e,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a new session with another account given a pre-key Olm message.
|
||||
|
@ -946,17 +913,10 @@ impl ReadOnlyAccount {
|
|||
their_identity_key: &str,
|
||||
message: PreKeyMessage,
|
||||
) -> Result<Session, OlmSessionError> {
|
||||
let session = self
|
||||
.inner
|
||||
.lock()
|
||||
.await
|
||||
.create_inbound_session_from(their_identity_key, message)?;
|
||||
let session =
|
||||
self.inner.lock().await.create_inbound_session_from(their_identity_key, message)?;
|
||||
|
||||
self.inner
|
||||
.lock()
|
||||
.await
|
||||
.remove_one_time_keys(&session)
|
||||
.expect(
|
||||
self.inner.lock().await.remove_one_time_keys(&session).expect(
|
||||
"Session was successfully created but the account doesn't hold a matching one-time key",
|
||||
);
|
||||
|
||||
|
@ -1028,8 +988,7 @@ impl ReadOnlyAccount {
|
|||
&self,
|
||||
room_id: &RoomId,
|
||||
) -> Result<(OutboundGroupSession, InboundGroupSession), ()> {
|
||||
self.create_group_session_pair(room_id, EncryptionSettings::default())
|
||||
.await
|
||||
self.create_group_session_pair(room_id, EncryptionSettings::default()).await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -1039,27 +998,19 @@ impl ReadOnlyAccount {
|
|||
|
||||
let device = ReadOnlyDevice::from_account(other).await;
|
||||
|
||||
let mut our_session = self
|
||||
.create_outbound_session(device.clone(), &one_time)
|
||||
.await
|
||||
.unwrap();
|
||||
let mut our_session =
|
||||
self.create_outbound_session(device.clone(), &one_time).await.unwrap();
|
||||
|
||||
other.mark_keys_as_published().await;
|
||||
|
||||
let message = our_session
|
||||
.encrypt(&device, EventType::Dummy, json!({}))
|
||||
.await
|
||||
.unwrap();
|
||||
let message = our_session.encrypt(&device, EventType::Dummy, json!({})).await.unwrap();
|
||||
let content = if let EncryptedEventScheme::OlmV1Curve25519AesSha2(c) = message.scheme {
|
||||
c
|
||||
} else {
|
||||
panic!("Invalid encrypted event algorithm");
|
||||
};
|
||||
|
||||
let own_ciphertext = content
|
||||
.ciphertext
|
||||
.get(other.identity_keys.curve25519())
|
||||
.unwrap();
|
||||
let own_ciphertext = content.ciphertext.get(other.identity_keys.curve25519()).unwrap();
|
||||
let message_type: u8 = own_ciphertext.message_type.try_into().unwrap();
|
||||
|
||||
let message =
|
||||
|
|
|
@ -19,19 +19,6 @@ use std::{
|
|||
sync::Arc,
|
||||
};
|
||||
|
||||
use olm_rs::{
|
||||
errors::OlmGroupSessionError, inbound_group_session::OlmInboundGroupSession, PicklingMode,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
pub use olm_rs::{
|
||||
account::IdentityKeys,
|
||||
session::{OlmMessage, PreKeyMessage},
|
||||
utility::OlmUtility,
|
||||
};
|
||||
|
||||
use matrix_sdk_common::{
|
||||
events::{
|
||||
forwarded_room_key::ForwardedRoomKeyToDeviceEventContent,
|
||||
|
@ -45,6 +32,17 @@ use matrix_sdk_common::{
|
|||
locks::Mutex,
|
||||
Raw,
|
||||
};
|
||||
pub use olm_rs::{
|
||||
account::IdentityKeys,
|
||||
session::{OlmMessage, PreKeyMessage},
|
||||
utility::OlmUtility,
|
||||
};
|
||||
use olm_rs::{
|
||||
errors::OlmGroupSessionError, inbound_group_session::OlmInboundGroupSession, PicklingMode,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
use super::{ExportedGroupSessionKey, ExportedRoomKey, GroupSessionKey};
|
||||
use crate::error::{EventError, MegolmResult};
|
||||
|
@ -149,10 +147,8 @@ impl InboundGroupSession {
|
|||
forwarding_chains.push(sender_key.to_owned());
|
||||
|
||||
let mut sender_claimed_key = BTreeMap::new();
|
||||
sender_claimed_key.insert(
|
||||
DeviceKeyAlgorithm::Ed25519,
|
||||
content.sender_claimed_ed25519_key.to_owned(),
|
||||
);
|
||||
sender_claimed_key
|
||||
.insert(DeviceKeyAlgorithm::Ed25519, content.sender_claimed_ed25519_key.to_owned());
|
||||
|
||||
Ok(InboundGroupSession {
|
||||
inner: Mutex::new(session).into(),
|
||||
|
@ -219,11 +215,7 @@ impl InboundGroupSession {
|
|||
let message_index = std::cmp::max(self.first_known_index(), message_index);
|
||||
|
||||
let session_key = ExportedGroupSessionKey(
|
||||
self.inner
|
||||
.lock()
|
||||
.await
|
||||
.export(message_index)
|
||||
.expect("Can't export session"),
|
||||
self.inner.lock().await.export(message_index).expect("Can't export session"),
|
||||
);
|
||||
|
||||
ExportedRoomKey {
|
||||
|
@ -316,9 +308,7 @@ impl InboundGroupSession {
|
|||
let (plaintext, message_index) = self.decrypt_helper(content.ciphertext.clone()).await?;
|
||||
|
||||
let mut decrypted_value = serde_json::from_str::<Value>(&plaintext)?;
|
||||
let decrypted_object = decrypted_value
|
||||
.as_object_mut()
|
||||
.ok_or(EventError::NotAnObject)?;
|
||||
let decrypted_object = decrypted_value.as_object_mut().ok_or(EventError::NotAnObject)?;
|
||||
|
||||
// TODO better number conversion here.
|
||||
let server_ts = event
|
||||
|
@ -337,10 +327,8 @@ impl InboundGroupSession {
|
|||
serde_json::to_value(&event.unsigned).unwrap_or_default(),
|
||||
);
|
||||
|
||||
if let Some(decrypted_content) = decrypted_object
|
||||
.get_mut("content")
|
||||
.map(|c| c.as_object_mut())
|
||||
.flatten()
|
||||
if let Some(decrypted_content) =
|
||||
decrypted_object.get_mut("content").map(|c| c.as_object_mut()).flatten()
|
||||
{
|
||||
if !decrypted_content.contains_key("m.relates_to") {
|
||||
if let Some(relation) = &event.content.relates_to {
|
||||
|
@ -352,19 +340,14 @@ impl InboundGroupSession {
|
|||
}
|
||||
}
|
||||
|
||||
Ok((
|
||||
serde_json::from_value::<Raw<AnySyncRoomEvent>>(decrypted_value)?,
|
||||
message_index,
|
||||
))
|
||||
Ok((serde_json::from_value::<Raw<AnySyncRoomEvent>>(decrypted_value)?, message_index))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(tarpaulin_include))]
|
||||
impl fmt::Debug for InboundGroupSession {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("InboundGroupSession")
|
||||
.field("session_id", &self.session_id())
|
||||
.finish()
|
||||
f.debug_struct("InboundGroupSession").field("session_id", &self.session_id()).finish()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -399,7 +382,8 @@ pub struct PickledInboundGroupSession {
|
|||
pub history_visibility: Option<HistoryVisibility>,
|
||||
}
|
||||
|
||||
/// The typed representation of a base64 encoded string of the GroupSession pickle.
|
||||
/// The typed representation of a base64 encoded string of the GroupSession
|
||||
/// pickle.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct InboundGroupSessionPickle(String);
|
||||
|
||||
|
|
|
@ -12,6 +12,8 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{collections::BTreeMap, convert::TryInto};
|
||||
|
||||
use matrix_sdk_common::{
|
||||
events::forwarded_room_key::{
|
||||
ForwardedRoomKeyToDeviceEventContent, ForwardedRoomKeyToDeviceEventContentInit,
|
||||
|
@ -19,7 +21,6 @@ use matrix_sdk_common::{
|
|||
identifiers::{DeviceKeyAlgorithm, EventEncryptionAlgorithm, RoomId},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{collections::BTreeMap, convert::TryInto};
|
||||
use zeroize::Zeroize;
|
||||
|
||||
mod inbound;
|
||||
|
@ -107,10 +108,8 @@ impl From<ForwardedRoomKeyToDeviceEventContent> for ExportedRoomKey {
|
|||
/// Convert the content of a forwarded room key into a exported room key.
|
||||
fn from(forwarded_key: ForwardedRoomKeyToDeviceEventContent) -> Self {
|
||||
let mut sender_claimed_keys: BTreeMap<DeviceKeyAlgorithm, String> = BTreeMap::new();
|
||||
sender_claimed_keys.insert(
|
||||
DeviceKeyAlgorithm::Ed25519,
|
||||
forwarded_key.sender_claimed_ed25519_key,
|
||||
);
|
||||
sender_claimed_keys
|
||||
.insert(DeviceKeyAlgorithm::Ed25519, forwarded_key.sender_claimed_ed25519_key);
|
||||
|
||||
Self {
|
||||
algorithm: forwarded_key.algorithm,
|
||||
|
@ -142,10 +141,7 @@ mod test {
|
|||
#[tokio::test]
|
||||
#[cfg(target_os = "linux")]
|
||||
async fn expiration() {
|
||||
let settings = EncryptionSettings {
|
||||
rotation_period_msgs: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let settings = EncryptionSettings { rotation_period_msgs: 1, ..Default::default() };
|
||||
|
||||
let account = ReadOnlyAccount::new(&user_id!("@alice:example.org"), "DEVICEID".into());
|
||||
let (session, _) = account
|
||||
|
@ -155,9 +151,9 @@ mod test {
|
|||
|
||||
assert!(!session.expired());
|
||||
let _ = session
|
||||
.encrypt(AnyMessageEventContent::RoomMessage(
|
||||
MessageEventContent::text_plain("Test message"),
|
||||
))
|
||||
.encrypt(AnyMessageEventContent::RoomMessage(MessageEventContent::text_plain(
|
||||
"Test message",
|
||||
)))
|
||||
.await;
|
||||
assert!(session.expired());
|
||||
|
||||
|
|
|
@ -12,15 +12,6 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use dashmap::DashMap;
|
||||
use matrix_sdk_common::{
|
||||
api::r0::to_device::DeviceIdOrAllDevices,
|
||||
events::room::{
|
||||
encrypted::MegolmV1AesSha2ContentInit, history_visibility::HistoryVisibility,
|
||||
message::Relation,
|
||||
},
|
||||
uuid::Uuid,
|
||||
};
|
||||
use std::{
|
||||
cmp::max,
|
||||
collections::BTreeMap,
|
||||
|
@ -31,23 +22,24 @@ use std::{
|
|||
},
|
||||
time::Duration,
|
||||
};
|
||||
use tracing::{debug, error, trace};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use matrix_sdk_common::{
|
||||
api::r0::to_device::DeviceIdOrAllDevices,
|
||||
events::{
|
||||
room::{
|
||||
encrypted::{EncryptedEventContent, EncryptedEventScheme},
|
||||
encrypted::{EncryptedEventContent, EncryptedEventScheme, MegolmV1AesSha2ContentInit},
|
||||
encryption::EncryptionEventContent,
|
||||
history_visibility::HistoryVisibility,
|
||||
message::Relation,
|
||||
},
|
||||
AnyMessageEventContent, EventContent,
|
||||
},
|
||||
identifiers::{DeviceId, DeviceIdBox, EventEncryptionAlgorithm, RoomId, UserId},
|
||||
instant::Instant,
|
||||
locks::Mutex,
|
||||
uuid::Uuid,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
pub use olm_rs::{
|
||||
account::IdentityKeys,
|
||||
session::{OlmMessage, PreKeyMessage},
|
||||
|
@ -56,13 +48,15 @@ pub use olm_rs::{
|
|||
use olm_rs::{
|
||||
errors::OlmGroupSessionError, outbound_group_session::OlmOutboundGroupSession, PicklingMode,
|
||||
};
|
||||
|
||||
use crate::ToDeviceRequest;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use tracing::{debug, error, trace};
|
||||
|
||||
use super::{
|
||||
super::{deserialize_instant, serialize_instant},
|
||||
GroupSessionKey,
|
||||
};
|
||||
use crate::ToDeviceRequest;
|
||||
|
||||
const ROTATION_PERIOD: Duration = Duration::from_millis(604800000);
|
||||
const ROTATION_MESSAGES: u64 = 100;
|
||||
|
@ -102,12 +96,10 @@ impl EncryptionSettings {
|
|||
/// Create new encryption settings using an `EncryptionEventContent` and a
|
||||
/// history visibility.
|
||||
pub fn new(content: EncryptionEventContent, history_visibility: HistoryVisibility) -> Self {
|
||||
let rotation_period: Duration = content
|
||||
.rotation_period_ms
|
||||
.map_or(ROTATION_PERIOD, |r| Duration::from_millis(r.into()));
|
||||
let rotation_period_msgs: u64 = content
|
||||
.rotation_period_msgs
|
||||
.map_or(ROTATION_MESSAGES, Into::into);
|
||||
let rotation_period: Duration =
|
||||
content.rotation_period_ms.map_or(ROTATION_PERIOD, |r| Duration::from_millis(r.into()));
|
||||
let rotation_period_msgs: u64 =
|
||||
content.rotation_period_msgs.map_or(ROTATION_MESSAGES, Into::into);
|
||||
|
||||
Self {
|
||||
algorithm: content.algorithm,
|
||||
|
@ -186,8 +178,7 @@ impl OutboundGroupSession {
|
|||
request: Arc<ToDeviceRequest>,
|
||||
message_index: u32,
|
||||
) {
|
||||
self.to_share_with_set
|
||||
.insert(request_id, (request, message_index));
|
||||
self.to_share_with_set.insert(request_id, (request, message_index));
|
||||
}
|
||||
|
||||
/// This should be called if an the user wishes to rotate this session.
|
||||
|
@ -225,10 +216,7 @@ impl OutboundGroupSession {
|
|||
});
|
||||
|
||||
user_pairs.for_each(|(u, d)| {
|
||||
self.shared_with_set
|
||||
.entry(u)
|
||||
.or_insert_with(DashMap::new)
|
||||
.extend(d);
|
||||
self.shared_with_set.entry(u).or_insert_with(DashMap::new).extend(d);
|
||||
});
|
||||
|
||||
if self.to_share_with_set.is_empty() {
|
||||
|
@ -241,11 +229,8 @@ impl OutboundGroupSession {
|
|||
self.mark_as_shared();
|
||||
}
|
||||
} else {
|
||||
let request_ids: Vec<String> = self
|
||||
.to_share_with_set
|
||||
.iter()
|
||||
.map(|e| e.key().to_string())
|
||||
.collect();
|
||||
let request_ids: Vec<String> =
|
||||
self.to_share_with_set.iter().map(|e| e.key().to_string()).collect();
|
||||
|
||||
error!(
|
||||
all_request_ids = ?request_ids,
|
||||
|
@ -296,11 +281,7 @@ impl OutboundGroupSession {
|
|||
|
||||
let relates_to: Option<Relation> = json_content
|
||||
.get("content")
|
||||
.map(|c| {
|
||||
c.get("m.relates_to")
|
||||
.cloned()
|
||||
.map(|r| serde_json::from_value(r).ok())
|
||||
})
|
||||
.map(|c| c.get("m.relates_to").cloned().map(|r| serde_json::from_value(r).ok()))
|
||||
.flatten()
|
||||
.flatten();
|
||||
|
||||
|
@ -443,10 +424,7 @@ impl OutboundGroupSession {
|
|||
/// Get the list of requests that need to be sent out for this session to be
|
||||
/// marked as shared.
|
||||
pub(crate) fn pending_requests(&self) -> Vec<Arc<ToDeviceRequest>> {
|
||||
self.to_share_with_set
|
||||
.iter()
|
||||
.map(|i| i.value().0.clone())
|
||||
.collect()
|
||||
self.to_share_with_set.iter().map(|i| i.value().0.clone()).collect()
|
||||
}
|
||||
|
||||
/// Get the list of request ids this session is waiting for to be sent out.
|
||||
|
@ -462,10 +440,10 @@ impl OutboundGroupSession {
|
|||
/// # Arguments
|
||||
///
|
||||
/// * `device_id` - The device id of the device that created this session.
|
||||
/// Put differently, our own device id.
|
||||
/// Put differently, our own device id.
|
||||
///
|
||||
/// * `identity_keys` - The identity keys of the device that created this
|
||||
/// session, our own identity keys.
|
||||
/// session, our own identity keys.
|
||||
///
|
||||
/// * `pickle` - The pickled version of the `OutboundGroupSession`.
|
||||
///
|
||||
|
@ -507,7 +485,8 @@ impl OutboundGroupSession {
|
|||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `pickle_mode` - The mode that should be used to pickle the group session,
|
||||
/// * `pickle_mode` - The mode that should be used to pickle the group
|
||||
/// session,
|
||||
/// either an unencrypted mode or an encrypted using passphrase.
|
||||
pub async fn pickle(&self, pickling_mode: PicklingMode) -> PickledOutboundGroupSession {
|
||||
let pickle: OutboundGroupSessionPickle =
|
||||
|
@ -528,10 +507,7 @@ impl OutboundGroupSession {
|
|||
(
|
||||
u.key().clone(),
|
||||
#[allow(clippy::map_clone)]
|
||||
u.value()
|
||||
.iter()
|
||||
.map(|d| (d.key().clone(), *d.value()))
|
||||
.collect(),
|
||||
u.value().iter().map(|d| (d.key().clone(), *d.value())).collect(),
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
|
@ -578,10 +554,7 @@ pub struct PickledOutboundGroupSession {
|
|||
/// The room id this session is used for.
|
||||
pub room_id: Arc<RoomId>,
|
||||
/// The timestamp when this session was created.
|
||||
#[serde(
|
||||
deserialize_with = "deserialize_instant",
|
||||
serialize_with = "serialize_instant"
|
||||
)]
|
||||
#[serde(deserialize_with = "deserialize_instant", serialize_with = "serialize_instant")]
|
||||
pub creation_time: Instant,
|
||||
/// The number of messages this session has already encrypted.
|
||||
pub message_count: u64,
|
||||
|
|
|
@ -30,14 +30,13 @@ pub use group_sessions::{
|
|||
OutboundGroupSession, PickledInboundGroupSession, PickledOutboundGroupSession,
|
||||
};
|
||||
pub(crate) use group_sessions::{GroupSessionKey, ShareState};
|
||||
use matrix_sdk_common::instant::{Duration, Instant};
|
||||
pub use olm_rs::{account::IdentityKeys, PicklingMode};
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
pub use session::{PickledSession, Session, SessionPickle};
|
||||
pub use signing::{PickledCrossSigningIdentity, PrivateCrossSigningIdentity};
|
||||
pub(crate) use utility::Utility;
|
||||
|
||||
use matrix_sdk_common::instant::{Duration, Instant};
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
|
||||
pub(crate) fn serialize_instant<S>(instant: &Instant, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
|
@ -60,14 +59,16 @@ where
|
|||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod test {
|
||||
use crate::olm::{InboundGroupSession, ReadOnlyAccount, Session};
|
||||
use std::{collections::BTreeMap, convert::TryInto};
|
||||
|
||||
use matrix_sdk_common::{
|
||||
api::r0::keys::SignedKey,
|
||||
events::forwarded_room_key::ForwardedRoomKeyToDeviceEventContent,
|
||||
identifiers::{room_id, user_id, DeviceId, UserId},
|
||||
};
|
||||
use olm_rs::session::OlmMessage;
|
||||
use std::{collections::BTreeMap, convert::TryInto};
|
||||
|
||||
use crate::olm::{InboundGroupSession, ReadOnlyAccount, Session};
|
||||
|
||||
fn alice_id() -> UserId {
|
||||
user_id!("@alice:example.org")
|
||||
|
@ -90,21 +91,12 @@ pub(crate) mod test {
|
|||
let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id());
|
||||
|
||||
bob.generate_one_time_keys_helper(1).await;
|
||||
let one_time_key = bob
|
||||
.one_time_keys()
|
||||
.await
|
||||
.curve25519()
|
||||
.iter()
|
||||
.next()
|
||||
.unwrap()
|
||||
.1
|
||||
.to_owned();
|
||||
let one_time_key =
|
||||
bob.one_time_keys().await.curve25519().iter().next().unwrap().1.to_owned();
|
||||
let one_time_key = SignedKey::new(one_time_key, BTreeMap::new());
|
||||
let sender_key = bob.identity_keys().curve25519().to_owned();
|
||||
let session = alice
|
||||
.create_outbound_session_helper(&sender_key, &one_time_key)
|
||||
.await
|
||||
.unwrap();
|
||||
let session =
|
||||
alice.create_outbound_session_helper(&sender_key, &one_time_key).await.unwrap();
|
||||
|
||||
(alice, session)
|
||||
}
|
||||
|
@ -120,10 +112,7 @@ pub(crate) mod test {
|
|||
assert_ne!(identity_keys.keys().len(), 0);
|
||||
assert_ne!(identity_keys.iter().len(), 0);
|
||||
assert!(identity_keys.contains_key("ed25519"));
|
||||
assert_eq!(
|
||||
identity_keys.ed25519(),
|
||||
identity_keys.get("ed25519").unwrap()
|
||||
);
|
||||
assert_eq!(identity_keys.ed25519(), identity_keys.get("ed25519").unwrap());
|
||||
assert!(!identity_keys.curve25519().is_empty());
|
||||
|
||||
account.mark_as_shared();
|
||||
|
@ -147,10 +136,7 @@ pub(crate) mod test {
|
|||
assert_ne!(one_time_keys.iter().len(), 0);
|
||||
assert!(one_time_keys.contains_key("curve25519"));
|
||||
assert_eq!(one_time_keys.curve25519().keys().len(), 10);
|
||||
assert_eq!(
|
||||
one_time_keys.curve25519(),
|
||||
one_time_keys.get("curve25519").unwrap()
|
||||
);
|
||||
assert_eq!(one_time_keys.curve25519(), one_time_keys.get("curve25519").unwrap());
|
||||
|
||||
account.mark_keys_as_published().await;
|
||||
let one_time_keys = account.one_time_keys().await;
|
||||
|
@ -166,13 +152,7 @@ pub(crate) mod test {
|
|||
let one_time_keys = alice.one_time_keys().await;
|
||||
alice.mark_keys_as_published().await;
|
||||
|
||||
let one_time_key = one_time_keys
|
||||
.curve25519()
|
||||
.iter()
|
||||
.next()
|
||||
.unwrap()
|
||||
.1
|
||||
.to_owned();
|
||||
let one_time_key = one_time_keys.curve25519().iter().next().unwrap().1.to_owned();
|
||||
|
||||
let one_time_key = SignedKey::new(one_time_key, BTreeMap::new());
|
||||
|
||||
|
@ -196,10 +176,7 @@ pub(crate) mod test {
|
|||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(alice_session
|
||||
.matches(bob_keys.curve25519(), prekey_message)
|
||||
.await
|
||||
.unwrap());
|
||||
assert!(alice_session.matches(bob_keys.curve25519(), prekey_message).await.unwrap());
|
||||
|
||||
assert_eq!(bob_session.session_id(), alice_session.session_id());
|
||||
|
||||
|
@ -212,10 +189,7 @@ pub(crate) mod test {
|
|||
let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id());
|
||||
let room_id = room_id!("!test:localhost");
|
||||
|
||||
let (outbound, _) = alice
|
||||
.create_group_session_pair_with_defaults(&room_id)
|
||||
.await
|
||||
.unwrap();
|
||||
let (outbound, _) = alice.create_group_session_pair_with_defaults(&room_id).await.unwrap();
|
||||
|
||||
assert_eq!(0, outbound.message_index().await);
|
||||
assert!(!outbound.shared());
|
||||
|
@ -238,10 +212,7 @@ pub(crate) mod test {
|
|||
let plaintext = "This is a secret to everybody".to_owned();
|
||||
let ciphertext = outbound.encrypt_helper(plaintext.clone()).await;
|
||||
|
||||
assert_eq!(
|
||||
plaintext,
|
||||
inbound.decrypt_helper(ciphertext).await.unwrap().0
|
||||
);
|
||||
assert_eq!(plaintext, inbound.decrypt_helper(ciphertext).await.unwrap().0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
@ -249,10 +220,7 @@ pub(crate) mod test {
|
|||
let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id());
|
||||
let room_id = room_id!("!test:localhost");
|
||||
|
||||
let (_, inbound) = alice
|
||||
.create_group_session_pair_with_defaults(&room_id)
|
||||
.await
|
||||
.unwrap();
|
||||
let (_, inbound) = alice.create_group_session_pair_with_defaults(&room_id).await.unwrap();
|
||||
|
||||
let export = inbound.export().await;
|
||||
let export: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap();
|
||||
|
|
|
@ -27,20 +27,18 @@ use matrix_sdk_common::{
|
|||
locks::Mutex,
|
||||
};
|
||||
use olm_rs::{errors::OlmSessionError, session::OlmSession, PicklingMode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use crate::{
|
||||
error::{EventError, OlmResult, SessionUnpicklingError},
|
||||
ReadOnlyDevice,
|
||||
};
|
||||
|
||||
pub use olm_rs::{
|
||||
session::{OlmMessage, PreKeyMessage},
|
||||
utility::OlmUtility,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use super::{deserialize_instant, serialize_instant, IdentityKeys};
|
||||
use crate::{
|
||||
error::{EventError, OlmResult, SessionUnpicklingError},
|
||||
ReadOnlyDevice,
|
||||
};
|
||||
|
||||
/// Cryptographic session that enables secure communication between two
|
||||
/// `Account`s
|
||||
|
@ -105,8 +103,8 @@ impl Session {
|
|||
/// # Arguments
|
||||
///
|
||||
/// * `recipient_device` - The device for which this message is going to be
|
||||
/// encrypted, this needs to be the device that was used to create this
|
||||
/// session with.
|
||||
/// encrypted, this needs to be the device that was used to create this
|
||||
/// session with.
|
||||
///
|
||||
/// * `event_type` - The type of the event.
|
||||
///
|
||||
|
@ -121,10 +119,8 @@ impl Session {
|
|||
.get_key(DeviceKeyAlgorithm::Ed25519)
|
||||
.ok_or(EventError::MissingSigningKey)?;
|
||||
|
||||
let relates_to = content
|
||||
.get("m.relates_to")
|
||||
.cloned()
|
||||
.and_then(|v| serde_json::from_value(v).ok());
|
||||
let relates_to =
|
||||
content.get("m.relates_to").cloned().and_then(|v| serde_json::from_value(v).ok());
|
||||
|
||||
let payload = json!({
|
||||
"sender": self.user_id.as_str(),
|
||||
|
@ -174,10 +170,7 @@ impl Session {
|
|||
their_identity_key: &str,
|
||||
message: PreKeyMessage,
|
||||
) -> Result<bool, OlmSessionError> {
|
||||
self.inner
|
||||
.lock()
|
||||
.await
|
||||
.matches_inbound_session_from(their_identity_key, message)
|
||||
self.inner.lock().await.matches_inbound_session_from(their_identity_key, message)
|
||||
}
|
||||
|
||||
/// Returns the unique identifier for this session.
|
||||
|
@ -259,20 +252,15 @@ pub struct PickledSession {
|
|||
/// The curve25519 key of the other user that we share this session with.
|
||||
pub sender_key: String,
|
||||
/// The relative time elapsed since the session was created.
|
||||
#[serde(
|
||||
deserialize_with = "deserialize_instant",
|
||||
serialize_with = "serialize_instant"
|
||||
)]
|
||||
#[serde(deserialize_with = "deserialize_instant", serialize_with = "serialize_instant")]
|
||||
pub creation_time: Instant,
|
||||
/// The relative time elapsed since the session was last used.
|
||||
#[serde(
|
||||
deserialize_with = "deserialize_instant",
|
||||
serialize_with = "serialize_instant"
|
||||
)]
|
||||
#[serde(deserialize_with = "deserialize_instant", serialize_with = "serialize_instant")]
|
||||
pub last_use_time: Instant,
|
||||
}
|
||||
|
||||
/// The typed representation of a base64 encoded string of the Olm Session pickle.
|
||||
/// The typed representation of a base64 encoded string of the Olm Session
|
||||
/// pickle.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SessionPickle(String);
|
||||
|
||||
|
|
|
@ -14,8 +14,6 @@
|
|||
|
||||
mod pk_signing;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Error as JsonError;
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
sync::{
|
||||
|
@ -30,14 +28,15 @@ use matrix_sdk_common::{
|
|||
identifiers::{DeviceKeyAlgorithm, DeviceKeyId, UserId},
|
||||
locks::Mutex,
|
||||
};
|
||||
use pk_signing::{MasterSigning, PickledSignings, SelfSigning, Signing, SigningError, UserSigning};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Error as JsonError;
|
||||
|
||||
use crate::{
|
||||
error::SignatureError, requests::UploadSigningKeysRequest, OwnUserIdentity, ReadOnlyAccount,
|
||||
ReadOnlyDevice, UserIdentity,
|
||||
};
|
||||
|
||||
use pk_signing::{MasterSigning, PickledSignings, SelfSigning, Signing, SigningError, UserSigning};
|
||||
|
||||
/// Private cross signing identity.
|
||||
///
|
||||
/// This object holds the private and public ed25519 key triplet that is used
|
||||
|
@ -186,10 +185,7 @@ impl PrivateCrossSigningIdentity {
|
|||
signed_keys
|
||||
.entry((&*self.user_id).to_owned())
|
||||
.or_insert_with(BTreeMap::new)
|
||||
.insert(
|
||||
device_keys.device_id.to_string(),
|
||||
serde_json::to_value(device_keys)?,
|
||||
);
|
||||
.insert(device_keys.device_id.to_string(), serde_json::to_value(device_keys)?);
|
||||
|
||||
Ok(SignatureUploadRequest::new(signed_keys))
|
||||
}
|
||||
|
@ -229,10 +225,7 @@ impl PrivateCrossSigningIdentity {
|
|||
signature,
|
||||
);
|
||||
|
||||
let master = MasterSigning {
|
||||
inner: master,
|
||||
public_key: public_key.into(),
|
||||
};
|
||||
let master = MasterSigning { inner: master, public_key: public_key.into() };
|
||||
|
||||
let identity = Self::new_helper(account.user_id(), master).await;
|
||||
let signature_request = identity
|
||||
|
@ -250,20 +243,14 @@ impl PrivateCrossSigningIdentity {
|
|||
let mut public_key = user.cross_signing_key(user_id.to_owned(), KeyUsage::UserSigning);
|
||||
master.sign_subkey(&mut public_key).await;
|
||||
|
||||
let user = UserSigning {
|
||||
inner: user,
|
||||
public_key: public_key.into(),
|
||||
};
|
||||
let user = UserSigning { inner: user, public_key: public_key.into() };
|
||||
|
||||
let self_signing = Signing::new();
|
||||
let mut public_key =
|
||||
self_signing.cross_signing_key(user_id.to_owned(), KeyUsage::SelfSigning);
|
||||
master.sign_subkey(&mut public_key).await;
|
||||
|
||||
let self_signing = SelfSigning {
|
||||
inner: self_signing,
|
||||
public_key: public_key.into(),
|
||||
};
|
||||
let self_signing = SelfSigning { inner: self_signing, public_key: public_key.into() };
|
||||
|
||||
Self {
|
||||
user_id: Arc::new(user_id.to_owned()),
|
||||
|
@ -281,10 +268,7 @@ impl PrivateCrossSigningIdentity {
|
|||
let master = Signing::new();
|
||||
|
||||
let public_key = master.cross_signing_key(user_id.clone(), KeyUsage::Master);
|
||||
let master = MasterSigning {
|
||||
inner: master,
|
||||
public_key: public_key.into(),
|
||||
};
|
||||
let master = MasterSigning { inner: master, public_key: public_key.into() };
|
||||
|
||||
Self::new_helper(&user_id, master).await
|
||||
}
|
||||
|
@ -334,11 +318,7 @@ impl PrivateCrossSigningIdentity {
|
|||
None
|
||||
};
|
||||
|
||||
let pickle = PickledSignings {
|
||||
master_key,
|
||||
user_signing_key,
|
||||
self_signing_key,
|
||||
};
|
||||
let pickle = PickledSignings { master_key, user_signing_key, self_signing_key };
|
||||
|
||||
let pickle = serde_json::to_string(&pickle)?;
|
||||
|
||||
|
@ -390,54 +370,35 @@ impl PrivateCrossSigningIdentity {
|
|||
/// Get the upload request that is needed to share the public keys of this
|
||||
/// identity.
|
||||
pub(crate) async fn as_upload_request(&self) -> UploadSigningKeysRequest {
|
||||
let master_key = self
|
||||
.master_key
|
||||
.lock()
|
||||
.await
|
||||
.as_ref()
|
||||
.cloned()
|
||||
.map(|k| k.public_key.into());
|
||||
let master_key =
|
||||
self.master_key.lock().await.as_ref().cloned().map(|k| k.public_key.into());
|
||||
|
||||
let user_signing_key = self
|
||||
.user_signing_key
|
||||
.lock()
|
||||
.await
|
||||
.as_ref()
|
||||
.cloned()
|
||||
.map(|k| k.public_key.into());
|
||||
let user_signing_key =
|
||||
self.user_signing_key.lock().await.as_ref().cloned().map(|k| k.public_key.into());
|
||||
|
||||
let self_signing_key = self
|
||||
.self_signing_key
|
||||
.lock()
|
||||
.await
|
||||
.as_ref()
|
||||
.cloned()
|
||||
.map(|k| k.public_key.into());
|
||||
let self_signing_key =
|
||||
self.self_signing_key.lock().await.as_ref().cloned().map(|k| k.public_key.into());
|
||||
|
||||
UploadSigningKeysRequest {
|
||||
master_key,
|
||||
self_signing_key,
|
||||
user_signing_key,
|
||||
}
|
||||
UploadSigningKeysRequest { master_key, self_signing_key, user_signing_key }
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::{
|
||||
identities::{ReadOnlyDevice, UserIdentity},
|
||||
olm::ReadOnlyAccount,
|
||||
};
|
||||
use std::{collections::BTreeMap, sync::Arc};
|
||||
|
||||
use super::{PrivateCrossSigningIdentity, Signing};
|
||||
|
||||
use matrix_sdk_common::{
|
||||
api::r0::keys::CrossSigningKey,
|
||||
identifiers::{user_id, UserId},
|
||||
};
|
||||
use matrix_sdk_test::async_test;
|
||||
|
||||
use super::{PrivateCrossSigningIdentity, Signing};
|
||||
use crate::{
|
||||
identities::{ReadOnlyDevice, UserIdentity},
|
||||
olm::ReadOnlyAccount,
|
||||
};
|
||||
|
||||
fn user_id() -> UserId {
|
||||
user_id!("@example:localhost")
|
||||
}
|
||||
|
@ -481,28 +442,12 @@ mod test {
|
|||
|
||||
assert!(master_key
|
||||
.public_key
|
||||
.verify_subkey(
|
||||
&identity
|
||||
.self_signing_key
|
||||
.lock()
|
||||
.await
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.public_key,
|
||||
)
|
||||
.verify_subkey(&identity.self_signing_key.lock().await.as_ref().unwrap().public_key,)
|
||||
.is_ok());
|
||||
|
||||
assert!(master_key
|
||||
.public_key
|
||||
.verify_subkey(
|
||||
&identity
|
||||
.user_signing_key
|
||||
.lock()
|
||||
.await
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.public_key,
|
||||
)
|
||||
.verify_subkey(&identity.user_signing_key.lock().await.as_ref().unwrap().public_key,)
|
||||
.is_ok());
|
||||
}
|
||||
|
||||
|
@ -512,15 +457,11 @@ mod test {
|
|||
|
||||
let pickled = identity.pickle(pickle_key()).await.unwrap();
|
||||
|
||||
let unpickled = PrivateCrossSigningIdentity::from_pickle(pickled, pickle_key())
|
||||
.await
|
||||
.unwrap();
|
||||
let unpickled =
|
||||
PrivateCrossSigningIdentity::from_pickle(pickled, pickle_key()).await.unwrap();
|
||||
|
||||
assert_eq!(identity.user_id, unpickled.user_id);
|
||||
assert_eq!(
|
||||
&*identity.master_key.lock().await,
|
||||
&*unpickled.master_key.lock().await
|
||||
);
|
||||
assert_eq!(&*identity.master_key.lock().await, &*unpickled.master_key.lock().await);
|
||||
assert_eq!(
|
||||
&*identity.user_signing_key.lock().await,
|
||||
&*unpickled.user_signing_key.lock().await
|
||||
|
@ -591,9 +532,6 @@ mod test {
|
|||
|
||||
bob_public.master_key = master.into();
|
||||
|
||||
user_signing
|
||||
.public_key
|
||||
.verify_master_key(bob_public.master_key())
|
||||
.unwrap();
|
||||
user_signing.public_key.verify_master_key(bob_public.master_key()).unwrap();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,32 +12,27 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{collections::BTreeMap, convert::TryInto, sync::Arc};
|
||||
|
||||
use aes_gcm::{
|
||||
aead::{generic_array::GenericArray, Aead, NewAead},
|
||||
Aes256Gcm,
|
||||
};
|
||||
use getrandom::getrandom;
|
||||
use matrix_sdk_common::{
|
||||
encryption::DeviceKeys,
|
||||
identifiers::{DeviceKeyAlgorithm, DeviceKeyId},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Error as JsonError, Value};
|
||||
use std::{collections::BTreeMap, convert::TryInto, sync::Arc};
|
||||
use thiserror::Error;
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
use olm_rs::pk::OlmPkSigning;
|
||||
|
||||
#[cfg(test)]
|
||||
use olm_rs::{errors::OlmUtilityError, utility::OlmUtility};
|
||||
|
||||
use matrix_sdk_common::{
|
||||
api::r0::keys::{CrossSigningKey, KeyUsage},
|
||||
identifiers::UserId,
|
||||
encryption::DeviceKeys,
|
||||
identifiers::{DeviceKeyAlgorithm, DeviceKeyId, UserId},
|
||||
locks::Mutex,
|
||||
CanonicalJsonValue,
|
||||
};
|
||||
use olm_rs::pk::OlmPkSigning;
|
||||
#[cfg(test)]
|
||||
use olm_rs::{errors::OlmUtilityError, utility::OlmUtility};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Error as JsonError, Value};
|
||||
use thiserror::Error;
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
use crate::{
|
||||
error::SignatureError,
|
||||
|
@ -73,9 +68,7 @@ pub struct Signing {
|
|||
|
||||
impl std::fmt::Debug for Signing {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Signing")
|
||||
.field("public_key", &self.public_key.as_str())
|
||||
.finish()
|
||||
f.debug_struct("Signing").field("public_key", &self.public_key.as_str()).finish()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -156,10 +149,7 @@ impl MasterSigning {
|
|||
) -> Result<Self, SigningError> {
|
||||
let inner = Signing::from_pickle(pickle.pickle, pickle_key)?;
|
||||
|
||||
Ok(Self {
|
||||
inner,
|
||||
public_key: pickle.public_key.into(),
|
||||
})
|
||||
Ok(Self { inner, public_key: pickle.public_key.into() })
|
||||
}
|
||||
|
||||
pub async fn sign_subkey<'a>(&self, subkey: &mut CrossSigningKey) {
|
||||
|
@ -200,10 +190,7 @@ impl UserSigning {
|
|||
user: &UserIdentity,
|
||||
) -> Result<BTreeMap<UserId, BTreeMap<String, Value>>, SignatureError> {
|
||||
let user_master: &CrossSigningKey = user.master_key().as_ref();
|
||||
let signature = self
|
||||
.inner
|
||||
.sign_json(serde_json::to_value(user_master)?)
|
||||
.await?;
|
||||
let signature = self.inner.sign_json(serde_json::to_value(user_master)?).await?;
|
||||
|
||||
let mut signatures = BTreeMap::new();
|
||||
|
||||
|
@ -228,10 +215,7 @@ impl UserSigning {
|
|||
) -> Result<Self, SigningError> {
|
||||
let inner = Signing::from_pickle(pickle.pickle, pickle_key)?;
|
||||
|
||||
Ok(Self {
|
||||
inner,
|
||||
public_key: pickle.public_key.into(),
|
||||
})
|
||||
Ok(Self { inner, public_key: pickle.public_key.into() })
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -279,10 +263,7 @@ impl SelfSigning {
|
|||
) -> Result<Self, SigningError> {
|
||||
let inner = Signing::from_pickle(pickle.pickle, pickle_key)?;
|
||||
|
||||
Ok(Self {
|
||||
inner,
|
||||
public_key: pickle.public_key.into(),
|
||||
})
|
||||
Ok(Self { inner, public_key: pickle.public_key.into() })
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -353,17 +334,12 @@ impl Signing {
|
|||
getrandom(&mut nonce).expect("Can't generate nonce to pickle the signing object");
|
||||
let nonce = GenericArray::from_slice(nonce.as_slice());
|
||||
|
||||
let ciphertext = cipher
|
||||
.encrypt(nonce, self.seed.as_slice())
|
||||
.expect("Can't encrypt signing pickle");
|
||||
let ciphertext =
|
||||
cipher.encrypt(nonce, self.seed.as_slice()).expect("Can't encrypt signing pickle");
|
||||
|
||||
let ciphertext = encode(ciphertext);
|
||||
|
||||
let pickle = InnerPickle {
|
||||
version: 1,
|
||||
nonce: encode(nonce.as_slice()),
|
||||
ciphertext,
|
||||
};
|
||||
let pickle = InnerPickle { version: 1, nonce: encode(nonce.as_slice()), ciphertext };
|
||||
|
||||
PickledSigning(serde_json::to_string(&pickle).expect("Can't encode pickled signing"))
|
||||
}
|
||||
|
@ -376,11 +352,8 @@ impl Signing {
|
|||
let mut keys = BTreeMap::new();
|
||||
|
||||
keys.insert(
|
||||
DeviceKeyId::from_parts(
|
||||
DeviceKeyAlgorithm::Ed25519,
|
||||
self.public_key().as_str().into(),
|
||||
)
|
||||
.to_string(),
|
||||
DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, self.public_key().as_str().into())
|
||||
.to_string(),
|
||||
self.public_key().to_string(),
|
||||
);
|
||||
|
||||
|
|
|
@ -12,14 +12,14 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use olm_rs::utility::OlmUtility;
|
||||
use serde_json::Value;
|
||||
use std::convert::TryInto;
|
||||
|
||||
use matrix_sdk_common::{
|
||||
identifiers::{DeviceKeyAlgorithm, DeviceKeyId, UserId},
|
||||
CanonicalJsonValue,
|
||||
};
|
||||
use olm_rs::utility::OlmUtility;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::error::SignatureError;
|
||||
|
||||
|
@ -29,9 +29,7 @@ pub(crate) struct Utility {
|
|||
|
||||
impl Utility {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
inner: OlmUtility::new(),
|
||||
}
|
||||
Self { inner: OlmUtility::new() }
|
||||
}
|
||||
|
||||
/// Verify a signed JSON object.
|
||||
|
@ -49,7 +47,7 @@ impl Utility {
|
|||
/// * `key_id` - The id of the key that signed the JSON object.
|
||||
///
|
||||
/// * `signing_key` - The public ed25519 key which was used to sign the JSON
|
||||
/// object.
|
||||
/// object.
|
||||
///
|
||||
/// * `json` - The JSON object that should be verified.
|
||||
pub(crate) fn verify_json(
|
||||
|
@ -67,29 +65,20 @@ impl Utility {
|
|||
let unsigned = json_object.remove("unsigned");
|
||||
let signatures = json_object.remove("signatures");
|
||||
|
||||
let canonical_json: CanonicalJsonValue = json
|
||||
.clone()
|
||||
.try_into()
|
||||
.map_err(|_| SignatureError::NotAnObject)?;
|
||||
let canonical_json: CanonicalJsonValue =
|
||||
json.clone().try_into().map_err(|_| SignatureError::NotAnObject)?;
|
||||
|
||||
let canonical_json: String = canonical_json.to_string();
|
||||
|
||||
let signatures = signatures.ok_or(SignatureError::NoSignatureFound)?;
|
||||
let signature_object = signatures
|
||||
.as_object()
|
||||
.ok_or(SignatureError::NoSignatureFound)?;
|
||||
let signature = signature_object
|
||||
.get(user_id.as_str())
|
||||
.ok_or(SignatureError::NoSignatureFound)?;
|
||||
let signature = signature
|
||||
.get(key_id.to_string())
|
||||
.ok_or(SignatureError::NoSignatureFound)?;
|
||||
let signature_object = signatures.as_object().ok_or(SignatureError::NoSignatureFound)?;
|
||||
let signature =
|
||||
signature_object.get(user_id.as_str()).ok_or(SignatureError::NoSignatureFound)?;
|
||||
let signature =
|
||||
signature.get(key_id.to_string()).ok_or(SignatureError::NoSignatureFound)?;
|
||||
let signature = signature.as_str().ok_or(SignatureError::NoSignatureFound)?;
|
||||
|
||||
let ret = match self
|
||||
.inner
|
||||
.ed25519_verify(signing_key, &canonical_json, signature)
|
||||
{
|
||||
let ret = match self.inner.ed25519_verify(signing_key, &canonical_json, signature) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(_) => Err(SignatureError::VerificationError),
|
||||
};
|
||||
|
@ -108,10 +97,11 @@ impl Utility {
|
|||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::Utility;
|
||||
use matrix_sdk_common::identifiers::{user_id, DeviceKeyAlgorithm, DeviceKeyId};
|
||||
use serde_json::json;
|
||||
|
||||
use super::Utility;
|
||||
|
||||
#[test]
|
||||
fn signature_test() {
|
||||
let mut device_keys = json!({
|
||||
|
|
|
@ -35,18 +35,19 @@ use matrix_sdk_common::{
|
|||
identifiers::{DeviceIdBox, RoomId, UserId},
|
||||
uuid::Uuid,
|
||||
};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::value::RawValue as RawJsonValue;
|
||||
|
||||
/// Customized version of `ruma_client_api::r0::to_device::send_event_to_device::Request`, using a
|
||||
/// Customized version of
|
||||
/// `ruma_client_api::r0::to_device::send_event_to_device::Request`, using a
|
||||
/// UUID for the transaction ID.
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct ToDeviceRequest {
|
||||
/// Type of event being sent to each device.
|
||||
pub event_type: EventType,
|
||||
|
||||
/// A request identifier unique to the access token used to send the request.
|
||||
/// A request identifier unique to the access token used to send the
|
||||
/// request.
|
||||
pub txn_id: Uuid,
|
||||
|
||||
/// A map of users to devices to a content for a message event to be
|
||||
|
@ -80,15 +81,18 @@ impl ToDeviceRequest {
|
|||
pub struct UploadSigningKeysRequest {
|
||||
/// The user's master key.
|
||||
pub master_key: Option<CrossSigningKey>,
|
||||
/// The user's self-signing key. Must be signed with the accompanied master, or by the
|
||||
/// user's most recently uploaded master key if no master key is included in the request.
|
||||
/// The user's self-signing key. Must be signed with the accompanied master,
|
||||
/// or by the user's most recently uploaded master key if no master key
|
||||
/// is included in the request.
|
||||
pub self_signing_key: Option<CrossSigningKey>,
|
||||
/// The user's user-signing key. Must be signed with the accompanied master, or by the
|
||||
/// user's most recently uploaded master key if no master key is included in the request.
|
||||
/// The user's user-signing key. Must be signed with the accompanied master,
|
||||
/// or by the user's most recently uploaded master key if no master key
|
||||
/// is included in the request.
|
||||
pub user_signing_key: Option<CrossSigningKey>,
|
||||
}
|
||||
|
||||
/// Customized version of `ruma_client_api::r0::keys::get_keys::Request`, without any references.
|
||||
/// Customized version of `ruma_client_api::r0::keys::get_keys::Request`,
|
||||
/// without any references.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct KeysQueryRequest {
|
||||
/// The time (in milliseconds) to wait when downloading keys from remote
|
||||
|
@ -109,11 +113,7 @@ pub struct KeysQueryRequest {
|
|||
|
||||
impl KeysQueryRequest {
|
||||
pub(crate) fn new(device_keys: BTreeMap<UserId, Vec<DeviceIdBox>>) -> Self {
|
||||
Self {
|
||||
timeout: None,
|
||||
device_keys,
|
||||
token: None,
|
||||
}
|
||||
Self { timeout: None, device_keys, token: None }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -177,19 +177,13 @@ impl From<SignatureUploadRequest> for OutgoingRequests {
|
|||
|
||||
impl From<OutgoingVerificationRequest> for OutgoingRequest {
|
||||
fn from(r: OutgoingVerificationRequest) -> Self {
|
||||
Self {
|
||||
request_id: r.request_id(),
|
||||
request: Arc::new(r.into()),
|
||||
}
|
||||
Self { request_id: r.request_id(), request: Arc::new(r.into()) }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SignatureUploadRequest> for OutgoingRequest {
|
||||
fn from(r: SignatureUploadRequest) -> Self {
|
||||
Self {
|
||||
request_id: Uuid::new_v4(),
|
||||
request: Arc::new(r.into()),
|
||||
}
|
||||
Self { request_id: Uuid::new_v4(), request: Arc::new(r.into()) }
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -17,9 +17,8 @@ use std::{
|
|||
sync::Arc,
|
||||
};
|
||||
|
||||
use futures::future::join_all;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use futures::future::join_all;
|
||||
use matrix_sdk_common::{
|
||||
api::r0::to_device::DeviceIdOrAllDevices,
|
||||
events::{
|
||||
|
@ -105,10 +104,7 @@ impl GroupSessionCache {
|
|||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
) -> StoreResult<Option<OutboundGroupSession>> {
|
||||
Ok(self
|
||||
.get_or_load(room_id)
|
||||
.await?
|
||||
.filter(|o| session_id == o.session_id()))
|
||||
Ok(self.get_or_load(room_id).await?.filter(|o| session_id == o.session_id()))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -127,11 +123,7 @@ impl GroupSessionManager {
|
|||
const MAX_TO_DEVICE_MESSAGES: usize = 250;
|
||||
|
||||
pub(crate) fn new(account: Account, store: Store) -> Self {
|
||||
Self {
|
||||
account,
|
||||
store: store.clone(),
|
||||
sessions: GroupSessionCache::new(store),
|
||||
}
|
||||
Self { account, store: store.clone(), sessions: GroupSessionCache::new(store) }
|
||||
}
|
||||
|
||||
pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
|
||||
|
@ -231,9 +223,7 @@ impl GroupSessionManager {
|
|||
Ok((s, None))
|
||||
}
|
||||
} else {
|
||||
self.create_outbound_group_session(room_id, settings)
|
||||
.await
|
||||
.map(|(o, i)| (o, i.into()))
|
||||
self.create_outbound_group_session(room_id, settings).await.map(|(o, i)| (o, i.into()))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -253,13 +243,10 @@ impl GroupSessionManager {
|
|||
|
||||
let used_session = match encrypted {
|
||||
Ok((session, encrypted)) => {
|
||||
message
|
||||
.entry(device.user_id().clone())
|
||||
.or_insert_with(BTreeMap::new)
|
||||
.insert(
|
||||
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
|
||||
serde_json::value::to_raw_value(&encrypted)?,
|
||||
);
|
||||
message.entry(device.user_id().clone()).or_insert_with(BTreeMap::new).insert(
|
||||
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
|
||||
serde_json::value::to_raw_value(&encrypted)?,
|
||||
);
|
||||
Some(session)
|
||||
}
|
||||
// TODO we'll want to create m.room_key.withheld here.
|
||||
|
@ -271,10 +258,8 @@ impl GroupSessionManager {
|
|||
Ok((used_session, message))
|
||||
};
|
||||
|
||||
let tasks: Vec<_> = devices
|
||||
.iter()
|
||||
.map(|d| spawn(encrypt(d.clone(), content.clone())))
|
||||
.collect();
|
||||
let tasks: Vec<_> =
|
||||
devices.iter().map(|d| spawn(encrypt(d.clone(), content.clone()))).collect();
|
||||
|
||||
let results = join_all(tasks).await;
|
||||
|
||||
|
@ -286,20 +271,14 @@ impl GroupSessionManager {
|
|||
}
|
||||
|
||||
for (user, device_messages) in message.into_iter() {
|
||||
messages
|
||||
.entry(user)
|
||||
.or_insert_with(BTreeMap::new)
|
||||
.extend(device_messages);
|
||||
messages.entry(user).or_insert_with(BTreeMap::new).extend(device_messages);
|
||||
}
|
||||
}
|
||||
|
||||
let id = Uuid::new_v4();
|
||||
|
||||
let request = ToDeviceRequest {
|
||||
event_type: EventType::RoomEncrypted,
|
||||
txn_id: id,
|
||||
messages,
|
||||
};
|
||||
let request =
|
||||
ToDeviceRequest { event_type: EventType::RoomEncrypted, txn_id: id, messages };
|
||||
|
||||
trace!(
|
||||
recipient_count = request.message_count(),
|
||||
|
@ -331,20 +310,14 @@ impl GroupSessionManager {
|
|||
"Calculating group session recipients"
|
||||
);
|
||||
|
||||
let users_shared_with: HashSet<UserId> = outbound
|
||||
.shared_with_set
|
||||
.iter()
|
||||
.map(|k| k.key().clone())
|
||||
.collect();
|
||||
let users_shared_with: HashSet<UserId> =
|
||||
outbound.shared_with_set.iter().map(|k| k.key().clone()).collect();
|
||||
|
||||
let users_shared_with: HashSet<&UserId> = users_shared_with.iter().collect();
|
||||
|
||||
// A user left if a user is missing from the set of users that should
|
||||
// get the session but is in the set of users that received the session.
|
||||
let user_left = !users_shared_with
|
||||
.difference(&users)
|
||||
.collect::<HashSet<_>>()
|
||||
.is_empty();
|
||||
let user_left = !users_shared_with.difference(&users).collect::<HashSet<_>>().is_empty();
|
||||
|
||||
let visibility_changed = outbound.settings().history_visibility != history_visibility;
|
||||
|
||||
|
@ -359,10 +332,8 @@ impl GroupSessionManager {
|
|||
|
||||
for user_id in users {
|
||||
let user_devices = self.store.get_user_devices(&user_id).await?;
|
||||
let non_blacklisted_devices: Vec<Device> = user_devices
|
||||
.devices()
|
||||
.filter(|d| !d.is_blacklisted())
|
||||
.collect();
|
||||
let non_blacklisted_devices: Vec<Device> =
|
||||
user_devices.devices().filter(|d| !d.is_blacklisted()).collect();
|
||||
|
||||
// If we haven't already concluded that the session should be
|
||||
// rotated for other reasons, we also need to check whether any
|
||||
|
@ -370,10 +341,8 @@ impl GroupSessionManager {
|
|||
// meantime. If so, we should also rotate the session.
|
||||
if !should_rotate {
|
||||
// Device IDs that should receive this session
|
||||
let non_blacklisted_device_ids: HashSet<&DeviceId> = non_blacklisted_devices
|
||||
.iter()
|
||||
.map(|d| d.device_id())
|
||||
.collect();
|
||||
let non_blacklisted_device_ids: HashSet<&DeviceId> =
|
||||
non_blacklisted_devices.iter().map(|d| d.device_id()).collect();
|
||||
|
||||
if let Some(shared) = outbound.shared_with_set.get(user_id) {
|
||||
#[allow(clippy::map_clone)]
|
||||
|
@ -389,9 +358,8 @@ impl GroupSessionManager {
|
|||
//
|
||||
// represents newly deleted or blacklisted devices. If this
|
||||
// set is non-empty, we must rotate.
|
||||
let newly_deleted_or_blacklisted = shared
|
||||
.difference(&non_blacklisted_device_ids)
|
||||
.collect::<HashSet<_>>();
|
||||
let newly_deleted_or_blacklisted =
|
||||
shared.difference(&non_blacklisted_device_ids).collect::<HashSet<_>>();
|
||||
|
||||
if !newly_deleted_or_blacklisted.is_empty() {
|
||||
should_rotate = true;
|
||||
|
@ -399,10 +367,7 @@ impl GroupSessionManager {
|
|||
};
|
||||
}
|
||||
|
||||
devices
|
||||
.entry(user_id.clone())
|
||||
.or_insert_with(Vec::new)
|
||||
.extend(non_blacklisted_devices);
|
||||
devices.entry(user_id.clone()).or_insert_with(Vec::new).extend(non_blacklisted_devices);
|
||||
}
|
||||
|
||||
debug!(
|
||||
|
@ -462,25 +427,22 @@ impl GroupSessionManager {
|
|||
let history_visibility = encryption_settings.history_visibility.clone();
|
||||
let mut changes = Changes::default();
|
||||
|
||||
let (outbound, inbound) = self
|
||||
.get_or_create_outbound_session(room_id, encryption_settings.clone())
|
||||
.await?;
|
||||
let (outbound, inbound) =
|
||||
self.get_or_create_outbound_session(room_id, encryption_settings.clone()).await?;
|
||||
|
||||
if let Some(inbound) = inbound {
|
||||
changes.outbound_group_sessions.push(outbound.clone());
|
||||
changes.inbound_group_sessions.push(inbound);
|
||||
}
|
||||
|
||||
let (should_rotate, devices) = self
|
||||
.collect_session_recipients(users, history_visibility, &outbound)
|
||||
.await?;
|
||||
let (should_rotate, devices) =
|
||||
self.collect_session_recipients(users, history_visibility, &outbound).await?;
|
||||
|
||||
let outbound = if should_rotate {
|
||||
let old_session_id = outbound.session_id();
|
||||
|
||||
let (outbound, inbound) = self
|
||||
.create_outbound_group_session(room_id, encryption_settings)
|
||||
.await?;
|
||||
let (outbound, inbound) =
|
||||
self.create_outbound_group_session(room_id, encryption_settings).await?;
|
||||
changes.outbound_group_sessions.push(outbound.clone());
|
||||
changes.inbound_group_sessions.push(inbound);
|
||||
|
||||
|
@ -515,9 +477,7 @@ impl GroupSessionManager {
|
|||
|
||||
if !devices.is_empty() {
|
||||
let users = devices.iter().fold(BTreeMap::new(), |mut acc, d| {
|
||||
acc.entry(d.user_id())
|
||||
.or_insert_with(BTreeSet::new)
|
||||
.insert(d.device_id());
|
||||
acc.entry(d.user_id()).or_insert_with(BTreeSet::new).insert(d.device_id());
|
||||
acc
|
||||
});
|
||||
|
||||
|
@ -626,14 +586,8 @@ mod test {
|
|||
|
||||
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
|
||||
|
||||
machine
|
||||
.mark_request_as_sent(&uuid, &keys_query)
|
||||
.await
|
||||
.unwrap();
|
||||
machine
|
||||
.mark_request_as_sent(&uuid, &keys_claim)
|
||||
.await
|
||||
.unwrap();
|
||||
machine.mark_request_as_sent(&uuid, &keys_query).await.unwrap();
|
||||
machine.mark_request_as_sent(&uuid, &keys_claim).await.unwrap();
|
||||
|
||||
machine
|
||||
}
|
||||
|
@ -647,11 +601,7 @@ mod test {
|
|||
let users: Vec<_> = keys_claim.one_time_keys.keys().collect();
|
||||
|
||||
let requests = machine
|
||||
.share_group_session(
|
||||
&room_id,
|
||||
users.clone().into_iter(),
|
||||
EncryptionSettings::default(),
|
||||
)
|
||||
.share_group_session(&room_id, users.clone().into_iter(), EncryptionSettings::default())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
|
@ -77,11 +77,7 @@ impl SessionManager {
|
|||
}
|
||||
|
||||
pub async fn mark_device_as_wedged(&self, sender: &UserId, curve_key: &str) -> StoreResult<()> {
|
||||
if let Some(device) = self
|
||||
.store
|
||||
.get_device_from_curve_key(sender, curve_key)
|
||||
.await?
|
||||
{
|
||||
if let Some(device) = self.store.get_device_from_curve_key(sender, curve_key).await? {
|
||||
let sessions = device.get_sessions().await?;
|
||||
|
||||
if let Some(sessions) = sessions {
|
||||
|
@ -120,25 +116,16 @@ impl SessionManager {
|
|||
///
|
||||
/// If the device was wedged this will queue up a dummy to-device message.
|
||||
async fn check_if_unwedged(&self, user_id: &UserId, device_id: &DeviceId) -> OlmResult<()> {
|
||||
if self
|
||||
.wedged_devices
|
||||
.get(user_id)
|
||||
.map(|d| d.remove(device_id))
|
||||
.flatten()
|
||||
.is_some()
|
||||
{
|
||||
if self.wedged_devices.get(user_id).map(|d| d.remove(device_id)).flatten().is_some() {
|
||||
if let Some(device) = self.store.get_device(user_id, device_id).await? {
|
||||
let (_, content) = device.encrypt(EventType::Dummy, json!({})).await?;
|
||||
let id = Uuid::new_v4();
|
||||
let mut messages = BTreeMap::new();
|
||||
|
||||
messages
|
||||
.entry(device.user_id().to_owned())
|
||||
.or_insert_with(BTreeMap::new)
|
||||
.insert(
|
||||
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
|
||||
to_raw_value(&content)?,
|
||||
);
|
||||
messages.entry(device.user_id().to_owned()).or_insert_with(BTreeMap::new).insert(
|
||||
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
|
||||
to_raw_value(&content)?,
|
||||
);
|
||||
|
||||
let request = OutgoingRequest {
|
||||
request_id: id,
|
||||
|
@ -307,13 +294,13 @@ impl SessionManager {
|
|||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use dashmap::DashMap;
|
||||
use matrix_sdk_common::locks::Mutex;
|
||||
use std::{collections::BTreeMap, sync::Arc};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use matrix_sdk_common::{
|
||||
api::r0::keys::claim_keys::Response as KeyClaimResponse,
|
||||
identifiers::{user_id, DeviceIdBox, UserId},
|
||||
locks::Mutex,
|
||||
};
|
||||
use matrix_sdk_test::async_test;
|
||||
|
||||
|
@ -347,9 +334,7 @@ mod test {
|
|||
let account = ReadOnlyAccount::new(&user_id, &device_id);
|
||||
let store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
|
||||
store.save_account(account.clone()).await.unwrap();
|
||||
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(
|
||||
user_id.clone(),
|
||||
)));
|
||||
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(user_id.clone())));
|
||||
let verification =
|
||||
VerificationMachine::new(account.clone(), identity.clone(), store.clone());
|
||||
|
||||
|
@ -358,10 +343,7 @@ mod test {
|
|||
|
||||
let store = Store::new(user_id.clone(), identity, store, verification);
|
||||
|
||||
let account = Account {
|
||||
inner: account,
|
||||
store: store.clone(),
|
||||
};
|
||||
let account = Account { inner: account, store: store.clone() };
|
||||
|
||||
let session_cache = GroupSessionCache::new(store.clone());
|
||||
|
||||
|
@ -405,10 +387,7 @@ mod test {
|
|||
|
||||
let response = KeyClaimResponse::new(one_time_keys);
|
||||
|
||||
manager
|
||||
.receive_keys_claim_response(&response)
|
||||
.await
|
||||
.unwrap();
|
||||
manager.receive_keys_claim_response(&response).await.unwrap();
|
||||
|
||||
assert!(manager
|
||||
.get_missing_sessions(&mut [bob.user_id().clone()].iter())
|
||||
|
@ -434,11 +413,7 @@ mod test {
|
|||
let bob_device = ReadOnlyDevice::from_account(&bob).await;
|
||||
session.creation_time = Arc::new(Instant::now() - Duration::from_secs(3601));
|
||||
|
||||
manager
|
||||
.store
|
||||
.save_devices(&[bob_device.clone()])
|
||||
.await
|
||||
.unwrap();
|
||||
manager.store.save_devices(&[bob_device.clone()]).await.unwrap();
|
||||
manager.store.save_sessions(&[session]).await.unwrap();
|
||||
|
||||
assert!(manager
|
||||
|
@ -451,10 +426,7 @@ mod test {
|
|||
|
||||
assert!(!manager.users_for_key_claim.contains_key(bob.user_id()));
|
||||
assert!(!manager.is_device_wedged(&bob_device));
|
||||
manager
|
||||
.mark_device_as_wedged(bob_device.user_id(), &curve_key)
|
||||
.await
|
||||
.unwrap();
|
||||
manager.mark_device_as_wedged(bob_device.user_id(), &curve_key).await.unwrap();
|
||||
assert!(manager.is_device_wedged(&bob_device));
|
||||
assert!(manager.users_for_key_claim.contains_key(bob.user_id()));
|
||||
|
||||
|
@ -480,10 +452,7 @@ mod test {
|
|||
|
||||
assert!(manager.outgoing_to_device_requests.is_empty());
|
||||
|
||||
manager
|
||||
.receive_keys_claim_response(&response)
|
||||
.await
|
||||
.unwrap();
|
||||
manager.receive_keys_claim_response(&response).await.unwrap();
|
||||
|
||||
assert!(!manager.is_device_wedged(&bob_device));
|
||||
assert!(manager
|
||||
|
|
|
@ -39,9 +39,7 @@ pub struct SessionStore {
|
|||
impl SessionStore {
|
||||
/// Create a new empty Session store.
|
||||
pub fn new() -> Self {
|
||||
SessionStore {
|
||||
entries: Arc::new(DashMap::new()),
|
||||
}
|
||||
SessionStore { entries: Arc::new(DashMap::new()) }
|
||||
}
|
||||
|
||||
/// Add a session to the store.
|
||||
|
@ -72,8 +70,7 @@ impl SessionStore {
|
|||
|
||||
/// Add a list of sessions belonging to the sender key.
|
||||
pub fn set_for_sender(&self, sender_key: &str, sessions: Vec<Session>) {
|
||||
self.entries
|
||||
.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
|
||||
self.entries.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -87,9 +84,7 @@ pub struct GroupSessionStore {
|
|||
impl GroupSessionStore {
|
||||
/// Create a new empty store.
|
||||
pub fn new() -> Self {
|
||||
GroupSessionStore {
|
||||
entries: Arc::new(DashMap::new()),
|
||||
}
|
||||
GroupSessionStore { entries: Arc::new(DashMap::new()) }
|
||||
}
|
||||
|
||||
/// Add an inbound group session to the store.
|
||||
|
@ -148,9 +143,7 @@ pub struct DeviceStore {
|
|||
impl DeviceStore {
|
||||
/// Create a new empty device store.
|
||||
pub fn new() -> Self {
|
||||
DeviceStore {
|
||||
entries: Arc::new(DashMap::new()),
|
||||
}
|
||||
DeviceStore { entries: Arc::new(DashMap::new()) }
|
||||
}
|
||||
|
||||
/// Add a device to the store.
|
||||
|
@ -167,19 +160,15 @@ impl DeviceStore {
|
|||
|
||||
/// Get the device with the given device_id and belonging to the given user.
|
||||
pub fn get(&self, user_id: &UserId, device_id: &DeviceId) -> Option<ReadOnlyDevice> {
|
||||
self.entries
|
||||
.get(user_id)
|
||||
.and_then(|m| m.get(device_id).map(|d| d.value().clone()))
|
||||
self.entries.get(user_id).and_then(|m| m.get(device_id).map(|d| d.value().clone()))
|
||||
}
|
||||
|
||||
/// Remove the device with the given device_id and belonging to the given user.
|
||||
/// Remove the device with the given device_id and belonging to the given
|
||||
/// user.
|
||||
///
|
||||
/// Returns the device if it was removed, None if it wasn't in the store.
|
||||
pub fn remove(&self, user_id: &UserId, device_id: &DeviceId) -> Option<ReadOnlyDevice> {
|
||||
self.entries
|
||||
.get(user_id)
|
||||
.and_then(|m| m.remove(device_id))
|
||||
.map(|(_, d)| d)
|
||||
self.entries.get(user_id).and_then(|m| m.remove(device_id)).map(|(_, d)| d)
|
||||
}
|
||||
|
||||
/// Get a read-only view over all devices of the given user.
|
||||
|
@ -195,12 +184,13 @@ impl DeviceStore {
|
|||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use matrix_sdk_common::identifiers::room_id;
|
||||
|
||||
use crate::{
|
||||
identities::device::test::get_device,
|
||||
olm::{test::get_account_and_session, InboundGroupSession},
|
||||
store::caches::{DeviceStore, GroupSessionStore, SessionStore},
|
||||
};
|
||||
use matrix_sdk_common::identifiers::room_id;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_session_store() {
|
||||
|
@ -239,10 +229,8 @@ mod test {
|
|||
let (account, _) = get_account_and_session().await;
|
||||
let room_id = room_id!("!test:localhost");
|
||||
|
||||
let (outbound, _) = account
|
||||
.create_group_session_pair_with_defaults(&room_id)
|
||||
.await
|
||||
.unwrap();
|
||||
let (outbound, _) =
|
||||
account.create_group_session_pair_with_defaults(&room_id).await.unwrap();
|
||||
|
||||
assert_eq!(0, outbound.message_index().await);
|
||||
assert!(!outbound.shared());
|
||||
|
@ -261,9 +249,7 @@ mod test {
|
|||
let store = GroupSessionStore::new();
|
||||
store.add(inbound.clone());
|
||||
|
||||
let loaded_session = store
|
||||
.get(&room_id, "test_key", outbound.session_id())
|
||||
.unwrap();
|
||||
let loaded_session = store.get(&room_id, "test_key", outbound.session_id()).unwrap();
|
||||
assert_eq!(inbound, loaded_session);
|
||||
}
|
||||
|
||||
|
|
|
@ -37,10 +37,7 @@ use crate::{
|
|||
};
|
||||
|
||||
fn encode_key_info(info: &RequestedKeyInfo) -> String {
|
||||
format!(
|
||||
"{}{}{}{}",
|
||||
info.room_id, info.sender_key, info.algorithm, info.session_id
|
||||
)
|
||||
format!("{}{}{}{}", info.room_id, info.sender_key, info.algorithm, info.session_id)
|
||||
}
|
||||
|
||||
/// An in-memory only store that will forget all the E2EE key once it's dropped.
|
||||
|
@ -121,22 +118,14 @@ impl CryptoStore for MemoryStore {
|
|||
|
||||
async fn save_changes(&self, mut changes: Changes) -> Result<()> {
|
||||
self.save_sessions(changes.sessions).await;
|
||||
self.save_inbound_group_sessions(changes.inbound_group_sessions)
|
||||
.await;
|
||||
self.save_inbound_group_sessions(changes.inbound_group_sessions).await;
|
||||
|
||||
self.save_devices(changes.devices.new).await;
|
||||
self.save_devices(changes.devices.changed).await;
|
||||
self.delete_devices(changes.devices.deleted).await;
|
||||
|
||||
for identity in changes
|
||||
.identities
|
||||
.new
|
||||
.drain(..)
|
||||
.chain(changes.identities.changed)
|
||||
{
|
||||
let _ = self
|
||||
.identities
|
||||
.insert(identity.user_id().to_owned(), identity.clone());
|
||||
for identity in changes.identities.new.drain(..).chain(changes.identities.changed) {
|
||||
let _ = self.identities.insert(identity.user_id().to_owned(), identity.clone());
|
||||
}
|
||||
|
||||
for hash in changes.message_hashes {
|
||||
|
@ -167,9 +156,7 @@ impl CryptoStore for MemoryStore {
|
|||
sender_key: &str,
|
||||
session_id: &str,
|
||||
) -> Result<Option<InboundGroupSession>> {
|
||||
Ok(self
|
||||
.inbound_group_sessions
|
||||
.get(room_id, sender_key, session_id))
|
||||
Ok(self.inbound_group_sessions.get(room_id, sender_key, session_id))
|
||||
}
|
||||
|
||||
async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
|
||||
|
@ -250,10 +237,7 @@ impl CryptoStore for MemoryStore {
|
|||
&self,
|
||||
request_id: Uuid,
|
||||
) -> Result<Option<OutgoingKeyRequest>> {
|
||||
Ok(self
|
||||
.outgoing_key_requests
|
||||
.get(&request_id)
|
||||
.map(|r| r.clone()))
|
||||
Ok(self.outgoing_key_requests.get(&request_id).map(|r| r.clone()))
|
||||
}
|
||||
|
||||
async fn get_key_request_by_info(
|
||||
|
@ -278,12 +262,10 @@ impl CryptoStore for MemoryStore {
|
|||
}
|
||||
|
||||
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> {
|
||||
self.outgoing_key_requests
|
||||
.remove(&request_id)
|
||||
.and_then(|(_, i)| {
|
||||
let key_info_string = encode_key_info(&i.info);
|
||||
self.key_requests_by_info.remove(&key_info_string)
|
||||
});
|
||||
self.outgoing_key_requests.remove(&request_id).and_then(|(_, i)| {
|
||||
let key_info_string = encode_key_info(&i.info);
|
||||
self.key_requests_by_info.remove(&key_info_string)
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -291,12 +273,13 @@ impl CryptoStore for MemoryStore {
|
|||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use matrix_sdk_common::identifiers::room_id;
|
||||
|
||||
use crate::{
|
||||
identities::device::test::get_device,
|
||||
olm::{test::get_account_and_session, InboundGroupSession, OlmMessageHash},
|
||||
store::{memorystore::MemoryStore, Changes, CryptoStore},
|
||||
};
|
||||
use matrix_sdk_common::identifiers::room_id;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_session_store() {
|
||||
|
@ -308,11 +291,7 @@ mod test {
|
|||
|
||||
store.save_sessions(vec![session.clone()]).await;
|
||||
|
||||
let sessions = store
|
||||
.get_sessions(&session.sender_key)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let sessions = store.get_sessions(&session.sender_key).await.unwrap().unwrap();
|
||||
let sessions = sessions.lock().await;
|
||||
|
||||
let loaded_session = &sessions[0];
|
||||
|
@ -325,10 +304,8 @@ mod test {
|
|||
let (account, _) = get_account_and_session().await;
|
||||
let room_id = room_id!("!test:localhost");
|
||||
|
||||
let (outbound, _) = account
|
||||
.create_group_session_pair_with_defaults(&room_id)
|
||||
.await
|
||||
.unwrap();
|
||||
let (outbound, _) =
|
||||
account.create_group_session_pair_with_defaults(&room_id).await.unwrap();
|
||||
let inbound = InboundGroupSession::new(
|
||||
"test_key",
|
||||
"test_key",
|
||||
|
@ -339,9 +316,7 @@ mod test {
|
|||
.unwrap();
|
||||
|
||||
let store = MemoryStore::new();
|
||||
let _ = store
|
||||
.save_inbound_group_sessions(vec![inbound.clone()])
|
||||
.await;
|
||||
let _ = store.save_inbound_group_sessions(vec![inbound.clone()]).await;
|
||||
|
||||
let loaded_session = store
|
||||
.get_inbound_group_session(&room_id, "test_key", outbound.session_id())
|
||||
|
@ -358,11 +333,8 @@ mod test {
|
|||
|
||||
store.save_devices(vec![device.clone()]).await;
|
||||
|
||||
let loaded_device = store
|
||||
.get_device(device.user_id(), device.device_id())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let loaded_device =
|
||||
store.get_device(device.user_id(), device.device_id()).await.unwrap().unwrap();
|
||||
|
||||
assert_eq!(device, loaded_device);
|
||||
|
||||
|
@ -376,11 +348,7 @@ mod test {
|
|||
assert_eq!(&device, loaded_device);
|
||||
|
||||
store.delete_devices(vec![device.clone()]).await;
|
||||
assert!(store
|
||||
.get_device(device.user_id(), device.device_id())
|
||||
.await
|
||||
.unwrap()
|
||||
.is_none());
|
||||
assert!(store.get_device(device.user_id(), device.device_id()).await.unwrap().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
@ -388,14 +356,8 @@ mod test {
|
|||
let device = get_device();
|
||||
let store = MemoryStore::new();
|
||||
|
||||
assert!(store
|
||||
.update_tracked_user(device.user_id(), false)
|
||||
.await
|
||||
.unwrap());
|
||||
assert!(!store
|
||||
.update_tracked_user(device.user_id(), false)
|
||||
.await
|
||||
.unwrap());
|
||||
assert!(store.update_tracked_user(device.user_id(), false).await.unwrap());
|
||||
assert!(!store.update_tracked_user(device.user_id(), false).await.unwrap());
|
||||
|
||||
assert!(store.is_user_tracked(device.user_id()));
|
||||
}
|
||||
|
@ -404,10 +366,8 @@ mod test {
|
|||
async fn test_message_hash() {
|
||||
let store = MemoryStore::new();
|
||||
|
||||
let hash = OlmMessageHash {
|
||||
sender_key: "test_sender".to_owned(),
|
||||
hash: "test_hash".to_owned(),
|
||||
};
|
||||
let hash =
|
||||
OlmMessageHash { sender_key: "test_sender".to_owned(), hash: "test_hash".to_owned() };
|
||||
|
||||
let mut changes = Changes::default();
|
||||
changes.message_hashes.push(hash.clone());
|
||||
|
|
|
@ -43,11 +43,6 @@ mod pickle_key;
|
|||
#[cfg(feature = "sled_cryptostore")]
|
||||
pub(crate) mod sled;
|
||||
|
||||
#[cfg(feature = "sled_cryptostore")]
|
||||
pub use self::sled::SledStore;
|
||||
pub use memorystore::MemoryStore;
|
||||
pub use pickle_key::{EncryptedPickleKey, PickleKey};
|
||||
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
fmt::Debug,
|
||||
|
@ -56,10 +51,6 @@ use std::{
|
|||
sync::Arc,
|
||||
};
|
||||
|
||||
use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
|
||||
use serde_json::Error as SerdeError;
|
||||
use thiserror::Error;
|
||||
|
||||
use matrix_sdk_common::{
|
||||
async_trait,
|
||||
events::room_key_request::RequestedKeyInfo,
|
||||
|
@ -71,7 +62,14 @@ use matrix_sdk_common::{
|
|||
uuid::Uuid,
|
||||
AsyncTraitDeps,
|
||||
};
|
||||
pub use memorystore::MemoryStore;
|
||||
use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
|
||||
pub use pickle_key::{EncryptedPickleKey, PickleKey};
|
||||
use serde_json::Error as SerdeError;
|
||||
use thiserror::Error;
|
||||
|
||||
#[cfg(feature = "sled_cryptostore")]
|
||||
pub use self::sled::SledStore;
|
||||
use crate::{
|
||||
error::SessionUnpicklingError,
|
||||
identities::{Device, ReadOnlyDevice, UserDevices, UserIdentities},
|
||||
|
@ -145,12 +143,7 @@ impl Store {
|
|||
store: Arc<Box<dyn CryptoStore>>,
|
||||
verification_machine: VerificationMachine,
|
||||
) -> Self {
|
||||
Self {
|
||||
user_id,
|
||||
identity,
|
||||
inner: store,
|
||||
verification_machine,
|
||||
}
|
||||
Self { user_id, identity, inner: store, verification_machine }
|
||||
}
|
||||
|
||||
pub async fn get_readonly_device(
|
||||
|
@ -162,10 +155,7 @@ impl Store {
|
|||
}
|
||||
|
||||
pub async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
|
||||
let changes = Changes {
|
||||
sessions: sessions.to_vec(),
|
||||
..Default::default()
|
||||
};
|
||||
let changes = Changes { sessions: sessions.to_vec(), ..Default::default() };
|
||||
|
||||
self.save_changes(changes).await
|
||||
}
|
||||
|
@ -173,10 +163,7 @@ impl Store {
|
|||
#[cfg(test)]
|
||||
pub async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()> {
|
||||
let changes = Changes {
|
||||
devices: DeviceChanges {
|
||||
changed: devices.to_vec(),
|
||||
..Default::default()
|
||||
},
|
||||
devices: DeviceChanges { changed: devices.to_vec(), ..Default::default() },
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
|
@ -188,10 +175,7 @@ impl Store {
|
|||
&self,
|
||||
sessions: &[InboundGroupSession],
|
||||
) -> Result<()> {
|
||||
let changes = Changes {
|
||||
inbound_group_sessions: sessions.to_vec(),
|
||||
..Default::default()
|
||||
};
|
||||
let changes = Changes { inbound_group_sessions: sessions.to_vec(), ..Default::default() };
|
||||
|
||||
self.save_changes(changes).await
|
||||
}
|
||||
|
@ -210,8 +194,7 @@ impl Store {
|
|||
) -> Result<Option<Device>> {
|
||||
self.get_user_devices(user_id).await.map(|d| {
|
||||
d.devices().find(|d| {
|
||||
d.get_key(DeviceKeyAlgorithm::Curve25519)
|
||||
.map_or(false, |k| k == curve_key)
|
||||
d.get_key(DeviceKeyAlgorithm::Curve25519).map_or(false, |k| k == curve_key)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
@ -219,12 +202,8 @@ impl Store {
|
|||
pub async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices> {
|
||||
let devices = self.inner.get_user_devices(user_id).await?;
|
||||
|
||||
let own_identity = self
|
||||
.inner
|
||||
.get_user_identity(&self.user_id)
|
||||
.await?
|
||||
.map(|i| i.own().cloned())
|
||||
.flatten();
|
||||
let own_identity =
|
||||
self.inner.get_user_identity(&self.user_id).await?.map(|i| i.own().cloned()).flatten();
|
||||
let device_owner_identity = self.inner.get_user_identity(user_id).await.ok().flatten();
|
||||
|
||||
Ok(UserDevices {
|
||||
|
@ -241,24 +220,17 @@ impl Store {
|
|||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
) -> Result<Option<Device>> {
|
||||
let own_identity = self
|
||||
.get_user_identity(&self.user_id)
|
||||
.await?
|
||||
.map(|i| i.own().cloned())
|
||||
.flatten();
|
||||
let own_identity =
|
||||
self.get_user_identity(&self.user_id).await?.map(|i| i.own().cloned()).flatten();
|
||||
let device_owner_identity = self.get_user_identity(user_id).await?;
|
||||
|
||||
Ok(self
|
||||
.inner
|
||||
.get_device(user_id, device_id)
|
||||
.await?
|
||||
.map(|d| Device {
|
||||
inner: d,
|
||||
private_identity: self.identity.clone(),
|
||||
verification_machine: self.verification_machine.clone(),
|
||||
own_identity,
|
||||
device_owner_identity,
|
||||
}))
|
||||
Ok(self.inner.get_device(user_id, device_id).await?.map(|d| Device {
|
||||
inner: d,
|
||||
private_identity: self.identity.clone(),
|
||||
verification_machine: self.verification_machine.clone(),
|
||||
own_identity,
|
||||
device_owner_identity,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -366,7 +338,8 @@ pub trait CryptoStore: AsyncTraitDeps {
|
|||
/// Get all the inbound group sessions we have stored.
|
||||
async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>>;
|
||||
|
||||
/// Get the outobund group sessions we have stored that is used for the given room.
|
||||
/// Get the outobund group sessions we have stored that is used for the
|
||||
/// given room.
|
||||
async fn get_outbound_group_sessions(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
|
|
|
@ -22,11 +22,10 @@ use getrandom::getrandom;
|
|||
use hmac::Hmac;
|
||||
use olm_rs::PicklingMode;
|
||||
use pbkdf2::pbkdf2;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::Sha256;
|
||||
use zeroize::{Zeroize, Zeroizing};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
const KEY_SIZE: usize = 32;
|
||||
const NONCE_SIZE: usize = 12;
|
||||
const KDF_SALT_SIZE: usize = 32;
|
||||
|
@ -114,9 +113,7 @@ impl PickleKey {
|
|||
|
||||
/// Get a `PicklingMode` version of this pickle key.
|
||||
pub fn pickle_mode(&self) -> PicklingMode {
|
||||
PicklingMode::Encrypted {
|
||||
key: self.aes256_key.clone(),
|
||||
}
|
||||
PicklingMode::Encrypted { key: self.aes256_key.clone() }
|
||||
}
|
||||
|
||||
/// Get the raw AES256 key.
|
||||
|
@ -142,10 +139,7 @@ impl PickleKey {
|
|||
getrandom(&mut nonce).expect("Can't generate new random nonce for the pickle key");
|
||||
|
||||
let ciphertext = cipher
|
||||
.encrypt(
|
||||
&GenericArray::from_slice(nonce.as_ref()),
|
||||
self.aes256_key.as_slice(),
|
||||
)
|
||||
.encrypt(&GenericArray::from_slice(nonce.as_ref()), self.aes256_key.as_slice())
|
||||
.expect("Can't encrypt pickle key");
|
||||
|
||||
EncryptedPickleKey {
|
||||
|
@ -181,9 +175,7 @@ impl PickleKey {
|
|||
}
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
aes256_key: decrypted,
|
||||
})
|
||||
Ok(Self { aes256_key: decrypted })
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -20,13 +20,6 @@ use std::{
|
|||
};
|
||||
|
||||
use dashmap::DashSet;
|
||||
use olm_rs::{account::IdentityKeys, PicklingMode};
|
||||
pub use sled::Error;
|
||||
use sled::{
|
||||
transaction::{ConflictableTransactionError, TransactionError},
|
||||
Config, Db, Transactional, Tree,
|
||||
};
|
||||
|
||||
use matrix_sdk_common::{
|
||||
async_trait,
|
||||
events::room_key_request::RequestedKeyInfo,
|
||||
|
@ -34,6 +27,12 @@ use matrix_sdk_common::{
|
|||
locks::Mutex,
|
||||
uuid,
|
||||
};
|
||||
use olm_rs::{account::IdentityKeys, PicklingMode};
|
||||
pub use sled::Error;
|
||||
use sled::{
|
||||
transaction::{ConflictableTransactionError, TransactionError},
|
||||
Config, Db, Transactional, Tree,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{
|
||||
|
@ -97,13 +96,7 @@ impl EncodeKey for &str {
|
|||
|
||||
impl EncodeKey for (&str, &str) {
|
||||
fn encode(&self) -> Vec<u8> {
|
||||
[
|
||||
self.0.as_bytes(),
|
||||
&[Self::SEPARATOR],
|
||||
self.1.as_bytes(),
|
||||
&[Self::SEPARATOR],
|
||||
]
|
||||
.concat()
|
||||
[self.0.as_bytes(), &[Self::SEPARATOR], self.1.as_bytes(), &[Self::SEPARATOR]].concat()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -164,9 +157,7 @@ impl std::fmt::Debug for SledStore {
|
|||
if let Some(path) = &self.path {
|
||||
f.debug_struct("SledStore").field("path", &path).finish()
|
||||
} else {
|
||||
f.debug_struct("SledStore")
|
||||
.field("path", &"memory store")
|
||||
.finish()
|
||||
f.debug_struct("SledStore").field("path", &"memory store").finish()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -253,9 +244,8 @@ impl SledStore {
|
|||
}
|
||||
|
||||
fn get_or_create_pickle_key(passphrase: &str, database: &Db) -> Result<PickleKey> {
|
||||
let key = if let Some(key) = database
|
||||
.get("pickle_key".encode())?
|
||||
.map(|v| serde_json::from_slice(&v))
|
||||
let key = if let Some(key) =
|
||||
database.get("pickle_key".encode())?.map(|v| serde_json::from_slice(&v))
|
||||
{
|
||||
PickleKey::from_encrypted(passphrase, key?)
|
||||
.map_err(|_| CryptoStoreError::UnpicklingError)?
|
||||
|
@ -297,9 +287,7 @@ impl SledStore {
|
|||
&self,
|
||||
room_id: &RoomId,
|
||||
) -> Result<Option<OutboundGroupSession>> {
|
||||
let account_info = self
|
||||
.get_account_info()
|
||||
.ok_or(CryptoStoreError::AccountUnset)?;
|
||||
let account_info = self.get_account_info().ok_or(CryptoStoreError::AccountUnset)?;
|
||||
|
||||
self.outbound_group_sessions
|
||||
.get(room_id.encode())?
|
||||
|
@ -501,17 +489,11 @@ impl SledStore {
|
|||
&self,
|
||||
id: &[u8],
|
||||
) -> Result<Option<OutgoingKeyRequest>> {
|
||||
let request = self
|
||||
.outgoing_key_requests
|
||||
.get(id)?
|
||||
.map(|r| serde_json::from_slice(&r))
|
||||
.transpose()?;
|
||||
let request =
|
||||
self.outgoing_key_requests.get(id)?.map(|r| serde_json::from_slice(&r)).transpose()?;
|
||||
|
||||
let request = if request.is_none() {
|
||||
self.unsent_key_requests
|
||||
.get(id)?
|
||||
.map(|r| serde_json::from_slice(&r))
|
||||
.transpose()?
|
||||
self.unsent_key_requests.get(id)?.map(|r| serde_json::from_slice(&r)).transpose()?
|
||||
} else {
|
||||
request
|
||||
};
|
||||
|
@ -553,10 +535,7 @@ impl CryptoStore for SledStore {
|
|||
|
||||
*self.account_info.write().unwrap() = Some(account_info);
|
||||
|
||||
let changes = Changes {
|
||||
account: Some(account),
|
||||
..Default::default()
|
||||
};
|
||||
let changes = Changes { account: Some(account), ..Default::default() };
|
||||
|
||||
self.save_changes(changes).await
|
||||
}
|
||||
|
@ -579,9 +558,7 @@ impl CryptoStore for SledStore {
|
|||
}
|
||||
|
||||
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
|
||||
let account_info = self
|
||||
.get_account_info()
|
||||
.ok_or(CryptoStoreError::AccountUnset)?;
|
||||
let account_info = self.get_account_info().ok_or(CryptoStoreError::AccountUnset)?;
|
||||
|
||||
if self.session_cache.get(sender_key).is_none() {
|
||||
let sessions: Result<Vec<Session>> = self
|
||||
|
@ -613,16 +590,10 @@ impl CryptoStore for SledStore {
|
|||
session_id: &str,
|
||||
) -> Result<Option<InboundGroupSession>> {
|
||||
let key = (room_id.as_str(), sender_key, session_id).encode();
|
||||
let pickle = self
|
||||
.inbound_group_sessions
|
||||
.get(&key)?
|
||||
.map(|p| serde_json::from_slice(&p));
|
||||
let pickle = self.inbound_group_sessions.get(&key)?.map(|p| serde_json::from_slice(&p));
|
||||
|
||||
if let Some(pickle) = pickle {
|
||||
Ok(Some(InboundGroupSession::from_pickle(
|
||||
pickle?,
|
||||
self.get_pickle_mode(),
|
||||
)?))
|
||||
Ok(Some(InboundGroupSession::from_pickle(pickle?, self.get_pickle_mode())?))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
|
@ -658,10 +629,7 @@ impl CryptoStore for SledStore {
|
|||
|
||||
fn users_for_key_query(&self) -> HashSet<UserId> {
|
||||
#[allow(clippy::map_clone)]
|
||||
self.users_for_key_query_cache
|
||||
.iter()
|
||||
.map(|u| u.clone())
|
||||
.collect()
|
||||
self.users_for_key_query_cache.iter().map(|u| u.clone()).collect()
|
||||
}
|
||||
|
||||
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
|
||||
|
@ -715,9 +683,7 @@ impl CryptoStore for SledStore {
|
|||
}
|
||||
|
||||
async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result<bool> {
|
||||
Ok(self
|
||||
.olm_hashes
|
||||
.contains_key(serde_json::to_vec(message_hash)?)?)
|
||||
Ok(self.olm_hashes.contains_key(serde_json::to_vec(message_hash)?)?)
|
||||
}
|
||||
|
||||
async fn get_outgoing_key_request(
|
||||
|
@ -753,36 +719,33 @@ impl CryptoStore for SledStore {
|
|||
}
|
||||
|
||||
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> {
|
||||
let ret: Result<(), TransactionError<serde_json::Error>> = (
|
||||
&self.outgoing_key_requests,
|
||||
&self.unsent_key_requests,
|
||||
&self.key_requests_by_info,
|
||||
)
|
||||
.transaction(
|
||||
|(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| {
|
||||
let sent_request: Option<OutgoingKeyRequest> = outgoing_key_requests
|
||||
.remove(request_id.encode())?
|
||||
.map(|r| serde_json::from_slice(&r))
|
||||
.transpose()
|
||||
.map_err(ConflictableTransactionError::Abort)?;
|
||||
let ret: Result<(), TransactionError<serde_json::Error>> =
|
||||
(&self.outgoing_key_requests, &self.unsent_key_requests, &self.key_requests_by_info)
|
||||
.transaction(
|
||||
|(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| {
|
||||
let sent_request: Option<OutgoingKeyRequest> = outgoing_key_requests
|
||||
.remove(request_id.encode())?
|
||||
.map(|r| serde_json::from_slice(&r))
|
||||
.transpose()
|
||||
.map_err(ConflictableTransactionError::Abort)?;
|
||||
|
||||
let unsent_request: Option<OutgoingKeyRequest> = unsent_key_requests
|
||||
.remove(request_id.encode())?
|
||||
.map(|r| serde_json::from_slice(&r))
|
||||
.transpose()
|
||||
.map_err(ConflictableTransactionError::Abort)?;
|
||||
let unsent_request: Option<OutgoingKeyRequest> = unsent_key_requests
|
||||
.remove(request_id.encode())?
|
||||
.map(|r| serde_json::from_slice(&r))
|
||||
.transpose()
|
||||
.map_err(ConflictableTransactionError::Abort)?;
|
||||
|
||||
if let Some(request) = sent_request {
|
||||
key_requests_by_info.remove((&request.info).encode())?;
|
||||
}
|
||||
if let Some(request) = sent_request {
|
||||
key_requests_by_info.remove((&request.info).encode())?;
|
||||
}
|
||||
|
||||
if let Some(request) = unsent_request {
|
||||
key_requests_by_info.remove((&request.info).encode())?;
|
||||
}
|
||||
if let Some(request) = unsent_request {
|
||||
key_requests_by_info.remove((&request.info).encode())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
},
|
||||
);
|
||||
Ok(())
|
||||
},
|
||||
);
|
||||
|
||||
ret?;
|
||||
self.inner.flush_async().await?;
|
||||
|
@ -793,6 +756,19 @@ impl CryptoStore for SledStore {
|
|||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use matrix_sdk_common::{
|
||||
api::r0::keys::SignedKey,
|
||||
events::room_key_request::RequestedKeyInfo,
|
||||
identifiers::{room_id, user_id, DeviceId, EventEncryptionAlgorithm, UserId},
|
||||
uuid::Uuid,
|
||||
};
|
||||
use matrix_sdk_test::async_test;
|
||||
use olm_rs::outbound_group_session::OlmOutboundGroupSession;
|
||||
use tempfile::tempdir;
|
||||
|
||||
use super::{CryptoStore, OutgoingKeyRequest, SledStore};
|
||||
use crate::{
|
||||
identities::{
|
||||
device::test::get_device,
|
||||
|
@ -804,18 +780,6 @@ mod test {
|
|||
},
|
||||
store::{Changes, DeviceChanges, IdentityChanges},
|
||||
};
|
||||
use matrix_sdk_common::{
|
||||
api::r0::keys::SignedKey,
|
||||
events::room_key_request::RequestedKeyInfo,
|
||||
identifiers::{room_id, user_id, DeviceId, EventEncryptionAlgorithm, UserId},
|
||||
uuid::Uuid,
|
||||
};
|
||||
use matrix_sdk_test::async_test;
|
||||
use olm_rs::outbound_group_session::OlmOutboundGroupSession;
|
||||
use std::collections::BTreeMap;
|
||||
use tempfile::tempdir;
|
||||
|
||||
use super::{CryptoStore, OutgoingKeyRequest, SledStore};
|
||||
|
||||
fn alice_id() -> UserId {
|
||||
user_id!("@alice:example.org")
|
||||
|
@ -846,10 +810,7 @@ mod test {
|
|||
async fn get_loaded_store() -> (ReadOnlyAccount, SledStore, tempfile::TempDir) {
|
||||
let (store, dir) = get_store(None).await;
|
||||
let account = get_account();
|
||||
store
|
||||
.save_account(account.clone())
|
||||
.await
|
||||
.expect("Can't save account");
|
||||
store.save_account(account.clone()).await.expect("Can't save account");
|
||||
|
||||
(account, store, dir)
|
||||
}
|
||||
|
@ -863,21 +824,12 @@ mod test {
|
|||
let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id());
|
||||
|
||||
bob.generate_one_time_keys_helper(1).await;
|
||||
let one_time_key = bob
|
||||
.one_time_keys()
|
||||
.await
|
||||
.curve25519()
|
||||
.iter()
|
||||
.next()
|
||||
.unwrap()
|
||||
.1
|
||||
.to_owned();
|
||||
let one_time_key =
|
||||
bob.one_time_keys().await.curve25519().iter().next().unwrap().1.to_owned();
|
||||
let one_time_key = SignedKey::new(one_time_key, BTreeMap::new());
|
||||
let sender_key = bob.identity_keys().curve25519().to_owned();
|
||||
let session = alice
|
||||
.create_outbound_session_helper(&sender_key, &one_time_key)
|
||||
.await
|
||||
.unwrap();
|
||||
let session =
|
||||
alice.create_outbound_session_helper(&sender_key, &one_time_key).await.unwrap();
|
||||
|
||||
(alice, session)
|
||||
}
|
||||
|
@ -895,10 +847,7 @@ mod test {
|
|||
assert!(store.load_account().await.unwrap().is_none());
|
||||
let account = get_account();
|
||||
|
||||
store
|
||||
.save_account(account)
|
||||
.await
|
||||
.expect("Can't save account");
|
||||
store.save_account(account).await.expect("Can't save account");
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
|
@ -906,10 +855,7 @@ mod test {
|
|||
let (store, _dir) = get_store(None).await;
|
||||
let account = get_account();
|
||||
|
||||
store
|
||||
.save_account(account.clone())
|
||||
.await
|
||||
.expect("Can't save account");
|
||||
store.save_account(account.clone()).await.expect("Can't save account");
|
||||
|
||||
let loaded_account = store.load_account().await.expect("Can't load account");
|
||||
let loaded_account = loaded_account.unwrap();
|
||||
|
@ -922,10 +868,7 @@ mod test {
|
|||
let (store, _dir) = get_store(Some("secret_passphrase")).await;
|
||||
let account = get_account();
|
||||
|
||||
store
|
||||
.save_account(account.clone())
|
||||
.await
|
||||
.expect("Can't save account");
|
||||
store.save_account(account.clone()).await.expect("Can't save account");
|
||||
|
||||
let loaded_account = store.load_account().await.expect("Can't load account");
|
||||
let loaded_account = loaded_account.unwrap();
|
||||
|
@ -938,50 +881,32 @@ mod test {
|
|||
let (store, _dir) = get_store(None).await;
|
||||
let account = get_account();
|
||||
|
||||
store
|
||||
.save_account(account.clone())
|
||||
.await
|
||||
.expect("Can't save account");
|
||||
store.save_account(account.clone()).await.expect("Can't save account");
|
||||
|
||||
account.mark_as_shared();
|
||||
account.update_uploaded_key_count(50);
|
||||
|
||||
store
|
||||
.save_account(account.clone())
|
||||
.await
|
||||
.expect("Can't save account");
|
||||
store.save_account(account.clone()).await.expect("Can't save account");
|
||||
|
||||
let loaded_account = store.load_account().await.expect("Can't load account");
|
||||
let loaded_account = loaded_account.unwrap();
|
||||
|
||||
assert_eq!(account, loaded_account);
|
||||
assert_eq!(
|
||||
account.uploaded_key_count(),
|
||||
loaded_account.uploaded_key_count()
|
||||
);
|
||||
assert_eq!(account.uploaded_key_count(), loaded_account.uploaded_key_count());
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
async fn load_sessions() {
|
||||
let (store, _dir) = get_store(None).await;
|
||||
let (account, session) = get_account_and_session().await;
|
||||
store
|
||||
.save_account(account.clone())
|
||||
.await
|
||||
.expect("Can't save account");
|
||||
store.save_account(account.clone()).await.expect("Can't save account");
|
||||
|
||||
let changes = Changes {
|
||||
sessions: vec![session.clone()],
|
||||
..Default::default()
|
||||
};
|
||||
let changes = Changes { sessions: vec![session.clone()], ..Default::default() };
|
||||
|
||||
store.save_changes(changes).await.unwrap();
|
||||
|
||||
let sessions = store
|
||||
.get_sessions(&session.sender_key)
|
||||
.await
|
||||
.expect("Can't load sessions")
|
||||
.unwrap();
|
||||
let sessions =
|
||||
store.get_sessions(&session.sender_key).await.expect("Can't load sessions").unwrap();
|
||||
let loaded_session = sessions.lock().await.get(0).cloned().unwrap();
|
||||
|
||||
assert_eq!(&session, &loaded_session);
|
||||
|
@ -994,15 +919,9 @@ mod test {
|
|||
let sender_key = session.sender_key.to_owned();
|
||||
let session_id = session.session_id().to_owned();
|
||||
|
||||
store
|
||||
.save_account(account.clone())
|
||||
.await
|
||||
.expect("Can't save account");
|
||||
store.save_account(account.clone()).await.expect("Can't save account");
|
||||
|
||||
let changes = Changes {
|
||||
sessions: vec![session.clone()],
|
||||
..Default::default()
|
||||
};
|
||||
let changes = Changes { sessions: vec![session.clone()], ..Default::default() };
|
||||
store.save_changes(changes).await.unwrap();
|
||||
|
||||
let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap();
|
||||
|
@ -1040,15 +959,9 @@ mod test {
|
|||
)
|
||||
.expect("Can't create session");
|
||||
|
||||
let changes = Changes {
|
||||
inbound_group_sessions: vec![session],
|
||||
..Default::default()
|
||||
};
|
||||
let changes = Changes { inbound_group_sessions: vec![session], ..Default::default() };
|
||||
|
||||
store
|
||||
.save_changes(changes)
|
||||
.await
|
||||
.expect("Can't save group session");
|
||||
store.save_changes(changes).await.expect("Can't save group session");
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
|
@ -1072,15 +985,10 @@ mod test {
|
|||
|
||||
let session = InboundGroupSession::from_export(export).unwrap();
|
||||
|
||||
let changes = Changes {
|
||||
inbound_group_sessions: vec![session.clone()],
|
||||
..Default::default()
|
||||
};
|
||||
let changes =
|
||||
Changes { inbound_group_sessions: vec![session.clone()], ..Default::default() };
|
||||
|
||||
store
|
||||
.save_changes(changes)
|
||||
.await
|
||||
.expect("Can't save group session");
|
||||
store.save_changes(changes).await.expect("Can't save group session");
|
||||
|
||||
drop(store);
|
||||
|
||||
|
@ -1103,21 +1011,12 @@ mod test {
|
|||
let (_account, store, dir) = get_loaded_store().await;
|
||||
let device = get_device();
|
||||
|
||||
assert!(store
|
||||
.update_tracked_user(device.user_id(), false)
|
||||
.await
|
||||
.unwrap());
|
||||
assert!(!store
|
||||
.update_tracked_user(device.user_id(), false)
|
||||
.await
|
||||
.unwrap());
|
||||
assert!(store.update_tracked_user(device.user_id(), false).await.unwrap());
|
||||
assert!(!store.update_tracked_user(device.user_id(), false).await.unwrap());
|
||||
|
||||
assert!(store.is_user_tracked(device.user_id()));
|
||||
assert!(!store.users_for_key_query().contains(device.user_id()));
|
||||
assert!(!store
|
||||
.update_tracked_user(device.user_id(), true)
|
||||
.await
|
||||
.unwrap());
|
||||
assert!(!store.update_tracked_user(device.user_id(), true).await.unwrap());
|
||||
assert!(store.users_for_key_query().contains(device.user_id()));
|
||||
drop(store);
|
||||
|
||||
|
@ -1128,10 +1027,7 @@ mod test {
|
|||
assert!(store.is_user_tracked(device.user_id()));
|
||||
assert!(store.users_for_key_query().contains(device.user_id()));
|
||||
|
||||
store
|
||||
.update_tracked_user(device.user_id(), false)
|
||||
.await
|
||||
.unwrap();
|
||||
store.update_tracked_user(device.user_id(), false).await.unwrap();
|
||||
assert!(!store.users_for_key_query().contains(device.user_id()));
|
||||
drop(store);
|
||||
|
||||
|
@ -1148,10 +1044,7 @@ mod test {
|
|||
let device = get_device();
|
||||
|
||||
let changes = Changes {
|
||||
devices: DeviceChanges {
|
||||
changed: vec![device.clone()],
|
||||
..Default::default()
|
||||
},
|
||||
devices: DeviceChanges { changed: vec![device.clone()], ..Default::default() },
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
|
@ -1163,11 +1056,8 @@ mod test {
|
|||
|
||||
store.load_account().await.unwrap();
|
||||
|
||||
let loaded_device = store
|
||||
.get_device(device.user_id(), device.device_id())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let loaded_device =
|
||||
store.get_device(device.user_id(), device.device_id()).await.unwrap().unwrap();
|
||||
|
||||
assert_eq!(device, loaded_device);
|
||||
|
||||
|
@ -1188,20 +1078,14 @@ mod test {
|
|||
let device = get_device();
|
||||
|
||||
let changes = Changes {
|
||||
devices: DeviceChanges {
|
||||
changed: vec![device.clone()],
|
||||
..Default::default()
|
||||
},
|
||||
devices: DeviceChanges { changed: vec![device.clone()], ..Default::default() },
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
store.save_changes(changes).await.unwrap();
|
||||
|
||||
let changes = Changes {
|
||||
devices: DeviceChanges {
|
||||
deleted: vec![device.clone()],
|
||||
..Default::default()
|
||||
},
|
||||
devices: DeviceChanges { deleted: vec![device.clone()], ..Default::default() },
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
|
@ -1212,10 +1096,7 @@ mod test {
|
|||
|
||||
store.load_account().await.unwrap();
|
||||
|
||||
let loaded_device = store
|
||||
.get_device(device.user_id(), device.device_id())
|
||||
.await
|
||||
.unwrap();
|
||||
let loaded_device = store.get_device(device.user_id(), device.device_id()).await.unwrap();
|
||||
|
||||
assert!(loaded_device.is_none());
|
||||
}
|
||||
|
@ -1232,10 +1113,7 @@ mod test {
|
|||
|
||||
let account = ReadOnlyAccount::new(&user_id, &device_id);
|
||||
|
||||
store
|
||||
.save_account(account.clone())
|
||||
.await
|
||||
.expect("Can't save account");
|
||||
store.save_account(account.clone()).await.expect("Can't save account");
|
||||
|
||||
let own_identity = get_own_identity();
|
||||
|
||||
|
@ -1247,10 +1125,7 @@ mod test {
|
|||
..Default::default()
|
||||
};
|
||||
|
||||
store
|
||||
.save_changes(changes)
|
||||
.await
|
||||
.expect("Can't save identity");
|
||||
store.save_changes(changes).await.expect("Can't save identity");
|
||||
|
||||
drop(store);
|
||||
|
||||
|
@ -1258,17 +1133,10 @@ mod test {
|
|||
|
||||
store.load_account().await.unwrap();
|
||||
|
||||
let loaded_user = store
|
||||
.get_user_identity(own_identity.user_id())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let loaded_user = store.get_user_identity(own_identity.user_id()).await.unwrap().unwrap();
|
||||
|
||||
assert_eq!(loaded_user.master_key(), own_identity.master_key());
|
||||
assert_eq!(
|
||||
loaded_user.self_signing_key(),
|
||||
own_identity.self_signing_key()
|
||||
);
|
||||
assert_eq!(loaded_user.self_signing_key(), own_identity.self_signing_key());
|
||||
assert_eq!(loaded_user, own_identity.clone().into());
|
||||
|
||||
let other_identity = get_other_identity();
|
||||
|
@ -1283,17 +1151,10 @@ mod test {
|
|||
|
||||
store.save_changes(changes).await.unwrap();
|
||||
|
||||
let loaded_user = store
|
||||
.get_user_identity(other_identity.user_id())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let loaded_user = store.get_user_identity(other_identity.user_id()).await.unwrap().unwrap();
|
||||
|
||||
assert_eq!(loaded_user.master_key(), other_identity.master_key());
|
||||
assert_eq!(
|
||||
loaded_user.self_signing_key(),
|
||||
other_identity.self_signing_key()
|
||||
);
|
||||
assert_eq!(loaded_user.self_signing_key(), other_identity.self_signing_key());
|
||||
assert_eq!(loaded_user, other_identity.into());
|
||||
|
||||
own_identity.mark_as_verified();
|
||||
|
@ -1317,10 +1178,7 @@ mod test {
|
|||
assert!(store.load_identity().await.unwrap().is_none());
|
||||
let identity = PrivateCrossSigningIdentity::new(alice_id()).await;
|
||||
|
||||
let changes = Changes {
|
||||
private_identity: Some(identity.clone()),
|
||||
..Default::default()
|
||||
};
|
||||
let changes = Changes { private_identity: Some(identity.clone()), ..Default::default() };
|
||||
|
||||
store.save_changes(changes).await.unwrap();
|
||||
let loaded_identity = store.load_identity().await.unwrap().unwrap();
|
||||
|
@ -1331,10 +1189,8 @@ mod test {
|
|||
async fn olm_hash_saving() {
|
||||
let (_, store, _dir) = get_loaded_store().await;
|
||||
|
||||
let hash = OlmMessageHash {
|
||||
sender_key: "test_sender".to_owned(),
|
||||
hash: "test_hash".to_owned(),
|
||||
};
|
||||
let hash =
|
||||
OlmMessageHash { sender_key: "test_sender".to_owned(), hash: "test_hash".to_owned() };
|
||||
|
||||
let mut changes = Changes::default();
|
||||
changes.message_hashes.push(hash.clone());
|
||||
|
|
|
@ -15,9 +15,6 @@
|
|||
use std::{convert::TryFrom, sync::Arc};
|
||||
|
||||
use dashmap::DashMap;
|
||||
|
||||
use tracing::{info, trace, warn};
|
||||
|
||||
use matrix_sdk_common::{
|
||||
events::{
|
||||
room::message::MessageType, AnyMessageEvent, AnySyncMessageEvent, AnySyncRoomEvent,
|
||||
|
@ -27,12 +24,12 @@ use matrix_sdk_common::{
|
|||
locks::Mutex,
|
||||
uuid::Uuid,
|
||||
};
|
||||
use tracing::{info, trace, warn};
|
||||
|
||||
use super::{
|
||||
requests::VerificationRequest,
|
||||
sas::{content_to_request, OutgoingContent, Sas, VerificationResult},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
olm::PrivateCrossSigningIdentity,
|
||||
requests::OutgoingRequest,
|
||||
|
@ -85,18 +82,14 @@ impl VerificationMachine {
|
|||
);
|
||||
|
||||
let request = match content.into() {
|
||||
OutgoingContent::Room(r, c) => RoomMessageRequest {
|
||||
room_id: r,
|
||||
txn_id: Uuid::new_v4(),
|
||||
content: c,
|
||||
OutgoingContent::Room(r, c) => {
|
||||
RoomMessageRequest { room_id: r, txn_id: Uuid::new_v4(), content: c }.into()
|
||||
}
|
||||
.into(),
|
||||
OutgoingContent::ToDevice(c) => {
|
||||
let request =
|
||||
content_to_request(device.user_id(), device.device_id().to_owned(), c);
|
||||
|
||||
self.verifications
|
||||
.insert(sas.flow_id().as_str().to_owned(), sas.clone());
|
||||
self.verifications.insert(sas.flow_id().as_str().to_owned(), sas.clone());
|
||||
|
||||
request.into()
|
||||
}
|
||||
|
@ -136,10 +129,7 @@ impl VerificationMachine {
|
|||
let request = content_to_request(recipient, recipient_device.to_owned(), c);
|
||||
let request_id = request.txn_id;
|
||||
|
||||
let request = OutgoingRequest {
|
||||
request_id,
|
||||
request: Arc::new(request.into()),
|
||||
};
|
||||
let request = OutgoingRequest { request_id, request: Arc::new(request.into()) };
|
||||
|
||||
self.outgoing_messages.insert(request_id, request);
|
||||
}
|
||||
|
@ -149,12 +139,7 @@ impl VerificationMachine {
|
|||
|
||||
let request = OutgoingRequest {
|
||||
request: Arc::new(
|
||||
RoomMessageRequest {
|
||||
room_id: r,
|
||||
txn_id: request_id,
|
||||
content: c,
|
||||
}
|
||||
.into(),
|
||||
RoomMessageRequest { room_id: r, txn_id: request_id, content: c }.into(),
|
||||
),
|
||||
request_id,
|
||||
};
|
||||
|
@ -181,24 +166,17 @@ impl VerificationMachine {
|
|||
}
|
||||
|
||||
pub fn outgoing_messages(&self) -> Vec<OutgoingRequest> {
|
||||
self.outgoing_messages
|
||||
.iter()
|
||||
.map(|r| (*r).clone())
|
||||
.collect()
|
||||
self.outgoing_messages.iter().map(|r| (*r).clone()).collect()
|
||||
}
|
||||
|
||||
pub fn garbage_collect(&self) {
|
||||
self.verifications
|
||||
.retain(|_, s| !(s.is_done() || s.is_canceled()));
|
||||
self.verifications.retain(|_, s| !(s.is_done() || s.is_canceled()));
|
||||
|
||||
for sas in self.verifications.iter() {
|
||||
if let Some(r) = sas.cancel_if_timed_out() {
|
||||
self.outgoing_messages.insert(
|
||||
r.request_id(),
|
||||
OutgoingRequest {
|
||||
request_id: r.request_id(),
|
||||
request: Arc::new(r.into()),
|
||||
},
|
||||
OutgoingRequest { request_id: r.request_id(), request: Arc::new(r.into()) },
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -239,8 +217,7 @@ impl VerificationMachine {
|
|||
r,
|
||||
);
|
||||
|
||||
self.requests
|
||||
.insert(request.flow_id().as_str().to_owned(), request);
|
||||
self.requests.insert(request.flow_id().as_str().to_owned(), request);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -261,10 +238,8 @@ impl VerificationMachine {
|
|||
if let Some((_, request)) =
|
||||
self.requests.remove(e.content.relation.event_id.as_str())
|
||||
{
|
||||
if let Some(d) = self
|
||||
.store
|
||||
.get_device(&e.sender, &e.content.from_device)
|
||||
.await?
|
||||
if let Some(d) =
|
||||
self.store.get_device(&e.sender, &e.content.from_device).await?
|
||||
{
|
||||
match request.into_started_sas(
|
||||
e,
|
||||
|
@ -370,8 +345,7 @@ impl VerificationMachine {
|
|||
&e.content,
|
||||
);
|
||||
|
||||
self.requests
|
||||
.insert(request.flow_id().as_str().to_string(), request);
|
||||
self.requests.insert(request.flow_id().as_str().to_string(), request);
|
||||
}
|
||||
AnyToDeviceEvent::KeyVerificationReady(e) => {
|
||||
if let Some(request) = self.requests.get(&e.content.transaction_id) {
|
||||
|
@ -388,11 +362,7 @@ impl VerificationMachine {
|
|||
e.content.from_device
|
||||
);
|
||||
|
||||
if let Some(d) = self
|
||||
.store
|
||||
.get_device(&e.sender, &e.content.from_device)
|
||||
.await?
|
||||
{
|
||||
if let Some(d) = self.store.get_device(&e.sender, &e.content.from_device).await? {
|
||||
let private_identity = self.private_identity.lock().await.clone();
|
||||
match Sas::from_start_event(
|
||||
self.account.clone(),
|
||||
|
@ -403,8 +373,7 @@ impl VerificationMachine {
|
|||
self.store.get_user_identity(&e.sender).await?,
|
||||
) {
|
||||
Ok(s) => {
|
||||
self.verifications
|
||||
.insert(e.content.transaction_id.clone(), s);
|
||||
self.verifications.insert(e.content.transaction_id.clone(), s);
|
||||
}
|
||||
Err(c) => {
|
||||
warn!(
|
||||
|
@ -455,10 +424,7 @@ impl VerificationMachine {
|
|||
|
||||
self.outgoing_messages.insert(
|
||||
request_id,
|
||||
OutgoingRequest {
|
||||
request_id,
|
||||
request: Arc::new(r.into()),
|
||||
},
|
||||
OutgoingRequest { request_id, request: Arc::new(r.into()) },
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -535,10 +501,7 @@ mod test {
|
|||
);
|
||||
|
||||
machine
|
||||
.receive_event(&wrap_any_to_device_content(
|
||||
bob_sas.user_id(),
|
||||
start_content.into(),
|
||||
))
|
||||
.receive_event(&wrap_any_to_device_content(bob_sas.user_id(), start_content.into()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
|
@ -17,11 +17,10 @@ mod requests;
|
|||
mod sas;
|
||||
|
||||
pub use machine::VerificationMachine;
|
||||
use matrix_sdk_common::identifiers::{EventId, RoomId};
|
||||
pub use requests::VerificationRequest;
|
||||
pub use sas::{AcceptSettings, Sas, VerificationResult};
|
||||
|
||||
use matrix_sdk_common::identifiers::{EventId, RoomId};
|
||||
|
||||
#[derive(Clone, Debug, Hash, PartialEq, PartialOrd)]
|
||||
pub enum FlowId {
|
||||
ToDevice(String),
|
||||
|
@ -59,18 +58,17 @@ impl From<(RoomId, EventId)> for FlowId {
|
|||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod test {
|
||||
use crate::{
|
||||
requests::{OutgoingRequest, OutgoingRequests},
|
||||
OutgoingVerificationRequest,
|
||||
};
|
||||
use serde_json::Value;
|
||||
|
||||
use matrix_sdk_common::{
|
||||
events::{AnyToDeviceEvent, AnyToDeviceEventContent, EventType, ToDeviceEvent},
|
||||
identifiers::UserId,
|
||||
};
|
||||
use serde_json::Value;
|
||||
|
||||
use super::sas::OutgoingContent;
|
||||
use crate::{
|
||||
requests::{OutgoingRequest, OutgoingRequests},
|
||||
OutgoingVerificationRequest,
|
||||
};
|
||||
|
||||
pub(crate) fn request_to_event(
|
||||
sender: &UserId,
|
||||
|
@ -94,11 +92,7 @@ pub(crate) mod test {
|
|||
sender: &UserId,
|
||||
content: OutgoingContent,
|
||||
) -> AnyToDeviceEvent {
|
||||
let content = if let OutgoingContent::ToDevice(c) = content {
|
||||
c
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
let content = if let OutgoingContent::ToDevice(c) = content { c } else { unreachable!() };
|
||||
|
||||
match content {
|
||||
AnyToDeviceEventContent::KeyVerificationKey(c) => {
|
||||
|
@ -133,22 +127,11 @@ pub(crate) mod test {
|
|||
pub(crate) fn get_content_from_request(
|
||||
request: &OutgoingVerificationRequest,
|
||||
) -> OutgoingContent {
|
||||
let request = if let OutgoingVerificationRequest::ToDevice(r) = request {
|
||||
r
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
let request =
|
||||
if let OutgoingVerificationRequest::ToDevice(r) = request { r } else { unreachable!() };
|
||||
|
||||
let json: Value = serde_json::from_str(
|
||||
request
|
||||
.messages
|
||||
.values()
|
||||
.next()
|
||||
.unwrap()
|
||||
.values()
|
||||
.next()
|
||||
.unwrap()
|
||||
.get(),
|
||||
request.messages.values().next().unwrap().values().next().unwrap().get(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
|
|
|
@ -35,6 +35,10 @@ use matrix_sdk_common::{
|
|||
uuid::Uuid,
|
||||
};
|
||||
|
||||
use super::{
|
||||
sas::{content_to_request, OutgoingContent, StartContent},
|
||||
FlowId,
|
||||
};
|
||||
use crate::{
|
||||
olm::{PrivateCrossSigningIdentity, ReadOnlyAccount},
|
||||
store::CryptoStore,
|
||||
|
@ -42,11 +46,6 @@ use crate::{
|
|||
UserIdentities,
|
||||
};
|
||||
|
||||
use super::{
|
||||
sas::{content_to_request, OutgoingContent, StartContent},
|
||||
FlowId,
|
||||
};
|
||||
|
||||
const SUPPORTED_METHODS: &[VerificationMethod] = &[VerificationMethod::MSasV1];
|
||||
|
||||
pub enum RequestContent<'a> {
|
||||
|
@ -256,15 +255,13 @@ impl VerificationRequest {
|
|||
content: RequestContent,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(Mutex::new(InnerRequest::Requested(
|
||||
RequestState::from_request_event(
|
||||
account.user_id(),
|
||||
account.device_id(),
|
||||
sender,
|
||||
&flow_id,
|
||||
content,
|
||||
),
|
||||
))),
|
||||
inner: Arc::new(Mutex::new(InnerRequest::Requested(RequestState::from_request_event(
|
||||
account.user_id(),
|
||||
account.device_id(),
|
||||
sender,
|
||||
&flow_id,
|
||||
content,
|
||||
)))),
|
||||
account,
|
||||
other_user_id: sender.clone().into(),
|
||||
private_cross_signing_identity,
|
||||
|
@ -278,15 +275,12 @@ impl VerificationRequest {
|
|||
let mut inner = self.inner.lock().unwrap();
|
||||
|
||||
inner.accept().map(|c| match c {
|
||||
OutgoingContent::ToDevice(content) => self
|
||||
.content_to_request(inner.other_device_id(), content)
|
||||
.into(),
|
||||
OutgoingContent::Room(room_id, content) => RoomMessageRequest {
|
||||
room_id,
|
||||
txn_id: Uuid::new_v4(),
|
||||
content,
|
||||
OutgoingContent::ToDevice(content) => {
|
||||
self.content_to_request(inner.other_device_id(), content).into()
|
||||
}
|
||||
OutgoingContent::Room(room_id, content) => {
|
||||
RoomMessageRequest { room_id, txn_id: Uuid::new_v4(), content }.into()
|
||||
}
|
||||
.into(),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -445,10 +439,7 @@ impl RequestState<Created> {
|
|||
own_user_id: own_user_id.to_owned(),
|
||||
own_device_id: own_device_id.to_owned(),
|
||||
other_user_id: other_user_id.to_owned(),
|
||||
state: Created {
|
||||
methods: SUPPORTED_METHODS.to_vec(),
|
||||
flow_id: flow_id.to_owned(),
|
||||
},
|
||||
state: Created { methods: SUPPORTED_METHODS.to_vec(), flow_id: flow_id.to_owned() },
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -589,14 +580,9 @@ impl RequestState<Ready> {
|
|||
other_identity: Option<UserIdentities>,
|
||||
) -> (Sas, StartContent) {
|
||||
match self.state.flow_id {
|
||||
FlowId::ToDevice(t) => Sas::start(
|
||||
account,
|
||||
private_identity,
|
||||
other_device,
|
||||
store,
|
||||
other_identity,
|
||||
Some(t),
|
||||
),
|
||||
FlowId::ToDevice(t) => {
|
||||
Sas::start(account, private_identity, other_device, store, other_identity, Some(t))
|
||||
}
|
||||
FlowId::InRoom(r, e) => Sas::start_in_room(
|
||||
e,
|
||||
r,
|
||||
|
@ -630,6 +616,7 @@ mod test {
|
|||
};
|
||||
use matrix_sdk_test::async_test;
|
||||
|
||||
use super::VerificationRequest;
|
||||
use crate::{
|
||||
olm::{PrivateCrossSigningIdentity, ReadOnlyAccount},
|
||||
store::{CryptoStore, MemoryStore},
|
||||
|
@ -640,8 +627,6 @@ mod test {
|
|||
ReadOnlyDevice,
|
||||
};
|
||||
|
||||
use super::VerificationRequest;
|
||||
|
||||
fn alice_id() -> UserId {
|
||||
UserId::try_from("@alice:example.org").unwrap()
|
||||
}
|
||||
|
@ -760,9 +745,7 @@ mod test {
|
|||
panic!("Invalid start event content type");
|
||||
};
|
||||
|
||||
let alice_sas = alice_request
|
||||
.into_started_sas(&event, bob_device, None)
|
||||
.unwrap();
|
||||
let alice_sas = alice_request.into_started_sas(&event, bob_device, None).unwrap();
|
||||
|
||||
assert!(!bob_sas.is_canceled());
|
||||
assert!(!alice_sas.is_canceled());
|
||||
|
|
|
@ -59,10 +59,7 @@ impl StartContent {
|
|||
StartContent::Room(_, c) => serde_json::to_value(c),
|
||||
};
|
||||
|
||||
content
|
||||
.expect("Can't serialize content")
|
||||
.try_into()
|
||||
.expect("Can't canonicalize content")
|
||||
content.expect("Can't serialize content").try_into().expect("Can't canonicalize content")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -287,14 +284,7 @@ impl From<OutgoingVerificationRequest> for OutgoingContent {
|
|||
match request {
|
||||
OutgoingVerificationRequest::ToDevice(r) => {
|
||||
let json: Value = serde_json::from_str(
|
||||
r.messages
|
||||
.values()
|
||||
.next()
|
||||
.unwrap()
|
||||
.values()
|
||||
.next()
|
||||
.unwrap()
|
||||
.get(),
|
||||
r.messages.values().next().unwrap().values().next().unwrap().get(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
|
|
|
@ -14,11 +14,6 @@
|
|||
|
||||
use std::{collections::BTreeMap, convert::TryInto};
|
||||
|
||||
use sha2::{Digest, Sha256};
|
||||
use tracing::{trace, warn};
|
||||
|
||||
use olm_rs::sas::OlmSas;
|
||||
|
||||
use matrix_sdk_common::{
|
||||
api::r0::to_device::DeviceIdOrAllDevices,
|
||||
events::{
|
||||
|
@ -32,17 +27,19 @@ use matrix_sdk_common::{
|
|||
identifiers::{DeviceKeyAlgorithm, DeviceKeyId, UserId},
|
||||
uuid::Uuid,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
identities::{ReadOnlyDevice, UserIdentities},
|
||||
utilities::encode,
|
||||
ReadOnlyAccount, ToDeviceRequest,
|
||||
};
|
||||
use olm_rs::sas::OlmSas;
|
||||
use sha2::{Digest, Sha256};
|
||||
use tracing::{trace, warn};
|
||||
|
||||
use super::{
|
||||
event_enums::{MacContent, StartContent},
|
||||
FlowId,
|
||||
};
|
||||
use crate::{
|
||||
identities::{ReadOnlyDevice, UserIdentities},
|
||||
utilities::encode,
|
||||
ReadOnlyAccount, ToDeviceRequest,
|
||||
};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct SasIds {
|
||||
|
@ -65,12 +62,7 @@ pub fn calculate_commitment(public_key: &str, content: impl Into<StartContent>)
|
|||
let content = content.into().canonical_json();
|
||||
let content_string = content.to_string();
|
||||
|
||||
encode(
|
||||
Sha256::new()
|
||||
.chain(&public_key)
|
||||
.chain(&content_string)
|
||||
.finalize(),
|
||||
)
|
||||
encode(Sha256::new().chain(&public_key).chain(&content_string).finalize())
|
||||
}
|
||||
|
||||
/// Get a tuple of an emoji and a description of the emoji using a number.
|
||||
|
@ -234,11 +226,7 @@ pub fn receive_mac_event(
|
|||
.calculate_mac(key, &format!("{}{}", info, key_id))
|
||||
.expect("Can't calculate SAS MAC")
|
||||
{
|
||||
trace!(
|
||||
"Successfully verified the device key {} from {}",
|
||||
key_id,
|
||||
sender
|
||||
);
|
||||
trace!("Successfully verified the device key {} from {}", key_id, sender);
|
||||
|
||||
verified_devices.push(ids.other_device.clone());
|
||||
} else {
|
||||
|
@ -253,11 +241,7 @@ pub fn receive_mac_event(
|
|||
.calculate_mac(key, &format!("{}{}", info, key_id))
|
||||
.expect("Can't calculate SAS MAC")
|
||||
{
|
||||
trace!(
|
||||
"Successfully verified the master key {} from {}",
|
||||
key_id,
|
||||
sender
|
||||
);
|
||||
trace!("Successfully verified the master key {} from {}", key_id, sender);
|
||||
verified_identities.push(identity.clone())
|
||||
} else {
|
||||
return Err(CancelCode::KeyMismatch);
|
||||
|
@ -319,8 +303,7 @@ pub fn get_mac_content(sas: &OlmSas, ids: &SasIds, flow_id: &FlowId) -> MacConte
|
|||
|
||||
mac.insert(
|
||||
key_id.to_string(),
|
||||
sas.calculate_mac(key, &format!("{}{}", info, key_id))
|
||||
.expect("Can't calculate SAS MAC"),
|
||||
sas.calculate_mac(key, &format!("{}{}", info, key_id)).expect("Can't calculate SAS MAC"),
|
||||
);
|
||||
|
||||
// TODO Add the cross signing master key here if we trust/have it.
|
||||
|
@ -332,23 +315,13 @@ pub fn get_mac_content(sas: &OlmSas, ids: &SasIds, flow_id: &FlowId) -> MacConte
|
|||
.expect("Can't calculate SAS MAC");
|
||||
|
||||
match flow_id {
|
||||
FlowId::ToDevice(s) => MacToDeviceEventContent {
|
||||
transaction_id: s.to_string(),
|
||||
keys,
|
||||
mac,
|
||||
FlowId::ToDevice(s) => {
|
||||
MacToDeviceEventContent { transaction_id: s.to_string(), keys, mac }.into()
|
||||
}
|
||||
FlowId::InRoom(r, e) => {
|
||||
(r.clone(), MacEventContent { mac, keys, relation: Relation { event_id: e.clone() } })
|
||||
.into()
|
||||
}
|
||||
.into(),
|
||||
FlowId::InRoom(r, e) => (
|
||||
r.clone(),
|
||||
MacEventContent {
|
||||
mac,
|
||||
keys,
|
||||
relation: Relation {
|
||||
event_id: e.clone(),
|
||||
},
|
||||
},
|
||||
)
|
||||
.into(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -369,24 +342,12 @@ fn extra_info_sas(
|
|||
flow_id: &str,
|
||||
we_started: bool,
|
||||
) -> String {
|
||||
let our_info = format!(
|
||||
"{}|{}|{}",
|
||||
ids.account.user_id(),
|
||||
ids.account.device_id(),
|
||||
own_pubkey
|
||||
);
|
||||
let their_info = format!(
|
||||
"{}|{}|{}",
|
||||
ids.other_device.user_id(),
|
||||
ids.other_device.device_id(),
|
||||
their_pubkey
|
||||
);
|
||||
let our_info = format!("{}|{}|{}", ids.account.user_id(), ids.account.device_id(), own_pubkey);
|
||||
let their_info =
|
||||
format!("{}|{}|{}", ids.other_device.user_id(), ids.other_device.device_id(), their_pubkey);
|
||||
|
||||
let (first_info, second_info) = if we_started {
|
||||
(our_info, their_info)
|
||||
} else {
|
||||
(their_info, our_info)
|
||||
};
|
||||
let (first_info, second_info) =
|
||||
if we_started { (our_info, their_info) } else { (their_info, our_info) };
|
||||
|
||||
let info = format!(
|
||||
"MATRIX_KEY_VERIFICATION_SAS|{first_info}|{second_info}|{flow_id}",
|
||||
|
@ -585,11 +546,7 @@ pub fn content_to_request(
|
|||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
ToDeviceRequest {
|
||||
txn_id: Uuid::new_v4(),
|
||||
event_type,
|
||||
messages,
|
||||
}
|
||||
ToDeviceRequest { txn_id: Uuid::new_v4(), event_type, messages }
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -627,18 +584,14 @@ mod test {
|
|||
#[test]
|
||||
fn emoji_generation() {
|
||||
let bytes = vec![0, 0, 0, 0, 0, 0];
|
||||
let index: Vec<(&'static str, &'static str)> = vec![0, 0, 0, 0, 0, 0, 0]
|
||||
.into_iter()
|
||||
.map(emoji_from_index)
|
||||
.collect();
|
||||
let index: Vec<(&'static str, &'static str)> =
|
||||
vec![0, 0, 0, 0, 0, 0, 0].into_iter().map(emoji_from_index).collect();
|
||||
assert_eq!(bytes_to_emoji(bytes), index.as_ref());
|
||||
|
||||
let bytes = vec![0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF];
|
||||
|
||||
let index: Vec<(&'static str, &'static str)> = vec![63, 63, 63, 63, 63, 63, 63]
|
||||
.into_iter()
|
||||
.map(emoji_from_index)
|
||||
.collect();
|
||||
let index: Vec<(&'static str, &'static str)> =
|
||||
vec![63, 63, 63, 63, 63, 63, 63].into_iter().map(emoji_from_index).collect();
|
||||
assert_eq!(bytes_to_emoji(bytes), index.as_ref());
|
||||
}
|
||||
|
||||
|
|
|
@ -12,11 +12,10 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
#[cfg(test)]
|
||||
use std::time::Instant;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use matrix_sdk_common::{
|
||||
events::{
|
||||
key::verification::{cancel::CancelCode, ShortAuthenticationString},
|
||||
|
@ -25,11 +24,6 @@ use matrix_sdk_common::{
|
|||
identifiers::{EventId, RoomId},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
identities::{ReadOnlyDevice, UserIdentities},
|
||||
ReadOnlyAccount,
|
||||
};
|
||||
|
||||
use super::{
|
||||
event_enums::{AcceptContent, CancelContent, MacContent, OutgoingContent},
|
||||
sas_state::{
|
||||
|
@ -38,6 +32,10 @@ use super::{
|
|||
},
|
||||
FlowId, StartContent,
|
||||
};
|
||||
use crate::{
|
||||
identities::{ReadOnlyDevice, UserIdentities},
|
||||
ReadOnlyAccount,
|
||||
};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum InnerSas {
|
||||
|
@ -315,14 +313,15 @@ impl InnerSas {
|
|||
_ => (self, None),
|
||||
},
|
||||
AnyToDeviceEvent::KeyVerificationMac(e) => match self {
|
||||
InnerSas::KeyRecieved(s) => match s.into_mac_received(&e.sender, e.content.clone())
|
||||
{
|
||||
Ok(s) => (InnerSas::MacReceived(s), None),
|
||||
Err(s) => {
|
||||
let content = s.as_content();
|
||||
(InnerSas::Canceled(s), Some(content.into()))
|
||||
InnerSas::KeyRecieved(s) => {
|
||||
match s.into_mac_received(&e.sender, e.content.clone()) {
|
||||
Ok(s) => (InnerSas::MacReceived(s), None),
|
||||
Err(s) => {
|
||||
let content = s.as_content();
|
||||
(InnerSas::Canceled(s), Some(content.into()))
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
InnerSas::Confirmed(s) => match s.into_done(&e.sender, e.content.clone()) {
|
||||
Ok(s) => (InnerSas::Done(s), None),
|
||||
Err(s) => {
|
||||
|
|
|
@ -17,13 +17,14 @@ mod helpers;
|
|||
mod inner_sas;
|
||||
mod sas_state;
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
#[cfg(test)]
|
||||
use std::time::Instant;
|
||||
|
||||
use event_enums::AcceptContent;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tracing::{error, info, trace, warn};
|
||||
|
||||
pub use event_enums::{CancelContent, OutgoingContent, StartContent};
|
||||
pub use helpers::content_to_request;
|
||||
use inner_sas::InnerSas;
|
||||
use matrix_sdk_common::{
|
||||
api::r0::keys::upload_signatures::Request as SignatureUploadRequest,
|
||||
events::{
|
||||
|
@ -37,7 +38,9 @@ use matrix_sdk_common::{
|
|||
identifiers::{DeviceId, EventId, RoomId, UserId},
|
||||
uuid::Uuid,
|
||||
};
|
||||
use tracing::{error, info, trace, warn};
|
||||
|
||||
use super::FlowId;
|
||||
use crate::{
|
||||
error::SignatureError,
|
||||
identities::{LocalTrust, ReadOnlyDevice, UserIdentities},
|
||||
|
@ -47,12 +50,6 @@ use crate::{
|
|||
ReadOnlyAccount, ToDeviceRequest,
|
||||
};
|
||||
|
||||
use super::FlowId;
|
||||
|
||||
pub use event_enums::{CancelContent, OutgoingContent, StartContent};
|
||||
pub use helpers::content_to_request;
|
||||
use inner_sas::InnerSas;
|
||||
|
||||
#[derive(Debug)]
|
||||
/// A result of a verification flow.
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
|
@ -275,22 +272,18 @@ impl Sas {
|
|||
&self,
|
||||
settings: AcceptSettings,
|
||||
) -> Option<OutgoingVerificationRequest> {
|
||||
self.inner
|
||||
.lock()
|
||||
.unwrap()
|
||||
.accept()
|
||||
.map(|c| match settings.apply(c) {
|
||||
AcceptContent::ToDevice(c) => {
|
||||
let content = AnyToDeviceEventContent::KeyVerificationAccept(c);
|
||||
self.content_to_request(content).into()
|
||||
}
|
||||
AcceptContent::Room(room_id, content) => RoomMessageRequest {
|
||||
room_id,
|
||||
txn_id: Uuid::new_v4(),
|
||||
content: AnyMessageEventContent::KeyVerificationAccept(content),
|
||||
}
|
||||
.into(),
|
||||
})
|
||||
self.inner.lock().unwrap().accept().map(|c| match settings.apply(c) {
|
||||
AcceptContent::ToDevice(c) => {
|
||||
let content = AnyToDeviceEventContent::KeyVerificationAccept(c);
|
||||
self.content_to_request(content).into()
|
||||
}
|
||||
AcceptContent::Room(room_id, content) => RoomMessageRequest {
|
||||
room_id,
|
||||
txn_id: Uuid::new_v4(),
|
||||
content: AnyMessageEventContent::KeyVerificationAccept(content),
|
||||
}
|
||||
.into(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Confirm the Sas verification.
|
||||
|
@ -303,10 +296,7 @@ impl Sas {
|
|||
pub async fn confirm(
|
||||
&self,
|
||||
) -> Result<
|
||||
(
|
||||
Option<OutgoingVerificationRequest>,
|
||||
Option<SignatureUploadRequest>,
|
||||
),
|
||||
(Option<OutgoingVerificationRequest>, Option<SignatureUploadRequest>),
|
||||
CryptoStoreError,
|
||||
> {
|
||||
let (content, done) = {
|
||||
|
@ -319,9 +309,9 @@ impl Sas {
|
|||
};
|
||||
|
||||
let mac_request = content.map(|c| match c {
|
||||
event_enums::MacContent::ToDevice(c) => self
|
||||
.content_to_request(AnyToDeviceEventContent::KeyVerificationMac(c))
|
||||
.into(),
|
||||
event_enums::MacContent::ToDevice(c) => {
|
||||
self.content_to_request(AnyToDeviceEventContent::KeyVerificationMac(c)).into()
|
||||
}
|
||||
event_enums::MacContent::Room(r, c) => RoomMessageRequest {
|
||||
room_id: r,
|
||||
txn_id: Uuid::new_v4(),
|
||||
|
@ -374,10 +364,7 @@ impl Sas {
|
|||
};
|
||||
|
||||
let mut changes = Changes {
|
||||
devices: DeviceChanges {
|
||||
changed: vec![device],
|
||||
..Default::default()
|
||||
},
|
||||
devices: DeviceChanges { changed: vec![device], ..Default::default() },
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
|
@ -437,10 +424,7 @@ impl Sas {
|
|||
.map(VerificationResult::SignatureUpload)
|
||||
.unwrap_or(VerificationResult::Ok))
|
||||
} else {
|
||||
Ok(self
|
||||
.cancel()
|
||||
.map(VerificationResult::Cancel)
|
||||
.unwrap_or(VerificationResult::Ok))
|
||||
Ok(self.cancel().map(VerificationResult::Cancel).unwrap_or(VerificationResult::Ok))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -463,14 +447,8 @@ impl Sas {
|
|||
.as_ref()
|
||||
.map_or(false, |i| i.master_key() == identity.master_key())
|
||||
{
|
||||
if self
|
||||
.verified_identities()
|
||||
.map_or(false, |i| i.contains(&identity))
|
||||
{
|
||||
trace!(
|
||||
"Marking user identity of {} as verified.",
|
||||
identity.user_id(),
|
||||
);
|
||||
if self.verified_identities().map_or(false, |i| i.contains(&identity)) {
|
||||
trace!("Marking user identity of {} as verified.", identity.user_id(),);
|
||||
|
||||
if let UserIdentities::Own(i) = &identity {
|
||||
i.mark_as_verified();
|
||||
|
@ -509,17 +487,11 @@ impl Sas {
|
|||
pub(crate) async fn mark_device_as_verified(
|
||||
&self,
|
||||
) -> Result<Option<ReadOnlyDevice>, CryptoStoreError> {
|
||||
let device = self
|
||||
.store
|
||||
.get_device(self.other_user_id(), self.other_device_id())
|
||||
.await?;
|
||||
let device = self.store.get_device(self.other_user_id(), self.other_device_id()).await?;
|
||||
|
||||
if let Some(device) = device {
|
||||
if device.keys() == self.other_device.keys() {
|
||||
if self
|
||||
.verified_devices()
|
||||
.map_or(false, |v| v.contains(&device))
|
||||
{
|
||||
if self.verified_devices().map_or(false, |v| v.contains(&device)) {
|
||||
trace!(
|
||||
"Marking device {} {} as verified.",
|
||||
device.user_id(),
|
||||
|
@ -580,9 +552,9 @@ impl Sas {
|
|||
content: AnyMessageEventContent::KeyVerificationCancel(content),
|
||||
}
|
||||
.into(),
|
||||
CancelContent::ToDevice(c) => self
|
||||
.content_to_request(AnyToDeviceEventContent::KeyVerificationCancel(c))
|
||||
.into(),
|
||||
CancelContent::ToDevice(c) => {
|
||||
self.content_to_request(AnyToDeviceEventContent::KeyVerificationCancel(c)).into()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -684,11 +656,7 @@ impl Sas {
|
|||
}
|
||||
|
||||
pub(crate) fn content_to_request(&self, content: AnyToDeviceEventContent) -> ToDeviceRequest {
|
||||
content_to_request(
|
||||
self.other_user_id(),
|
||||
self.other_device_id().to_owned(),
|
||||
content,
|
||||
)
|
||||
content_to_request(self.other_user_id(), self.other_device_id().to_owned(), content)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -717,9 +685,7 @@ impl AcceptSettings {
|
|||
///
|
||||
/// * `methods` - The methods this client allows at most
|
||||
pub fn with_allowed_methods(methods: Vec<ShortAuthenticationString>) -> Self {
|
||||
Self {
|
||||
allowed_methods: methods,
|
||||
}
|
||||
Self { allowed_methods: methods }
|
||||
}
|
||||
|
||||
fn apply(self, mut content: AcceptContent) -> AcceptContent {
|
||||
|
@ -728,15 +694,8 @@ impl AcceptSettings {
|
|||
method: AcceptMethod::MSasV1(c),
|
||||
..
|
||||
})
|
||||
| AcceptContent::Room(
|
||||
_,
|
||||
AcceptEventContent {
|
||||
method: AcceptMethod::MSasV1(c),
|
||||
..
|
||||
},
|
||||
) => {
|
||||
c.short_authentication_string
|
||||
.retain(|sas| self.allowed_methods.contains(sas));
|
||||
| AcceptContent::Room(_, AcceptEventContent { method: AcceptMethod::MSasV1(c), .. }) => {
|
||||
c.short_authentication_string.retain(|sas| self.allowed_methods.contains(sas));
|
||||
content
|
||||
}
|
||||
_ => content,
|
||||
|
@ -750,6 +709,7 @@ mod test {
|
|||
|
||||
use matrix_sdk_common::identifiers::{DeviceId, UserId};
|
||||
|
||||
use super::Sas;
|
||||
use crate::{
|
||||
olm::PrivateCrossSigningIdentity,
|
||||
store::{CryptoStore, MemoryStore},
|
||||
|
@ -757,8 +717,6 @@ mod test {
|
|||
ReadOnlyAccount, ReadOnlyDevice,
|
||||
};
|
||||
|
||||
use super::Sas;
|
||||
|
||||
fn alice_id() -> UserId {
|
||||
UserId::try_from("@alice:example.org").unwrap()
|
||||
}
|
||||
|
@ -841,13 +799,7 @@ mod test {
|
|||
);
|
||||
alice.receive_event(&event);
|
||||
|
||||
assert!(alice
|
||||
.verified_devices()
|
||||
.unwrap()
|
||||
.contains(&alice.other_device()));
|
||||
assert!(bob
|
||||
.verified_devices()
|
||||
.unwrap()
|
||||
.contains(&bob.other_device()));
|
||||
assert!(alice.verified_devices().unwrap().contains(&alice.other_device()));
|
||||
assert!(bob.verified_devices().unwrap().contains(&bob.other_device()));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,8 +19,6 @@ use std::{
|
|||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use olm_rs::sas::OlmSas;
|
||||
|
||||
use matrix_sdk_common::{
|
||||
events::key::verification::{
|
||||
accept::{
|
||||
|
@ -40,6 +38,7 @@ use matrix_sdk_common::{
|
|||
identifiers::{DeviceId, EventId, RoomId, UserId},
|
||||
uuid::Uuid,
|
||||
};
|
||||
use olm_rs::sas::OlmSas;
|
||||
use tracing::info;
|
||||
|
||||
use super::{
|
||||
|
@ -51,7 +50,6 @@ use super::{
|
|||
receive_mac_event, SasIds,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
identities::{ReadOnlyDevice, UserIdentities},
|
||||
verification::FlowId,
|
||||
|
@ -62,10 +60,8 @@ const KEY_AGREEMENT_PROTOCOLS: &[KeyAgreementProtocol] =
|
|||
&[KeyAgreementProtocol::Curve25519HkdfSha256];
|
||||
const HASHES: &[HashAlgorithm] = &[HashAlgorithm::Sha256];
|
||||
const MACS: &[MessageAuthenticationCode] = &[MessageAuthenticationCode::HkdfHmacSha256];
|
||||
const STRINGS: &[ShortAuthenticationString] = &[
|
||||
ShortAuthenticationString::Decimal,
|
||||
ShortAuthenticationString::Emoji,
|
||||
];
|
||||
const STRINGS: &[ShortAuthenticationString] =
|
||||
&[ShortAuthenticationString::Decimal, ShortAuthenticationString::Emoji];
|
||||
|
||||
// The max time a SAS flow can take from start to done.
|
||||
const MAX_AGE: Duration = Duration::from_secs(60 * 5);
|
||||
|
@ -91,9 +87,7 @@ impl TryFrom<AcceptV1Content> for AcceptedProtocols {
|
|||
if !KEY_AGREEMENT_PROTOCOLS.contains(&content.key_agreement_protocol)
|
||||
|| !HASHES.contains(&content.hash)
|
||||
|| !MACS.contains(&content.message_authentication_code)
|
||||
|| (!content
|
||||
.short_authentication_string
|
||||
.contains(&ShortAuthenticationString::Emoji)
|
||||
|| (!content.short_authentication_string.contains(&ShortAuthenticationString::Emoji)
|
||||
&& !content
|
||||
.short_authentication_string
|
||||
.contains(&ShortAuthenticationString::Decimal))
|
||||
|
@ -402,11 +396,7 @@ impl SasState<Created> {
|
|||
) -> SasState<Created> {
|
||||
SasState {
|
||||
inner: Arc::new(Mutex::new(OlmSas::new())),
|
||||
ids: SasIds {
|
||||
account,
|
||||
other_device,
|
||||
other_identity,
|
||||
},
|
||||
ids: SasIds { account, other_device, other_identity },
|
||||
verification_flow_id: flow_id.into(),
|
||||
|
||||
creation_time: Arc::new(Instant::now()),
|
||||
|
@ -441,9 +431,7 @@ impl SasState<Created> {
|
|||
MSasV1Content::new(self.state.protocol_definitions.clone())
|
||||
.expect("Invalid initial protocol definitions."),
|
||||
),
|
||||
relation: Relation {
|
||||
event_id: e.clone(),
|
||||
},
|
||||
relation: Relation { event_id: e.clone() },
|
||||
},
|
||||
),
|
||||
}
|
||||
|
@ -490,8 +478,8 @@ impl SasState<Created> {
|
|||
}
|
||||
|
||||
impl SasState<Started> {
|
||||
/// Create a new SAS verification flow from an in-room m.key.verification.start
|
||||
/// event.
|
||||
/// Create a new SAS verification flow from an in-room
|
||||
/// m.key.verification.start event.
|
||||
///
|
||||
/// This will put us in the `started` state.
|
||||
///
|
||||
|
@ -549,11 +537,7 @@ impl SasState<Started> {
|
|||
Ok(SasState {
|
||||
inner: Arc::new(Mutex::new(sas)),
|
||||
|
||||
ids: SasIds {
|
||||
account,
|
||||
other_device,
|
||||
other_identity,
|
||||
},
|
||||
ids: SasIds { account, other_device, other_identity },
|
||||
|
||||
creation_time: Arc::new(Instant::now()),
|
||||
last_event_time: Arc::new(Instant::now()),
|
||||
|
@ -605,19 +589,12 @@ impl SasState<Started> {
|
|||
);
|
||||
|
||||
match self.verification_flow_id.as_ref() {
|
||||
FlowId::ToDevice(s) => AcceptToDeviceEventContent {
|
||||
transaction_id: s.to_string(),
|
||||
method,
|
||||
FlowId::ToDevice(s) => {
|
||||
AcceptToDeviceEventContent { transaction_id: s.to_string(), method }.into()
|
||||
}
|
||||
.into(),
|
||||
FlowId::InRoom(r, e) => (
|
||||
r.clone(),
|
||||
AcceptEventContent {
|
||||
method,
|
||||
relation: Relation {
|
||||
event_id: e.clone(),
|
||||
},
|
||||
},
|
||||
AcceptEventContent { method, relation: Relation { event_id: e.clone() } },
|
||||
)
|
||||
.into(),
|
||||
}
|
||||
|
@ -683,10 +660,8 @@ impl SasState<Accepted> {
|
|||
self.check_event(&sender, content.flow_id().as_str())
|
||||
.map_err(|c| self.clone().cancel(c))?;
|
||||
|
||||
let commitment = calculate_commitment(
|
||||
content.public_key(),
|
||||
self.state.start_content.as_ref().clone(),
|
||||
);
|
||||
let commitment =
|
||||
calculate_commitment(content.public_key(), self.state.start_content.as_ref().clone());
|
||||
|
||||
if self.state.commitment != commitment {
|
||||
Err(self.cancel(CancelCode::InvalidMessage))
|
||||
|
@ -728,9 +703,7 @@ impl SasState<Accepted> {
|
|||
r.clone(),
|
||||
KeyEventContent {
|
||||
key: self.inner.lock().unwrap().public_key(),
|
||||
relation: Relation {
|
||||
event_id: e.clone(),
|
||||
},
|
||||
relation: Relation { event_id: e.clone() },
|
||||
},
|
||||
)
|
||||
.into(),
|
||||
|
@ -754,9 +727,7 @@ impl SasState<KeyReceived> {
|
|||
r.clone(),
|
||||
KeyEventContent {
|
||||
key: self.inner.lock().unwrap().public_key(),
|
||||
relation: Relation {
|
||||
event_id: e.clone(),
|
||||
},
|
||||
relation: Relation { event_id: e.clone() },
|
||||
},
|
||||
)
|
||||
.into(),
|
||||
|
@ -779,8 +750,8 @@ impl SasState<KeyReceived> {
|
|||
|
||||
/// Get the index of the emoji of the short authentication string.
|
||||
///
|
||||
/// Returns seven u8 numbers in the range from 0 to 63 inclusive, those numbers
|
||||
/// can be converted to a unique emoji defined by the spec.
|
||||
/// Returns seven u8 numbers in the range from 0 to 63 inclusive, those
|
||||
/// numbers can be converted to a unique emoji defined by the spec.
|
||||
pub fn get_emoji_index(&self) -> [u8; 7] {
|
||||
get_emoji_index(
|
||||
&self.inner.lock().unwrap(),
|
||||
|
@ -952,11 +923,7 @@ impl SasState<Confirmed> {
|
|||
///
|
||||
/// The content needs to be automatically sent to the other side.
|
||||
pub fn as_content(&self) -> MacContent {
|
||||
get_mac_content(
|
||||
&self.inner.lock().unwrap(),
|
||||
&self.ids,
|
||||
&self.verification_flow_id,
|
||||
)
|
||||
get_mac_content(&self.inner.lock().unwrap(), &self.ids, &self.verification_flow_id)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1015,8 +982,8 @@ impl SasState<MacReceived> {
|
|||
|
||||
/// Get the index of the emoji of the short authentication string.
|
||||
///
|
||||
/// Returns seven u8 numbers in the range from 0 to 63 inclusive, those numbers
|
||||
/// can be converted to a unique emoji defined by the spec.
|
||||
/// Returns seven u8 numbers in the range from 0 to 63 inclusive, those
|
||||
/// numbers can be converted to a unique emoji defined by the spec.
|
||||
pub fn get_emoji_index(&self) -> [u8; 7] {
|
||||
get_emoji_index(
|
||||
&self.inner.lock().unwrap(),
|
||||
|
@ -1048,11 +1015,7 @@ impl SasState<WaitingForDone> {
|
|||
/// The content needs to be automatically sent to the other side if it
|
||||
/// wasn't already sent.
|
||||
pub fn as_content(&self) -> MacContent {
|
||||
get_mac_content(
|
||||
&self.inner.lock().unwrap(),
|
||||
&self.ids,
|
||||
&self.verification_flow_id,
|
||||
)
|
||||
get_mac_content(&self.inner.lock().unwrap(), &self.ids, &self.verification_flow_id)
|
||||
}
|
||||
|
||||
pub fn done_content(&self) -> DoneContent {
|
||||
|
@ -1060,15 +1023,9 @@ impl SasState<WaitingForDone> {
|
|||
FlowId::ToDevice(_) => {
|
||||
unreachable!("The done content isn't supported yet for to-device verifications")
|
||||
}
|
||||
FlowId::InRoom(r, e) => (
|
||||
r.clone(),
|
||||
DoneEventContent {
|
||||
relation: Relation {
|
||||
event_id: e.clone(),
|
||||
},
|
||||
},
|
||||
)
|
||||
.into(),
|
||||
FlowId::InRoom(r, e) => {
|
||||
(r.clone(), DoneEventContent { relation: Relation { event_id: e.clone() } }).into()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1110,11 +1067,7 @@ impl SasState<Done> {
|
|||
/// The content needs to be automatically sent to the other side if it
|
||||
/// wasn't already sent.
|
||||
pub fn as_content(&self) -> MacContent {
|
||||
get_mac_content(
|
||||
&self.inner.lock().unwrap(),
|
||||
&self.ids,
|
||||
&self.verification_flow_id,
|
||||
)
|
||||
get_mac_content(&self.inner.lock().unwrap(), &self.ids, &self.verification_flow_id)
|
||||
}
|
||||
|
||||
pub fn done_content(&self) -> DoneContent {
|
||||
|
@ -1122,15 +1075,9 @@ impl SasState<Done> {
|
|||
FlowId::ToDevice(_) => {
|
||||
unreachable!("The done content isn't supported yet for to-device verifications")
|
||||
}
|
||||
FlowId::InRoom(r, e) => (
|
||||
r.clone(),
|
||||
DoneEventContent {
|
||||
relation: Relation {
|
||||
event_id: e.clone(),
|
||||
},
|
||||
},
|
||||
)
|
||||
.into(),
|
||||
FlowId::InRoom(r, e) => {
|
||||
(r.clone(), DoneEventContent { relation: Relation { event_id: e.clone() } }).into()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1166,10 +1113,7 @@ impl Canceled {
|
|||
_ => unimplemented!(),
|
||||
};
|
||||
|
||||
Canceled {
|
||||
cancel_code: code,
|
||||
reason,
|
||||
}
|
||||
Canceled { cancel_code: code, reason }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1188,9 +1132,7 @@ impl SasState<Canceled> {
|
|||
CancelEventContent {
|
||||
reason: self.state.reason.to_string(),
|
||||
code: self.state.cancel_code.clone(),
|
||||
relation: Relation {
|
||||
event_id: e.clone(),
|
||||
},
|
||||
relation: Relation { event_id: e.clone() },
|
||||
},
|
||||
)
|
||||
.into(),
|
||||
|
@ -1202,10 +1144,6 @@ impl SasState<Canceled> {
|
|||
mod test {
|
||||
use std::convert::TryFrom;
|
||||
|
||||
use crate::{
|
||||
verification::sas::{event_enums::AcceptContent, StartContent},
|
||||
ReadOnlyAccount, ReadOnlyDevice,
|
||||
};
|
||||
use matrix_sdk_common::{
|
||||
events::key::verification::{
|
||||
accept::{AcceptMethod, CustomContent},
|
||||
|
@ -1215,6 +1153,10 @@ mod test {
|
|||
};
|
||||
|
||||
use super::{Accepted, Created, SasState, Started};
|
||||
use crate::{
|
||||
verification::sas::{event_enums::AcceptContent, StartContent},
|
||||
ReadOnlyAccount, ReadOnlyDevice,
|
||||
};
|
||||
|
||||
fn alice_id() -> UserId {
|
||||
UserId::try_from("@alice:example.org").unwrap()
|
||||
|
@ -1353,9 +1295,7 @@ mod test {
|
|||
|
||||
let content = bob.as_content();
|
||||
let sender = UserId::try_from("@malory:example.org").unwrap();
|
||||
alice
|
||||
.into_accepted(&sender, content)
|
||||
.expect_err("Didn't cancel on a invalid sender");
|
||||
alice.into_accepted(&sender, content).expect_err("Didn't cancel on a invalid sender");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
use std::{collections::HashMap, panic};
|
||||
|
||||
use http::Response;
|
||||
|
||||
use matrix_sdk_common::{
|
||||
api::r0::sync::sync_events::Response as SyncResponse,
|
||||
events::{
|
||||
|
@ -11,9 +10,8 @@ use matrix_sdk_common::{
|
|||
identifiers::{room_id, RoomId},
|
||||
IncomingResponse,
|
||||
};
|
||||
use serde_json::Value as JsonValue;
|
||||
|
||||
pub use matrix_sdk_test_macros::async_test;
|
||||
use serde_json::Value as JsonValue;
|
||||
|
||||
pub mod test_json;
|
||||
|
||||
|
@ -44,16 +42,17 @@ pub enum EventsJson {
|
|||
Typing,
|
||||
}
|
||||
|
||||
/// The `EventBuilder` struct can be used to easily generate valid sync responses for testing.
|
||||
/// These can be then fed into either `Client` or `Room`.
|
||||
/// The `EventBuilder` struct can be used to easily generate valid sync
|
||||
/// responses for testing. These can be then fed into either `Client` or `Room`.
|
||||
///
|
||||
/// It supports generated a number of canned events, such as a member entering a room, his power
|
||||
/// level and display name changing and similar. It also supports insertion of custom events in the
|
||||
/// form of `EventsJson` values.
|
||||
/// It supports generated a number of canned events, such as a member entering a
|
||||
/// room, his power level and display name changing and similar. It also
|
||||
/// supports insertion of custom events in the form of `EventsJson` values.
|
||||
///
|
||||
/// **Important** You *must* use the *same* builder when sending multiple sync responses to
|
||||
/// a single client. Otherwise, the subsequent responses will be *ignored* by the client because
|
||||
/// the `next_batch` sync token will not be rotated properly.
|
||||
/// **Important** You *must* use the *same* builder when sending multiple sync
|
||||
/// responses to a single client. Otherwise, the subsequent responses will be
|
||||
/// *ignored* by the client because the `next_batch` sync token will not be
|
||||
/// rotated properly.
|
||||
///
|
||||
/// # Example usage
|
||||
///
|
||||
|
@ -94,7 +93,8 @@ pub struct EventBuilder {
|
|||
ephemeral: Vec<AnySyncEphemeralRoomEvent>,
|
||||
/// The account data events that determine the state of a `Room`.
|
||||
account_data: Vec<AnyGlobalAccountDataEvent>,
|
||||
/// Internal counter to enable the `prev_batch` and `next_batch` of each sync response to vary.
|
||||
/// Internal counter to enable the `prev_batch` and `next_batch` of each
|
||||
/// sync response to vary.
|
||||
batch_counter: i64,
|
||||
}
|
||||
|
||||
|
@ -154,10 +154,7 @@ impl EventBuilder {
|
|||
}
|
||||
|
||||
fn add_joined_event(&mut self, room_id: &RoomId, event: AnySyncRoomEvent) {
|
||||
self.joined_room_events
|
||||
.entry(room_id.clone())
|
||||
.or_insert_with(Vec::new)
|
||||
.push(event);
|
||||
self.joined_room_events.entry(room_id.clone()).or_insert_with(Vec::new).push(event);
|
||||
}
|
||||
|
||||
pub fn add_custom_invited_event(
|
||||
|
@ -166,10 +163,7 @@ impl EventBuilder {
|
|||
event: serde_json::Value,
|
||||
) -> &mut Self {
|
||||
let event = serde_json::from_value::<AnySyncStateEvent>(event).unwrap();
|
||||
self.invited_room_events
|
||||
.entry(room_id.clone())
|
||||
.or_insert_with(Vec::new)
|
||||
.push(event);
|
||||
self.invited_room_events.entry(room_id.clone()).or_insert_with(Vec::new).push(event);
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -179,10 +173,7 @@ impl EventBuilder {
|
|||
event: serde_json::Value,
|
||||
) -> &mut Self {
|
||||
let event = serde_json::from_value::<AnySyncRoomEvent>(event).unwrap();
|
||||
self.left_room_events
|
||||
.entry(room_id.clone())
|
||||
.or_insert_with(Vec::new)
|
||||
.push(event);
|
||||
self.left_room_events.entry(room_id.clone()).or_insert_with(Vec::new).push(event);
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -227,7 +218,8 @@ impl EventBuilder {
|
|||
pub fn build_json_sync_response(&mut self) -> JsonValue {
|
||||
let main_room_id = room_id!("!SVkFJHzfwvuaIEawgC:localhost");
|
||||
|
||||
// First time building a sync response, so initialize the `prev_batch` to a default one.
|
||||
// First time building a sync response, so initialize the `prev_batch` to a
|
||||
// default one.
|
||||
let prev_batch = self.generate_sync_token();
|
||||
self.batch_counter += 1;
|
||||
let next_batch = self.generate_sync_token();
|
||||
|
@ -352,9 +344,7 @@ impl EventBuilder {
|
|||
pub fn build_sync_response(&mut self) -> SyncResponse {
|
||||
let body = self.build_json_sync_response();
|
||||
|
||||
let response = Response::builder()
|
||||
.body(serde_json::to_vec(&body).unwrap())
|
||||
.unwrap();
|
||||
let response = Response::builder().body(serde_json::to_vec(&body).unwrap()).unwrap();
|
||||
|
||||
SyncResponse::try_from_http_response(response).unwrap()
|
||||
}
|
||||
|
@ -395,15 +385,10 @@ pub fn sync_response(kind: SyncResponseFile) -> SyncResponse {
|
|||
SyncResponseFile::Voip => &test_json::VOIP_SYNC,
|
||||
};
|
||||
|
||||
let response = Response::builder()
|
||||
.body(data.to_string().as_bytes().to_vec())
|
||||
.unwrap();
|
||||
let response = Response::builder().body(data.to_string().as_bytes().to_vec()).unwrap();
|
||||
SyncResponse::try_from_http_response(response).unwrap()
|
||||
}
|
||||
|
||||
pub fn response_from_file(json: &serde_json::Value) -> Response<Vec<u8>> {
|
||||
Response::builder()
|
||||
.status(200)
|
||||
.body(json.to_string().as_bytes().to_vec())
|
||||
.unwrap()
|
||||
Response::builder().status(200).body(json.to_string().as_bytes().to_vec()).unwrap()
|
||||
}
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
//! Test data for the matrix-sdk crates.
|
||||
//!
|
||||
//! Exporting each const allows all the test data to have a single source of truth.
|
||||
//! When running `cargo publish` no external folders are allowed so all the
|
||||
//! test data needs to be contained within this crate.
|
||||
//! Exporting each const allows all the test data to have a single source of
|
||||
//! truth. When running `cargo publish` no external folders are allowed so all
|
||||
//! the test data needs to be contained within this crate.
|
||||
|
||||
use lazy_static::lazy_static;
|
||||
use serde_json::{json, Value as JsonValue};
|
||||
|
@ -17,12 +17,11 @@ pub use events::{
|
|||
PUBLIC_ROOMS, REACTION, REDACTED, REDACTED_INVALID, REDACTED_STATE, REDACTION,
|
||||
REGISTRATION_RESPONSE_ERR, ROOM_ID, ROOM_MESSAGES, TYPING,
|
||||
};
|
||||
pub use members::MEMBERS;
|
||||
pub use sync::{
|
||||
DEFAULT_SYNC_SUMMARY, INVITE_SYNC, LEAVE_SYNC, LEAVE_SYNC_EVENT, MORE_SYNC, SYNC, VOIP_SYNC,
|
||||
};
|
||||
|
||||
pub use members::MEMBERS;
|
||||
|
||||
lazy_static! {
|
||||
pub static ref DEVICES: JsonValue = json!({
|
||||
"devices": [
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use proc_macro::TokenStream;
|
||||
|
||||
/// Attribute to use `wasm_bindgen_test` for wasm32 targets and `tokio::test` for everything else
|
||||
/// Attribute to use `wasm_bindgen_test` for wasm32 targets and `tokio::test`
|
||||
/// for everything else
|
||||
#[proc_macro_attribute]
|
||||
pub fn async_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
let attrs = r#"
|
||||
|
|
Loading…
Reference in a new issue