Merge branch 'master' into sas-longer-flow

master
Damir Jelić 2021-05-13 11:26:40 +02:00
commit 3f57a2a9f2
79 changed files with 2020 additions and 4180 deletions

View File

@ -20,7 +20,7 @@ jobs:
- name: Install rust - name: Install rust
uses: actions-rs/toolchain@v1 uses: actions-rs/toolchain@v1
with: with:
toolchain: stable toolchain: nightly
components: rustfmt components: rustfmt
profile: minimal profile: minimal
override: true override: true

6
.rustfmt.toml Normal file
View File

@ -0,0 +1,6 @@
max_width = 100
comment_width = 80
wrap_comments = true
imports_granularity = "Crate"
use_small_heuristics = "Max"
group_imports = "StdExternalCrate"

View File

@ -1,5 +1,4 @@
use std::{env, process::exit}; use std::{env, process::exit};
use tokio::time::{sleep, Duration};
use matrix_sdk::{ use matrix_sdk::{
self, async_trait, self, async_trait,
@ -7,6 +6,7 @@ use matrix_sdk::{
room::Room, room::Room,
Client, ClientConfig, EventHandler, SyncSettings, Client, ClientConfig, EventHandler, SyncSettings,
}; };
use tokio::time::{sleep, Duration};
use url::Url; use url::Url;
struct AutoJoinBot { struct AutoJoinBot {
@ -72,15 +72,11 @@ async fn login_and_sync(
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL"); let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
let client = Client::new_with_config(homeserver_url, client_config).unwrap(); let client = Client::new_with_config(homeserver_url, client_config).unwrap();
client client.login(username, password, None, Some("autojoin bot")).await?;
.login(username, password, None, Some("autojoin bot"))
.await?;
println!("logged in as {}", username); println!("logged in as {}", username);
client client.set_event_handler(Box::new(AutoJoinBot::new(client.clone()))).await;
.set_event_handler(Box::new(AutoJoinBot::new(client.clone())))
.await;
client.sync(SyncSettings::default()).await; client.sync(SyncSettings::default()).await;

View File

@ -69,24 +69,23 @@ async fn login_and_sync(
// create a new Client with the given homeserver url and config // create a new Client with the given homeserver url and config
let client = Client::new_with_config(homeserver_url, client_config).unwrap(); let client = Client::new_with_config(homeserver_url, client_config).unwrap();
client client.login(&username, &password, None, Some("command bot")).await?;
.login(&username, &password, None, Some("command bot"))
.await?;
println!("logged in as {}", username); println!("logged in as {}", username);
// An initial sync to set up state and so our bot doesn't respond to old messages. // An initial sync to set up state and so our bot doesn't respond to old
// If the `StateStore` finds saved state in the location given the initial sync will // messages. If the `StateStore` finds saved state in the location given the
// be skipped in favor of loading state from the store // initial sync will be skipped in favor of loading state from the store
client.sync_once(SyncSettings::default()).await.unwrap(); client.sync_once(SyncSettings::default()).await.unwrap();
// add our CommandBot to be notified of incoming messages, we do this after the initial // add our CommandBot to be notified of incoming messages, we do this after the
// sync to avoid responding to messages before the bot was running. // initial sync to avoid responding to messages before the bot was running.
client.set_event_handler(Box::new(CommandBot::new())).await; client.set_event_handler(Box::new(CommandBot::new())).await;
// since we called `sync_once` before we entered our sync loop we must pass // since we called `sync_once` before we entered our sync loop we must pass
// that sync token to `sync` // that sync token to `sync`
let settings = SyncSettings::default().token(client.sync_token().await.unwrap()); let settings = SyncSettings::default().token(client.sync_token().await.unwrap());
// this keeps state from the server streaming in to CommandBot via the EventHandler trait // this keeps state from the server streaming in to CommandBot via the
// EventHandler trait
client.sync(settings).await; client.sync(settings).await;
Ok(()) Ok(())

View File

@ -5,12 +5,11 @@ use std::{
sync::atomic::{AtomicBool, Ordering}, sync::atomic::{AtomicBool, Ordering},
}; };
use serde_json::json;
use url::Url;
use matrix_sdk::{ use matrix_sdk::{
self, api::r0::uiaa::AuthData, identifiers::UserId, Client, LoopCtrl, SyncSettings, self, api::r0::uiaa::AuthData, identifiers::UserId, Client, LoopCtrl, SyncSettings,
}; };
use serde_json::json;
use url::Url;
fn auth_data<'a>(user: &UserId, password: &str, session: Option<&'a str>) -> AuthData<'a> { fn auth_data<'a>(user: &UserId, password: &str, session: Option<&'a str>) -> AuthData<'a> {
let mut auth_parameters = BTreeMap::new(); let mut auth_parameters = BTreeMap::new();
@ -22,11 +21,7 @@ fn auth_data<'a>(user: &UserId, password: &str, session: Option<&'a str>) -> Aut
auth_parameters.insert("identifier".to_owned(), identifier); auth_parameters.insert("identifier".to_owned(), identifier);
auth_parameters.insert("password".to_owned(), password.to_owned().into()); auth_parameters.insert("password".to_owned(), password.to_owned().into());
AuthData::DirectRequest { AuthData::DirectRequest { kind: "m.login.password", auth_parameters, session }
kind: "m.login.password",
auth_parameters,
session,
}
} }
async fn bootstrap(client: Client, user_id: UserId, password: String) { async fn bootstrap(client: Client, user_id: UserId, password: String) {
@ -34,9 +29,7 @@ async fn bootstrap(client: Client, user_id: UserId, password: String) {
let mut input = String::new(); let mut input = String::new();
io::stdin() io::stdin().read_line(&mut input).expect("error: unable to read user input");
.read_line(&mut input)
.expect("error: unable to read user input");
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
if let Err(e) = client.bootstrap_cross_signing(None).await { if let Err(e) = client.bootstrap_cross_signing(None).await {
@ -63,9 +56,7 @@ async fn login(
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL"); let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
let client = Client::new(homeserver_url).unwrap(); let client = Client::new(homeserver_url).unwrap();
let response = client let response = client.login(username, password, None, Some("rust-sdk")).await?;
.login(username, password, None, Some("rust-sdk"))
.await?;
let user_id = &response.user_id; let user_id = &response.user_id;
let client_ref = &client; let client_ref = &client;

View File

@ -6,7 +6,6 @@ use std::{
Arc, Arc,
}, },
}; };
use url::Url;
use matrix_sdk::{ use matrix_sdk::{
self, self,
@ -14,14 +13,13 @@ use matrix_sdk::{
identifiers::UserId, identifiers::UserId,
Client, LoopCtrl, Sas, SyncSettings, Client, LoopCtrl, Sas, SyncSettings,
}; };
use url::Url;
async fn wait_for_confirmation(client: Client, sas: Sas) { async fn wait_for_confirmation(client: Client, sas: Sas) {
println!("Does the emoji match: {:?}", sas.emoji()); println!("Does the emoji match: {:?}", sas.emoji());
let mut input = String::new(); let mut input = String::new();
io::stdin() io::stdin().read_line(&mut input).expect("error: unable to read user input");
.read_line(&mut input)
.expect("error: unable to read user input");
match input.trim().to_lowercase().as_ref() { match input.trim().to_lowercase().as_ref() {
"yes" | "true" | "ok" => { "yes" | "true" | "ok" => {
@ -68,9 +66,7 @@ async fn login(
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL"); let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
let client = Client::new(homeserver_url).unwrap(); let client = Client::new(homeserver_url).unwrap();
client client.login(username, password, None, Some("rust-sdk")).await?;
.login(username, password, None, Some("rust-sdk"))
.await?;
let client_ref = &client; let client_ref = &client;
let initial_sync = Arc::new(AtomicBool::from(true)); let initial_sync = Arc::new(AtomicBool::from(true));
@ -81,12 +77,7 @@ async fn login(
let client = &client_ref; let client = &client_ref;
let initial = &initial_ref; let initial = &initial_ref;
for event in response for event in response.to_device.events.iter().filter_map(|e| e.deserialize().ok()) {
.to_device
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{
match event { match event {
AnyToDeviceEvent::KeyVerificationStart(e) => { AnyToDeviceEvent::KeyVerificationStart(e) => {
let sas = client let sas = client
@ -129,11 +120,8 @@ async fn login(
if !initial.load(Ordering::SeqCst) { if !initial.load(Ordering::SeqCst) {
for (_room_id, room_info) in response.rooms.join { for (_room_id, room_info) in response.rooms.join {
for event in room_info for event in
.timeline room_info.timeline.events.iter().filter_map(|e| e.event.deserialize().ok())
.events
.iter()
.filter_map(|e| e.event.deserialize().ok())
{ {
if let AnySyncRoomEvent::Message(event) = event { if let AnySyncRoomEvent::Message(event) = event {
match event { match event {

View File

@ -1,13 +1,12 @@
use std::{convert::TryFrom, env, process::exit}; use std::{convert::TryFrom, env, process::exit};
use url::Url;
use matrix_sdk::{ use matrix_sdk::{
self, self,
api::r0::profile, api::r0::profile,
identifiers::{MxcUri, UserId}, identifiers::{MxcUri, UserId},
Client, Result as MatrixResult, Client, Result as MatrixResult,
}; };
use url::Url;
#[derive(Debug)] #[derive(Debug)]
struct UserProfile { struct UserProfile {
@ -29,10 +28,7 @@ async fn get_profile(client: Client, mxid: &UserId) -> MatrixResult<UserProfile>
// Use the response and construct a UserProfile struct. // Use the response and construct a UserProfile struct.
// See https://docs.rs/ruma-client-api/0.9.0/ruma_client_api/r0/profile/get_profile/struct.Response.html // See https://docs.rs/ruma-client-api/0.9.0/ruma_client_api/r0/profile/get_profile/struct.Response.html
// for details on the Response for this Request // for details on the Response for this Request
let user_profile = UserProfile { let user_profile = UserProfile { avatar_url: resp.avatar_url, displayname: resp.displayname };
avatar_url: resp.avatar_url,
displayname: resp.displayname,
};
Ok(user_profile) Ok(user_profile)
} }
@ -44,9 +40,7 @@ async fn login(
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL"); let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
let client = Client::new(homeserver_url).unwrap(); let client = Client::new(homeserver_url).unwrap();
client client.login(username, password, None, Some("rust-sdk")).await?;
.login(username, password, None, Some("rust-sdk"))
.await?;
Ok(client) Ok(client)
} }

View File

@ -6,7 +6,6 @@ use std::{
process::exit, process::exit,
sync::Arc, sync::Arc,
}; };
use tokio::sync::Mutex;
use matrix_sdk::{ use matrix_sdk::{
self, async_trait, self, async_trait,
@ -17,6 +16,7 @@ use matrix_sdk::{
room::Room, room::Room,
Client, EventHandler, SyncSettings, Client, EventHandler, SyncSettings,
}; };
use tokio::sync::Mutex;
use url::Url; use url::Url;
struct ImageBot { struct ImageBot {
@ -52,9 +52,7 @@ impl EventHandler for ImageBot {
println!("sending image"); println!("sending image");
let mut image = self.image.lock().await; let mut image = self.image.lock().await;
room.send_attachment("cat", &mime::IMAGE_JPEG, &mut *image, None) room.send_attachment("cat", &mime::IMAGE_JPEG, &mut *image, None).await.unwrap();
.await
.unwrap();
image.seek(SeekFrom::Start(0)).unwrap(); image.seek(SeekFrom::Start(0)).unwrap();
@ -73,14 +71,10 @@ async fn login_and_sync(
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL"); let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
let client = Client::new(homeserver_url).unwrap(); let client = Client::new(homeserver_url).unwrap();
client client.login(&username, &password, None, Some("command bot")).await?;
.login(&username, &password, None, Some("command bot"))
.await?;
client.sync_once(SyncSettings::default()).await.unwrap(); client.sync_once(SyncSettings::default()).await.unwrap();
client client.set_event_handler(Box::new(ImageBot::new(image))).await;
.set_event_handler(Box::new(ImageBot::new(image)))
.await;
let settings = SyncSettings::default().token(client.sync_token().await.unwrap()); let settings = SyncSettings::default().token(client.sync_token().await.unwrap());
client.sync(settings).await; client.sync(settings).await;
@ -91,26 +85,19 @@ async fn login_and_sync(
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), matrix_sdk::Error> { async fn main() -> Result<(), matrix_sdk::Error> {
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::init();
let (homeserver_url, username, password, image_path) = match ( let (homeserver_url, username, password, image_path) =
env::args().nth(1), match (env::args().nth(1), env::args().nth(2), env::args().nth(3), env::args().nth(4)) {
env::args().nth(2), (Some(a), Some(b), Some(c), Some(d)) => (a, b, c, d),
env::args().nth(3), _ => {
env::args().nth(4), eprintln!(
) { "Usage: {} <homeserver_url> <username> <password> <image>",
(Some(a), Some(b), Some(c), Some(d)) => (a, b, c, d), env::args().next().unwrap()
_ => { );
eprintln!( exit(1)
"Usage: {} <homeserver_url> <username> <password> <image>", }
env::args().next().unwrap() };
);
exit(1)
}
};
println!( println!("helloooo {} {} {} {:#?}", homeserver_url, username, password, image_path);
"helloooo {} {} {} {:#?}",
homeserver_url, username, password, image_path
);
let path = PathBuf::from(image_path); let path = PathBuf::from(image_path);
let image = File::open(path).expect("Can't open image file."); let image = File::open(path).expect("Can't open image file.");

View File

@ -1,5 +1,4 @@
use std::{env, process::exit}; use std::{env, process::exit};
use url::Url;
use matrix_sdk::{ use matrix_sdk::{
self, async_trait, self, async_trait,
@ -10,6 +9,7 @@ use matrix_sdk::{
room::Room, room::Room,
Client, EventHandler, SyncSettings, Client, EventHandler, SyncSettings,
}; };
use url::Url;
struct EventCallback; struct EventCallback;
@ -28,9 +28,7 @@ impl EventHandler for EventCallback {
} = event } = event
{ {
let member = room.get_member(&sender).await.unwrap().unwrap(); let member = room.get_member(&sender).await.unwrap().unwrap();
let name = member let name = member.display_name().unwrap_or_else(|| member.user_id().as_str());
.display_name()
.unwrap_or_else(|| member.user_id().as_str());
println!("{}: {}", name, msg_body); println!("{}: {}", name, msg_body);
} }
} }
@ -47,9 +45,7 @@ async fn login(
client.set_event_handler(Box::new(EventCallback)).await; client.set_event_handler(Box::new(EventCallback)).await;
client client.login(username, password, None, Some("rust-sdk")).await?;
.login(username, password, None, Some("rust-sdk"))
.await?;
client.sync(SyncSettings::new()).await; client.sync(SyncSettings::new()).await;
Ok(()) Ok(())

File diff suppressed because it is too large Load Diff

View File

@ -40,7 +40,8 @@ impl Deref for Device {
impl Device { impl Device {
/// Start a interactive verification with this `Device` /// Start a interactive verification with this `Device`
/// ///
/// Returns a `Sas` object that represents the interactive verification flow. /// Returns a `Sas` object that represents the interactive verification
/// flow.
/// ///
/// # Example /// # Example
/// ///
@ -65,10 +66,7 @@ impl Device {
let (sas, request) = self.inner.start_verification().await?; let (sas, request) = self.inner.start_verification().await?;
self.client.send_to_device(&request).await?; self.client.send_to_device(&request).await?;
Ok(Sas { Ok(Sas { inner: sas, client: self.client.clone() })
inner: sas,
client: self.client.clone(),
})
} }
/// Is the device trusted. /// Is the device trusted.
@ -102,10 +100,7 @@ pub struct UserDevices {
impl UserDevices { impl UserDevices {
/// Get the specific device with the given device id. /// Get the specific device with the given device id.
pub fn get(&self, device_id: &DeviceId) -> Option<Device> { pub fn get(&self, device_id: &DeviceId) -> Option<Device> {
self.inner.get(device_id).map(|d| Device { self.inner.get(device_id).map(|d| Device { inner: d, client: self.client.clone() })
inner: d,
client: self.client.clone(),
})
} }
/// Iterator over all the device ids of the user devices. /// Iterator over all the device ids of the user devices.
@ -117,9 +112,6 @@ impl UserDevices {
pub fn devices(&self) -> impl Iterator<Item = Device> + '_ { pub fn devices(&self) -> impl Iterator<Item = Device> + '_ {
let client = self.client.clone(); let client = self.client.clone();
self.inner.devices().map(move |d| Device { self.inner.devices().map(move |d| Device { inner: d, client: client.clone() })
inner: d,
client: client.clone(),
})
} }
} }

View File

@ -14,7 +14,11 @@
//! Error conditions. //! Error conditions.
use std::io::Error as IoError;
use http::StatusCode; use http::StatusCode;
#[cfg(feature = "encryption")]
use matrix_sdk_base::crypto::store::CryptoStoreError;
use matrix_sdk_base::{Error as MatrixError, StoreError}; use matrix_sdk_base::{Error as MatrixError, StoreError};
use matrix_sdk_common::{ use matrix_sdk_common::{
api::{ api::{
@ -26,12 +30,8 @@ use matrix_sdk_common::{
}; };
use reqwest::Error as ReqwestError; use reqwest::Error as ReqwestError;
use serde_json::Error as JsonError; use serde_json::Error as JsonError;
use std::io::Error as IoError;
use thiserror::Error; use thiserror::Error;
#[cfg(feature = "encryption")]
use matrix_sdk_base::crypto::store::CryptoStoreError;
/// Result type of the rust-sdk. /// Result type of the rust-sdk.
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
@ -43,11 +43,13 @@ pub enum HttpError {
#[error(transparent)] #[error(transparent)]
Reqwest(#[from] ReqwestError), Reqwest(#[from] ReqwestError),
/// Queried endpoint requires authentication but was called on an anonymous client. /// Queried endpoint requires authentication but was called on an anonymous
/// client.
#[error("the queried endpoint requires authentication but was called before logging in")] #[error("the queried endpoint requires authentication but was called before logging in")]
AuthenticationRequired, AuthenticationRequired,
/// Client tried to force authentication but did not provide an access token. /// Client tried to force authentication but did not provide an access
/// token.
#[error("tried to force authentication but no access token was provided")] #[error("tried to force authentication but no access token was provided")]
ForcedAuthenticationWithoutAccessToken, ForcedAuthenticationWithoutAccessToken,
@ -69,9 +71,10 @@ pub enum HttpError {
/// An error occurred while authenticating. /// An error occurred while authenticating.
/// ///
/// When registering or authenticating the Matrix server can send a `UiaaResponse` /// When registering or authenticating the Matrix server can send a
/// as the error type, this is a User-Interactive Authentication API response. This /// `UiaaResponse` as the error type, this is a User-Interactive
/// represents an error with information about how to authenticate the user. /// Authentication API response. This represents an error with
/// information about how to authenticate the user.
#[error(transparent)] #[error(transparent)]
UiaaError(#[from] FromHttpResponseError<UiaaError>), UiaaError(#[from] FromHttpResponseError<UiaaError>),
@ -96,7 +99,8 @@ pub enum Error {
#[error(transparent)] #[error(transparent)]
Http(#[from] HttpError), Http(#[from] HttpError),
/// Queried endpoint requires authentication but was called on an anonymous client. /// Queried endpoint requires authentication but was called on an anonymous
/// client.
#[error("the queried endpoint requires authentication but was called before logging in")] #[error("the queried endpoint requires authentication but was called before logging in")]
AuthenticationRequired, AuthenticationRequired,

View File

@ -16,6 +16,7 @@ use std::ops::Deref;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::push::get_notifications::Notification, api::r0::push::get_notifications::Notification,
async_trait,
events::{ events::{
fully_read::FullyReadEventContent, AnySyncRoomEvent, GlobalAccountDataEvent, fully_read::FullyReadEventContent, AnySyncRoomEvent, GlobalAccountDataEvent,
RoomAccountDataEvent, RoomAccountDataEvent,
@ -56,7 +57,6 @@ use crate::{
room::Room, room::Room,
Client, Client,
}; };
use matrix_sdk_common::async_trait;
pub(crate) struct Handler { pub(crate) struct Handler {
pub(crate) inner: Box<dyn EventHandler>, pub(crate) inner: Box<dyn EventHandler>,
@ -77,50 +77,29 @@ impl Handler {
} }
pub(crate) async fn handle_sync(&self, response: &SyncResponse) { pub(crate) async fn handle_sync(&self, response: &SyncResponse) {
for event in response for event in response.account_data.events.iter().filter_map(|e| e.deserialize().ok()) {
.account_data
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{
self.handle_account_data_event(&event).await; self.handle_account_data_event(&event).await;
} }
for (room_id, room_info) in &response.rooms.join { for (room_id, room_info) in &response.rooms.join {
if let Some(room) = self.get_room(room_id) { if let Some(room) = self.get_room(room_id) {
for event in room_info for event in room_info.ephemeral.events.iter().filter_map(|e| e.deserialize().ok())
.ephemeral
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{ {
self.handle_ephemeral_event(room.clone(), &event).await; self.handle_ephemeral_event(room.clone(), &event).await;
} }
for event in room_info for event in
.account_data room_info.account_data.events.iter().filter_map(|e| e.deserialize().ok())
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{ {
self.handle_room_account_data_event(room.clone(), &event) self.handle_room_account_data_event(room.clone(), &event).await;
.await;
} }
for event in room_info for event in room_info.state.events.iter().filter_map(|e| e.deserialize().ok()) {
.state
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{
self.handle_state_event(room.clone(), &event).await; self.handle_state_event(room.clone(), &event).await;
} }
for event in room_info for event in
.timeline room_info.timeline.events.iter().filter_map(|e| e.event.deserialize().ok())
.events
.iter()
.filter_map(|e| e.event.deserialize().ok())
{ {
self.handle_timeline_event(room.clone(), &event).await; self.handle_timeline_event(room.clone(), &event).await;
} }
@ -129,30 +108,18 @@ impl Handler {
for (room_id, room_info) in &response.rooms.leave { for (room_id, room_info) in &response.rooms.leave {
if let Some(room) = self.get_room(room_id) { if let Some(room) = self.get_room(room_id) {
for event in room_info for event in
.account_data room_info.account_data.events.iter().filter_map(|e| e.deserialize().ok())
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{ {
self.handle_room_account_data_event(room.clone(), &event) self.handle_room_account_data_event(room.clone(), &event).await;
.await;
} }
for event in room_info for event in room_info.state.events.iter().filter_map(|e| e.deserialize().ok()) {
.state
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{
self.handle_state_event(room.clone(), &event).await; self.handle_state_event(room.clone(), &event).await;
} }
for event in room_info for event in
.timeline room_info.timeline.events.iter().filter_map(|e| e.event.deserialize().ok())
.events
.iter()
.filter_map(|e| e.event.deserialize().ok())
{ {
self.handle_timeline_event(room.clone(), &event).await; self.handle_timeline_event(room.clone(), &event).await;
} }
@ -161,31 +128,22 @@ impl Handler {
for (room_id, room_info) in &response.rooms.invite { for (room_id, room_info) in &response.rooms.invite {
if let Some(room) = self.get_room(room_id) { if let Some(room) = self.get_room(room_id) {
for event in room_info for event in
.invite_state room_info.invite_state.events.iter().filter_map(|e| e.deserialize().ok())
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{ {
self.handle_stripped_state_event(room.clone(), &event).await; self.handle_stripped_state_event(room.clone(), &event).await;
} }
} }
} }
for event in response for event in response.presence.events.iter().filter_map(|e| e.deserialize().ok()) {
.presence
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{
self.on_presence_event(&event).await; self.on_presence_event(&event).await;
} }
for (room_id, notifications) in &response.notifications { for (room_id, notifications) in &response.notifications {
if let Some(room) = self.get_room(&room_id) { if let Some(room) = self.get_room(&room_id) {
for notification in notifications { for notification in notifications {
self.on_room_notification(room.clone(), notification.clone()) self.on_room_notification(room.clone(), notification.clone()).await;
.await;
} }
} }
} }
@ -249,8 +207,7 @@ impl Handler {
self.on_room_tombstone(room, &tomb).await self.on_room_tombstone(room, &tomb).await
} }
AnySyncStateEvent::Custom(custom) => { AnySyncStateEvent::Custom(custom) => {
self.on_custom_event(room, &CustomEvent::State(custom)) self.on_custom_event(room, &CustomEvent::State(custom)).await
.await
} }
_ => {} _ => {}
} }
@ -268,8 +225,7 @@ impl Handler {
} }
AnyStrippedStateEvent::RoomName(name) => self.on_stripped_state_name(room, &name).await, AnyStrippedStateEvent::RoomName(name) => self.on_stripped_state_name(room, &name).await,
AnyStrippedStateEvent::RoomCanonicalAlias(canonical) => { AnyStrippedStateEvent::RoomCanonicalAlias(canonical) => {
self.on_stripped_state_canonical_alias(room, &canonical) self.on_stripped_state_canonical_alias(room, &canonical).await
.await
} }
AnyStrippedStateEvent::RoomAliases(aliases) => { AnyStrippedStateEvent::RoomAliases(aliases) => {
self.on_stripped_state_aliases(room, &aliases).await self.on_stripped_state_aliases(room, &aliases).await
@ -341,8 +297,9 @@ pub enum CustomEvent<'c> {
StrippedState(&'c StrippedStateEvent<CustomEventContent>), StrippedState(&'c StrippedStateEvent<CustomEventContent>),
} }
/// This trait allows any type implementing `EventHandler` to specify event callbacks for each event. /// This trait allows any type implementing `EventHandler` to specify event
/// The `Client` calls each method when the corresponding event is received. /// callbacks for each event. The `Client` calls each method when the
/// corresponding event is received.
/// ///
/// # Examples /// # Examples
/// ``` /// ```
@ -427,8 +384,8 @@ pub trait EventHandler: Send + Sync {
/// Fires when `Client` receives a `RoomEvent::Tombstone` event. /// Fires when `Client` receives a `RoomEvent::Tombstone` event.
async fn on_room_tombstone(&self, _: Room, _: &SyncStateEvent<TombstoneEventContent>) {} async fn on_room_tombstone(&self, _: Room, _: &SyncStateEvent<TombstoneEventContent>) {}
/// Fires when `Client` receives room events that trigger notifications according to /// Fires when `Client` receives room events that trigger notifications
/// the push rules of the user. /// according to the push rules of the user.
async fn on_room_notification(&self, _: Room, _: Notification) {} async fn on_room_notification(&self, _: Room, _: Notification) {}
// `RoomEvent`s from `IncomingState` // `RoomEvent`s from `IncomingState`
@ -453,7 +410,8 @@ pub trait EventHandler: Send + Sync {
async fn on_state_join_rules(&self, _: Room, _: &SyncStateEvent<JoinRulesEventContent>) {} async fn on_state_join_rules(&self, _: Room, _: &SyncStateEvent<JoinRulesEventContent>) {}
// `AnyStrippedStateEvent`s // `AnyStrippedStateEvent`s
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomMember` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomMember` event.
async fn on_stripped_state_member( async fn on_stripped_state_member(
&self, &self,
_: Room, _: Room,
@ -461,32 +419,38 @@ pub trait EventHandler: Send + Sync {
_: Option<MemberEventContent>, _: Option<MemberEventContent>,
) { ) {
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomName` event. /// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomName`
/// event.
async fn on_stripped_state_name(&self, _: Room, _: &StrippedStateEvent<NameEventContent>) {} async fn on_stripped_state_name(&self, _: Room, _: &StrippedStateEvent<NameEventContent>) {}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event.
async fn on_stripped_state_canonical_alias( async fn on_stripped_state_canonical_alias(
&self, &self,
_: Room, _: Room,
_: &StrippedStateEvent<CanonicalAliasEventContent>, _: &StrippedStateEvent<CanonicalAliasEventContent>,
) { ) {
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAliases` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomAliases` event.
async fn on_stripped_state_aliases( async fn on_stripped_state_aliases(
&self, &self,
_: Room, _: Room,
_: &StrippedStateEvent<AliasesEventContent>, _: &StrippedStateEvent<AliasesEventContent>,
) { ) {
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAvatar` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomAvatar` event.
async fn on_stripped_state_avatar(&self, _: Room, _: &StrippedStateEvent<AvatarEventContent>) {} async fn on_stripped_state_avatar(&self, _: Room, _: &StrippedStateEvent<AvatarEventContent>) {}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomPowerLevels` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomPowerLevels` event.
async fn on_stripped_state_power_levels( async fn on_stripped_state_power_levels(
&self, &self,
_: Room, _: Room,
_: &StrippedStateEvent<PowerLevelsEventContent>, _: &StrippedStateEvent<PowerLevelsEventContent>,
) { ) {
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomJoinRules` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomJoinRules` event.
async fn on_stripped_state_join_rules( async fn on_stripped_state_join_rules(
&self, &self,
_: Room, _: Room,
@ -523,31 +487,33 @@ pub trait EventHandler: Send + Sync {
/// Fires when `Client` receives a `NonRoomEvent::RoomAliases` event. /// Fires when `Client` receives a `NonRoomEvent::RoomAliases` event.
async fn on_presence_event(&self, _: &PresenceEvent) {} async fn on_presence_event(&self, _: &PresenceEvent) {}
/// Fires when `Client` receives a `Event::Custom` event or if deserialization fails /// Fires when `Client` receives a `Event::Custom` event or if
/// because the event was unknown to ruma. /// deserialization fails because the event was unknown to ruma.
/// ///
/// The only guarantee this method can give about the event is that it is valid JSON. /// The only guarantee this method can give about the event is that it is
/// valid JSON.
async fn on_unrecognized_event(&self, _: Room, _: &RawJsonValue) {} async fn on_unrecognized_event(&self, _: Room, _: &RawJsonValue) {}
/// Fires when `Client` receives a `Event::Custom` event or if deserialization fails /// Fires when `Client` receives a `Event::Custom` event or if
/// because the event was unknown to ruma. /// deserialization fails because the event was unknown to ruma.
/// ///
/// The only guarantee this method can give about the event is that it is in the /// The only guarantee this method can give about the event is that it is in
/// shape of a valid matrix event. /// the shape of a valid matrix event.
async fn on_custom_event(&self, _: Room, _: &CustomEvent<'_>) {} async fn on_custom_event(&self, _: Room, _: &CustomEvent<'_>) {}
} }
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::*; use std::{sync::Arc, time::Duration};
use matrix_sdk_common::{async_trait, locks::Mutex}; use matrix_sdk_common::{async_trait, locks::Mutex};
use matrix_sdk_test::{async_test, test_json}; use matrix_sdk_test::{async_test, test_json};
use mockito::{mock, Matcher}; use mockito::{mock, Matcher};
use std::{sync::Arc, time::Duration};
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
pub use wasm_bindgen_test::*; pub use wasm_bindgen_test::*;
use super::*;
#[derive(Clone)] #[derive(Clone)]
pub struct EvHandlerTest(Arc<Mutex<Vec<String>>>); pub struct EvHandlerTest(Arc<Mutex<Vec<String>>>);
@ -640,56 +606,50 @@ mod test {
} }
// `AnyStrippedStateEvent`s // `AnyStrippedStateEvent`s
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomMember` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomMember` event.
async fn on_stripped_state_member( async fn on_stripped_state_member(
&self, &self,
_: Room, _: Room,
_: &StrippedStateEvent<MemberEventContent>, _: &StrippedStateEvent<MemberEventContent>,
_: Option<MemberEventContent>, _: Option<MemberEventContent>,
) { ) {
self.0 self.0.lock().await.push("stripped state member".to_string())
.lock()
.await
.push("stripped state member".to_string())
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomName` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomName` event.
async fn on_stripped_state_name(&self, _: Room, _: &StrippedStateEvent<NameEventContent>) { async fn on_stripped_state_name(&self, _: Room, _: &StrippedStateEvent<NameEventContent>) {
self.0.lock().await.push("stripped state name".to_string()) self.0.lock().await.push("stripped state name".to_string())
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event.
async fn on_stripped_state_canonical_alias( async fn on_stripped_state_canonical_alias(
&self, &self,
_: Room, _: Room,
_: &StrippedStateEvent<CanonicalAliasEventContent>, _: &StrippedStateEvent<CanonicalAliasEventContent>,
) { ) {
self.0 self.0.lock().await.push("stripped state canonical".to_string())
.lock()
.await
.push("stripped state canonical".to_string())
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAliases` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomAliases` event.
async fn on_stripped_state_aliases( async fn on_stripped_state_aliases(
&self, &self,
_: Room, _: Room,
_: &StrippedStateEvent<AliasesEventContent>, _: &StrippedStateEvent<AliasesEventContent>,
) { ) {
self.0 self.0.lock().await.push("stripped state aliases".to_string())
.lock()
.await
.push("stripped state aliases".to_string())
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAvatar` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomAvatar` event.
async fn on_stripped_state_avatar( async fn on_stripped_state_avatar(
&self, &self,
_: Room, _: Room,
_: &StrippedStateEvent<AvatarEventContent>, _: &StrippedStateEvent<AvatarEventContent>,
) { ) {
self.0 self.0.lock().await.push("stripped state avatar".to_string())
.lock()
.await
.push("stripped state avatar".to_string())
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomPowerLevels` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomPowerLevels` event.
async fn on_stripped_state_power_levels( async fn on_stripped_state_power_levels(
&self, &self,
_: Room, _: Room,
@ -697,7 +657,8 @@ mod test {
) { ) {
self.0.lock().await.push("stripped state power".to_string()) self.0.lock().await.push("stripped state power".to_string())
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomJoinRules` event. /// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomJoinRules` event.
async fn on_stripped_state_join_rules( async fn on_stripped_state_join_rules(
&self, &self,
_: Room, _: Room,
@ -768,14 +729,11 @@ mod test {
} }
async fn mock_sync(client: &Client, response: String) { async fn mock_sync(client: &Client, response: String) {
let _m = mock( let _m = mock("GET", Matcher::Regex(r"^/_matrix/client/r0/sync\?.*$".to_string()))
"GET", .with_status(200)
Matcher::Regex(r"^/_matrix/client/r0/sync\?.*$".to_string()), .match_header("authorization", "Bearer 1234")
) .with_body(response)
.with_status(200) .create();
.match_header("authorization", "Bearer 1234")
.with_body(response)
.create();
let sync_settings = SyncSettings::new().timeout(Duration::from_millis(3000)); let sync_settings = SyncSettings::new().timeout(Duration::from_millis(3000));
let _response = client.sync_once(sync_settings).await.unwrap(); let _response = client.sync_once(sync_settings).await.unwrap();
@ -823,14 +781,7 @@ mod test {
mock_sync(&client, test_json::INVITE_SYNC.to_string()).await; mock_sync(&client, test_json::INVITE_SYNC.to_string()).await;
let v = test_vec.lock().await; let v = test_vec.lock().await;
assert_eq!( assert_eq!(v.as_slice(), ["stripped state name", "stripped state member", "presence event"],)
v.as_slice(),
[
"stripped state name",
"stripped state member",
"presence event"
],
)
} }
#[async_test] #[async_test]
@ -897,15 +848,7 @@ mod test {
mock_sync(&client, test_json::VOIP_SYNC.to_string()).await; mock_sync(&client, test_json::VOIP_SYNC.to_string()).await;
let v = test_vec.lock().await; let v = test_vec.lock().await;
assert_eq!( assert_eq!(v.as_slice(), ["call invite", "call answer", "call candidates", "call hangup",],)
v.as_slice(),
[
"call invite",
"call answer",
"call candidates",
"call hangup",
],
)
} }
#[async_test] #[async_test]

View File

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#[cfg(all(not(target_arch = "wasm32")))]
use std::sync::atomic::{AtomicU64, Ordering};
use std::{convert::TryFrom, fmt::Debug, sync::Arc}; use std::{convert::TryFrom, fmt::Debug, sync::Arc};
#[cfg(all(not(target_arch = "wasm32")))] #[cfg(all(not(target_arch = "wasm32")))]
@ -19,16 +21,13 @@ use backoff::{future::retry, Error as RetryError, ExponentialBackoff};
#[cfg(all(not(target_arch = "wasm32")))] #[cfg(all(not(target_arch = "wasm32")))]
use http::StatusCode; use http::StatusCode;
use http::{HeaderValue, Response as HttpResponse}; use http::{HeaderValue, Response as HttpResponse};
use reqwest::{Client, Response};
#[cfg(all(not(target_arch = "wasm32")))]
use std::sync::atomic::{AtomicU64, Ordering};
use tracing::trace;
use url::Url;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::media::create_content, async_trait, locks::RwLock, AsyncTraitDeps, AuthScheme, api::r0::media::create_content, async_trait, locks::RwLock, AsyncTraitDeps, AuthScheme,
FromHttpResponseError, IncomingResponse, SendAccessToken, FromHttpResponseError, IncomingResponse, SendAccessToken,
}; };
use reqwest::{Client, Response};
use tracing::trace;
use url::Url;
use crate::{ use crate::{
error::HttpError, Bytes, BytesMut, ClientConfig, OutgoingRequest, RequestConfig, Session, error::HttpError, Bytes, BytesMut, ClientConfig, OutgoingRequest, RequestConfig, Session,
@ -39,13 +38,16 @@ use crate::{
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)] #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait HttpSend: AsyncTraitDeps { pub trait HttpSend: AsyncTraitDeps {
/// The method abstracting sending request types and receiving response types. /// The method abstracting sending request types and receiving response
/// types.
/// ///
/// This is called by the client every time it wants to send anything to a homeserver. /// This is called by the client every time it wants to send anything to a
/// homeserver.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `request` - The http request that has been converted from a ruma `Request`. /// * `request` - The http request that has been converted from a ruma
/// `Request`.
/// ///
/// * `request_config` - The config used for this request. /// * `request_config` - The config used for this request.
/// ///
@ -122,8 +124,7 @@ impl HttpClient {
let request = if !self.request_config.assert_identity { let request = if !self.request_config.assert_identity {
self.try_into_http_request(request, session, config).await? self.try_into_http_request(request, session, config).await?
} else { } else {
self.try_into_http_request_with_identy_assertion(request, session, config) self.try_into_http_request_with_identy_assertion(request, session, config).await?
.await?
}; };
self.inner.send_request(request, config).await self.inner.send_request(request, config).await
@ -202,9 +203,7 @@ impl HttpClient {
request: create_content::Request<'_>, request: create_content::Request<'_>,
config: Option<RequestConfig>, config: Option<RequestConfig>,
) -> Result<create_content::Response, HttpError> { ) -> Result<create_content::Response, HttpError> {
let response = self let response = self.send_request(request, self.session.clone(), config).await?;
.send_request(request, self.session.clone(), config)
.await?;
Ok(create_content::Response::try_from_http_response(response)?) Ok(create_content::Response::try_from_http_response(response)?)
} }
@ -217,9 +216,7 @@ impl HttpClient {
Request: OutgoingRequest + Debug, Request: OutgoingRequest + Debug,
HttpError: From<FromHttpResponseError<Request::EndpointError>>, HttpError: From<FromHttpResponseError<Request::EndpointError>>,
{ {
let response = self let response = self.send_request(request, self.session.clone(), config).await?;
.send_request(request, self.session.clone(), config)
.await?;
trace!("Got response: {:?}", response); trace!("Got response: {:?}", response);
@ -256,9 +253,7 @@ pub(crate) fn client_with_config(config: &ClientConfig) -> Result<Client, HttpEr
headers.insert(reqwest::header::USER_AGENT, user_agent); headers.insert(reqwest::header::USER_AGENT, user_agent);
http_client http_client.default_headers(headers).timeout(config.request_config.timeout)
.default_headers(headers)
.timeout(config.request_config.timeout)
}; };
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
@ -274,9 +269,7 @@ async fn response_to_http_response(
let status = response.status(); let status = response.status();
let mut http_builder = HttpResponse::builder().status(status); let mut http_builder = HttpResponse::builder().status(status);
let headers = http_builder let headers = http_builder.headers_mut().expect("Can't get the response builder headers");
.headers_mut()
.expect("Can't get the response builder headers");
for (k, v) in response.headers_mut().drain() { for (k, v) in response.headers_mut().drain() {
if let Some(key) = k { if let Some(key) = k {
@ -286,9 +279,7 @@ async fn response_to_http_response(
let body = response.bytes().await?; let body = response.bytes().await?;
Ok(http_builder Ok(http_builder.body(body).expect("Can't construct a response using the given body"))
.body(body)
.expect("Can't construct a response using the given body"))
} }
#[cfg(any(target_arch = "wasm32"))] #[cfg(any(target_arch = "wasm32"))]
@ -329,18 +320,12 @@ async fn send_request(
}; };
// Turn errors into permanent errors when the retry limit is reached // Turn errors into permanent errors when the retry limit is reached
let error_type = if stop { let error_type = if stop { RetryError::Permanent } else { RetryError::Transient };
RetryError::Permanent
} else {
RetryError::Transient
};
let request = request.try_clone().ok_or(HttpError::UnableToCloneRequest)?; let request = request.try_clone().ok_or(HttpError::UnableToCloneRequest)?;
let response = client let response =
.execute(request) client.execute(request).await.map_err(|e| error_type(HttpError::Reqwest(e)))?;
.await
.map_err(|e| error_type(HttpError::Reqwest(e)))?;
let status_code = response.status(); let status_code = response.status();
// TODO TOO_MANY_REQUESTS will have a retry timeout which we should // TODO TOO_MANY_REQUESTS will have a retry timeout which we should

View File

@ -17,19 +17,21 @@
//! //!
//! # Enabling logging //! # Enabling logging
//! //!
//! Users of the matrix-sdk crate can enable log output by depending on the `tracing-subscriber` //! Users of the matrix-sdk crate can enable log output by depending on the
//! crate and including the following line in their application (e.g. at the start of `main`): //! `tracing-subscriber` crate and including the following line in their
//! application (e.g. at the start of `main`):
//! //!
//! ```rust //! ```rust
//! tracing_subscriber::fmt::init(); //! tracing_subscriber::fmt::init();
//! ``` //! ```
//! //!
//! The log output is controlled via the `RUST_LOG` environment variable by setting it to one of //! The log output is controlled via the `RUST_LOG` environment variable by
//! the `error`, `warn`, `info`, `debug` or `trace` levels. The output is printed to stdout. //! setting it to one of the `error`, `warn`, `info`, `debug` or `trace` levels.
//! The output is printed to stdout.
//! //!
//! The `RUST_LOG` variable also supports a more advanced syntax for filtering log output more //! The `RUST_LOG` variable also supports a more advanced syntax for filtering
//! precisely, for instance with crate-level granularity. For more information on this, check out //! log output more precisely, for instance with crate-level granularity. For
//! the [tracing_subscriber //! more information on this, check out the [tracing_subscriber
//! documentation](https://tracing.rs/tracing_subscriber/filter/struct.envfilter). //! documentation](https://tracing.rs/tracing_subscriber/filter/struct.envfilter).
//! //!
//! # Crate Feature Flags //! # Crate Feature Flags
@ -44,10 +46,13 @@
//! * `markdown`: Support for sending markdown formatted messages. //! * `markdown`: Support for sending markdown formatted messages.
//! * `socks`: Enables SOCKS support in reqwest, the default HTTP client. //! * `socks`: Enables SOCKS support in reqwest, the default HTTP client.
//! * `sso_login`: Enables SSO login with a local http server. //! * `sso_login`: Enables SSO login with a local http server.
//! * `require_auth_for_profile_requests`: Whether to send the access token in the authentication //! * `require_auth_for_profile_requests`: Whether to send the access token in
//! header when calling endpoints that retrieve profile data. This matches the synapse //! the authentication
//! configuration `require_auth_for_profile_requests`. Enabled by default. //! header when calling endpoints that retrieve profile data. This matches the
//! * `appservice`: Enables low-level appservice functionality. For an high-level API there's the //! synapse configuration `require_auth_for_profile_requests`. Enabled by
//! default.
//! * `appservice`: Enables low-level appservice functionality. For an
//! high-level API there's the
//! `matrix-sdk-appservice` crate //! `matrix-sdk-appservice` crate
#![deny( #![deny(
@ -71,6 +76,7 @@ compile_error!("only one of 'native-tls' or 'rustls-tls' features can be enabled
#[cfg(all(feature = "sso_login", target_arch = "wasm32"))] #[cfg(all(feature = "sso_login", target_arch = "wasm32"))]
compile_error!("'sso_login' cannot be enabled on 'wasm32' arch"); compile_error!("'sso_login' cannot be enabled on 'wasm32' arch");
pub use bytes::{Bytes, BytesMut};
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub use matrix_sdk_base::crypto::{EncryptionInfo, LocalTrust}; pub use matrix_sdk_base::crypto::{EncryptionInfo, LocalTrust};
@ -78,8 +84,6 @@ pub use matrix_sdk_base::{
Error as BaseError, Room as BaseRoom, RoomInfo, RoomMember as BaseRoomMember, RoomType, Error as BaseError, Room as BaseRoom, RoomInfo, RoomMember as BaseRoomMember, RoomType,
Session, StateChanges, StoreError, Session, StateChanges, StoreError,
}; };
pub use bytes::{Bytes, BytesMut};
pub use matrix_sdk_common::*; pub use matrix_sdk_common::*;
pub use reqwest; pub use reqwest;

View File

@ -1,3 +1,5 @@
use std::{ops::Deref, sync::Arc};
use matrix_sdk_base::{deserialized_responses::MembersResponse, identifiers::UserId}; use matrix_sdk_base::{deserialized_responses::MembersResponse, identifiers::UserId};
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::{ api::r0::{
@ -8,11 +10,10 @@ use matrix_sdk_common::{
locks::Mutex, locks::Mutex,
}; };
use std::{ops::Deref, sync::Arc};
use crate::{BaseRoom, Client, Result, RoomMember}; use crate::{BaseRoom, Client, Result, RoomMember};
/// A struct containing methodes that are common for Joined, Invited and Left Rooms /// A struct containing methodes that are common for Joined, Invited and Left
/// Rooms
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Common { pub struct Common {
inner: BaseRoom, inner: BaseRoom,
@ -36,10 +37,7 @@ impl Common {
/// * `room` - The underlaying room. /// * `room` - The underlaying room.
pub fn new(client: Client, room: BaseRoom) -> Self { pub fn new(client: Client, room: BaseRoom) -> Self {
// TODO: Make this private // TODO: Make this private
Self { Self { inner: room, client }
inner: room,
client,
}
} }
/// Leave this room. /// Leave this room.
@ -111,9 +109,9 @@ impl Common {
} }
} }
/// Sends a request to `/_matrix/client/r0/rooms/{room_id}/messages` and returns /// Sends a request to `/_matrix/client/r0/rooms/{room_id}/messages` and
/// a `get_message_events::Response` that contains a chunk of room and state events /// returns a `get_message_events::Response` that contains a chunk of
/// (`AnyRoomEvent` and `AnyStateEvent`). /// room and state events (`AnyRoomEvent` and `AnyStateEvent`).
/// ///
/// # Arguments /// # Arguments
/// ///
@ -152,35 +150,25 @@ impl Common {
pub(crate) async fn request_members(&self) -> Result<Option<MembersResponse>> { pub(crate) async fn request_members(&self) -> Result<Option<MembersResponse>> {
#[allow(clippy::map_clone)] #[allow(clippy::map_clone)]
if let Some(mutex) = self if let Some(mutex) =
.client self.client.members_request_locks.get(self.inner.room_id()).map(|m| m.clone())
.members_request_locks
.get(self.inner.room_id())
.map(|m| m.clone())
{ {
mutex.lock().await; mutex.lock().await;
Ok(None) Ok(None)
} else { } else {
let mutex = Arc::new(Mutex::new(())); let mutex = Arc::new(Mutex::new(()));
self.client self.client.members_request_locks.insert(self.inner.room_id().clone(), mutex.clone());
.members_request_locks
.insert(self.inner.room_id().clone(), mutex.clone());
let _guard = mutex.lock().await; let _guard = mutex.lock().await;
let request = get_member_events::Request::new(self.inner.room_id()); let request = get_member_events::Request::new(self.inner.room_id());
let response = self.client.send(request, None).await?; let response = self.client.send(request, None).await?;
let response = self let response =
.client self.client.base_client.receive_members(self.inner.room_id(), &response).await?;
.base_client
.receive_members(self.inner.room_id(), &response)
.await?;
self.client self.client.members_request_locks.remove(self.inner.room_id());
.members_request_locks
.remove(self.inner.room_id());
Ok(Some(response)) Ok(Some(response))
} }
@ -248,9 +236,9 @@ impl Common {
/// Get all the joined members of this room. /// Get all the joined members of this room.
/// ///
/// *Note*: This method will not fetch the members from the homeserver if the /// *Note*: This method will not fetch the members from the homeserver if
/// member list isn't synchronized due to member lazy loading. Thus, members /// the member list isn't synchronized due to member lazy loading. Thus,
/// could be missing from the list. /// members could be missing from the list.
/// ///
/// Use [joined_members()](#method.joined_members) if you want to ensure to /// Use [joined_members()](#method.joined_members) if you want to ensure to
/// always get the full member list. /// always get the full member list.
@ -284,9 +272,9 @@ impl Common {
/// Get a specific member of this room. /// Get a specific member of this room.
/// ///
/// *Note*: This method will not fetch the members from the homeserver if the /// *Note*: This method will not fetch the members from the homeserver if
/// member list isn't synchronized due to member lazy loading. Thus, members /// the member list isn't synchronized due to member lazy loading. Thus,
/// could be missing. /// members could be missing.
/// ///
/// Use [get_member()](#method.get_member) if you want to ensure to always /// Use [get_member()](#method.get_member) if you want to ensure to always
/// have the full member list to chose from. /// have the full member list to chose from.
@ -295,7 +283,6 @@ impl Common {
/// ///
/// * `user_id` - The ID of the user that should be fetched out of the /// * `user_id` - The ID of the user that should be fetched out of the
/// store. /// store.
///
pub async fn get_member_no_sync(&self, user_id: &UserId) -> Result<Option<RoomMember>> { pub async fn get_member_no_sync(&self, user_id: &UserId) -> Result<Option<RoomMember>> {
Ok(self Ok(self
.inner .inner
@ -304,7 +291,8 @@ impl Common {
.map(|member| RoomMember::new(self.client.clone(), member))) .map(|member| RoomMember::new(self.client.clone(), member)))
} }
/// Get all members for this room, includes invited, joined and left members. /// Get all members for this room, includes invited, joined and left
/// members.
/// ///
/// *Note*: This method will fetch the members from the homeserver if the /// *Note*: This method will fetch the members from the homeserver if the
/// member list isn't synchronized due to member lazy loading. Because of /// member list isn't synchronized due to member lazy loading. Because of
@ -317,11 +305,12 @@ impl Common {
self.members_no_sync().await self.members_no_sync().await
} }
/// Get all members for this room, includes invited, joined and left members. /// Get all members for this room, includes invited, joined and left
/// members.
/// ///
/// *Note*: This method will not fetch the members from the homeserver if the /// *Note*: This method will not fetch the members from the homeserver if
/// member list isn't synchronized due to member lazy loading. Thus, members /// the member list isn't synchronized due to member lazy loading. Thus,
/// could be missing. /// members could be missing.
/// ///
/// Use [members()](#method.members) if you want to ensure to always get /// Use [members()](#method.members) if you want to ensure to always get
/// the full member list. /// the full member list.

View File

@ -1,17 +1,20 @@
use crate::{room::Common, BaseRoom, Client, Result, RoomType};
use std::ops::Deref; use std::ops::Deref;
use crate::{room::Common, BaseRoom, Client, Result, RoomType};
/// A room in the invited state. /// A room in the invited state.
/// ///
/// This struct contains all methodes specific to a `Room` with type `RoomType::Invited`. /// This struct contains all methodes specific to a `Room` with type
/// Operations may fail once the underlaying `Room` changes `RoomType`. /// `RoomType::Invited`. Operations may fail once the underlaying `Room` changes
/// `RoomType`.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Invited { pub struct Invited {
pub(crate) inner: Common, pub(crate) inner: Common,
} }
impl Invited { impl Invited {
/// Create a new `room::Invited` if the underlaying `Room` has type `RoomType::Invited`. /// Create a new `room::Invited` if the underlaying `Room` has type
/// `RoomType::Invited`.
/// ///
/// # Arguments /// # Arguments
/// * `client` - The client used to make requests. /// * `client` - The client used to make requests.
@ -20,9 +23,7 @@ impl Invited {
pub fn new(client: Client, room: BaseRoom) -> Option<Self> { pub fn new(client: Client, room: BaseRoom) -> Option<Self> {
// TODO: Make this private // TODO: Make this private
if room.room_type() == RoomType::Invited { if room.room_type() == RoomType::Invited {
Some(Self { Some(Self { inner: Common::new(client, room) })
inner: Common::new(client, room),
})
} else { } else {
None None
} }

View File

@ -1,9 +1,11 @@
use crate::{room::Common, BaseRoom, Client, Result, RoomType}; #[cfg(feature = "encryption")]
use std::sync::Arc;
use std::{io::Read, ops::Deref}; use std::{io::Read, ops::Deref};
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use std::sync::Arc; use matrix_sdk_base::crypto::AttachmentEncryptor;
#[cfg(feature = "encryption")]
use matrix_sdk_common::locks::Mutex;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::{ api::r0::{
membership::{ membership::{
@ -34,25 +36,20 @@ use matrix_sdk_common::{
receipt::ReceiptType, receipt::ReceiptType,
uuid::Uuid, uuid::Uuid,
}; };
use mime::{self, Mime}; use mime::{self, Mime};
#[cfg(feature = "encryption")]
use matrix_sdk_common::locks::Mutex;
#[cfg(feature = "encryption")]
use matrix_sdk_base::crypto::AttachmentEncryptor;
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use tracing::instrument; use tracing::instrument;
use crate::{room::Common, BaseRoom, Client, Result, RoomType};
const TYPING_NOTICE_TIMEOUT: Duration = Duration::from_secs(4); const TYPING_NOTICE_TIMEOUT: Duration = Duration::from_secs(4);
const TYPING_NOTICE_RESEND_TIMEOUT: Duration = Duration::from_secs(3); const TYPING_NOTICE_RESEND_TIMEOUT: Duration = Duration::from_secs(3);
/// A room in the joined state. /// A room in the joined state.
/// ///
/// The `JoinedRoom` contains all methodes specific to a `Room` with type `RoomType::Joined`. /// The `JoinedRoom` contains all methodes specific to a `Room` with type
/// Operations may fail once the underlaying `Room` changes `RoomType`. /// `RoomType::Joined`. Operations may fail once the underlaying `Room` changes
/// `RoomType`.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Joined { pub struct Joined {
pub(crate) inner: Common, pub(crate) inner: Common,
@ -67,7 +64,8 @@ impl Deref for Joined {
} }
impl Joined { impl Joined {
/// Create a new `room::Joined` if the underlaying `BaseRoom` has type `RoomType::Joined`. /// Create a new `room::Joined` if the underlaying `BaseRoom` has type
/// `RoomType::Joined`.
/// ///
/// # Arguments /// # Arguments
/// * `client` - The client used to make requests. /// * `client` - The client used to make requests.
@ -76,9 +74,7 @@ impl Joined {
pub fn new(client: Client, room: BaseRoom) -> Option<Self> { pub fn new(client: Client, room: BaseRoom) -> Option<Self> {
// TODO: Make this private // TODO: Make this private
if room.room_type() == RoomType::Joined { if room.room_type() == RoomType::Joined {
Some(Self { Some(Self { inner: Common::new(client, room) })
inner: Common::new(client, room),
})
} else { } else {
None None
} }
@ -97,9 +93,7 @@ impl Joined {
/// ///
/// * `reason` - The reason for banning this user. /// * `reason` - The reason for banning this user.
pub async fn ban_user(&self, user_id: &UserId, reason: Option<&str>) -> Result<()> { pub async fn ban_user(&self, user_id: &UserId, reason: Option<&str>) -> Result<()> {
let request = assign!(ban_user::Request::new(self.inner.room_id(), user_id), { let request = assign!(ban_user::Request::new(self.inner.room_id(), user_id), { reason });
reason
});
self.client.send(request, None).await?; self.client.send(request, None).await?;
Ok(()) Ok(())
} }
@ -108,13 +102,12 @@ impl Joined {
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `user_id` - The `UserId` of the user that should be kicked out of the room. /// * `user_id` - The `UserId` of the user that should be kicked out of the
/// room.
/// ///
/// * `reason` - Optional reason why the room member is being kicked out. /// * `reason` - Optional reason why the room member is being kicked out.
pub async fn kick_user(&self, user_id: &UserId, reason: Option<&str>) -> Result<()> { pub async fn kick_user(&self, user_id: &UserId, reason: Option<&str>) -> Result<()> {
let request = assign!(kick_user::Request::new(self.inner.room_id(), user_id), { let request = assign!(kick_user::Request::new(self.inner.room_id(), user_id), { reason });
reason
});
self.client.send(request, None).await?; self.client.send(request, None).await?;
Ok(()) Ok(())
} }
@ -148,9 +141,10 @@ impl Joined {
/// Activate typing notice for this room. /// Activate typing notice for this room.
/// ///
/// The typing notice remains active for 4s. It can be deactivate at any point by setting /// The typing notice remains active for 4s. It can be deactivate at any
/// typing to `false`. If this method is called while the typing notice is active nothing will happen. /// point by setting typing to `false`. If this method is called while
/// This method can be called on every key stroke, since it will do nothing while typing is /// the typing notice is active nothing will happen. This method can be
/// called on every key stroke, since it will do nothing while typing is
/// active. /// active.
/// ///
/// # Arguments /// # Arguments
@ -183,21 +177,23 @@ impl Joined {
/// # }); /// # });
/// ``` /// ```
pub async fn typing_notice(&self, typing: bool) -> Result<()> { pub async fn typing_notice(&self, typing: bool) -> Result<()> {
// Only send a request to the homeserver if the old timeout has elapsed or the typing // Only send a request to the homeserver if the old timeout has elapsed
// notice changed state within the TYPING_NOTICE_TIMEOUT // or the typing notice changed state within the
// TYPING_NOTICE_TIMEOUT
let send = let send =
if let Some(typing_time) = self.client.typing_notice_times.get(self.inner.room_id()) { if let Some(typing_time) = self.client.typing_notice_times.get(self.inner.room_id()) {
if typing_time.elapsed() > TYPING_NOTICE_RESEND_TIMEOUT { if typing_time.elapsed() > TYPING_NOTICE_RESEND_TIMEOUT {
// We always reactivate the typing notice if typing is true or we may need to // We always reactivate the typing notice if typing is true or
// deactivate it if it's currently active if typing is false // we may need to deactivate it if it's
// currently active if typing is false
typing || typing_time.elapsed() <= TYPING_NOTICE_TIMEOUT typing || typing_time.elapsed() <= TYPING_NOTICE_TIMEOUT
} else { } else {
// Only send a request when we need to deactivate typing // Only send a request when we need to deactivate typing
!typing !typing
} }
} else { } else {
// Typing notice is currently deactivated, therefore, send a request only when it's // Typing notice is currently deactivated, therefore, send a request
// about to be activated // only when it's about to be activated
typing typing
}; };
@ -220,11 +216,13 @@ impl Joined {
Ok(()) Ok(())
} }
/// Send a request to notify this room that the user has read specific event. /// Send a request to notify this room that the user has read specific
/// event.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `event_id` - The `EventId` specifies the event to set the read receipt on. /// * `event_id` - The `EventId` specifies the event to set the read receipt
/// on.
pub async fn read_receipt(&self, event_id: &EventId) -> Result<()> { pub async fn read_receipt(&self, event_id: &EventId) -> Result<()> {
let request = let request =
create_receipt::Request::new(self.inner.room_id(), ReceiptType::Read, event_id); create_receipt::Request::new(self.inner.room_id(), ReceiptType::Read, event_id);
@ -233,22 +231,23 @@ impl Joined {
Ok(()) Ok(())
} }
/// Send a request to notify this room that the user has read up to specific event. /// Send a request to notify this room that the user has read up to specific
/// event.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * fully_read - The `EventId` of the event the user has read to. /// * fully_read - The `EventId` of the event the user has read to.
/// ///
/// * read_receipt - An `EventId` to specify the event to set the read receipt on. /// * read_receipt - An `EventId` to specify the event to set the read
/// receipt on.
pub async fn read_marker( pub async fn read_marker(
&self, &self,
fully_read: &EventId, fully_read: &EventId,
read_receipt: Option<&EventId>, read_receipt: Option<&EventId>,
) -> Result<()> { ) -> Result<()> {
let request = assign!( let request = assign!(set_read_marker::Request::new(self.inner.room_id(), fully_read), {
set_read_marker::Request::new(self.inner.room_id(), fully_read), read_receipt
{ read_receipt } });
);
self.client.send(request, None).await?; self.client.send(request, None).await?;
Ok(()) Ok(())
@ -266,11 +265,8 @@ impl Joined {
// TODO expose this publicly so people can pre-share a group session if // TODO expose this publicly so people can pre-share a group session if
// e.g. a user starts to type a message for a room. // e.g. a user starts to type a message for a room.
#[allow(clippy::map_clone)] #[allow(clippy::map_clone)]
if let Some(mutex) = self if let Some(mutex) =
.client self.client.group_session_locks.get(self.inner.room_id()).map(|m| m.clone())
.group_session_locks
.get(self.inner.room_id())
.map(|m| m.clone())
{ {
// If a group session share request is already going on, // If a group session share request is already going on,
// await the release of the lock. // await the release of the lock.
@ -279,23 +275,14 @@ impl Joined {
// Otherwise create a new lock and share the group // Otherwise create a new lock and share the group
// session. // session.
let mutex = Arc::new(Mutex::new(())); let mutex = Arc::new(Mutex::new(()));
self.client self.client.group_session_locks.insert(self.inner.room_id().clone(), mutex.clone());
.group_session_locks
.insert(self.inner.room_id().clone(), mutex.clone());
let _guard = mutex.lock().await; let _guard = mutex.lock().await;
{ {
let joined = self let joined = self.client.store().get_joined_user_ids(self.inner.room_id()).await?;
.client let invited =
.store() self.client.store().get_invited_user_ids(self.inner.room_id()).await?;
.get_joined_user_ids(self.inner.room_id())
.await?;
let invited = self
.client
.store()
.get_invited_user_ids(self.inner.room_id())
.await?;
let members = joined.iter().chain(&invited); let members = joined.iter().chain(&invited);
self.client.claim_one_time_keys(members).await?; self.client.claim_one_time_keys(members).await?;
}; };
@ -308,10 +295,7 @@ impl Joined {
// session as using it would end up in undecryptable // session as using it would end up in undecryptable
// messages. // messages.
if let Err(r) = response { if let Err(r) = response {
self.client self.client.base_client.invalidate_group_session(self.inner.room_id()).await?;
.base_client
.invalidate_group_session(self.inner.room_id())
.await?;
return Err(r); return Err(r);
} }
} }
@ -328,19 +312,13 @@ impl Joined {
#[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[cfg_attr(feature = "docs", doc(cfg(encryption)))]
#[instrument] #[instrument]
async fn share_group_session(&self) -> Result<()> { async fn share_group_session(&self) -> Result<()> {
let mut requests = self let mut requests =
.client self.client.base_client.share_group_session(self.inner.room_id()).await?;
.base_client
.share_group_session(self.inner.room_id())
.await?;
for request in requests.drain(..) { for request in requests.drain(..) {
let response = self.client.send_to_device(&request).await?; let response = self.client.send_to_device(&request).await?;
self.client self.client.base_client.mark_request_as_sent(&request.txn_id, &response).await?;
.base_client
.mark_request_as_sent(&request.txn_id, &response)
.await?;
} }
Ok(()) Ok(())
@ -407,10 +385,7 @@ impl Joined {
self.preshare_group_session().await?; self.preshare_group_session().await?;
AnyMessageEventContent::RoomEncrypted( AnyMessageEventContent::RoomEncrypted(
self.client self.client.base_client.encrypt(self.inner.room_id(), content).await?,
.base_client
.encrypt(self.inner.room_id(), content)
.await?,
) )
} else { } else {
content.into() content.into()
@ -430,8 +405,9 @@ impl Joined {
/// If the room is encrypted and the encryption feature is enabled the /// If the room is encrypted and the encryption feature is enabled the
/// upload will be encrypted. /// upload will be encrypted.
/// ///
/// This is a convenience method that calls the [`Client::upload()`](#Client::method.upload) /// This is a convenience method that calls the
/// and afterwards the [`send()`](#method.send). /// [`Client::upload()`](#Client::method.upload) and afterwards the
/// [`send()`](#method.send).
/// ///
/// # Arguments /// # Arguments
/// * `body` - A textual representation of the media that is going to be /// * `body` - A textual representation of the media that is going to be
@ -538,11 +514,8 @@ impl Joined {
}), }),
}; };
self.send( self.send(AnyMessageEventContent::RoomMessage(MessageEventContent::new(content)), txn_id)
AnyMessageEventContent::RoomMessage(MessageEventContent::new(content)), .await
txn_id,
)
.await
} }
/// Send a room state event to the homeserver. /// Send a room state event to the homeserver.
@ -639,10 +612,10 @@ impl Joined {
txn_id: Option<Uuid>, txn_id: Option<Uuid>,
) -> Result<redact_event::Response> { ) -> Result<redact_event::Response> {
let txn_id = txn_id.unwrap_or_else(Uuid::new_v4).to_string(); let txn_id = txn_id.unwrap_or_else(Uuid::new_v4).to_string();
let request = assign!( let request =
redact_event::Request::new(self.inner.room_id(), event_id, &txn_id), assign!(redact_event::Request::new(self.inner.room_id(), event_id, &txn_id), {
{ reason } reason
); });
self.client.send(request, None).await self.client.send(request, None).await
} }

View File

@ -1,19 +1,22 @@
use crate::{room::Common, BaseRoom, Client, Result, RoomType};
use std::ops::Deref; use std::ops::Deref;
use matrix_sdk_common::api::r0::membership::forget_room; use matrix_sdk_common::api::r0::membership::forget_room;
use crate::{room::Common, BaseRoom, Client, Result, RoomType};
/// A room in the left state. /// A room in the left state.
/// ///
/// This struct contains all methodes specific to a `Room` with type `RoomType::Left`. /// This struct contains all methodes specific to a `Room` with type
/// Operations may fail once the underlaying `Room` changes `RoomType`. /// `RoomType::Left`. Operations may fail once the underlaying `Room` changes
/// `RoomType`.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Left { pub struct Left {
pub(crate) inner: Common, pub(crate) inner: Common,
} }
impl Left { impl Left {
/// Create a new `room::Left` if the underlaying `Room` has type `RoomType::Left`. /// Create a new `room::Left` if the underlaying `Room` has type
/// `RoomType::Left`.
/// ///
/// # Arguments /// # Arguments
/// * `client` - The client used to make requests. /// * `client` - The client used to make requests.
@ -22,9 +25,7 @@ impl Left {
pub fn new(client: Client, room: BaseRoom) -> Option<Self> { pub fn new(client: Client, room: BaseRoom) -> Option<Self> {
// TODO: Make this private // TODO: Make this private
if room.room_type() == RoomType::Left { if room.room_type() == RoomType::Left {
Some(Self { Some(Self { inner: Common::new(client, room) })
inner: Common::new(client, room),
})
} else { } else {
None None
} }

View File

@ -1,7 +1,7 @@
use matrix_sdk_common::api::r0::media::{get_content, get_content_thumbnail};
use std::ops::Deref; use std::ops::Deref;
use matrix_sdk_common::api::r0::media::{get_content, get_content_thumbnail};
use crate::{BaseRoomMember, Client, Result}; use crate::{BaseRoomMember, Client, Result};
/// The high-level `RoomMember` representation /// The high-level `RoomMember` representation
@ -21,10 +21,7 @@ impl Deref for RoomMember {
impl RoomMember { impl RoomMember {
pub(crate) fn new(client: Client, member: BaseRoomMember) -> Self { pub(crate) fn new(client: Client, member: BaseRoomMember) -> Self {
Self { Self { inner: member, client }
inner: member,
client,
}
} }
/// Gets the avatar of this member, if set. /// Gets the avatar of this member, if set.

View File

@ -26,11 +26,7 @@ impl AppserviceEventHandler {
#[async_trait] #[async_trait]
impl EventHandler for AppserviceEventHandler { impl EventHandler for AppserviceEventHandler {
async fn on_room_member(&self, room: Room, event: &SyncStateEvent<MemberEventContent>) { async fn on_room_member(&self, room: Room, event: &SyncStateEvent<MemberEventContent>) {
if !self if !self.appservice.user_id_is_in_namespace(&event.state_key).unwrap() {
.appservice
.user_id_is_in_namespace(&event.state_key)
.unwrap()
{
dbg!("not an appservice user"); dbg!("not an appservice user");
return; return;
} }
@ -38,11 +34,7 @@ impl EventHandler for AppserviceEventHandler {
if let MembershipState::Invite = event.content.membership { if let MembershipState::Invite = event.content.membership {
let user_id = UserId::try_from(event.state_key.clone()).unwrap(); let user_id = UserId::try_from(event.state_key.clone()).unwrap();
let client = self let client = self.appservice.client_with_localpart(user_id.localpart()).await.unwrap();
.appservice
.client_with_localpart(user_id.localpart())
.await
.unwrap();
client.join_room_by_id(room.room_id()).await.unwrap(); client.join_room_by_id(room.room_id()).await.unwrap();
} }
@ -51,10 +43,7 @@ impl EventHandler for AppserviceEventHandler {
#[actix_web::main] #[actix_web::main]
pub async fn main() -> std::io::Result<()> { pub async fn main() -> std::io::Result<()> {
env::set_var( env::set_var("RUST_LOG", "actix_web=debug,actix_server=info,matrix_sdk=debug");
"RUST_LOG",
"actix_web=debug,actix_server=info,matrix_sdk=debug",
);
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::init();
let homeserver_url = "http://localhost:8008"; let homeserver_url = "http://localhost:8008";
@ -62,16 +51,11 @@ pub async fn main() -> std::io::Result<()> {
let registration = let registration =
AppserviceRegistration::try_from_yaml_file("./tests/registration.yaml").unwrap(); AppserviceRegistration::try_from_yaml_file("./tests/registration.yaml").unwrap();
let appservice = Appservice::new(homeserver_url, server_name, registration) let appservice = Appservice::new(homeserver_url, server_name, registration).await.unwrap();
.await
.unwrap();
let event_handler = AppserviceEventHandler::new(appservice.clone()); let event_handler = AppserviceEventHandler::new(appservice.clone());
appservice appservice.client().set_event_handler(Box::new(event_handler)).await;
.client()
.set_event_handler(Box::new(event_handler))
.await;
HttpServer::new(move || App::new().service(appservice.actix_service())) HttpServer::new(move || App::new().service(appservice.actix_service()))
.bind(("0.0.0.0", 8090))? .bind(("0.0.0.0", 8090))?

View File

@ -17,6 +17,7 @@ use std::{
pin::Pin, pin::Pin,
}; };
pub use actix_web::Scope;
use actix_web::{ use actix_web::{
dev::Payload, dev::Payload,
error::PayloadError, error::PayloadError,
@ -30,8 +31,6 @@ use futures::Future;
use futures_util::{TryFutureExt, TryStreamExt}; use futures_util::{TryFutureExt, TryStreamExt};
use matrix_sdk::api_appservice as api; use matrix_sdk::api_appservice as api;
pub use actix_web::Scope;
use crate::{error::Error, Appservice}; use crate::{error::Error, Appservice};
pub async fn run_server( pub async fn run_server(
@ -53,10 +52,7 @@ pub fn get_scope() -> Scope {
} }
fn gen_scope(scope: &str) -> Scope { fn gen_scope(scope: &str) -> Scope {
web::scope(scope) web::scope(scope).service(push_transactions).service(query_user_id).service(query_room_alias)
.service(push_transactions)
.service(query_user_id)
.service(query_room_alias)
} }
#[tracing::instrument] #[tracing::instrument]
@ -69,11 +65,7 @@ async fn push_transactions(
return Ok(HttpResponse::Unauthorized().finish()); return Ok(HttpResponse::Unauthorized().finish());
} }
appservice appservice.client().receive_transaction(request.incoming).await.unwrap();
.client()
.receive_transaction(request.incoming)
.await
.unwrap();
Ok(HttpResponse::Ok().json("{}")) Ok(HttpResponse::Ok().json("{}"))
} }
@ -136,13 +128,9 @@ impl<T: matrix_sdk::IncomingRequest> FromRequest for IncomingRequest<T> {
uri uri
}; };
let mut builder = http::request::Builder::new() let mut builder = http::request::Builder::new().method(request.method()).uri(uri);
.method(request.method())
.uri(uri);
let headers = builder let headers = builder.headers_mut().ok_or(Error::UnknownHttpRequestBuilder)?;
.headers_mut()
.ok_or(Error::UnknownHttpRequestBuilder)?;
for (key, value) in request.headers().iter() { for (key, value) in request.headers().iter() {
headers.append(key, value.to_owned()); headers.append(key, value.to_owned());
} }
@ -158,10 +146,7 @@ impl<T: matrix_sdk::IncomingRequest> FromRequest for IncomingRequest<T> {
let access_token = match request.uri().query() { let access_token = match request.uri().query() {
Some(query) => { Some(query) => {
let query: Vec<(String, String)> = matrix_sdk::urlencoded::from_str(query)?; let query: Vec<(String, String)> = matrix_sdk::urlencoded::from_str(query)?;
query query.into_iter().find(|(key, _)| key == "access_token").map(|(_, value)| value)
.into_iter()
.find(|(key, _)| key == "access_token")
.map(|(_, value)| value)
} }
None => None, None => None,
}; };

View File

@ -14,11 +14,14 @@
//! Matrix [Application Service] library //! Matrix [Application Service] library
//! //!
//! The appservice crate aims to provide a batteries-included experience. That means that we //! The appservice crate aims to provide a batteries-included experience. That
//! * ship with functionality to configure your webserver crate or simply run the webserver for you //! means that we
//! * ship with functionality to configure your webserver crate or simply run
//! the webserver for you
//! * receive and validate requests from the homeserver correctly //! * receive and validate requests from the homeserver correctly
//! * allow calling the homeserver with proper virtual user identity assertion //! * allow calling the homeserver with proper virtual user identity assertion
//! * have the goal to have a consistent room state available by leveraging the stores that the matrix-sdk provides //! * have the goal to have a consistent room state available by leveraging the
//! stores that the matrix-sdk provides
//! //!
//! # Quickstart //! # Quickstart
//! //!
@ -62,6 +65,8 @@ use std::{
}; };
use http::Uri; use http::Uri;
#[doc(inline)]
pub use matrix_sdk::api_appservice as api;
use matrix_sdk::{ use matrix_sdk::{
api::{ api::{
error::ErrorKind, error::ErrorKind,
@ -81,9 +86,6 @@ use regex::Regex;
use tracing::error; use tracing::error;
use tracing::warn; use tracing::warn;
#[doc(inline)]
pub use matrix_sdk::api_appservice as api;
#[cfg(feature = "actix")] #[cfg(feature = "actix")]
mod actix; mod actix;
mod error; mod error;
@ -104,9 +106,7 @@ impl AppserviceRegistration {
/// ///
/// See the fields of [`Registration`] for the required format /// See the fields of [`Registration`] for the required format
pub fn try_from_yaml_str(value: impl AsRef<str>) -> Result<Self> { pub fn try_from_yaml_str(value: impl AsRef<str>) -> Result<Self> {
Ok(Self { Ok(Self { inner: serde_yaml::from_str(value.as_ref())? })
inner: serde_yaml::from_str(value.as_ref())?,
})
} }
/// Try to load registration from yaml file /// Try to load registration from yaml file
@ -115,9 +115,7 @@ impl AppserviceRegistration {
pub fn try_from_yaml_file(path: impl Into<PathBuf>) -> Result<Self> { pub fn try_from_yaml_file(path: impl Into<PathBuf>) -> Result<Self> {
let file = File::open(path.into())?; let file = File::open(path.into())?;
Ok(Self { Ok(Self { inner: serde_yaml::from_reader(file)? })
inner: serde_yaml::from_reader(file)?,
})
} }
} }
@ -177,8 +175,10 @@ impl Appservice {
/// # Arguments /// # Arguments
/// ///
/// * `homeserver_url` - The homeserver that the client should connect to. /// * `homeserver_url` - The homeserver that the client should connect to.
/// * `server_name` - The server name to use when constructing user ids from the localpart. /// * `server_name` - The server name to use when constructing user ids from
/// * `registration` - The [Appservice Registration] to use when interacting with the homserver. /// the localpart.
/// * `registration` - The [Appservice Registration] to use when interacting
/// with the homserver.
/// ///
/// [Appservice Registration]: https://matrix.org/docs/spec/application_service/r0.1.2#registration /// [Appservice Registration]: https://matrix.org/docs/spec/application_service/r0.1.2#registration
pub async fn new( pub async fn new(
@ -209,8 +209,9 @@ impl Appservice {
/// Get `Client` for the given `localpart` /// Get `Client` for the given `localpart`
/// ///
/// If the `localpart` is covered by the `namespaces` in the [registration] all requests to the /// If the `localpart` is covered by the `namespaces` in the [registration]
/// homeserver will [assert the identity] to the according virtual user. /// all requests to the homeserver will [assert the identity] to the
/// according virtual user.
/// ///
/// [registration]: https://matrix.org/docs/spec/application_service/r0.1.2#registration /// [registration]: https://matrix.org/docs/spec/application_service/r0.1.2#registration
/// [assert the identity]: /// [assert the identity]:
@ -291,7 +292,8 @@ impl Appservice {
/// Get the host and port from the registration URL /// Get the host and port from the registration URL
/// ///
/// If no port is found it falls back to scheme defaults: 80 for http and 443 for https /// If no port is found it falls back to scheme defaults: 80 for http and
/// 443 for https
pub fn get_host_and_port_from_registration(&self) -> Result<(Host, Port)> { pub fn get_host_and_port_from_registration(&self) -> Result<(Host, Port)> {
let uri = Uri::try_from(&self.registration.url)?; let uri = Uri::try_from(&self.registration.url)?;
@ -315,9 +317,11 @@ impl Appservice {
actix::get_scope().data(self.clone()) actix::get_scope().data(self.clone())
} }
/// Convenience method that runs an http server depending on the selected server feature /// Convenience method that runs an http server depending on the selected
/// server feature
/// ///
/// This is a blocking call that tries to listen on the provided host and port /// This is a blocking call that tries to listen on the provided host and
/// port
pub async fn run(&self, host: impl AsRef<str>, port: impl Into<u16>) -> Result<()> { pub async fn run(&self, host: impl AsRef<str>, port: impl Into<u16>) -> Result<()> {
#[cfg(feature = "actix")] #[cfg(feature = "actix")]
{ {

View File

@ -1,14 +1,12 @@
#[cfg(feature = "actix")] #[cfg(feature = "actix")]
mod actix { mod actix {
use actix_web::{test, App};
use matrix_sdk_appservice::*;
use std::env; use std::env;
use actix_web::{test, App};
use matrix_sdk_appservice::*;
async fn appservice() -> Appservice { async fn appservice() -> Appservice {
env::set_var( env::set_var("RUST_LOG", "mockito=debug,matrix_sdk=debug,ruma=debug,actix_web=debug");
"RUST_LOG",
"mockito=debug,matrix_sdk=debug,ruma=debug,actix_web=debug",
);
let _ = tracing_subscriber::fmt::try_init(); let _ = tracing_subscriber::fmt::try_init();
Appservice::new( Appservice::new(
@ -109,7 +107,8 @@ mod actix {
let resp = test::call_service(&app, req).await; let resp = test::call_service(&app, req).await;
// TODO: this should actually return a 401 but is 500 because something in the extractor fails // TODO: this should actually return a 401 but is 500 because something in the
// extractor fails
assert_eq!(resp.status(), 500); assert_eq!(resp.status(), 500);
} }
} }

View File

@ -76,10 +76,7 @@ async fn test_event_handler() -> Result<()> {
} }
} }
appservice appservice.client().set_event_handler(Box::new(Example::new())).await;
.client()
.set_event_handler(Box::new(Example::new()))
.await;
let event = serde_json::from_value::<AnyStateEvent>(member_json()).unwrap(); let event = serde_json::from_value::<AnyStateEvent>(member_json()).unwrap();
let event: Raw<AnyRoomEvent> = AnyRoomEvent::State(event).into(); let event: Raw<AnyRoomEvent> = AnyRoomEvent::State(event).into();

View File

@ -1,12 +1,15 @@
use std::{convert::TryFrom, fmt::Debug, sync::Arc}; use std::{convert::TryFrom, fmt::Debug, sync::Arc};
use futures::executor::block_on;
use serde::Serialize;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
use atty::Stream; use atty::Stream;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
use clap::{App as Argparse, AppSettings as ArgParseSettings, Arg, ArgMatches, SubCommand}; use clap::{App as Argparse, AppSettings as ArgParseSettings, Arg, ArgMatches, SubCommand};
use futures::executor::block_on;
use matrix_sdk_base::{
events::EventType,
identifiers::{RoomId, UserId},
RoomInfo, Store,
};
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
use rustyline::{ use rustyline::{
completion::{Completer, Pair}, completion::{Completer, Pair},
@ -18,7 +21,7 @@ use rustyline::{
}; };
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
use rustyline_derive::Helper; use rustyline_derive::Helper;
use serde::Serialize;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
use syntect::{ use syntect::{
dumps::from_binary, dumps::from_binary,
@ -28,12 +31,6 @@ use syntect::{
util::{as_24_bit_terminal_escaped, LinesWithEndings}, util::{as_24_bit_terminal_escaped, LinesWithEndings},
}; };
use matrix_sdk_base::{
events::EventType,
identifiers::{RoomId, UserId},
RoomInfo, Store,
};
#[derive(Clone)] #[derive(Clone)]
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
struct Inspector { struct Inspector {
@ -79,17 +76,8 @@ impl InspectorHelper {
fn complete_event_types(&self, arg: Option<&&str>) -> Vec<Pair> { fn complete_event_types(&self, arg: Option<&&str>) -> Vec<Pair> {
Self::EVENT_TYPES Self::EVENT_TYPES
.iter() .iter()
.map(|t| Pair { .map(|t| Pair { display: t.to_string(), replacement: format!("{} ", t) })
display: t.to_string(), .filter(|r| if let Some(arg) = arg { r.replacement.starts_with(arg) } else { true })
replacement: format!("{} ", t),
})
.filter(|r| {
if let Some(arg) = arg {
r.replacement.starts_with(arg)
} else {
true
}
})
.collect() .collect()
} }
@ -102,13 +90,7 @@ impl InspectorHelper {
display: r.room_id.to_string(), display: r.room_id.to_string(),
replacement: format!("{} ", r.room_id.to_string()), replacement: format!("{} ", r.room_id.to_string()),
}) })
.filter(|r| { .filter(|r| if let Some(arg) = arg { r.replacement.starts_with(arg) } else { true })
if let Some(arg) = arg {
r.replacement.starts_with(arg)
} else {
true
}
})
.collect() .collect()
} }
} }
@ -127,15 +109,9 @@ impl Completer for InspectorHelper {
let commands = vec![ let commands = vec![
("get-state", "get a state event in the given room"), ("get-state", "get a state event in the given room"),
( ("get-profiles", "get all the stored profiles in the given room"),
"get-profiles",
"get all the stored profiles in the given room",
),
("list-rooms", "list all rooms"), ("list-rooms", "list all rooms"),
( ("get-members", "get all the membership events in the given room"),
"get-members",
"get all the membership events in the given room",
),
] ]
.iter() .iter()
.map(|(r, d)| Pair { .map(|(r, d)| Pair {
@ -154,19 +130,13 @@ impl Completer for InspectorHelper {
} else { } else {
Ok(( Ok((
0, 0,
commands commands.into_iter().filter(|c| c.replacement.starts_with(args[0])).collect(),
.into_iter()
.filter(|c| c.replacement.starts_with(args[0]))
.collect(),
)) ))
} }
} else if args.len() == 2 { } else if args.len() == 2 {
if args[0] == "get-state" { if args[0] == "get-state" {
if line.ends_with(' ') { if line.ends_with(' ') {
Ok(( Ok((args[0].len() + args[1].len() + 2, self.complete_event_types(args.get(2))))
args[0].len() + args[1].len() + 2,
self.complete_event_types(args.get(2)),
))
} else { } else {
Ok((args[0].len() + 1, self.complete_rooms(args.get(1)))) Ok((args[0].len() + 1, self.complete_rooms(args.get(1))))
} }
@ -177,10 +147,7 @@ impl Completer for InspectorHelper {
} }
} else if args.len() == 3 { } else if args.len() == 3 {
if args[0] == "get-state" { if args[0] == "get-state" {
Ok(( Ok((args[0].len() + args[1].len() + 2, self.complete_event_types(args.get(2))))
args[0].len() + args[1].len() + 2,
self.complete_event_types(args.get(2)),
))
} else { } else {
Ok((pos, vec![])) Ok((pos, vec![]))
} }
@ -216,12 +183,7 @@ impl Printer {
let syntax_set: SyntaxSet = from_binary(include_bytes!("./syntaxes.bin")); let syntax_set: SyntaxSet = from_binary(include_bytes!("./syntaxes.bin"));
let themes: ThemeSet = from_binary(include_bytes!("./themes.bin")); let themes: ThemeSet = from_binary(include_bytes!("./themes.bin"));
Self { Self { ps: syntax_set.into(), ts: themes.into(), json, color }
ps: syntax_set.into(),
ts: themes.into(),
json,
color,
}
} }
fn pretty_print_struct<T: Debug + Serialize>(&self, data: &T) { fn pretty_print_struct<T: Debug + Serialize>(&self, data: &T) {
@ -232,13 +194,9 @@ impl Printer {
}; };
let syntax = if self.json { let syntax = if self.json {
self.ps self.ps.find_syntax_by_extension("rs").expect("Can't find rust syntax extension")
.find_syntax_by_extension("rs")
.expect("Can't find rust syntax extension")
} else { } else {
self.ps self.ps.find_syntax_by_extension("json").expect("Can't find json syntax extension")
.find_syntax_by_extension("json")
.expect("Can't find json syntax extension")
}; };
if self.color { if self.color {
@ -305,11 +263,7 @@ impl Inspector {
} }
async fn get_display_name_owners(&self, room_id: RoomId, display_name: String) { async fn get_display_name_owners(&self, room_id: RoomId, display_name: String) {
let users = self let users = self.store.get_users_with_display_name(&room_id, &display_name).await.unwrap();
.store
.get_users_with_display_name(&room_id, &display_name)
.await
.unwrap();
self.printer.pretty_print_struct(&users); self.printer.pretty_print_struct(&users);
} }
@ -326,22 +280,14 @@ impl Inspector {
let joined: Vec<UserId> = self.store.get_joined_user_ids(&room_id).await.unwrap(); let joined: Vec<UserId> = self.store.get_joined_user_ids(&room_id).await.unwrap();
for member in joined { for member in joined {
let event = self let event = self.store.get_member_event(&room_id, &member).await.unwrap();
.store
.get_member_event(&room_id, &member)
.await
.unwrap();
self.printer.pretty_print_struct(&event); self.printer.pretty_print_struct(&event);
} }
} }
async fn get_state(&self, room_id: RoomId, event_type: EventType) { async fn get_state(&self, room_id: RoomId, event_type: EventType) {
self.printer.pretty_print_struct( self.printer.pretty_print_struct(
&self &self.store.get_state_event(&room_id, event_type, "").await.unwrap(),
.store
.get_state_event(&room_id, event_type, "")
.await
.unwrap(),
); );
} }
@ -350,35 +296,25 @@ impl Inspector {
SubCommand::with_name("list-rooms"), SubCommand::with_name("list-rooms"),
SubCommand::with_name("get-members").arg( SubCommand::with_name("get-members").arg(
Arg::with_name("room-id").required(true).validator(|r| { Arg::with_name("room-id").required(true).validator(|r| {
RoomId::try_from(r) RoomId::try_from(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned())
.map(|_| ())
.map_err(|_| "Invalid room id given".to_owned())
}), }),
), ),
SubCommand::with_name("get-profiles").arg( SubCommand::with_name("get-profiles").arg(
Arg::with_name("room-id").required(true).validator(|r| { Arg::with_name("room-id").required(true).validator(|r| {
RoomId::try_from(r) RoomId::try_from(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned())
.map(|_| ())
.map_err(|_| "Invalid room id given".to_owned())
}), }),
), ),
SubCommand::with_name("get-display-names") SubCommand::with_name("get-display-names")
.arg(Arg::with_name("room-id").required(true).validator(|r| { .arg(Arg::with_name("room-id").required(true).validator(|r| {
RoomId::try_from(r) RoomId::try_from(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned())
.map(|_| ())
.map_err(|_| "Invalid room id given".to_owned())
})) }))
.arg(Arg::with_name("display-name").required(true)), .arg(Arg::with_name("display-name").required(true)),
SubCommand::with_name("get-state") SubCommand::with_name("get-state")
.arg(Arg::with_name("room-id").required(true).validator(|r| { .arg(Arg::with_name("room-id").required(true).validator(|r| {
RoomId::try_from(r) RoomId::try_from(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned())
.map(|_| ())
.map_err(|_| "Invalid room id given".to_owned())
})) }))
.arg(Arg::with_name("event-type").required(true).validator(|e| { .arg(Arg::with_name("event-type").required(true).validator(|e| {
EventType::try_from(e) EventType::try_from(e).map(|_| ()).map_err(|_| "Invalid event type".to_string())
.map(|_| ())
.map_err(|_| "Invalid event type".to_string())
})), })),
] ]
} }

View File

@ -87,20 +87,21 @@ pub struct AdditionalUnsignedData {
pub prev_content: Option<Raw<MemberEventContent>>, pub prev_content: Option<Raw<MemberEventContent>>,
} }
/// Transform state event by hoisting `prev_content` field from `unsigned` to the top level. /// Transform state event by hoisting `prev_content` field from `unsigned` to
/// the top level.
/// ///
/// Due to a [bug in synapse][synapse-bug], `prev_content` often ends up in `unsigned` contrary to /// Due to a [bug in synapse][synapse-bug], `prev_content` often ends up in
/// the C2S spec. Some more discussion can be found [here][discussion]. Until this is fixed in /// `unsigned` contrary to the C2S spec. Some more discussion can be found
/// synapse or handled in Ruma, we use this to hoist up `prev_content` to the top level. /// [here][discussion]. Until this is fixed in synapse or handled in Ruma, we
/// use this to hoist up `prev_content` to the top level.
/// ///
/// [synapse-bug]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668> /// [synapse-bug]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668>
/// [discussion]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668> /// [discussion]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668>
pub fn hoist_and_deserialize_state_event( pub fn hoist_and_deserialize_state_event(
event: &Raw<AnySyncStateEvent>, event: &Raw<AnySyncStateEvent>,
) -> StdResult<AnySyncStateEvent, serde_json::Error> { ) -> StdResult<AnySyncStateEvent, serde_json::Error> {
let prev_content = serde_json::from_str::<AdditionalEventData>(event.json().get())? let prev_content =
.unsigned serde_json::from_str::<AdditionalEventData>(event.json().get())?.unsigned.prev_content;
.prev_content;
let mut ev = event.deserialize()?; let mut ev = event.deserialize()?;
@ -116,9 +117,8 @@ pub fn hoist_and_deserialize_state_event(
fn hoist_member_event( fn hoist_member_event(
event: &Raw<StateEvent<MemberEventContent>>, event: &Raw<StateEvent<MemberEventContent>>,
) -> StdResult<StateEvent<MemberEventContent>, serde_json::Error> { ) -> StdResult<StateEvent<MemberEventContent>, serde_json::Error> {
let prev_content = serde_json::from_str::<AdditionalEventData>(event.json().get())? let prev_content =
.unsigned serde_json::from_str::<AdditionalEventData>(event.json().get())?.unsigned.prev_content;
.prev_content;
let mut e = event.deserialize()?; let mut e = event.deserialize()?;
@ -340,7 +340,8 @@ impl BaseClient {
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `response` - A successful login response that contains our access token /// * `response` - A successful login response that contains our access
/// token
/// and device id. /// and device id.
pub async fn receive_login_response( pub async fn receive_login_response(
&self, &self,
@ -440,9 +441,7 @@ impl BaseClient {
AnySyncRoomEvent::State(s) => match s { AnySyncRoomEvent::State(s) => match s {
AnySyncStateEvent::RoomMember(member) => { AnySyncStateEvent::RoomMember(member) => {
if let Ok(member) = MemberEvent::try_from(member.clone()) { if let Ok(member) = MemberEvent::try_from(member.clone()) {
ambiguity_cache ambiguity_cache.handle_event(changes, room_id, &member).await?;
.handle_event(changes, room_id, &member)
.await?;
match member.content.membership { match member.content.membership {
MembershipState::Join | MembershipState::Invite => { MembershipState::Join | MembershipState::Invite => {
@ -500,8 +499,7 @@ impl BaseClient {
} }
if let Some(context) = &mut push_context { if let Some(context) = &mut push_context {
self.update_push_room_context(context, user_id, room_info, changes) self.update_push_room_context(context, user_id, room_info, changes).await;
.await;
} else { } else {
push_context = self.get_push_room_context(room, room_info, changes).await?; push_context = self.get_push_room_context(room, room_info, changes).await?;
} }
@ -521,10 +519,13 @@ impl BaseClient {
), ),
); );
} }
// TODO if there is an Action::SetTweak(Tweak::Highlight) we need to store // TODO if there is an
// its value with the event so a client can show if the event is highlighted // Action::SetTweak(Tweak::Highlight) we need to store
// its value with the event so a client can show if the
// event is highlighted
// in the UI. // in the UI.
// Requires the possibility to associate custom data with events and to // Requires the possibility to associate custom data
// with events and to
// store them. // store them.
} }
} }
@ -762,18 +763,14 @@ impl BaseClient {
let mut changes = StateChanges::new(next_batch.clone()); let mut changes = StateChanges::new(next_batch.clone());
let mut ambiguity_cache = AmbiguityCache::new(self.store.clone()); let mut ambiguity_cache = AmbiguityCache::new(self.store.clone());
self.handle_account_data(&account_data.events, &mut changes) self.handle_account_data(&account_data.events, &mut changes).await;
.await;
let push_rules = self.get_push_rules(&changes).await?; let push_rules = self.get_push_rules(&changes).await?;
let mut new_rooms = Rooms::default(); let mut new_rooms = Rooms::default();
for (room_id, new_info) in rooms.join { for (room_id, new_info) in rooms.join {
let room = self let room = self.store.get_or_create_room(&room_id, RoomType::Joined).await;
.store
.get_or_create_room(&room_id, RoomType::Joined)
.await;
let mut room_info = room.clone_info(); let mut room_info = room.clone_info();
room_info.mark_as_joined(); room_info.mark_as_joined();
@ -844,10 +841,7 @@ impl BaseClient {
} }
for (room_id, new_info) in rooms.leave { for (room_id, new_info) in rooms.leave {
let room = self let room = self.store.get_or_create_room(&room_id, RoomType::Left).await;
.store
.get_or_create_room(&room_id, RoomType::Left)
.await;
let mut room_info = room.clone_info(); let mut room_info = room.clone_info();
room_info.mark_as_left(); room_info.mark_as_left();
@ -876,18 +870,14 @@ impl BaseClient {
.await; .await;
changes.add_room(room_info); changes.add_room(room_info);
new_rooms.leave.insert( new_rooms
room_id, .leave
LeftRoom::new(timeline, new_info.state, new_info.account_data), .insert(room_id, LeftRoom::new(timeline, new_info.state, new_info.account_data));
);
} }
for (room_id, new_info) in rooms.invite { for (room_id, new_info) in rooms.invite {
{ {
let room = self let room = self.store.get_or_create_room(&room_id, RoomType::Invited).await;
.store
.get_or_create_room(&room_id, RoomType::Invited)
.await;
let mut room_info = room.clone_info(); let mut room_info = room.clone_info();
room_info.mark_as_invited(); room_info.mark_as_invited();
changes.add_room(room_info); changes.add_room(room_info);
@ -934,9 +924,7 @@ impl BaseClient {
.into_iter() .into_iter()
.map(|(k, v)| (k, v.into())) .map(|(k, v)| (k, v.into()))
.collect(), .collect(),
ambiguity_changes: AmbiguityChanges { ambiguity_changes: AmbiguityChanges { changes: ambiguity_cache.changes },
changes: ambiguity_cache.changes,
},
notifications: changes.notifications, notifications: changes.notifications,
}; };
@ -968,11 +956,7 @@ impl BaseClient {
let members: Vec<MemberEvent> = response let members: Vec<MemberEvent> = response
.chunk .chunk
.iter() .iter()
.filter_map(|e| { .filter_map(|e| hoist_member_event(e).ok().and_then(|e| MemberEvent::try_from(e).ok()))
hoist_member_event(e)
.ok()
.and_then(|e| MemberEvent::try_from(e).ok())
})
.collect(); .collect();
let mut ambiguity_cache = AmbiguityCache::new(self.store.clone()); let mut ambiguity_cache = AmbiguityCache::new(self.store.clone());
@ -986,12 +970,7 @@ impl BaseClient {
let mut user_ids = BTreeSet::new(); let mut user_ids = BTreeSet::new();
for member in &members { for member in &members {
if self if self.store.get_member_event(&room_id, &member.state_key).await?.is_none() {
.store
.get_member_event(&room_id, &member.state_key)
.await?
.is_none()
{
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
match member.content.membership { match member.content.membership {
MembershipState::Join | MembershipState::Invite => { MembershipState::Join | MembershipState::Invite => {
@ -1000,9 +979,7 @@ impl BaseClient {
_ => (), _ => (),
} }
ambiguity_cache ambiguity_cache.handle_event(&changes, room_id, &member).await?;
.handle_event(&changes, room_id, &member)
.await?;
if member.state_key == member.sender { if member.state_key == member.sender {
changes changes
@ -1036,9 +1013,7 @@ impl BaseClient {
Ok(MembersResponse { Ok(MembersResponse {
chunk: members, chunk: members,
ambiguity_changes: AmbiguityChanges { ambiguity_changes: AmbiguityChanges { changes: ambiguity_cache.changes },
changes: ambiguity_cache.changes,
},
}) })
} }
@ -1050,7 +1025,8 @@ impl BaseClient {
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `filter_name` - The name that should be used to persist the filter id in /// * `filter_name` - The name that should be used to persist the filter id
/// in
/// the store. /// the store.
/// ///
/// * `response` - The successful filter upload response containing the /// * `response` - The successful filter upload response containing the
@ -1062,10 +1038,7 @@ impl BaseClient {
filter_name: &str, filter_name: &str,
response: &api::filter::create_filter::Response, response: &api::filter::create_filter::Response,
) -> Result<()> { ) -> Result<()> {
Ok(self Ok(self.store.save_filter(filter_name, &response.filter_id).await?)
.store
.save_filter(filter_name, &response.filter_id)
.await?)
} }
/// Get the filter id of a previously uploaded filter. /// Get the filter id of a previously uploaded filter.
@ -1224,18 +1197,14 @@ impl BaseClient {
/// # Arguments /// # Arguments
/// ///
/// * `flow_id` - The unique id that identifies a interactive verification /// * `flow_id` - The unique id that identifies a interactive verification
/// flow. For in-room verifications this will be the event id of the /// flow. For in-room verifications this will be the event id of the
/// *m.key.verification.request* event that started the flow, for the /// *m.key.verification.request* event that started the flow, for the
/// to-device verification flows this will be the transaction id of the /// to-device verification flows this will be the transaction id of the
/// *m.key.verification.start* event. /// *m.key.verification.start* event.
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn get_verification(&self, flow_id: &str) -> Option<Sas> { pub async fn get_verification(&self, flow_id: &str) -> Option<Sas> {
self.olm self.olm.lock().await.as_ref().and_then(|o| o.get_verification(flow_id))
.lock()
.await
.as_ref()
.and_then(|o| o.get_verification(flow_id))
} }
/// Get a specific device of a user. /// Get a specific device of a user.
@ -1284,10 +1253,12 @@ impl BaseClient {
/// Get the user login session. /// Get the user login session.
/// ///
/// If the client is currently logged in, this will return a `matrix_sdk::Session` object which /// If the client is currently logged in, this will return a
/// can later be given to `restore_login`. /// `matrix_sdk::Session` object which can later be given to
/// `restore_login`.
/// ///
/// Returns a session object if the client is logged in. Otherwise returns `None`. /// Returns a session object if the client is logged in. Otherwise returns
/// `None`.
pub async fn get_session(&self) -> Option<Session> { pub async fn get_session(&self) -> Option<Session> {
self.session.read().await.clone() self.session.read().await.clone()
} }
@ -1349,8 +1320,9 @@ impl BaseClient {
/// Get the push rules. /// Get the push rules.
/// ///
/// Gets the push rules from `changes` if they have been updated, otherwise get them from the /// Gets the push rules from `changes` if they have been updated, otherwise
/// store. As a fallback, uses `Ruleset::server_default` if the user is logged in. /// get them from the store. As a fallback, uses
/// `Ruleset::server_default` if the user is logged in.
pub async fn get_push_rules(&self, changes: &StateChanges) -> Result<Ruleset> { pub async fn get_push_rules(&self, changes: &StateChanges) -> Result<Ruleset> {
if let Some(AnyGlobalAccountDataEvent::PushRules(event)) = changes if let Some(AnyGlobalAccountDataEvent::PushRules(event)) = changes
.account_data .account_data
@ -1374,11 +1346,11 @@ impl BaseClient {
/// Get the push context for the given room. /// Get the push context for the given room.
/// ///
/// Tries to get the data from `changes` or the up to date `room_info`. Loads the data from the /// Tries to get the data from `changes` or the up to date `room_info`.
/// store otherwise. /// Loads the data from the store otherwise.
/// ///
/// Returns `None` if some data couldn't be found. This should only happen in brand new rooms, /// Returns `None` if some data couldn't be found. This should only happen
/// while we process its state. /// in brand new rooms, while we process its state.
pub async fn get_push_room_context( pub async fn get_push_room_context(
&self, &self,
room: &Room, room: &Room,
@ -1390,16 +1362,10 @@ impl BaseClient {
let member_count = room_info.active_members_count(); let member_count = room_info.active_members_count();
let user_display_name = if let Some(member) = changes let user_display_name = if let Some(member) =
.members changes.members.get(room_id).and_then(|members| members.get(user_id))
.get(room_id)
.and_then(|members| members.get(user_id))
{ {
member member.content.displayname.clone().unwrap_or_else(|| user_id.localpart().to_owned())
.content
.displayname
.clone()
.unwrap_or_else(|| user_id.localpart().to_owned())
} else if let Some(member) = room.get_member(user_id).await? { } else if let Some(member) = room.get_member(user_id).await? {
member.name().to_owned() member.name().to_owned()
} else { } else {
@ -1449,16 +1415,10 @@ impl BaseClient {
push_rules.member_count = UInt::new(room_info.active_members_count()).unwrap_or(UInt::MAX); push_rules.member_count = UInt::new(room_info.active_members_count()).unwrap_or(UInt::MAX);
if let Some(member) = changes if let Some(member) = changes.members.get(room_id).and_then(|members| members.get(user_id))
.members
.get(room_id)
.and_then(|members| members.get(user_id))
{ {
push_rules.user_display_name = member push_rules.user_display_name =
.content member.content.displayname.clone().unwrap_or_else(|| user_id.localpart().to_owned())
.displayname
.clone()
.unwrap_or_else(|| user_id.localpart().to_owned())
} }
if let Some(AnySyncStateEvent::RoomPowerLevels(event)) = changes if let Some(AnySyncStateEvent::RoomPowerLevels(event)) = changes

View File

@ -15,12 +15,12 @@
//! Error conditions. //! Error conditions.
use serde_json::Error as JsonError;
use std::io::Error as IoError; use std::io::Error as IoError;
use thiserror::Error;
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use matrix_sdk_crypto::{CryptoStoreError, MegolmError, OlmError}; use matrix_sdk_crypto::{CryptoStoreError, MegolmError, OlmError};
use serde_json::Error as JsonError;
use thiserror::Error;
/// Result type of the rust-sdk. /// Result type of the rust-sdk.
pub type Result<T, E = Error> = std::result::Result<T, E>; pub type Result<T, E = Error> = std::result::Result<T, E>;
@ -28,7 +28,8 @@ pub type Result<T, E = Error> = std::result::Result<T, E>;
/// Internal representation of errors. /// Internal representation of errors.
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum Error { pub enum Error {
/// Queried endpoint requires authentication but was called on an anonymous client. /// Queried endpoint requires authentication but was called on an anonymous
/// client.
#[error("the queried endpoint requires authentication but was called before logging in")] #[error("the queried endpoint requires authentication but was called before logging in")]
AuthenticationRequired, AuthenticationRequired,

View File

@ -36,11 +36,12 @@
)] )]
#![cfg_attr(feature = "docs", feature(doc_cfg))] #![cfg_attr(feature = "docs", feature(doc_cfg))]
pub use matrix_sdk_common::*;
pub use crate::{ pub use crate::{
error::{Error, Result}, error::{Error, Result},
session::Session, session::Session,
}; };
pub use matrix_sdk_common::*;
mod client; mod client;
mod error; mod error;
@ -48,11 +49,9 @@ mod rooms;
mod session; mod session;
mod store; mod store;
pub use rooms::{Room, RoomInfo, RoomMember, RoomType};
pub use store::{StateChanges, StateStore, Store, StoreError};
pub use client::{BaseClient, BaseClientConfig}; pub use client::{BaseClient, BaseClientConfig};
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub use matrix_sdk_crypto as crypto; pub use matrix_sdk_crypto as crypto;
pub use rooms::{Room, RoomInfo, RoomMember, RoomType};
pub use store::{StateChanges, StateStore, Store, StoreError};

View File

@ -1,27 +1,22 @@
mod members; mod members;
mod normal; mod normal;
use matrix_sdk_common::{
events::room::{
create::CreateEventContent, guest_access::GuestAccess,
history_visibility::HistoryVisibility, join_rules::JoinRule,
},
identifiers::{MxcUri, UserId},
};
pub use normal::{Room, RoomInfo, RoomType};
pub use members::RoomMember;
use serde::{Deserialize, Serialize};
use std::cmp::max; use std::cmp::max;
use matrix_sdk_common::{ use matrix_sdk_common::{
events::{ events::{
room::{encryption::EncryptionEventContent, tombstone::TombstoneEventContent}, room::{
create::CreateEventContent, encryption::EncryptionEventContent,
guest_access::GuestAccess, history_visibility::HistoryVisibility, join_rules::JoinRule,
tombstone::TombstoneEventContent,
},
AnyStateEventContent, AnyStateEventContent,
}, },
identifiers::RoomAliasId, identifiers::{MxcUri, RoomAliasId, UserId},
}; };
pub use members::RoomMember;
pub use normal::{Room, RoomInfo, RoomType};
use serde::{Deserialize, Serialize};
/// A base room info struct that is the backbone of normal as well as stripped /// A base room info struct that is the backbone of normal as well as stripped
/// rooms. Holds all the state events that are important to present a room to /// rooms. Holds all the state events that are important to present a room to
@ -71,20 +66,12 @@ impl BaseRoomInfo {
let invited_joined = (invited_member_count + joined_member_count).saturating_sub(1); let invited_joined = (invited_member_count + joined_member_count).saturating_sub(1);
if heroes_count >= invited_joined { if heroes_count >= invited_joined {
let mut names = heroes let mut names = heroes.iter().take(3).map(|mem| mem.name()).collect::<Vec<&str>>();
.iter()
.take(3)
.map(|mem| mem.name())
.collect::<Vec<&str>>();
// stabilize ordering // stabilize ordering
names.sort_unstable(); names.sort_unstable();
names.join(", ") names.join(", ")
} else if heroes_count < invited_joined && invited_joined > 1 { } else if heroes_count < invited_joined && invited_joined > 1 {
let mut names = heroes let mut names = heroes.iter().take(3).map(|mem| mem.name()).collect::<Vec<&str>>();
.iter()
.take(3)
.map(|mem| mem.name())
.collect::<Vec<&str>>();
names.sort_unstable(); names.sort_unstable();
// TODO: What length does the spec want us to use here and in // TODO: What length does the spec want us to use here and in
@ -149,10 +136,8 @@ impl BaseRoomInfo {
true true
} }
AnyStateEventContent::RoomPowerLevels(p) => { AnyStateEventContent::RoomPowerLevels(p) => {
let max_power_level = p let max_power_level =
.users p.users.values().fold(self.max_power_level, |acc, p| max(acc, (*p).into()));
.values()
.fold(self.max_power_level, |acc, p| max(acc, (*p).into()));
self.max_power_level = max_power_level; self.max_power_level = max_power_level;
true true
} }

View File

@ -37,14 +37,14 @@ use matrix_sdk_common::{
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tracing::info; use tracing::info;
use super::{BaseRoomInfo, RoomMember};
use crate::{ use crate::{
deserialized_responses::UnreadNotificationsCount, deserialized_responses::UnreadNotificationsCount,
store::{Result as StoreResult, StateStore}, store::{Result as StoreResult, StateStore},
}; };
use super::{BaseRoomInfo, RoomMember}; /// The underlying room data structure collecting state for joined, left and
/// invtied rooms.
/// The underlying room data structure collecting state for joined, left and invtied rooms.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Room { pub struct Room {
room_id: Arc<RoomId>, room_id: Arc<RoomId>,
@ -135,7 +135,8 @@ impl Room {
/// Check if the room has it's members fully synced. /// Check if the room has it's members fully synced.
/// ///
/// Members might be missing if lazy member loading was enabled for the sync. /// Members might be missing if lazy member loading was enabled for the
/// sync.
/// ///
/// Returns true if no members are missing, false otherwise. /// Returns true if no members are missing, false otherwise.
pub fn are_members_synced(&self) -> bool { pub fn are_members_synced(&self) -> bool {
@ -200,12 +201,7 @@ impl Room {
/// Get the history visibility policy of this room. /// Get the history visibility policy of this room.
pub fn history_visibility(&self) -> HistoryVisibility { pub fn history_visibility(&self) -> HistoryVisibility {
self.inner self.inner.read().unwrap().base_info.history_visibility.clone()
.read()
.unwrap()
.base_info
.history_visibility
.clone()
} }
/// Is the room considered to be public. /// Is the room considered to be public.
@ -367,9 +363,7 @@ impl Room {
); );
let inner = self.inner.read().unwrap(); let inner = self.inner.read().unwrap();
Ok(inner Ok(inner.base_info.calculate_room_name(joined, invited, members))
.base_info
.calculate_room_name(joined, invited, members))
} }
pub(crate) fn clone_info(&self) -> RoomInfo { pub(crate) fn clone_info(&self) -> RoomInfo {
@ -394,11 +388,8 @@ impl Room {
return Ok(None); return Ok(None);
}; };
let presence = self let presence =
.store self.store.get_presence_event(user_id).await?.and_then(|e| e.deserialize().ok());
.get_presence_event(user_id)
.await?
.and_then(|e| e.deserialize().ok());
let profile = self.store.get_profile(self.room_id(), user_id).await?; let profile = self.store.get_profile(self.room_id(), user_id).await?;
let max_power_level = self.max_power_level(); let max_power_level = self.max_power_level();
let is_room_creator = self let is_room_creator = self
@ -411,28 +402,24 @@ impl Room {
.map(|c| &c.creator == user_id) .map(|c| &c.creator == user_id)
.unwrap_or(false); .unwrap_or(false);
let power = self let power =
.store self.store
.get_state_event(self.room_id(), EventType::RoomPowerLevels, "") .get_state_event(self.room_id(), EventType::RoomPowerLevels, "")
.await? .await?
.and_then(|e| e.deserialize().ok()) .and_then(|e| e.deserialize().ok())
.and_then(|e| { .and_then(|e| {
if let AnySyncStateEvent::RoomPowerLevels(e) = e { if let AnySyncStateEvent::RoomPowerLevels(e) = e {
Some(e) Some(e)
} else { } else {
None None
} }
}); });
let ambiguous = self let ambiguous = self
.store .store
.get_users_with_display_name( .get_users_with_display_name(
self.room_id(), self.room_id(),
member_event member_event.content.displayname.as_deref().unwrap_or_else(|| user_id.localpart()),
.content
.displayname
.as_deref()
.unwrap_or_else(|| user_id.localpart()),
) )
.await? .await?
.len() .len()
@ -558,8 +545,6 @@ impl RoomInfo {
/// ///
/// The return value is saturated at `u64::MAX`. /// The return value is saturated at `u64::MAX`.
pub fn active_members_count(&self) -> u64 { pub fn active_members_count(&self) -> u64 {
self.summary self.summary.joined_member_count.saturating_add(self.summary.invited_member_count)
.joined_member_count
.saturating_add(self.summary.invited_member_count)
} }
} }

View File

@ -15,9 +15,8 @@
//! User sessions. //! User sessions.
use serde::{Deserialize, Serialize};
use matrix_sdk_common::identifiers::{DeviceId, UserId}; use matrix_sdk_common::identifiers::{DeviceId, UserId};
use serde::{Deserialize, Serialize};
/// A user session, containing an access token and information about the /// A user session, containing an access token and information about the
/// associated user account. /// associated user account.

View File

@ -19,12 +19,10 @@ use matrix_sdk_common::{
events::room::member::MembershipState, events::room::member::MembershipState,
identifiers::{EventId, RoomId, UserId}, identifiers::{EventId, RoomId, UserId},
}; };
use tracing::trace; use tracing::trace;
use crate::Store;
use super::{Result, StateChanges}; use super::{Result, StateChanges};
use crate::Store;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct AmbiguityCache { pub struct AmbiguityCache {
@ -51,11 +49,8 @@ impl AmbiguityMap {
} }
fn add(&mut self, user_id: UserId) -> Option<UserId> { fn add(&mut self, user_id: UserId) -> Option<UserId> {
let ambiguous_user = if self.user_count() == 1 { let ambiguous_user =
self.users.iter().next().cloned() if self.user_count() == 1 { self.users.iter().next().cloned() } else { None };
} else {
None
};
self.users.insert(user_id); self.users.insert(user_id);
@ -73,11 +68,7 @@ impl AmbiguityMap {
impl AmbiguityCache { impl AmbiguityCache {
pub fn new(store: Store) -> Self { pub fn new(store: Store) -> Self {
Self { Self { store, cache: BTreeMap::new(), changes: BTreeMap::new() }
store,
cache: BTreeMap::new(),
changes: BTreeMap::new(),
}
} }
pub async fn handle_event( pub async fn handle_event(
@ -115,12 +106,9 @@ impl AmbiguityCache {
return Ok(()); return Ok(());
} }
let disambiguated_member = old_map let disambiguated_member = old_map.as_mut().and_then(|o| o.remove(&member_event.state_key));
.as_mut() let ambiguated_member =
.and_then(|o| o.remove(&member_event.state_key)); new_map.as_mut().and_then(|n| n.add(member_event.state_key.clone()));
let ambiguated_member = new_map
.as_mut()
.and_then(|n| n.add(member_event.state_key.clone()));
let ambiguous = new_map.as_ref().map(|n| n.is_ambiguous()).unwrap_or(false); let ambiguous = new_map.as_ref().map(|n| n.is_ambiguous()).unwrap_or(false);
self.update(room_id, old_map, new_map); self.update(room_id, old_map, new_map);
@ -131,11 +119,7 @@ impl AmbiguityCache {
member_ambiguous: ambiguous, member_ambiguous: ambiguous,
}; };
trace!( trace!("Handling display name ambiguity for {}: {:#?}", member_event.state_key, change);
"Handling display name ambiguity for {}: {:#?}",
member_event.state_key,
change
);
self.add_change(room_id, member_event.event_id.clone(), change); self.add_change(room_id, member_event.event_id.clone(), change);
@ -148,10 +132,7 @@ impl AmbiguityCache {
old_map: Option<AmbiguityMap>, old_map: Option<AmbiguityMap>,
new_map: Option<AmbiguityMap>, new_map: Option<AmbiguityMap>,
) { ) {
let entry = self let entry = self.cache.entry(room_id.clone()).or_insert_with(BTreeMap::new);
.cache
.entry(room_id.clone())
.or_insert_with(BTreeMap::new);
if let Some(old) = old_map { if let Some(old) = old_map {
entry.insert(old.display_name, old.users); entry.insert(old.display_name, old.users);
@ -163,10 +144,7 @@ impl AmbiguityCache {
} }
fn add_change(&mut self, room_id: &RoomId, event_id: EventId, change: AmbiguityChange) { fn add_change(&mut self, room_id: &RoomId, event_id: EventId, change: AmbiguityChange) {
self.changes self.changes.entry(room_id.clone()).or_insert_with(BTreeMap::new).insert(event_id, change);
.entry(room_id.clone())
.or_insert_with(BTreeMap::new)
.insert(event_id, change);
} }
async fn get( async fn get(
@ -177,16 +155,12 @@ impl AmbiguityCache {
) -> Result<(Option<AmbiguityMap>, Option<AmbiguityMap>)> { ) -> Result<(Option<AmbiguityMap>, Option<AmbiguityMap>)> {
use MembershipState::*; use MembershipState::*;
let old_event = if let Some(m) = changes let old_event = if let Some(m) =
.members changes.members.get(room_id).and_then(|m| m.get(&member_event.state_key))
.get(room_id)
.and_then(|m| m.get(&member_event.state_key))
{ {
Some(m.clone()) Some(m.clone())
} else { } else {
self.store self.store.get_member_event(room_id, &member_event.state_key).await?
.get_member_event(room_id, &member_event.state_key)
.await?
}; };
let old_display_name = if let Some(event) = old_event { let old_display_name = if let Some(event) = old_event {
@ -218,23 +192,15 @@ impl AmbiguityCache {
}; };
let old_map = if let Some(old_name) = old_display_name.as_deref() { let old_map = if let Some(old_name) = old_display_name.as_deref() {
let old_display_name_map = if let Some(u) = self let old_display_name_map = if let Some(u) =
.cache self.cache.entry(room_id.clone()).or_insert_with(BTreeMap::new).get(old_name)
.entry(room_id.clone())
.or_insert_with(BTreeMap::new)
.get(old_name)
{ {
u.clone() u.clone()
} else { } else {
self.store self.store.get_users_with_display_name(&room_id, &old_name).await?
.get_users_with_display_name(&room_id, &old_name)
.await?
}; };
Some(AmbiguityMap { Some(AmbiguityMap { display_name: old_name.to_string(), users: old_display_name_map })
display_name: old_name.to_string(),
users: old_display_name_map,
})
} else { } else {
None None
}; };
@ -246,8 +212,9 @@ impl AmbiguityCache {
.as_deref() .as_deref()
.unwrap_or_else(|| member_event.state_key.localpart()); .unwrap_or_else(|| member_event.state_key.localpart());
// We don't allow other users to set the display name, so if we have // We don't allow other users to set the display name, so if we
// a more trusted version of the display name use that. // have a more trusted version of the display
// name use that.
let new_display_name = if member_event.sender.as_str() == member_event.state_key { let new_display_name = if member_event.sender.as_str() == member_event.state_key {
new new
} else if let Some(old) = old_display_name.as_deref() { } else if let Some(old) = old_display_name.as_deref() {
@ -264,9 +231,7 @@ impl AmbiguityCache {
{ {
u.clone() u.clone()
} else { } else {
self.store self.store.get_users_with_display_name(&room_id, &new_display_name).await?
.get_users_with_display_name(&room_id, &new_display_name)
.await?
}; };
Some(AmbiguityMap { Some(AmbiguityMap {

View File

@ -30,12 +30,10 @@ use matrix_sdk_common::{
instant::Instant, instant::Instant,
Raw, Raw,
}; };
use tracing::info; use tracing::info;
use crate::deserialized_responses::{MemberEvent, StrippedMemberEvent};
use super::{Result, RoomInfo, StateChanges, StateStore}; use super::{Result, RoomInfo, StateChanges, StateStore};
use crate::deserialized_responses::{MemberEvent, StrippedMemberEvent};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct MemoryStore { pub struct MemoryStore {
@ -82,8 +80,7 @@ impl MemoryStore {
} }
async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> { async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> {
self.filters self.filters.insert(filter_name.to_string(), filter_id.to_string());
.insert(filter_name.to_string(), filter_id.to_string());
Ok(()) Ok(())
} }
@ -164,8 +161,7 @@ impl MemoryStore {
} }
for (event_type, event) in &changes.account_data { for (event_type, event) in &changes.account_data {
self.account_data self.account_data.insert(event_type.to_string(), event.clone());
.insert(event_type.to_string(), event.clone());
} }
for (room, events) in &changes.room_account_data { for (room, events) in &changes.room_account_data {
@ -199,8 +195,7 @@ impl MemoryStore {
} }
for (room_id, info) in &changes.invited_room_info { for (room_id, info) in &changes.invited_room_info {
self.stripped_room_info self.stripped_room_info.insert(room_id.clone(), info.clone());
.insert(room_id.clone(), info.clone());
} }
for (room, events) in &changes.stripped_members { for (room, events) in &changes.stripped_members {
@ -243,8 +238,7 @@ impl MemoryStore {
) -> Result<Option<Raw<AnySyncStateEvent>>> { ) -> Result<Option<Raw<AnySyncStateEvent>>> {
#[allow(clippy::map_clone)] #[allow(clippy::map_clone)]
Ok(self.room_state.get(room_id).and_then(|e| { Ok(self.room_state.get(room_id).and_then(|e| {
e.get(event_type.as_ref()) e.get(event_type.as_ref()).and_then(|s| s.get(state_key).map(|e| e.clone()))
.and_then(|s| s.get(state_key).map(|e| e.clone()))
})) }))
} }
@ -254,10 +248,7 @@ impl MemoryStore {
user_id: &UserId, user_id: &UserId,
) -> Result<Option<MemberEventContent>> { ) -> Result<Option<MemberEventContent>> {
#[allow(clippy::map_clone)] #[allow(clippy::map_clone)]
Ok(self Ok(self.profiles.get(room_id).and_then(|p| p.get(user_id).map(|p| p.clone())))
.profiles
.get(room_id)
.and_then(|p| p.get(user_id).map(|p| p.clone())))
} }
async fn get_member_event( async fn get_member_event(
@ -266,10 +257,7 @@ impl MemoryStore {
state_key: &UserId, state_key: &UserId,
) -> Result<Option<MemberEvent>> { ) -> Result<Option<MemberEvent>> {
#[allow(clippy::map_clone)] #[allow(clippy::map_clone)]
Ok(self Ok(self.members.get(room_id).and_then(|m| m.get(state_key).map(|m| m.clone())))
.members
.get(room_id)
.and_then(|m| m.get(state_key).map(|m| m.clone())))
} }
fn get_user_ids(&self, room_id: &RoomId) -> Vec<UserId> { fn get_user_ids(&self, room_id: &RoomId) -> Vec<UserId> {
@ -310,10 +298,7 @@ impl MemoryStore {
&self, &self,
event_type: EventType, event_type: EventType,
) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> { ) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> {
Ok(self Ok(self.account_data.get(event_type.as_ref()).map(|e| e.clone()))
.account_data
.get(event_type.as_ref())
.map(|e| e.clone()))
} }
async fn get_room_account_data_event( async fn get_room_account_data_event(

View File

@ -12,15 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#[cfg(feature = "sled_state_store")]
use std::path::Path;
use std::{ use std::{
collections::{BTreeMap, BTreeSet}, collections::{BTreeMap, BTreeSet},
ops::Deref, ops::Deref,
sync::Arc, sync::Arc,
}; };
#[cfg(feature = "sled_state_store")]
use std::path::Path;
use dashmap::DashMap; use dashmap::DashMap;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::push::get_notifications::Notification, api::r0::push::get_notifications::Notification,
@ -201,7 +200,8 @@ pub trait StateStore: AsyncTraitDeps {
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `room_id` - The id of the room for which the room account data event should /// * `room_id` - The id of the room for which the room account data event
/// should
/// be fetched. /// be fetched.
/// ///
/// * `event_type` - The event type of the room account data event. /// * `event_type` - The event type of the room account data event.
@ -298,20 +298,16 @@ impl Store {
/// Get all the rooms this store knows about. /// Get all the rooms this store knows about.
pub fn get_rooms(&self) -> Vec<Room> { pub fn get_rooms(&self) -> Vec<Room> {
self.rooms self.rooms.iter().filter_map(|r| self.get_room(r.key())).collect()
.iter()
.filter_map(|r| self.get_room(r.key()))
.collect()
} }
/// Get the room with the given room id. /// Get the room with the given room id.
pub fn get_room(&self, room_id: &RoomId) -> Option<Room> { pub fn get_room(&self, room_id: &RoomId) -> Option<Room> {
self.get_bare_room(room_id) self.get_bare_room(room_id).and_then(|r| match r.room_type() {
.and_then(|r| match r.room_type() { RoomType::Joined => Some(r),
RoomType::Joined => Some(r), RoomType::Left => Some(r),
RoomType::Left => Some(r), RoomType::Invited => self.get_stripped_room(room_id),
RoomType::Invited => self.get_stripped_room(room_id), })
})
} }
fn get_stripped_room(&self, room_id: &RoomId) -> Option<Room> { fn get_stripped_room(&self, room_id: &RoomId) -> Option<Room> {
@ -321,10 +317,7 @@ impl Store {
pub(crate) async fn get_or_create_stripped_room(&self, room_id: &RoomId) -> Room { pub(crate) async fn get_or_create_stripped_room(&self, room_id: &RoomId) -> Room {
let session = self.session.read().await; let session = self.session.read().await;
let user_id = &session let user_id = &session.as_ref().expect("Creating room while not being logged in").user_id;
.as_ref()
.expect("Creating room while not being logged in")
.user_id;
self.stripped_rooms self.stripped_rooms
.entry(room_id.clone()) .entry(room_id.clone())
@ -334,10 +327,7 @@ impl Store {
pub(crate) async fn get_or_create_room(&self, room_id: &RoomId, room_type: RoomType) -> Room { pub(crate) async fn get_or_create_room(&self, room_id: &RoomId, room_type: RoomType) -> Room {
let session = self.session.read().await; let session = self.session.read().await;
let user_id = &session let user_id = &session.as_ref().expect("Creating room while not being logged in").user_id;
.as_ref()
.expect("Creating room while not being logged in")
.user_id;
self.rooms self.rooms
.entry(room_id.clone()) .entry(room_id.clone())
@ -359,7 +349,8 @@ impl Deref for Store {
pub struct StateChanges { pub struct StateChanges {
/// The sync token that relates to this update. /// The sync token that relates to this update.
pub sync_token: Option<String>, pub sync_token: Option<String>,
/// A user session, containing an access token and information about the associated user account. /// A user session, containing an access token and information about the
/// associated user account.
pub session: Option<Session>, pub session: Option<Session>,
/// A mapping of event type string to `AnyBasicEvent`. /// A mapping of event type string to `AnyBasicEvent`.
pub account_data: BTreeMap<String, Raw<AnyGlobalAccountDataEvent>>, pub account_data: BTreeMap<String, Raw<AnyGlobalAccountDataEvent>>,
@ -371,14 +362,16 @@ pub struct StateChanges {
/// A mapping of `RoomId` to a map of users and their `MemberEventContent`. /// A mapping of `RoomId` to a map of users and their `MemberEventContent`.
pub profiles: BTreeMap<RoomId, BTreeMap<UserId, MemberEventContent>>, pub profiles: BTreeMap<RoomId, BTreeMap<UserId, MemberEventContent>>,
/// A mapping of `RoomId` to a map of event type string to a state key and `AnySyncStateEvent`. /// A mapping of `RoomId` to a map of event type string to a state key and
/// `AnySyncStateEvent`.
pub state: BTreeMap<RoomId, BTreeMap<String, BTreeMap<String, Raw<AnySyncStateEvent>>>>, pub state: BTreeMap<RoomId, BTreeMap<String, BTreeMap<String, Raw<AnySyncStateEvent>>>>,
/// A mapping of `RoomId` to a map of event type string to `AnyBasicEvent`. /// A mapping of `RoomId` to a map of event type string to `AnyBasicEvent`.
pub room_account_data: BTreeMap<RoomId, BTreeMap<String, Raw<AnyRoomAccountDataEvent>>>, pub room_account_data: BTreeMap<RoomId, BTreeMap<String, Raw<AnyRoomAccountDataEvent>>>,
/// A map of `RoomId` to `RoomInfo`. /// A map of `RoomId` to `RoomInfo`.
pub room_infos: BTreeMap<RoomId, RoomInfo>, pub room_infos: BTreeMap<RoomId, RoomInfo>,
/// A mapping of `RoomId` to a map of event type to a map of state key to `AnyStrippedStateEvent`. /// A mapping of `RoomId` to a map of event type to a map of state key to
/// `AnyStrippedStateEvent`.
pub stripped_state: pub stripped_state:
BTreeMap<RoomId, BTreeMap<String, BTreeMap<String, Raw<AnyStrippedStateEvent>>>>, BTreeMap<RoomId, BTreeMap<String, BTreeMap<String, Raw<AnyStrippedStateEvent>>>>,
/// A mapping of `RoomId` to a map of users and their `StrippedMemberEvent`. /// A mapping of `RoomId` to a map of users and their `StrippedMemberEvent`.
@ -396,10 +389,7 @@ pub struct StateChanges {
impl StateChanges { impl StateChanges {
/// Create a new `StateChanges` struct with the given sync_token. /// Create a new `StateChanges` struct with the given sync_token.
pub fn new(sync_token: String) -> Self { pub fn new(sync_token: String) -> Self {
Self { Self { sync_token: Some(sync_token), ..Default::default() }
sync_token: Some(sync_token),
..Default::default()
}
} }
/// Update the `StateChanges` struct with the given `PresenceEvent`. /// Update the `StateChanges` struct with the given `PresenceEvent`.
@ -409,14 +399,12 @@ impl StateChanges {
/// Update the `StateChanges` struct with the given `RoomInfo`. /// Update the `StateChanges` struct with the given `RoomInfo`.
pub fn add_room(&mut self, room: RoomInfo) { pub fn add_room(&mut self, room: RoomInfo) {
self.room_infos self.room_infos.insert(room.room_id.as_ref().to_owned(), room);
.insert(room.room_id.as_ref().to_owned(), room);
} }
/// Update the `StateChanges` struct with the given `RoomInfo`. /// Update the `StateChanges` struct with the given `RoomInfo`.
pub fn add_stripped_room(&mut self, room: RoomInfo) { pub fn add_stripped_room(&mut self, room: RoomInfo) {
self.invited_room_info self.invited_room_info.insert(room.room_id.as_ref().to_owned(), room);
.insert(room.room_id.as_ref().to_owned(), room);
} }
/// Update the `StateChanges` struct with the given `AnyBasicEvent`. /// Update the `StateChanges` struct with the given `AnyBasicEvent`.
@ -425,11 +413,11 @@ impl StateChanges {
event: AnyGlobalAccountDataEvent, event: AnyGlobalAccountDataEvent,
raw_event: Raw<AnyGlobalAccountDataEvent>, raw_event: Raw<AnyGlobalAccountDataEvent>,
) { ) {
self.account_data self.account_data.insert(event.content().event_type().to_owned(), raw_event);
.insert(event.content().event_type().to_owned(), raw_event);
} }
/// Update the `StateChanges` struct with the given room with a new `AnyBasicEvent`. /// Update the `StateChanges` struct with the given room with a new
/// `AnyBasicEvent`.
pub fn add_room_account_data( pub fn add_room_account_data(
&mut self, &mut self,
room_id: &RoomId, room_id: &RoomId,
@ -442,7 +430,8 @@ impl StateChanges {
.insert(event.content().event_type().to_owned(), raw_event); .insert(event.content().event_type().to_owned(), raw_event);
} }
/// Update the `StateChanges` struct with the given room with a new `StrippedMemberEvent`. /// Update the `StateChanges` struct with the given room with a new
/// `StrippedMemberEvent`.
pub fn add_stripped_member(&mut self, room_id: &RoomId, event: StrippedMemberEvent) { pub fn add_stripped_member(&mut self, room_id: &RoomId, event: StrippedMemberEvent) {
let user_id = event.state_key.clone(); let user_id = event.state_key.clone();
@ -452,7 +441,8 @@ impl StateChanges {
.insert(user_id, event); .insert(user_id, event);
} }
/// Update the `StateChanges` struct with the given room with a new `AnySyncStateEvent`. /// Update the `StateChanges` struct with the given room with a new
/// `AnySyncStateEvent`.
pub fn add_state_event( pub fn add_state_event(
&mut self, &mut self,
room_id: &RoomId, room_id: &RoomId,
@ -467,11 +457,9 @@ impl StateChanges {
.insert(event.state_key().to_string(), raw_event); .insert(event.state_key().to_string(), raw_event);
} }
/// Update the `StateChanges` struct with the given room with a new `Notification`. /// Update the `StateChanges` struct with the given room with a new
/// `Notification`.
pub fn add_notification(&mut self, room_id: &RoomId, notification: Notification) { pub fn add_notification(&mut self, room_id: &RoomId, notification: Notification) {
self.notifications self.notifications.entry(room_id.to_owned()).or_insert_with(Vec::new).push(notification);
.entry(room_id.to_owned())
.or_insert_with(Vec::new)
.push(notification);
} }
} }

View File

@ -37,18 +37,15 @@ use matrix_sdk_common::{
Raw, Raw,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sled::{ use sled::{
transaction::{ConflictableTransactionError, TransactionError}, transaction::{ConflictableTransactionError, TransactionError},
Config, Db, Transactional, Tree, Config, Db, Transactional, Tree,
}; };
use tracing::info; use tracing::info;
use crate::deserialized_responses::MemberEvent;
use self::store_key::{EncryptedEvent, StoreKey}; use self::store_key::{EncryptedEvent, StoreKey};
use super::{Result, RoomInfo, StateChanges, StateStore, StoreError}; use super::{Result, RoomInfo, StateChanges, StateStore, StoreError};
use crate::deserialized_responses::MemberEvent;
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub enum DatabaseType { pub enum DatabaseType {
@ -111,13 +108,7 @@ impl EncodeKey for &str {
impl EncodeKey for (&str, &str) { impl EncodeKey for (&str, &str) {
fn encode(&self) -> Vec<u8> { fn encode(&self) -> Vec<u8> {
[ [self.0.as_bytes(), &[Self::SEPARATOR], self.1.as_bytes(), &[Self::SEPARATOR]].concat()
self.0.as_bytes(),
&[Self::SEPARATOR],
self.1.as_bytes(),
&[Self::SEPARATOR],
]
.concat()
} }
} }
@ -167,9 +158,7 @@ impl std::fmt::Debug for SledStore {
if let Some(path) = &self.path { if let Some(path) = &self.path {
f.debug_struct("SledStore").field("path", &path).finish() f.debug_struct("SledStore").field("path", &path).finish()
} else { } else {
f.debug_struct("SledStore") f.debug_struct("SledStore").field("path", &"memory store").finish()
.field("path", &"memory store")
.finish()
} }
} }
} }
@ -239,8 +228,7 @@ impl SledStore {
} else { } else {
let key = StoreKey::new().map_err::<StoreError, _>(|e| e.into())?; let key = StoreKey::new().map_err::<StoreError, _>(|e| e.into())?;
let encrypted_key = DatabaseType::Encrypted( let encrypted_key = DatabaseType::Encrypted(
key.export(passphrase) key.export(passphrase).map_err::<StoreError, _>(|e| e.into())?,
.map_err::<StoreError, _>(|e| e.into())?,
); );
db.insert("store_key".encode(), serde_json::to_vec(&encrypted_key)?)?; db.insert("store_key".encode(), serde_json::to_vec(&encrypted_key)?)?;
key key
@ -278,8 +266,7 @@ impl SledStore {
} }
pub async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> { pub async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> {
self.session self.session.insert(("filter", filter_name).encode(), filter_id)?;
.insert(("filter", filter_name).encode(), filter_id)?;
Ok(()) Ok(())
} }
@ -479,11 +466,7 @@ impl SledStore {
} }
pub async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<Raw<PresenceEvent>>> { pub async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<Raw<PresenceEvent>>> {
Ok(self Ok(self.presence.get(user_id.encode())?.map(|e| self.deserialize_event(&e)).transpose()?)
.presence
.get(user_id.encode())?
.map(|e| self.deserialize_event(&e))
.transpose()?)
} }
pub async fn get_state_event( pub async fn get_state_event(
@ -534,14 +517,10 @@ impl SledStore {
&self, &self,
room_id: &RoomId, room_id: &RoomId,
) -> impl Stream<Item = Result<UserId>> { ) -> impl Stream<Item = Result<UserId>> {
stream::iter( stream::iter(self.invited_user_ids.scan_prefix(room_id.encode()).map(|u| {
self.invited_user_ids UserId::try_from(String::from_utf8_lossy(&u?.1).to_string())
.scan_prefix(room_id.encode()) .map_err(StoreError::Identifier)
.map(|u| { }))
UserId::try_from(String::from_utf8_lossy(&u?.1).to_string())
.map_err(StoreError::Identifier)
}),
)
} }
pub async fn get_joined_user_ids( pub async fn get_joined_user_ids(
@ -557,9 +536,7 @@ impl SledStore {
pub async fn get_room_infos(&self) -> impl Stream<Item = Result<RoomInfo>> { pub async fn get_room_infos(&self) -> impl Stream<Item = Result<RoomInfo>> {
let db = self.clone(); let db = self.clone();
stream::iter( stream::iter(
self.room_info self.room_info.iter().map(move |r| db.deserialize_event(&r?.1).map_err(|e| e.into())),
.iter()
.map(move |r| db.deserialize_event(&r?.1).map_err(|e| e.into())),
) )
} }
@ -683,8 +660,7 @@ impl StateStore for SledStore {
room_id: &RoomId, room_id: &RoomId,
display_name: &str, display_name: &str,
) -> Result<BTreeSet<UserId>> { ) -> Result<BTreeSet<UserId>> {
self.get_users_with_display_name(room_id, display_name) self.get_users_with_display_name(room_id, display_name).await
.await
} }
async fn get_account_data_event( async fn get_account_data_event(
@ -770,11 +746,7 @@ mod test {
let room_id = room_id!("!test:localhost"); let room_id = room_id!("!test:localhost");
let user_id = user_id(); let user_id = user_id();
assert!(store assert!(store.get_member_event(&room_id, &user_id).await.unwrap().is_none());
.get_member_event(&room_id, &user_id)
.await
.unwrap()
.is_none());
let mut changes = StateChanges::default(); let mut changes = StateChanges::default();
changes changes
.members .members
@ -783,11 +755,7 @@ mod test {
.insert(user_id.clone(), membership_event()); .insert(user_id.clone(), membership_event());
store.save_changes(&changes).await.unwrap(); store.save_changes(&changes).await.unwrap();
assert!(store assert!(store.get_member_event(&room_id, &user_id).await.unwrap().is_some());
.get_member_event(&room_id, &user_id)
.await
.unwrap()
.is_some());
} }
#[async_test] #[async_test]

View File

@ -21,11 +21,10 @@ use chacha20poly1305::{
use hmac::Hmac; use hmac::Hmac;
use pbkdf2::pbkdf2; use pbkdf2::pbkdf2;
use rand::{thread_rng, Error as RngError, Fill}; use rand::{thread_rng, Error as RngError, Fill};
use serde::{Deserialize, Serialize};
use sha2::Sha256; use sha2::Sha256;
use zeroize::{Zeroize, Zeroizing}; use zeroize::{Zeroize, Zeroizing};
use serde::{Deserialize, Serialize};
use crate::StoreError; use crate::StoreError;
const VERSION: u8 = 1; const VERSION: u8 = 1;
@ -76,9 +75,11 @@ pub struct EncryptedEvent {
#[derive(Debug, Serialize, Deserialize, PartialEq)] #[derive(Debug, Serialize, Deserialize, PartialEq)]
pub enum KdfInfo { pub enum KdfInfo {
Pbkdf2ToChaCha20Poly1305 { Pbkdf2ToChaCha20Poly1305 {
/// The number of PBKDF rounds that were used when deriving the store key. /// The number of PBKDF rounds that were used when deriving the store
/// key.
rounds: u32, rounds: u32,
/// The salt that was used when the passphrase was expanded into a store key. /// The salt that was used when the passphrase was expanded into a store
/// key.
kdf_salt: Vec<u8>, kdf_salt: Vec<u8>,
}, },
} }
@ -170,10 +171,7 @@ impl StoreKey {
cipher.encrypt(Nonce::from_slice(nonce.as_ref()), self.inner.as_slice())?; cipher.encrypt(Nonce::from_slice(nonce.as_ref()), self.inner.as_slice())?;
Ok(EncryptedStoreKey { Ok(EncryptedStoreKey {
kdf_info: KdfInfo::Pbkdf2ToChaCha20Poly1305 { kdf_info: KdfInfo::Pbkdf2ToChaCha20Poly1305 { rounds: KDF_ROUNDS, kdf_salt: salt },
rounds: KDF_ROUNDS,
kdf_salt: salt,
},
ciphertext_info: CipherTextInfo::ChaCha20Poly1305 { nonce, ciphertext }, ciphertext_info: CipherTextInfo::ChaCha20Poly1305 { nonce, ciphertext },
}) })
} }
@ -196,11 +194,7 @@ impl StoreKey {
let ciphertext = cipher.encrypt(xnonce, event.as_ref())?; let ciphertext = cipher.encrypt(xnonce, event.as_ref())?;
Ok(EncryptedEvent { Ok(EncryptedEvent { version: VERSION, ciphertext, nonce })
version: VERSION,
ciphertext,
nonce,
})
} }
pub fn decrypt<T: for<'b> Deserialize<'b>>(&self, event: EncryptedEvent) -> Result<T, Error> { pub fn decrypt<T: for<'b> Deserialize<'b>>(&self, event: EncryptedEvent) -> Result<T, Error> {
@ -248,9 +242,10 @@ impl StoreKey {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::StoreKey;
use serde_json::{json, Value}; use serde_json::{json, Value};
use super::StoreKey;
#[test] #[test]
fn generating() { fn generating() {
StoreKey::new().unwrap(); StoreKey::new().unwrap();

View File

@ -1,3 +1,5 @@
use std::{collections::BTreeMap, convert::TryFrom, time::SystemTime};
use ruma::{ use ruma::{
api::client::r0::sync::sync_events::{ api::client::r0::sync::sync_events::{
Ephemeral, InvitedRoom, Presence, RoomAccountData, State, ToDevice, Ephemeral, InvitedRoom, Presence, RoomAccountData, State, ToDevice,
@ -6,7 +8,6 @@ use ruma::{
DeviceIdBox, DeviceIdBox,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{collections::BTreeMap, convert::TryFrom, time::SystemTime};
use super::{ use super::{
api::r0::{ api::r0::{
@ -103,16 +104,14 @@ pub struct SyncRoomEvent {
impl From<Raw<AnySyncRoomEvent>> for SyncRoomEvent { impl From<Raw<AnySyncRoomEvent>> for SyncRoomEvent {
fn from(inner: Raw<AnySyncRoomEvent>) -> Self { fn from(inner: Raw<AnySyncRoomEvent>) -> Self {
Self { Self { encryption_info: None, event: inner }
encryption_info: None,
event: inner,
}
} }
} }
#[derive(Clone, Debug, Default, Deserialize, Serialize)] #[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct SyncResponse { pub struct SyncResponse {
/// The batch token to supply in the `since` param of the next `/sync` request. /// The batch token to supply in the `since` param of the next `/sync`
/// request.
pub next_batch: String, pub next_batch: String,
/// Updates to rooms. /// Updates to rooms.
pub rooms: Rooms, pub rooms: Rooms,
@ -137,10 +136,7 @@ pub struct SyncResponse {
impl SyncResponse { impl SyncResponse {
pub fn new(next_batch: String) -> Self { pub fn new(next_batch: String) -> Self {
Self { Self { next_batch, ..Default::default() }
next_batch,
..Default::default()
}
} }
} }
@ -161,14 +157,15 @@ pub struct JoinedRoom {
pub unread_notifications: UnreadNotificationsCount, pub unread_notifications: UnreadNotificationsCount,
/// The timeline of messages and state changes in the room. /// The timeline of messages and state changes in the room.
pub timeline: Timeline, pub timeline: Timeline,
/// Updates to the state, between the time indicated by the `since` parameter, and the start /// Updates to the state, between the time indicated by the `since`
/// of the `timeline` (or all state up to the start of the `timeline`, if `since` is not /// parameter, and the start of the `timeline` (or all state up to the
/// given, or `full_state` is true). /// start of the `timeline`, if `since` is not given, or `full_state` is
/// true).
pub state: State, pub state: State,
/// The private data that this user has attached to this room. /// The private data that this user has attached to this room.
pub account_data: RoomAccountData, pub account_data: RoomAccountData,
/// The ephemeral events in the room that aren't recorded in the timeline or state of the /// The ephemeral events in the room that aren't recorded in the timeline or
/// room. e.g. typing. /// state of the room. e.g. typing.
pub ephemeral: Ephemeral, pub ephemeral: Ephemeral,
} }
@ -180,20 +177,15 @@ impl JoinedRoom {
ephemeral: Ephemeral, ephemeral: Ephemeral,
unread_notifications: UnreadNotificationsCount, unread_notifications: UnreadNotificationsCount,
) -> Self { ) -> Self {
Self { Self { unread_notifications, timeline, state, account_data, ephemeral }
unread_notifications,
timeline,
state,
account_data,
ephemeral,
}
} }
} }
/// Counts of unread notifications for a room. /// Counts of unread notifications for a room.
#[derive(Copy, Clone, Debug, Default, Deserialize, Serialize)] #[derive(Copy, Clone, Debug, Default, Deserialize, Serialize)]
pub struct UnreadNotificationsCount { pub struct UnreadNotificationsCount {
/// The number of unread notifications for this room with the highlight flag set. /// The number of unread notifications for this room with the highlight flag
/// set.
pub highlight_count: u64, pub highlight_count: u64,
/// The total number of unread notifications for this room. /// The total number of unread notifications for this room.
pub notification_count: u64, pub notification_count: u64,
@ -203,10 +195,7 @@ impl From<RumaUnreadNotificationsCount> for UnreadNotificationsCount {
fn from(notifications: RumaUnreadNotificationsCount) -> Self { fn from(notifications: RumaUnreadNotificationsCount) -> Self {
Self { Self {
highlight_count: notifications.highlight_count.map(|c| c.into()).unwrap_or(0), highlight_count: notifications.highlight_count.map(|c| c.into()).unwrap_or(0),
notification_count: notifications notification_count: notifications.notification_count.map(|c| c.into()).unwrap_or(0),
.notification_count
.map(|c| c.into())
.unwrap_or(0),
} }
} }
} }
@ -216,9 +205,10 @@ pub struct LeftRoom {
/// The timeline of messages and state changes in the room up to the point /// The timeline of messages and state changes in the room up to the point
/// when the user left. /// when the user left.
pub timeline: Timeline, pub timeline: Timeline,
/// Updates to the state, between the time indicated by the `since` parameter, and the start /// Updates to the state, between the time indicated by the `since`
/// of the `timeline` (or all state up to the start of the `timeline`, if `since` is not /// parameter, and the start of the `timeline` (or all state up to the
/// given, or `full_state` is true). /// start of the `timeline`, if `since` is not given, or `full_state` is
/// true).
pub state: State, pub state: State,
/// The private data that this user has attached to this room. /// The private data that this user has attached to this room.
pub account_data: RoomAccountData, pub account_data: RoomAccountData,
@ -226,18 +216,15 @@ pub struct LeftRoom {
impl LeftRoom { impl LeftRoom {
pub fn new(timeline: Timeline, state: State, account_data: RoomAccountData) -> Self { pub fn new(timeline: Timeline, state: State, account_data: RoomAccountData) -> Self {
Self { Self { timeline, state, account_data }
timeline,
state,
account_data,
}
} }
} }
/// Events in the room. /// Events in the room.
#[derive(Clone, Debug, Default, Deserialize, Serialize)] #[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Timeline { pub struct Timeline {
/// True if the number of events returned was limited by the `limit` on the filter. /// True if the number of events returned was limited by the `limit` on the
/// filter.
pub limited: bool, pub limited: bool,
/// A token that can be supplied to to the `from` parameter of the /// A token that can be supplied to to the `from` parameter of the
@ -250,11 +237,7 @@ pub struct Timeline {
impl Timeline { impl Timeline {
pub fn new(limited: bool, prev_batch: Option<String>) -> Self { pub fn new(limited: bool, prev_batch: Option<String>) -> Self {
Self { Self { limited, prev_batch, ..Default::default() }
limited,
prev_batch,
..Default::default()
}
} }
} }

View File

@ -6,14 +6,12 @@ use std::{
task::{Context, Poll}, task::{Context, Poll},
}; };
#[cfg(not(target_arch = "wasm32"))]
pub use tokio::spawn;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_futures::spawn_local;
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
use futures::{future::RemoteHandle, Future, FutureExt}; use futures::{future::RemoteHandle, Future, FutureExt};
#[cfg(not(target_arch = "wasm32"))]
pub use tokio::spawn;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_futures::spawn_local;
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
pub fn spawn<F, T>(future: F) -> JoinHandle<T> pub fn spawn<F, T>(future: F) -> JoinHandle<T>

View File

@ -17,7 +17,6 @@ pub use ruma::{
serde::{CanonicalJsonValue, Raw}, serde::{CanonicalJsonValue, Raw},
thirdparty, uint, Int, Outgoing, UInt, thirdparty, uint, Int, Outgoing, UInt,
}; };
pub use uuid; pub use uuid;
pub mod deserialized_responses; pub mod deserialized_responses;

View File

@ -4,6 +4,5 @@
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
pub use futures_locks::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard}; pub use futures_locks::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
pub use tokio::sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard}; pub use tokio::sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};

View File

@ -4,7 +4,6 @@ mod perf;
use std::sync::Arc; use std::sync::Arc;
use criterion::*; use criterion::*;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::{ api::r0::{
keys::{claim_keys, get_keys}, keys::{claim_keys, get_keys},
@ -50,17 +49,12 @@ fn huge_keys_query_resopnse() -> get_keys::Response {
} }
pub fn keys_query(c: &mut Criterion) { pub fn keys_query(c: &mut Criterion) {
let runtime = Builder::new_multi_thread() let runtime = Builder::new_multi_thread().build().expect("Can't create runtime");
.build()
.expect("Can't create runtime");
let machine = OlmMachine::new(&alice_id(), &alice_device_id()); let machine = OlmMachine::new(&alice_id(), &alice_device_id());
let response = keys_query_response(); let response = keys_query_response();
let uuid = Uuid::new_v4(); let uuid = Uuid::new_v4();
let count = response let count = response.device_keys.values().fold(0, |acc, d| acc + d.len())
.device_keys
.values()
.fold(0, |acc, d| acc + d.len())
+ response.master_keys.len() + response.master_keys.len()
+ response.self_signing_keys.len() + response.self_signing_keys.len()
+ response.user_signing_keys.len(); + response.user_signing_keys.len();
@ -70,14 +64,10 @@ pub fn keys_query(c: &mut Criterion) {
let name = format!("{} device and cross signing keys", count); let name = format!("{} device and cross signing keys", count);
group.bench_with_input( group.bench_with_input(BenchmarkId::new("memory store", &name), &response, |b, response| {
BenchmarkId::new("memory store", &name), b.to_async(&runtime)
&response, .iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
|b, response| { });
b.to_async(&runtime)
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
},
);
let dir = tempfile::tempdir().unwrap(); let dir = tempfile::tempdir().unwrap();
let machine = runtime let machine = runtime
@ -89,99 +79,74 @@ pub fn keys_query(c: &mut Criterion) {
)) ))
.unwrap(); .unwrap();
group.bench_with_input( group.bench_with_input(BenchmarkId::new("sled store", &name), &response, |b, response| {
BenchmarkId::new("sled store", &name), b.to_async(&runtime)
&response, .iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
|b, response| { });
b.to_async(&runtime)
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
},
);
group.finish() group.finish()
} }
pub fn keys_claiming(c: &mut Criterion) { pub fn keys_claiming(c: &mut Criterion) {
let runtime = Arc::new( let runtime = Arc::new(Builder::new_multi_thread().build().expect("Can't create runtime"));
Builder::new_multi_thread()
.build()
.expect("Can't create runtime"),
);
let keys_query_response = keys_query_response(); let keys_query_response = keys_query_response();
let uuid = Uuid::new_v4(); let uuid = Uuid::new_v4();
let response = keys_claim_response(); let response = keys_claim_response();
let count = response let count = response.one_time_keys.values().fold(0, |acc, d| acc + d.len());
.one_time_keys
.values()
.fold(0, |acc, d| acc + d.len());
let mut group = c.benchmark_group("Olm session creation"); let mut group = c.benchmark_group("Olm session creation");
group.throughput(Throughput::Elements(count as u64)); group.throughput(Throughput::Elements(count as u64));
let name = format!("{} one-time keys", count); let name = format!("{} one-time keys", count);
group.bench_with_input( group.bench_with_input(BenchmarkId::new("memory store", &name), &response, |b, response| {
BenchmarkId::new("memory store", &name), b.iter_batched(
&response, || {
|b, response| { let machine = OlmMachine::new(&alice_id(), &alice_device_id());
b.iter_batched( runtime
|| { .block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
let machine = OlmMachine::new(&alice_id(), &alice_device_id()); .unwrap();
runtime (machine, runtime.clone())
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)) },
.unwrap(); move |(machine, runtime)| {
(machine, runtime.clone()) runtime.block_on(machine.mark_request_as_sent(&uuid, response)).unwrap()
}, },
move |(machine, runtime)| { BatchSize::SmallInput,
runtime )
.block_on(machine.mark_request_as_sent(&uuid, response)) });
.unwrap()
},
BatchSize::SmallInput,
)
},
);
group.bench_with_input( group.bench_with_input(BenchmarkId::new("sled store", &name), &response, |b, response| {
BenchmarkId::new("sled store", &name), b.iter_batched(
&response, || {
|b, response| { let dir = tempfile::tempdir().unwrap();
b.iter_batched( let machine = runtime
|| { .block_on(OlmMachine::new_with_default_store(
let dir = tempfile::tempdir().unwrap(); &alice_id(),
let machine = runtime &alice_device_id(),
.block_on(OlmMachine::new_with_default_store( dir.path(),
&alice_id(), None,
&alice_device_id(), ))
dir.path(), .unwrap();
None, runtime
)) .block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
.unwrap(); .unwrap();
runtime (machine, runtime.clone())
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)) },
.unwrap(); move |(machine, runtime)| {
(machine, runtime.clone()) runtime.block_on(machine.mark_request_as_sent(&uuid, response)).unwrap()
}, },
move |(machine, runtime)| { BatchSize::SmallInput,
runtime )
.block_on(machine.mark_request_as_sent(&uuid, response)) });
.unwrap()
},
BatchSize::SmallInput,
)
},
);
group.finish() group.finish()
} }
pub fn room_key_sharing(c: &mut Criterion) { pub fn room_key_sharing(c: &mut Criterion) {
let runtime = Builder::new_multi_thread() let runtime = Builder::new_multi_thread().build().expect("Can't create runtime");
.build()
.expect("Can't create runtime");
let keys_query_response = keys_query_response(); let keys_query_response = keys_query_response();
let uuid = Uuid::new_v4(); let uuid = Uuid::new_v4();
@ -191,18 +156,11 @@ pub fn room_key_sharing(c: &mut Criterion) {
let to_device_response = ToDeviceResponse::new(); let to_device_response = ToDeviceResponse::new();
let users: Vec<UserId> = keys_query_response.device_keys.keys().cloned().collect(); let users: Vec<UserId> = keys_query_response.device_keys.keys().cloned().collect();
let count = response let count = response.one_time_keys.values().fold(0, |acc, d| acc + d.len());
.one_time_keys
.values()
.fold(0, |acc, d| acc + d.len());
let machine = OlmMachine::new(&alice_id(), &alice_device_id()); let machine = OlmMachine::new(&alice_id(), &alice_device_id());
runtime runtime.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)).unwrap();
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)) runtime.block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
let mut group = c.benchmark_group("Room key sharing"); let mut group = c.benchmark_group("Room key sharing");
group.throughput(Throughput::Elements(count as u64)); group.throughput(Throughput::Elements(count as u64));
@ -218,10 +176,7 @@ pub fn room_key_sharing(c: &mut Criterion) {
assert!(!requests.is_empty()); assert!(!requests.is_empty());
for request in requests { for request in requests {
machine machine.mark_request_as_sent(&request.txn_id, &to_device_response).await.unwrap();
.mark_request_as_sent(&request.txn_id, &to_device_response)
.await
.unwrap();
} }
machine.invalidate_group_session(&room_id).await.unwrap(); machine.invalidate_group_session(&room_id).await.unwrap();
@ -237,12 +192,8 @@ pub fn room_key_sharing(c: &mut Criterion) {
None, None,
)) ))
.unwrap(); .unwrap();
runtime runtime.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)).unwrap();
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)) runtime.block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
group.bench_function(BenchmarkId::new("sled store", &name), |b| { group.bench_function(BenchmarkId::new("sled store", &name), |b| {
b.to_async(&runtime).iter(|| async { b.to_async(&runtime).iter(|| async {
@ -254,10 +205,7 @@ pub fn room_key_sharing(c: &mut Criterion) {
assert!(!requests.is_empty()); assert!(!requests.is_empty());
for request in requests { for request in requests {
machine machine.mark_request_as_sent(&request.txn_id, &to_device_response).await.unwrap();
.mark_request_as_sent(&request.txn_id, &to_device_response)
.await
.unwrap();
} }
machine.invalidate_group_session(&room_id).await.unwrap(); machine.invalidate_group_session(&room_id).await.unwrap();
@ -268,28 +216,21 @@ pub fn room_key_sharing(c: &mut Criterion) {
} }
pub fn devices_missing_sessions_collecting(c: &mut Criterion) { pub fn devices_missing_sessions_collecting(c: &mut Criterion) {
let runtime = Builder::new_multi_thread() let runtime = Builder::new_multi_thread().build().expect("Can't create runtime");
.build()
.expect("Can't create runtime");
let machine = OlmMachine::new(&alice_id(), &alice_device_id()); let machine = OlmMachine::new(&alice_id(), &alice_device_id());
let response = huge_keys_query_resopnse(); let response = huge_keys_query_resopnse();
let uuid = Uuid::new_v4(); let uuid = Uuid::new_v4();
let users: Vec<UserId> = response.device_keys.keys().cloned().collect(); let users: Vec<UserId> = response.device_keys.keys().cloned().collect();
let count = response let count = response.device_keys.values().fold(0, |acc, d| acc + d.len());
.device_keys
.values()
.fold(0, |acc, d| acc + d.len());
let mut group = c.benchmark_group("Devices missing sessions collecting"); let mut group = c.benchmark_group("Devices missing sessions collecting");
group.throughput(Throughput::Elements(count as u64)); group.throughput(Throughput::Elements(count as u64));
let name = format!("{} devices", count); let name = format!("{} devices", count);
runtime runtime.block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
group.bench_function(BenchmarkId::new("memory store", &name), |b| { group.bench_function(BenchmarkId::new("memory store", &name), |b| {
b.to_async(&runtime).iter_with_large_drop(|| async { b.to_async(&runtime).iter_with_large_drop(|| async {
@ -307,9 +248,7 @@ pub fn devices_missing_sessions_collecting(c: &mut Criterion) {
)) ))
.unwrap(); .unwrap();
runtime runtime.block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
group.bench_function(BenchmarkId::new("sled store", &name), |b| { group.bench_function(BenchmarkId::new("sled store", &name), |b| {
b.to_async(&runtime) b.to_async(&runtime)

View File

@ -6,8 +6,9 @@ use std::{fs::File, os::raw::c_int, path::Path};
use criterion::profiler::Profiler; use criterion::profiler::Profiler;
use pprof::ProfilerGuard; use pprof::ProfilerGuard;
/// Small custom profiler that can be used with Criterion to create a flamegraph for benchmarks. /// Small custom profiler that can be used with Criterion to create a flamegraph
/// Also see [the Criterion documentation on this][custom-profiler]. /// for benchmarks. Also see [the Criterion documentation on
/// this][custom-profiler].
/// ///
/// ## Example on how to enable the custom profiler: /// ## Example on how to enable the custom profiler:
/// ///
@ -30,12 +31,12 @@ use pprof::ProfilerGuard;
/// } /// }
/// ``` /// ```
/// ///
/// The neat thing about this is that it will sample _only_ the benchmark, and not other stuff like /// The neat thing about this is that it will sample _only_ the benchmark, and
/// the setup process. /// not other stuff like the setup process.
/// ///
/// Further, it will only kick in if `--profile-time <time>` is passed to the benchmark binary. /// Further, it will only kick in if `--profile-time <time>` is passed to the
/// A flamegraph will be created for each individual benchmark in its report directory under /// benchmark binary. A flamegraph will be created for each individual benchmark
/// `profile/flamegraph.svg`. /// in its report directory under `profile/flamegraph.svg`.
/// ///
/// [custom-profiler]: https://bheisler.github.io/criterion.rs/book/user_guide/profiling.html#implementing-in-process-profiling-hooks /// [custom-profiler]: https://bheisler.github.io/criterion.rs/book/user_guide/profiling.html#implementing-in-process-profiling-hooks
pub struct FlamegraphProfiler<'a> { pub struct FlamegraphProfiler<'a> {
@ -45,10 +46,7 @@ pub struct FlamegraphProfiler<'a> {
impl<'a> FlamegraphProfiler<'a> { impl<'a> FlamegraphProfiler<'a> {
pub fn new(frequency: c_int) -> Self { pub fn new(frequency: c_int) -> Self {
FlamegraphProfiler { FlamegraphProfiler { frequency, active_profiler: None }
frequency,
active_profiler: None,
}
} }
} }

View File

@ -17,21 +17,17 @@ use std::{
io::{Error as IoError, ErrorKind, Read}, io::{Error as IoError, ErrorKind, Read},
}; };
use thiserror::Error;
use zeroize::Zeroizing;
use serde::{Deserialize, Serialize};
use matrix_sdk_common::events::room::JsonWebKey;
use getrandom::getrandom;
use aes_ctr::{ use aes_ctr::{
cipher::{NewStreamCipher, SyncStreamCipher}, cipher::{NewStreamCipher, SyncStreamCipher},
Aes256Ctr, Aes256Ctr,
}; };
use base64::DecodeError; use base64::DecodeError;
use getrandom::getrandom;
use matrix_sdk_common::events::room::JsonWebKey;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
use thiserror::Error;
use zeroize::Zeroizing;
use crate::utilities::{decode, decode_url_safe, encode, encode_url_safe}; use crate::utilities::{decode, decode_url_safe, encode, encode_url_safe};
@ -59,10 +55,7 @@ impl<'a, R: Read> Read for AttachmentDecryptor<'a, R> {
if hash.as_slice() == self.expected_hash.as_slice() { if hash.as_slice() == self.expected_hash.as_slice() {
Ok(0) Ok(0)
} else { } else {
Err(IoError::new( Err(IoError::new(ErrorKind::Other, "Hash missmatch while decrypting"))
ErrorKind::Other,
"Hash missmatch while decrypting",
))
} }
} else { } else {
self.sha.update(&buf[0..read_bytes]); self.sha.update(&buf[0..read_bytes]);
@ -130,23 +123,14 @@ impl<'a, R: Read + 'a> AttachmentDecryptor<'a, R> {
return Err(DecryptorError::UnknownVersion); return Err(DecryptorError::UnknownVersion);
} }
let hash = decode( let hash = decode(info.hashes.get("sha256").ok_or(DecryptorError::MissingHash)?)?;
info.hashes
.get("sha256")
.ok_or(DecryptorError::MissingHash)?,
)?;
let key = Zeroizing::from(decode_url_safe(info.web_key.k)?); let key = Zeroizing::from(decode_url_safe(info.web_key.k)?);
let iv = decode(info.iv)?; let iv = decode(info.iv)?;
let sha = Sha256::default(); let sha = Sha256::default();
let aes = Aes256Ctr::new_var(&key, &iv).map_err(|_| DecryptorError::KeyNonceLength)?; let aes = Aes256Ctr::new_var(&key, &iv).map_err(|_| DecryptorError::KeyNonceLength)?;
Ok(AttachmentDecryptor { Ok(AttachmentDecryptor { inner_reader: input, expected_hash: hash, sha, aes })
inner_reader: input,
expected_hash: hash,
sha,
aes,
})
} }
} }
@ -168,9 +152,7 @@ impl<'a, R: Read + 'a> Read for AttachmentEncryptor<'a, R> {
if read_bytes == 0 { if read_bytes == 0 {
let hash = self.sha.finalize_reset(); let hash = self.sha.finalize_reset();
self.hashes self.hashes.entry("sha256".to_owned()).or_insert_with(|| encode(hash));
.entry("sha256".to_owned())
.or_insert_with(|| encode(hash));
Ok(0) Ok(0)
} else { } else {
self.aes.apply_keystream(&mut buf[0..read_bytes]); self.aes.apply_keystream(&mut buf[0..read_bytes]);
@ -244,9 +226,7 @@ impl<'a, R: Read + 'a> AttachmentEncryptor<'a, R> {
/// Consume the encryptor and get the encryption key. /// Consume the encryptor and get the encryption key.
pub fn finish(mut self) -> EncryptionInfo { pub fn finish(mut self) -> EncryptionInfo {
let hash = self.sha.finalize(); let hash = self.sha.finalize();
self.hashes self.hashes.entry("sha256".to_owned()).or_insert_with(|| encode(hash));
.entry("sha256".to_owned())
.or_insert_with(|| encode(hash));
EncryptionInfo { EncryptionInfo {
version: VERSION.to_string(), version: VERSION.to_string(),
@ -274,10 +254,12 @@ pub struct EncryptionInfo {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::{AttachmentDecryptor, AttachmentEncryptor, EncryptionInfo};
use serde_json::json;
use std::io::{Cursor, Read}; use std::io::{Cursor, Read};
use serde_json::json;
use super::{AttachmentDecryptor, AttachmentEncryptor, EncryptionInfo};
const EXAMPLE_DATA: &[u8] = &[ const EXAMPLE_DATA: &[u8] = &[
179, 154, 118, 127, 186, 127, 110, 33, 203, 33, 33, 134, 67, 100, 173, 46, 235, 27, 215, 179, 154, 118, 127, 186, 127, 110, 33, 203, 33, 33, 134, 67, 100, 173, 46, 235, 27, 215,
172, 36, 26, 75, 47, 33, 160, 172, 36, 26, 75, 47, 33, 160,

View File

@ -12,20 +12,19 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use serde_json::Error as SerdeError;
use std::io::{Cursor, Read, Seek, SeekFrom}; use std::io::{Cursor, Read, Seek, SeekFrom};
use thiserror::Error;
use byteorder::{BigEndian, ReadBytesExt};
use getrandom::getrandom;
use aes_ctr::{ use aes_ctr::{
cipher::{NewStreamCipher, SyncStreamCipher}, cipher::{NewStreamCipher, SyncStreamCipher},
Aes256Ctr, Aes256Ctr,
}; };
use byteorder::{BigEndian, ReadBytesExt};
use getrandom::getrandom;
use hmac::{Hmac, Mac, NewMac}; use hmac::{Hmac, Mac, NewMac};
use pbkdf2::pbkdf2; use pbkdf2::pbkdf2;
use serde_json::Error as SerdeError;
use sha2::{Sha256, Sha512}; use sha2::{Sha256, Sha512};
use thiserror::Error;
use crate::{ use crate::{
olm::ExportedRoomKey, olm::ExportedRoomKey,
@ -99,14 +98,10 @@ pub fn decrypt_key_export(
return Err(KeyExportError::InvalidHeaders); return Err(KeyExportError::InvalidHeaders);
} }
let payload: String = x let payload: String =
.lines() x.lines().filter(|l| !(l.starts_with(HEADER) || l.starts_with(FOOTER))).collect();
.filter(|l| !(l.starts_with(HEADER) || l.starts_with(FOOTER)))
.collect();
Ok(serde_json::from_str(&decrypt_helper( Ok(serde_json::from_str(&decrypt_helper(&payload, passphrase)?)?)
&payload, passphrase,
)?)?)
} }
/// Encrypt the list of exported room keys using the given passphrase. /// Encrypt the list of exported room keys using the given passphrase.
@ -231,12 +226,12 @@ fn decrypt_helper(ciphertext: &str, passphrase: &str) -> Result<String, KeyExpor
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use indoc::indoc;
use proptest::prelude::*;
use std::io::Cursor; use std::io::Cursor;
use indoc::indoc;
use matrix_sdk_common::identifiers::room_id; use matrix_sdk_common::identifiers::room_id;
use matrix_sdk_test::async_test; use matrix_sdk_test::async_test;
use proptest::prelude::*;
use super::{decode, decrypt_helper, decrypt_key_export, encrypt_helper, encrypt_key_export}; use super::{decode, decrypt_helper, decrypt_key_export, encrypt_helper, encrypt_key_export};
use crate::machine::test::get_prepared_machine; use crate::machine::test::get_prepared_machine;
@ -261,10 +256,7 @@ mod test {
"}; "};
fn export_wihtout_headers() -> String { fn export_wihtout_headers() -> String {
TEST_EXPORT TEST_EXPORT.lines().filter(|l| !l.starts_with("-----")).collect()
.lines()
.filter(|l| !l.starts_with("-----"))
.collect()
} }
#[test] #[test]
@ -301,14 +293,8 @@ mod test {
let (machine, _) = get_prepared_machine().await; let (machine, _) = get_prepared_machine().await;
let room_id = room_id!("!test:localhost"); let room_id = room_id!("!test:localhost");
machine machine.create_outbound_group_session_with_defaults(&room_id).await.unwrap();
.create_outbound_group_session_with_defaults(&room_id) let export = machine.export_keys(|s| s.room_id() == &room_id).await.unwrap();
.await
.unwrap();
let export = machine
.export_keys(|s| s.room_id() == &room_id)
.await
.unwrap();
assert!(!export.is_empty()); assert!(!export.is_empty());
@ -316,10 +302,7 @@ mod test {
let decrypted = decrypt_key_export(Cursor::new(encrypted), "1234").unwrap(); let decrypted = decrypt_key_export(Cursor::new(encrypted), "1234").unwrap();
assert_eq!(export, decrypted); assert_eq!(export, decrypted);
assert_eq!( assert_eq!(machine.import_keys(decrypted, |_, _| {}).await.unwrap(), (0, 1));
machine.import_keys(decrypted, |_, _| {}).await.unwrap(),
(0, 1)
);
} }
#[test] #[test]

View File

@ -39,24 +39,17 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_json::{json, Value}; use serde_json::{json, Value};
use tracing::warn; use tracing::warn;
use crate::{ use super::{atomic_bool_deserializer, atomic_bool_serializer};
olm::{InboundGroupSession, PrivateCrossSigningIdentity, Session},
store::{Changes, DeviceChanges},
OutgoingVerificationRequest,
};
#[cfg(test)]
use crate::{OlmMachine, ReadOnlyAccount};
use crate::{ use crate::{
error::{EventError, OlmError, OlmResult, SignatureError}, error::{EventError, OlmError, OlmResult, SignatureError},
identities::{OwnUserIdentity, UserIdentities}, identities::{OwnUserIdentity, UserIdentities},
olm::Utility, olm::{InboundGroupSession, PrivateCrossSigningIdentity, Session, Utility},
store::{CryptoStore, Result as StoreResult}, store::{Changes, CryptoStore, DeviceChanges, Result as StoreResult},
verification::VerificationMachine, verification::VerificationMachine,
Sas, ToDeviceRequest, OutgoingVerificationRequest, Sas, ToDeviceRequest,
}; };
#[cfg(test)]
use super::{atomic_bool_deserializer, atomic_bool_serializer}; use crate::{OlmMachine, ReadOnlyAccount};
/// A read-only version of a `Device`. /// A read-only version of a `Device`.
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize)]
@ -120,9 +113,7 @@ pub struct Device {
impl std::fmt::Debug for Device { impl std::fmt::Debug for Device {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Device") f.debug_struct("Device").field("device", &self.inner).finish()
.field("device", &self.inner)
.finish()
} }
} }
@ -139,10 +130,7 @@ impl Device {
/// ///
/// Returns a `Sas` object and to-device request that needs to be sent out. /// Returns a `Sas` object and to-device request that needs to be sent out.
pub async fn start_verification(&self) -> StoreResult<(Sas, ToDeviceRequest)> { pub async fn start_verification(&self) -> StoreResult<(Sas, ToDeviceRequest)> {
let (sas, request) = self let (sas, request) = self.verification_machine.start_sas(self.inner.clone()).await?;
.verification_machine
.start_sas(self.inner.clone())
.await?;
if let OutgoingVerificationRequest::ToDevice(r) = request { if let OutgoingVerificationRequest::ToDevice(r) = request {
Ok((sas, r)) Ok((sas, r))
@ -162,8 +150,7 @@ impl Device {
/// Get the trust state of the device. /// Get the trust state of the device.
pub fn trust_state(&self) -> bool { pub fn trust_state(&self) -> bool {
self.inner self.inner.trust_state(&self.own_identity, &self.device_owner_identity)
.trust_state(&self.own_identity, &self.device_owner_identity)
} }
/// Set the local trust state of the device to the given state. /// Set the local trust state of the device to the given state.
@ -178,10 +165,7 @@ impl Device {
self.inner.set_trust_state(trust_state); self.inner.set_trust_state(trust_state);
let changes = Changes { let changes = Changes {
devices: DeviceChanges { devices: DeviceChanges { changed: vec![self.inner.clone()], ..Default::default() },
changed: vec![self.inner.clone()],
..Default::default()
},
..Default::default() ..Default::default()
}; };
@ -200,9 +184,7 @@ impl Device {
event_type: EventType, event_type: EventType,
content: Value, content: Value,
) -> OlmResult<(Session, EncryptedEventContent)> { ) -> OlmResult<(Session, EncryptedEventContent)> {
self.inner self.inner.encrypt(&**self.verification_machine.store, event_type, content).await
.encrypt(&**self.verification_machine.store, event_type, content)
.await
} }
/// Encrypt the given inbound group session as a forwarded room key for this /// Encrypt the given inbound group session as a forwarded room key for this
@ -261,9 +243,7 @@ impl UserDevices {
/// Returns true if there is at least one devices of this user that is /// Returns true if there is at least one devices of this user that is
/// considered to be verified, false otherwise. /// considered to be verified, false otherwise.
pub fn is_any_verified(&self) -> bool { pub fn is_any_verified(&self) -> bool {
self.inner self.inner.values().any(|d| d.trust_state(&self.own_identity, &self.device_owner_identity))
.values()
.any(|d| d.trust_state(&self.own_identity, &self.device_owner_identity))
} }
/// Iterator over all the device ids of the user devices. /// Iterator over all the device ids of the user devices.
@ -348,8 +328,7 @@ impl ReadOnlyDevice {
/// Get the key of the given key algorithm belonging to this device. /// Get the key of the given key algorithm belonging to this device.
pub fn get_key(&self, algorithm: DeviceKeyAlgorithm) -> Option<&String> { pub fn get_key(&self, algorithm: DeviceKeyAlgorithm) -> Option<&String> {
self.keys self.keys.get(&DeviceKeyId::from_parts(algorithm, &self.device_id))
.get(&DeviceKeyId::from_parts(algorithm, &self.device_id))
} }
/// Get a map containing all the device keys. /// Get a map containing all the device keys.
@ -496,9 +475,8 @@ impl ReadOnlyDevice {
} }
fn is_signed_by_device(&self, json: &mut Value) -> Result<(), SignatureError> { fn is_signed_by_device(&self, json: &mut Value) -> Result<(), SignatureError> {
let signing_key = self let signing_key =
.get_key(DeviceKeyAlgorithm::Ed25519) self.get_key(DeviceKeyAlgorithm::Ed25519).ok_or(SignatureError::MissingSigningKey)?;
.ok_or(SignatureError::MissingSigningKey)?;
let utility = Utility::new(); let utility = Utility::new();
@ -590,14 +568,15 @@ impl PartialEq for ReadOnlyDevice {
#[cfg(test)] #[cfg(test)]
pub(crate) mod test { pub(crate) mod test {
use serde_json::json;
use std::convert::TryFrom; use std::convert::TryFrom;
use crate::identities::{LocalTrust, ReadOnlyDevice};
use matrix_sdk_common::{ use matrix_sdk_common::{
encryption::DeviceKeys, encryption::DeviceKeys,
identifiers::{user_id, DeviceKeyAlgorithm}, identifiers::{user_id, DeviceKeyAlgorithm},
}; };
use serde_json::json;
use crate::identities::{LocalTrust, ReadOnlyDevice};
fn device_keys() -> DeviceKeys { fn device_keys() -> DeviceKeys {
let device_keys = json!({ let device_keys = json!({
@ -640,10 +619,7 @@ pub(crate) mod test {
assert_eq!(device_id, device.device_id()); assert_eq!(device_id, device.device_id());
assert_eq!(device.algorithms.len(), 2); assert_eq!(device.algorithms.len(), 2);
assert_eq!(LocalTrust::Unset, device.local_trust_state()); assert_eq!(LocalTrust::Unset, device.local_trust_state());
assert_eq!( assert_eq!("Alice's mobile phone", device.display_name().as_ref().unwrap());
"Alice's mobile phone",
device.display_name().as_ref().unwrap()
);
assert_eq!( assert_eq!(
device.get_key(DeviceKeyAlgorithm::Curve25519).unwrap(), device.get_key(DeviceKeyAlgorithm::Curve25519).unwrap(),
"xfgbLIC5WAl1OIkpOzoxpCe8FsRDT6nch7NQsOb15nc" "xfgbLIC5WAl1OIkpOzoxpCe8FsRDT6nch7NQsOb15nc"
@ -658,10 +634,7 @@ pub(crate) mod test {
fn update_a_device() { fn update_a_device() {
let mut device = get_device(); let mut device = get_device();
assert_eq!( assert_eq!("Alice's mobile phone", device.display_name().as_ref().unwrap());
"Alice's mobile phone",
device.display_name().as_ref().unwrap()
);
let display_name = "Alice's work computer".to_owned(); let display_name = "Alice's work computer".to_owned();

View File

@ -12,20 +12,20 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use futures::future::join_all;
use std::{ use std::{
collections::{BTreeMap, HashSet}, collections::{BTreeMap, HashSet},
convert::TryFrom, convert::TryFrom,
sync::Arc, sync::Arc,
}; };
use tracing::{trace, warn};
use futures::future::join_all;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::keys::get_keys::Response as KeysQueryResponse, api::r0::keys::get_keys::Response as KeysQueryResponse,
encryption::DeviceKeys, encryption::DeviceKeys,
executor::spawn, executor::spawn,
identifiers::{DeviceIdBox, UserId}, identifiers::{DeviceIdBox, UserId},
}; };
use tracing::{trace, warn};
use crate::{ use crate::{
error::OlmResult, error::OlmResult,
@ -54,11 +54,7 @@ impl IdentityManager {
const MAX_KEY_QUERY_USERS: usize = 250; const MAX_KEY_QUERY_USERS: usize = 250;
pub fn new(user_id: Arc<UserId>, device_id: Arc<DeviceIdBox>, store: Store) -> Self { pub fn new(user_id: Arc<UserId>, device_id: Arc<DeviceIdBox>, store: Store) -> Self {
IdentityManager { IdentityManager { user_id, device_id, store }
user_id,
device_id,
store,
}
} }
fn user_id(&self) -> &UserId { fn user_id(&self) -> &UserId {
@ -78,9 +74,8 @@ impl IdentityManager {
&self, &self,
response: &KeysQueryResponse, response: &KeysQueryResponse,
) -> OlmResult<(DeviceChanges, IdentityChanges)> { ) -> OlmResult<(DeviceChanges, IdentityChanges)> {
let changed_devices = self let changed_devices =
.handle_devices_from_key_query(response.device_keys.clone()) self.handle_devices_from_key_query(response.device_keys.clone()).await?;
.await?;
let changed_identities = self.handle_cross_singing_keys(response).await?; let changed_identities = self.handle_cross_singing_keys(response).await?;
let changes = Changes { let changes = Changes {
@ -104,9 +99,8 @@ impl IdentityManager {
store: Store, store: Store,
device_keys: DeviceKeys, device_keys: DeviceKeys,
) -> StoreResult<DeviceChange> { ) -> StoreResult<DeviceChange> {
let old_device = store let old_device =
.get_readonly_device(&device_keys.user_id, &device_keys.device_id) store.get_readonly_device(&device_keys.user_id, &device_keys.device_id).await?;
.await?;
if let Some(mut device) = old_device { if let Some(mut device) = old_device {
if let Err(e) = device.update_device(&device_keys) { if let Err(e) = device.update_device(&device_keys) {
@ -148,25 +142,20 @@ impl IdentityManager {
let current_devices: HashSet<DeviceIdBox> = device_map.keys().cloned().collect(); let current_devices: HashSet<DeviceIdBox> = device_map.keys().cloned().collect();
let tasks = device_map let tasks = device_map.into_iter().filter_map(|(device_id, device_keys)| {
.into_iter() // We don't need our own device in the device store.
.filter_map(|(device_id, device_keys)| { if user_id == *own_user_id && device_id == *own_device_id {
// We don't need our own device in the device store. None
if user_id == *own_user_id && device_id == *own_device_id { } else if user_id != device_keys.user_id || device_id != device_keys.device_id {
None warn!(
} else if user_id != device_keys.user_id || device_id != device_keys.device_id { "Mismatch in device keys payload of device {}|{} from user {}|{}",
warn!( device_id, device_keys.device_id, user_id, device_keys.user_id
"Mismatch in device keys payload of device {}|{} from user {}|{}", );
device_id, device_keys.device_id, user_id, device_keys.user_id None
); } else {
None Some(spawn(Self::update_or_create_device(store.clone(), device_keys)))
} else { }
Some(spawn(Self::update_or_create_device( });
store.clone(),
device_keys,
)))
}
});
let results = join_all(tasks).await; let results = join_all(tasks).await;
@ -211,17 +200,15 @@ impl IdentityManager {
) -> StoreResult<DeviceChanges> { ) -> StoreResult<DeviceChanges> {
let mut changes = DeviceChanges::default(); let mut changes = DeviceChanges::default();
let tasks = device_keys_map let tasks = device_keys_map.into_iter().map(|(user_id, device_keys_map)| {
.into_iter() spawn(Self::update_user_devices(
.map(|(user_id, device_keys_map)| { self.store.clone(),
spawn(Self::update_user_devices( self.user_id.clone(),
self.store.clone(), self.device_id.clone(),
self.user_id.clone(), user_id,
self.device_id.clone(), device_keys_map,
user_id, ))
device_keys_map, });
))
});
let results = join_all(tasks).await; let results = join_all(tasks).await;
@ -254,10 +241,7 @@ impl IdentityManager {
let self_signing = if let Some(s) = response.self_signing_keys.get(user_id) { let self_signing = if let Some(s) = response.self_signing_keys.get(user_id) {
SelfSigningPubkey::from(s) SelfSigningPubkey::from(s)
} else { } else {
warn!( warn!("User identity for user {} didn't contain a self signing pubkey", user_id);
"User identity for user {} didn't contain a self signing pubkey",
user_id
);
continue; continue;
}; };
@ -276,13 +260,11 @@ impl IdentityManager {
continue; continue;
}; };
identity identity.update(master_key, self_signing, user_signing).map(|_| (i, false))
.update(master_key, self_signing, user_signing) }
.map(|_| (i, false)) UserIdentities::Other(ref mut identity) => {
identity.update(master_key, self_signing).map(|_| (i, false))
} }
UserIdentities::Other(ref mut identity) => identity
.update(master_key, self_signing)
.map(|_| (i, false)),
} }
} else if user_id == self.user_id() { } else if user_id == self.user_id() {
if let Some(s) = response.user_signing_keys.get(user_id) { if let Some(s) = response.user_signing_keys.get(user_id) {
@ -310,10 +292,7 @@ impl IdentityManager {
continue; continue;
} }
} else if master_key.user_id() != user_id || self_signing.user_id() != user_id { } else if master_key.user_id() != user_id || self_signing.user_id() != user_id {
warn!( warn!("User id mismatch in one of the cross signing keys for user {}", user_id);
"User id mismatch in one of the cross signing keys for user {}",
user_id
);
continue; continue;
} else { } else {
UserIdentity::new(master_key, self_signing) UserIdentity::new(master_key, self_signing)
@ -322,11 +301,7 @@ impl IdentityManager {
match result { match result {
Ok((i, new)) => { Ok((i, new)) => {
trace!( trace!("Updated or created new user identity for {}: {:?}", user_id, i);
"Updated or created new user identity for {}: {:?}",
user_id,
i
);
if new { if new {
changes.new.push(i); changes.new.push(i);
} else { } else {
@ -334,10 +309,7 @@ impl IdentityManager {
} }
} }
Err(e) => { Err(e) => {
warn!( warn!("Couldn't update or create new user identity for {}: {:?}", user_id, e);
"Couldn't update or create new user identity for {}: {:?}",
user_id, e
);
continue; continue;
} }
} }
@ -424,9 +396,7 @@ pub(crate) mod test {
locks::Mutex, locks::Mutex,
IncomingResponse, IncomingResponse,
}; };
use matrix_sdk_test::async_test; use matrix_sdk_test::async_test;
use serde_json::json; use serde_json::json;
use crate::{ use crate::{
@ -637,10 +607,7 @@ pub(crate) mod test {
let devices = manager.store.get_user_devices(&other_user).await.unwrap(); let devices = manager.store.get_user_devices(&other_user).await.unwrap();
assert_eq!(devices.devices().count(), 0); assert_eq!(devices.devices().count(), 0);
manager manager.receive_keys_query_response(&other_key_query()).await.unwrap();
.receive_keys_query_response(&other_key_query())
.await
.unwrap();
let devices = manager.store.get_user_devices(&other_user).await.unwrap(); let devices = manager.store.get_user_devices(&other_user).await.unwrap();
assert_eq!(devices.devices().count(), 1); assert_eq!(devices.devices().count(), 1);
@ -651,12 +618,7 @@ pub(crate) mod test {
.await .await
.unwrap() .unwrap()
.unwrap(); .unwrap();
let identity = manager let identity = manager.store.get_user_identity(&other_user).await.unwrap().unwrap();
.store
.get_user_identity(&other_user)
.await
.unwrap()
.unwrap();
let identity = identity.other().unwrap(); let identity = identity.other().unwrap();
assert!(identity.is_device_signed(&device).is_ok()) assert!(identity.is_device_signed(&device).is_ok())
@ -669,10 +631,7 @@ pub(crate) mod test {
let devices = manager.store.get_user_devices(&other_user).await.unwrap(); let devices = manager.store.get_user_devices(&other_user).await.unwrap();
assert_eq!(devices.devices().count(), 0); assert_eq!(devices.devices().count(), 0);
manager manager.receive_keys_query_response(&other_key_query()).await.unwrap();
.receive_keys_query_response(&other_key_query())
.await
.unwrap();
let devices = manager.store.get_user_devices(&other_user).await.unwrap(); let devices = manager.store.get_user_devices(&other_user).await.unwrap();
assert_eq!(devices.devices().count(), 1); assert_eq!(devices.devices().count(), 1);
@ -683,12 +642,7 @@ pub(crate) mod test {
.await .await
.unwrap() .unwrap()
.unwrap(); .unwrap();
let identity = manager let identity = manager.store.get_user_identity(&other_user).await.unwrap().unwrap();
.store
.get_user_identity(&other_user)
.await
.unwrap()
.unwrap();
let identity = identity.other().unwrap(); let identity = identity.other().unwrap();
assert!(identity.is_device_signed(&device).is_ok()) assert!(identity.is_device_signed(&device).is_ok())

View File

@ -29,10 +29,10 @@
//! //!
//! ## User //! ## User
//! //!
//! Cross-signing capable devices will upload 3 additional (master, self-signing, //! Cross-signing capable devices will upload 3 additional (master,
//! user-signing) public keys which represent the user identity owning all the //! self-signing, user-signing) public keys which represent the user identity
//! devices. This is represented in two ways, as a `UserIdentity` for other //! owning all the devices. This is represented in two ways, as a `UserIdentity`
//! users and as `OwnUserIdentity` for our own user. //! for other users and as `OwnUserIdentity` for our own user.
//! //!
//! This is done because the server will only give us access to 2 of the 3 //! This is done because the server will only give us access to 2 of the 3
//! additional public keys for other users, while it will give us access to all //! additional public keys for other users, while it will give us access to all
@ -44,19 +44,19 @@ pub(crate) mod device;
mod manager; mod manager;
pub(crate) mod user; pub(crate) mod user;
pub use device::{Device, LocalTrust, ReadOnlyDevice, UserDevices};
pub(crate) use manager::IdentityManager;
pub use user::{
MasterPubkey, OwnUserIdentity, SelfSigningPubkey, UserIdentities, UserIdentity,
UserSigningPubkey,
};
use serde::{Deserialize, Deserializer, Serializer};
use std::sync::{ use std::sync::{
atomic::{AtomicBool, Ordering}, atomic::{AtomicBool, Ordering},
Arc, Arc,
}; };
pub use device::{Device, LocalTrust, ReadOnlyDevice, UserDevices};
pub(crate) use manager::IdentityManager;
use serde::{Deserialize, Deserializer, Serializer};
pub use user::{
MasterPubkey, OwnUserIdentity, SelfSigningPubkey, UserIdentities, UserIdentity,
UserSigningPubkey,
};
// These methods are only here because Serialize and Deserialize don't seem to // These methods are only here because Serialize and Deserialize don't seem to
// be implemented for WASM. // be implemented for WASM.
fn atomic_bool_serializer<S>(x: &AtomicBool, s: S) -> Result<S::Ok, S::Error> fn atomic_bool_serializer<S>(x: &AtomicBool, s: S) -> Result<S::Ok, S::Error>

View File

@ -21,20 +21,18 @@ use std::{
}, },
}; };
use serde::{Deserialize, Serialize};
use serde_json::to_value;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::keys::{CrossSigningKey, KeyUsage}, api::r0::keys::{CrossSigningKey, KeyUsage},
identifiers::{DeviceKeyId, UserId}, identifiers::{DeviceKeyId, UserId},
}; };
use serde::{Deserialize, Serialize};
use serde_json::to_value;
use super::{atomic_bool_deserializer, atomic_bool_serializer};
#[cfg(test)] #[cfg(test)]
use crate::olm::PrivateCrossSigningIdentity; use crate::olm::PrivateCrossSigningIdentity;
use crate::{error::SignatureError, olm::Utility, ReadOnlyDevice}; use crate::{error::SignatureError, olm::Utility, ReadOnlyDevice};
use super::{atomic_bool_deserializer, atomic_bool_serializer};
/// Wrapper for a cross signing key marking it as the master key. /// Wrapper for a cross signing key marking it as the master key.
/// ///
/// Master keys are used to sign other cross signing keys, the self signing and /// Master keys are used to sign other cross signing keys, the self signing and
@ -227,12 +225,7 @@ impl MasterPubkey {
&self, &self,
subkey: impl Into<CrossSigningSubKeys<'a>>, subkey: impl Into<CrossSigningSubKeys<'a>>,
) -> Result<(), SignatureError> { ) -> Result<(), SignatureError> {
let (key_id, key) = self let (key_id, key) = self.0.keys.iter().next().ok_or(SignatureError::MissingSigningKey)?;
.0
.keys
.iter()
.next()
.ok_or(SignatureError::MissingSigningKey)?;
let key_id = DeviceKeyId::try_from(key_id.as_str())?; let key_id = DeviceKeyId::try_from(key_id.as_str())?;
@ -289,12 +282,7 @@ impl UserSigningPubkey {
&self, &self,
master_key: &MasterPubkey, master_key: &MasterPubkey,
) -> Result<(), SignatureError> { ) -> Result<(), SignatureError> {
let (key_id, key) = self let (key_id, key) = self.0.keys.iter().next().ok_or(SignatureError::MissingSigningKey)?;
.0
.keys
.iter()
.next()
.ok_or(SignatureError::MissingSigningKey)?;
// TODO check that the usage is OK. // TODO check that the usage is OK.
@ -337,12 +325,7 @@ impl SelfSigningPubkey {
/// Returns an empty result if the signature check succeeded, otherwise a /// Returns an empty result if the signature check succeeded, otherwise a
/// SignatureError indicating why the check failed. /// SignatureError indicating why the check failed.
pub(crate) fn verify_device(&self, device: &ReadOnlyDevice) -> Result<(), SignatureError> { pub(crate) fn verify_device(&self, device: &ReadOnlyDevice) -> Result<(), SignatureError> {
let (key_id, key) = self let (key_id, key) = self.0.keys.iter().next().ok_or(SignatureError::MissingSigningKey)?;
.0
.keys
.iter()
.next()
.ok_or(SignatureError::MissingSigningKey)?;
// TODO check that the usage is OK. // TODO check that the usage is OK.
@ -474,37 +457,16 @@ impl UserIdentity {
) -> Result<Self, SignatureError> { ) -> Result<Self, SignatureError> {
master_key.verify_subkey(&self_signing_key)?; master_key.verify_subkey(&self_signing_key)?;
Ok(Self { Ok(Self { user_id: Arc::new(master_key.0.user_id.clone()), master_key, self_signing_key })
user_id: Arc::new(master_key.0.user_id.clone()),
master_key,
self_signing_key,
})
} }
#[cfg(test)] #[cfg(test)]
pub async fn from_private(identity: &PrivateCrossSigningIdentity) -> Self { pub async fn from_private(identity: &PrivateCrossSigningIdentity) -> Self {
let master_key = identity let master_key = identity.master_key.lock().await.as_ref().unwrap().public_key.clone();
.master_key let self_signing_key =
.lock() identity.self_signing_key.lock().await.as_ref().unwrap().public_key.clone();
.await
.as_ref()
.unwrap()
.public_key
.clone();
let self_signing_key = identity
.self_signing_key
.lock()
.await
.as_ref()
.unwrap()
.public_key
.clone();
Self { Self { user_id: Arc::new(identity.user_id().clone()), master_key, self_signing_key }
user_id: Arc::new(identity.user_id().clone()),
master_key,
self_signing_key,
}
} }
/// Get the user id of this identity. /// Get the user id of this identity.
@ -646,8 +608,7 @@ impl OwnUserIdentity {
/// Returns an empty result if the signature check succeeded, otherwise a /// Returns an empty result if the signature check succeeded, otherwise a
/// SignatureError indicating why the check failed. /// SignatureError indicating why the check failed.
pub fn is_identity_signed(&self, identity: &UserIdentity) -> Result<(), SignatureError> { pub fn is_identity_signed(&self, identity: &UserIdentity) -> Result<(), SignatureError> {
self.user_signing_key self.user_signing_key.verify_master_key(&identity.master_key)
.verify_master_key(&identity.master_key)
} }
/// Check if the given device has been signed by this identity. /// Check if the given device has been signed by this identity.
@ -719,6 +680,12 @@ impl OwnUserIdentity {
pub(crate) mod test { pub(crate) mod test {
use std::{convert::TryFrom, sync::Arc}; use std::{convert::TryFrom, sync::Arc};
use matrix_sdk_common::{
api::r0::keys::get_keys::Response as KeyQueryResponse, identifiers::user_id, locks::Mutex,
};
use matrix_sdk_test::async_test;
use super::{OwnUserIdentity, UserIdentities, UserIdentity};
use crate::{ use crate::{
identities::{ identities::{
manager::test::{other_key_query, own_key_query}, manager::test::{other_key_query, own_key_query},
@ -729,13 +696,6 @@ pub(crate) mod test {
verification::VerificationMachine, verification::VerificationMachine,
}; };
use matrix_sdk_common::{
api::r0::keys::get_keys::Response as KeyQueryResponse, identifiers::user_id, locks::Mutex,
};
use matrix_sdk_test::async_test;
use super::{OwnUserIdentity, UserIdentities, UserIdentity};
fn device(response: &KeyQueryResponse) -> (ReadOnlyDevice, ReadOnlyDevice) { fn device(response: &KeyQueryResponse) -> (ReadOnlyDevice, ReadOnlyDevice) {
let mut devices = response.device_keys.values().next().unwrap().values(); let mut devices = response.device_keys.values().next().unwrap().values();
let first = ReadOnlyDevice::try_from(devices.next().unwrap()).unwrap(); let first = ReadOnlyDevice::try_from(devices.next().unwrap()).unwrap();
@ -793,9 +753,8 @@ pub(crate) mod test {
assert!(identity.is_device_signed(&first).is_err()); assert!(identity.is_device_signed(&first).is_err());
assert!(identity.is_device_signed(&second).is_ok()); assert!(identity.is_device_signed(&second).is_ok());
let private_identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty( let private_identity =
second.user_id().clone(), Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(second.user_id().clone())));
)));
let verification_machine = VerificationMachine::new( let verification_machine = VerificationMachine::new(
ReadOnlyAccount::new(second.user_id(), second.device_id()), ReadOnlyAccount::new(second.user_id(), second.device_id()),
private_identity.clone(), private_identity.clone(),

View File

@ -20,13 +20,9 @@
// If we don't trust the device store an object that remembers the request and // If we don't trust the device store an object that remembers the request and
// let the users introspect that object. // let the users introspect that object.
use dashmap::{mapref::entry::Entry, DashMap, DashSet};
use serde::{Deserialize, Serialize};
use serde_json::value::to_raw_value;
use std::{collections::BTreeMap, sync::Arc}; use std::{collections::BTreeMap, sync::Arc};
use thiserror::Error;
use tracing::{error, info, trace, warn};
use dashmap::{mapref::entry::Entry, DashMap, DashSet};
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::to_device::DeviceIdOrAllDevices, api::r0::to_device::DeviceIdOrAllDevices,
events::{ events::{
@ -37,6 +33,10 @@ use matrix_sdk_common::{
identifiers::{DeviceId, DeviceIdBox, EventEncryptionAlgorithm, RoomId, UserId}, identifiers::{DeviceId, DeviceIdBox, EventEncryptionAlgorithm, RoomId, UserId},
uuid::Uuid, uuid::Uuid,
}; };
use serde::{Deserialize, Serialize};
use serde_json::value::to_raw_value;
use thiserror::Error;
use tracing::{error, info, trace, warn};
use crate::{ use crate::{
error::{OlmError, OlmResult}, error::{OlmError, OlmResult},
@ -105,10 +105,8 @@ impl WaitQueue {
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
) -> Vec<( ) -> Vec<((UserId, DeviceIdBox, String), ToDeviceEvent<RoomKeyRequestToDeviceEventContent>)>
(UserId, DeviceIdBox, String), {
ToDeviceEvent<RoomKeyRequestToDeviceEventContent>,
)> {
self.requests_ids_waiting self.requests_ids_waiting
.remove(&(user_id.to_owned(), device_id.into())) .remove(&(user_id.to_owned(), device_id.into()))
.map(|(_, request_ids)| { .map(|(_, request_ids)| {
@ -204,12 +202,7 @@ fn wrap_key_request_content(
Ok(OutgoingRequest { Ok(OutgoingRequest {
request_id: id, request_id: id,
request: Arc::new( request: Arc::new(
ToDeviceRequest { ToDeviceRequest { event_type: EventType::RoomKeyRequest, txn_id: id, messages }.into(),
event_type: EventType::RoomKeyRequest,
txn_id: id,
messages,
}
.into(),
), ),
}) })
} }
@ -241,10 +234,7 @@ impl KeyRequestMachine {
.await? .await?
.into_iter() .into_iter()
.filter(|i| !i.sent_out) .filter(|i| !i.sent_out)
.map(|info| { .map(|info| info.to_request(self.device_id()).map_err(CryptoStoreError::from))
info.to_request(self.device_id())
.map_err(CryptoStoreError::from)
})
.collect() .collect()
} }
@ -262,11 +252,8 @@ impl KeyRequestMachine {
&self, &self,
) -> Result<Vec<OutgoingRequest>, CryptoStoreError> { ) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
let mut key_requests = self.load_outgoing_requests().await?; let mut key_requests = self.load_outgoing_requests().await?;
let key_forwards: Vec<OutgoingRequest> = self let key_forwards: Vec<OutgoingRequest> =
.outgoing_to_device_requests self.outgoing_to_device_requests.iter().map(|i| i.value().clone()).collect();
.iter()
.map(|i| i.value().clone())
.collect();
key_requests.extend(key_forwards); key_requests.extend(key_forwards);
Ok(key_requests) Ok(key_requests)
@ -281,8 +268,7 @@ impl KeyRequestMachine {
let device_id = event.content.requesting_device_id.clone(); let device_id = event.content.requesting_device_id.clone();
let request_id = event.content.request_id.clone(); let request_id = event.content.request_id.clone();
self.incoming_key_requests self.incoming_key_requests.insert((sender, device_id, request_id), event.clone());
.insert((sender, device_id, request_id), event.clone());
} }
/// Handle all the incoming key requests that are queued up and empty our /// Handle all the incoming key requests that are queued up and empty our
@ -401,10 +387,8 @@ impl KeyRequestMachine {
return Ok(None); return Ok(None);
}; };
let device = self let device =
.store self.store.get_device(&event.sender, &event.content.requesting_device_id).await?;
.get_device(&event.sender, &event.content.requesting_device_id)
.await?;
if let Some(device) = device { if let Some(device) = device {
match self.should_share_key(&device, &session).await { match self.should_share_key(&device, &session).await {
@ -461,30 +445,22 @@ impl KeyRequestMachine {
device: &Device, device: &Device,
message_index: Option<u32>, message_index: Option<u32>,
) -> OlmResult<Session> { ) -> OlmResult<Session> {
let (used_session, content) = device let (used_session, content) =
.encrypt_session(session.clone(), message_index) device.encrypt_session(session.clone(), message_index).await?;
.await?;
let id = Uuid::new_v4(); let id = Uuid::new_v4();
let mut messages = BTreeMap::new(); let mut messages = BTreeMap::new();
messages messages.entry(device.user_id().to_owned()).or_insert_with(BTreeMap::new).insert(
.entry(device.user_id().to_owned()) DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
.or_insert_with(BTreeMap::new) to_raw_value(&content)?,
.insert( );
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
to_raw_value(&content)?,
);
let request = OutgoingRequest { let request = OutgoingRequest {
request_id: id, request_id: id,
request: Arc::new( request: Arc::new(
ToDeviceRequest { ToDeviceRequest { event_type: EventType::RoomEncrypted, txn_id: id, messages }
event_type: EventType::RoomEncrypted, .into(),
txn_id: id,
messages,
}
.into(),
), ),
}; };
@ -542,8 +518,8 @@ impl KeyRequestMachine {
} else { } else {
Err(KeyshareDecision::OutboundSessionNotShared) Err(KeyshareDecision::OutboundSessionNotShared)
} }
// Else just check if it's one of our own devices that requested the key and // Else just check if it's one of our own devices that requested the key
// check if the device is trusted. // and check if the device is trusted.
} else if device.user_id() == self.user_id() { } else if device.user_id() == self.user_id() {
own_device_check() own_device_check()
// Otherwise, there's not enough info to decide if we can safely share // Otherwise, there's not enough info to decide if we can safely share
@ -711,9 +687,7 @@ impl KeyRequestMachine {
/// Delete the given outgoing key info. /// Delete the given outgoing key info.
async fn delete_key_info(&self, info: &OutgoingKeyRequest) -> Result<(), CryptoStoreError> { async fn delete_key_info(&self, info: &OutgoingKeyRequest) -> Result<(), CryptoStoreError> {
self.store self.store.delete_outgoing_key_request(info.request_id).await
.delete_outgoing_key_request(info.request_id)
.await
} }
/// Mark the outgoing request as sent. /// Mark the outgoing request as sent.
@ -736,20 +710,15 @@ impl KeyRequestMachine {
/// This will queue up a request cancelation. /// This will queue up a request cancelation.
async fn mark_as_done(&self, key_info: OutgoingKeyRequest) -> Result<(), CryptoStoreError> { async fn mark_as_done(&self, key_info: OutgoingKeyRequest) -> Result<(), CryptoStoreError> {
// TODO perhaps only remove the key info if the first known index is 0. // TODO perhaps only remove the key info if the first known index is 0.
trace!( trace!("Successfully received a forwarded room key for {:#?}", key_info);
"Successfully received a forwarded room key for {:#?}",
key_info
);
self.outgoing_to_device_requests self.outgoing_to_device_requests.remove(&key_info.request_id);
.remove(&key_info.request_id);
// TODO return the key info instead of deleting it so the sync handler // TODO return the key info instead of deleting it so the sync handler
// can delete it in one transaction. // can delete it in one transaction.
self.delete_key_info(&key_info).await?; self.delete_key_info(&key_info).await?;
let request = key_info.to_cancelation(self.device_id())?; let request = key_info.to_cancelation(self.device_id())?;
self.outgoing_to_device_requests self.outgoing_to_device_requests.insert(request.request_id, request);
.insert(request.request_id, request);
Ok(()) Ok(())
} }
@ -801,10 +770,7 @@ impl KeyRequestMachine {
); );
} }
Ok(( Ok((Some(AnyToDeviceEvent::ForwardedRoomKey(event.clone())), session))
Some(AnyToDeviceEvent::ForwardedRoomKey(event.clone())),
session,
))
} else { } else {
info!( info!(
"Received a forwarded room key from {}, but no key info was found.", "Received a forwarded room key from {}, but no key info was found.",
@ -817,6 +783,8 @@ impl KeyRequestMachine {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use std::{convert::TryInto, sync::Arc};
use dashmap::DashMap; use dashmap::DashMap;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::to_device::DeviceIdOrAllDevices, api::r0::to_device::DeviceIdOrAllDevices,
@ -829,8 +797,8 @@ mod test {
locks::Mutex, locks::Mutex,
}; };
use matrix_sdk_test::async_test; use matrix_sdk_test::async_test;
use std::{convert::TryInto, sync::Arc};
use super::{KeyRequestMachine, KeyshareDecision};
use crate::{ use crate::{
identities::{LocalTrust, ReadOnlyDevice}, identities::{LocalTrust, ReadOnlyDevice},
olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount}, olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount},
@ -839,8 +807,6 @@ mod test {
verification::VerificationMachine, verification::VerificationMachine,
}; };
use super::{KeyRequestMachine, KeyshareDecision};
fn alice_id() -> UserId { fn alice_id() -> UserId {
user_id!("@alice:example.org") user_id!("@alice:example.org")
} }
@ -919,11 +885,7 @@ mod test {
async fn create_machine() { async fn create_machine() {
let machine = get_machine().await; let machine = get_machine().await;
assert!(machine assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty());
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
} }
#[async_test] #[async_test]
@ -931,16 +893,10 @@ mod test {
let machine = get_machine().await; let machine = get_machine().await;
let account = account(); let account = account();
let (_, session) = account let (_, session) =
.create_group_session_pair_with_defaults(&room_id()) account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
.await
.unwrap();
assert!(machine assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty());
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
let (cancel, request) = machine let (cancel, request) = machine
.request_key(session.room_id(), &session.sender_key, session.session_id()) .request_key(session.room_id(), &session.sender_key, session.session_id())
.await .await
@ -948,10 +904,7 @@ mod test {
assert!(cancel.is_none()); assert!(cancel.is_none());
machine machine.mark_outgoing_request_as_sent(request.request_id).await.unwrap();
.mark_outgoing_request_as_sent(request.request_id)
.await
.unwrap();
let (cancel, _) = machine let (cancel, _) = machine
.request_key(session.room_id(), &session.sender_key, session.session_id()) .request_key(session.room_id(), &session.sender_key, session.session_id())
@ -972,16 +925,10 @@ mod test {
alice_device.set_trust_state(LocalTrust::Verified); alice_device.set_trust_state(LocalTrust::Verified);
machine.store.save_devices(&[alice_device]).await.unwrap(); machine.store.save_devices(&[alice_device]).await.unwrap();
let (_, session) = account let (_, session) =
.create_group_session_pair_with_defaults(&room_id()) account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
.await
.unwrap();
assert!(machine assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty());
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
machine machine
.create_outgoing_key_request( .create_outgoing_key_request(
session.room_id(), session.room_id(),
@ -990,15 +937,8 @@ mod test {
) )
.await .await
.unwrap(); .unwrap();
assert!(!machine assert!(!machine.outgoing_to_device_requests().await.unwrap().is_empty());
.outgoing_to_device_requests() assert_eq!(machine.outgoing_to_device_requests().await.unwrap().len(), 1);
.await
.unwrap()
.is_empty());
assert_eq!(
machine.outgoing_to_device_requests().await.unwrap().len(),
1
);
machine machine
.create_outgoing_key_request( .create_outgoing_key_request(
@ -1014,15 +954,8 @@ mod test {
let request = requests.get(0).unwrap(); let request = requests.get(0).unwrap();
machine machine.mark_outgoing_request_as_sent(request.request_id).await.unwrap();
.mark_outgoing_request_as_sent(request.request_id) assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty());
.await
.unwrap();
assert!(machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
} }
#[async_test] #[async_test]
@ -1037,10 +970,8 @@ mod test {
alice_device.set_trust_state(LocalTrust::Verified); alice_device.set_trust_state(LocalTrust::Verified);
machine.store.save_devices(&[alice_device]).await.unwrap(); machine.store.save_devices(&[alice_device]).await.unwrap();
let (_, session) = account let (_, session) =
.create_group_session_pair_with_defaults(&room_id()) account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
.await
.unwrap();
machine machine
.create_outgoing_key_request( .create_outgoing_key_request(
session.room_id(), session.room_id(),
@ -1060,10 +991,7 @@ mod test {
let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap(); let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap();
let mut event = ToDeviceEvent { let mut event = ToDeviceEvent { sender: alice_id(), content };
sender: alice_id(),
content,
};
assert!( assert!(
machine machine
@ -1078,19 +1006,13 @@ mod test {
.is_none() .is_none()
); );
let (_, first_session) = machine let (_, first_session) =
.receive_forwarded_room_key(&session.sender_key, &mut event) machine.receive_forwarded_room_key(&session.sender_key, &mut event).await.unwrap();
.await
.unwrap();
let first_session = first_session.unwrap(); let first_session = first_session.unwrap();
assert_eq!(first_session.first_known_index(), 10); assert_eq!(first_session.first_known_index(), 10);
machine machine.store.save_inbound_group_sessions(&[first_session.clone()]).await.unwrap();
.store
.save_inbound_group_sessions(&[first_session.clone()])
.await
.unwrap();
// Get the cancel request. // Get the cancel request.
let request = machine.outgoing_to_device_requests.iter().next().unwrap(); let request = machine.outgoing_to_device_requests.iter().next().unwrap();
@ -1110,24 +1032,16 @@ mod test {
let requests = machine.outgoing_to_device_requests().await.unwrap(); let requests = machine.outgoing_to_device_requests().await.unwrap();
let request = &requests[0]; let request = &requests[0];
machine machine.mark_outgoing_request_as_sent(request.request_id).await.unwrap();
.mark_outgoing_request_as_sent(request.request_id)
.await
.unwrap();
let export = session.export_at_index(15).await; let export = session.export_at_index(15).await;
let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap(); let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap();
let mut event = ToDeviceEvent { let mut event = ToDeviceEvent { sender: alice_id(), content };
sender: alice_id(),
content,
};
let (_, second_session) = machine let (_, second_session) =
.receive_forwarded_room_key(&session.sender_key, &mut event) machine.receive_forwarded_room_key(&session.sender_key, &mut event).await.unwrap();
.await
.unwrap();
assert!(second_session.is_none()); assert!(second_session.is_none());
@ -1135,15 +1049,10 @@ mod test {
let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap(); let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap();
let mut event = ToDeviceEvent { let mut event = ToDeviceEvent { sender: alice_id(), content };
sender: alice_id(),
content,
};
let (_, second_session) = machine let (_, second_session) =
.receive_forwarded_room_key(&session.sender_key, &mut event) machine.receive_forwarded_room_key(&session.sender_key, &mut event).await.unwrap();
.await
.unwrap();
assert_eq!(second_session.unwrap().first_known_index(), 0); assert_eq!(second_session.unwrap().first_known_index(), 0);
} }
@ -1153,17 +1062,11 @@ mod test {
let machine = get_machine().await; let machine = get_machine().await;
let account = account(); let account = account();
let own_device = machine let own_device =
.store machine.store.get_device(&alice_id(), &alice_device_id()).await.unwrap().unwrap();
.get_device(&alice_id(), &alice_device_id())
.await
.unwrap()
.unwrap();
let (outbound, inbound) = account let (outbound, inbound) =
.create_group_session_pair_with_defaults(&room_id()) account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
.await
.unwrap();
// We don't share keys with untrusted devices. // We don't share keys with untrusted devices.
assert_eq!( assert_eq!(
@ -1175,20 +1078,13 @@ mod test {
); );
own_device.set_trust_state(LocalTrust::Verified); own_device.set_trust_state(LocalTrust::Verified);
// Now we do want to share the keys. // Now we do want to share the keys.
assert!(machine assert!(machine.should_share_key(&own_device, &inbound).await.is_ok());
.should_share_key(&own_device, &inbound)
.await
.is_ok());
let bob_device = ReadOnlyDevice::from_account(&bob_account()).await; let bob_device = ReadOnlyDevice::from_account(&bob_account()).await;
machine.store.save_devices(&[bob_device]).await.unwrap(); machine.store.save_devices(&[bob_device]).await.unwrap();
let bob_device = machine let bob_device =
.store machine.store.get_device(&bob_id(), &bob_device_id()).await.unwrap().unwrap();
.get_device(&bob_id(), &bob_device_id())
.await
.unwrap()
.unwrap();
// We don't share sessions with other user's devices if no outbound // We don't share sessions with other user's devices if no outbound
// session was provided. // session was provided.
@ -1231,17 +1127,12 @@ mod test {
// We now share the session, since it was shared before. // We now share the session, since it was shared before.
outbound.mark_shared_with(bob_device.user_id(), bob_device.device_id()); outbound.mark_shared_with(bob_device.user_id(), bob_device.device_id());
assert!(machine assert!(machine.should_share_key(&bob_device, &inbound).await.is_ok());
.should_share_key(&bob_device, &inbound)
.await
.is_ok());
// But we don't share some other session that doesn't match our outbound // But we don't share some other session that doesn't match our outbound
// session // session
let (_, other_inbound) = account let (_, other_inbound) =
.create_group_session_pair_with_defaults(&room_id()) account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
.await
.unwrap();
assert_eq!( assert_eq!(
machine machine
@ -1255,10 +1146,7 @@ mod test {
#[async_test] #[async_test]
async fn key_share_cycle() { async fn key_share_cycle() {
let alice_machine = get_machine().await; let alice_machine = get_machine().await;
let alice_account = Account { let alice_account = Account { inner: account(), store: alice_machine.store.clone() };
inner: account(),
store: alice_machine.store.clone(),
};
let bob_machine = bob_machine(); let bob_machine = bob_machine();
let bob_account = bob_account(); let bob_account = bob_account();
@ -1268,11 +1156,7 @@ mod test {
// We need a trusted device, otherwise we won't request keys // We need a trusted device, otherwise we won't request keys
alice_device.set_trust_state(LocalTrust::Verified); alice_device.set_trust_state(LocalTrust::Verified);
alice_machine alice_machine.store.save_devices(&[alice_device]).await.unwrap();
.store
.save_devices(&[alice_device])
.await
.unwrap();
// Create Olm sessions for our two accounts. // Create Olm sessions for our two accounts.
let (alice_session, bob_session) = alice_account.create_session_for(&bob_account).await; let (alice_session, bob_session) = alice_account.create_session_for(&bob_account).await;
@ -1282,37 +1166,15 @@ mod test {
// Populate our stores with Olm sessions and a Megolm session. // Populate our stores with Olm sessions and a Megolm session.
alice_machine alice_machine.store.save_sessions(&[alice_session]).await.unwrap();
.store alice_machine.store.save_devices(&[bob_device]).await.unwrap();
.save_sessions(&[alice_session]) bob_machine.store.save_sessions(&[bob_session]).await.unwrap();
.await bob_machine.store.save_devices(&[alice_device]).await.unwrap();
.unwrap();
alice_machine
.store
.save_devices(&[bob_device])
.await
.unwrap();
bob_machine
.store
.save_sessions(&[bob_session])
.await
.unwrap();
bob_machine
.store
.save_devices(&[alice_device])
.await
.unwrap();
let (group_session, inbound_group_session) = bob_account let (group_session, inbound_group_session) =
.create_group_session_pair_with_defaults(&room_id()) bob_account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
.await
.unwrap();
bob_machine bob_machine.store.save_inbound_group_sessions(&[inbound_group_session]).await.unwrap();
.store
.save_inbound_group_sessions(&[inbound_group_session])
.await
.unwrap();
// Alice wants to request the outbound group session from bob. // Alice wants to request the outbound group session from bob.
alice_machine alice_machine
@ -1326,9 +1188,7 @@ mod test {
group_session.mark_shared_with(&alice_id(), &alice_device_id()); group_session.mark_shared_with(&alice_id(), &alice_device_id());
// Put the outbound session into bobs store. // Put the outbound session into bobs store.
bob_machine bob_machine.outbound_group_sessions.insert(group_session.clone());
.outbound_group_sessions
.insert(group_session.clone());
// Get the request and convert it into a event. // Get the request and convert it into a event.
let requests = alice_machine.outgoing_to_device_requests().await.unwrap(); let requests = alice_machine.outgoing_to_device_requests().await.unwrap();
@ -1346,15 +1206,9 @@ mod test {
let content: RoomKeyRequestToDeviceEventContent = let content: RoomKeyRequestToDeviceEventContent =
serde_json::from_str(content.get()).unwrap(); serde_json::from_str(content.get()).unwrap();
alice_machine alice_machine.mark_outgoing_request_as_sent(id).await.unwrap();
.mark_outgoing_request_as_sent(id)
.await
.unwrap();
let event = ToDeviceEvent { let event = ToDeviceEvent { sender: alice_id(), content };
sender: alice_id(),
content,
};
// Bob doesn't have any outgoing requests. // Bob doesn't have any outgoing requests.
assert!(bob_machine.outgoing_to_device_requests.is_empty()); assert!(bob_machine.outgoing_to_device_requests.is_empty());
@ -1383,10 +1237,7 @@ mod test {
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap(); bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();
let event = ToDeviceEvent { let event = ToDeviceEvent { sender: bob_id(), content };
sender: bob_id(),
content,
};
// Check that alice doesn't have the session. // Check that alice doesn't have the session.
assert!(alice_machine assert!(alice_machine
@ -1407,11 +1258,7 @@ mod test {
.receive_forwarded_room_key(&decrypted.sender_key, &mut e) .receive_forwarded_room_key(&decrypted.sender_key, &mut e)
.await .await
.unwrap(); .unwrap();
alice_machine alice_machine.store.save_inbound_group_sessions(&[session.unwrap()]).await.unwrap();
.store
.save_inbound_group_sessions(&[session.unwrap()])
.await
.unwrap();
} else { } else {
panic!("Invalid decrypted event type"); panic!("Invalid decrypted event type");
} }
@ -1434,10 +1281,7 @@ mod test {
#[async_test] #[async_test]
async fn key_share_cycle_without_session() { async fn key_share_cycle_without_session() {
let alice_machine = get_machine().await; let alice_machine = get_machine().await;
let alice_account = Account { let alice_account = Account { inner: account(), store: alice_machine.store.clone() };
inner: account(),
store: alice_machine.store.clone(),
};
let bob_machine = bob_machine(); let bob_machine = bob_machine();
let bob_account = bob_account(); let bob_account = bob_account();
@ -1447,11 +1291,7 @@ mod test {
// We need a trusted device, otherwise we won't request keys // We need a trusted device, otherwise we won't request keys
alice_device.set_trust_state(LocalTrust::Verified); alice_device.set_trust_state(LocalTrust::Verified);
alice_machine alice_machine.store.save_devices(&[alice_device]).await.unwrap();
.store
.save_devices(&[alice_device])
.await
.unwrap();
// Create Olm sessions for our two accounts. // Create Olm sessions for our two accounts.
let (alice_session, bob_session) = alice_account.create_session_for(&bob_account).await; let (alice_session, bob_session) = alice_account.create_session_for(&bob_account).await;
@ -1461,27 +1301,13 @@ mod test {
// Populate our stores with Olm sessions and a Megolm session. // Populate our stores with Olm sessions and a Megolm session.
alice_machine alice_machine.store.save_devices(&[bob_device]).await.unwrap();
.store bob_machine.store.save_devices(&[alice_device]).await.unwrap();
.save_devices(&[bob_device])
.await
.unwrap();
bob_machine
.store
.save_devices(&[alice_device])
.await
.unwrap();
let (group_session, inbound_group_session) = bob_account let (group_session, inbound_group_session) =
.create_group_session_pair_with_defaults(&room_id()) bob_account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
.await
.unwrap();
bob_machine bob_machine.store.save_inbound_group_sessions(&[inbound_group_session]).await.unwrap();
.store
.save_inbound_group_sessions(&[inbound_group_session])
.await
.unwrap();
// Alice wants to request the outbound group session from bob. // Alice wants to request the outbound group session from bob.
alice_machine alice_machine
@ -1495,9 +1321,7 @@ mod test {
group_session.mark_shared_with(&alice_id(), &alice_device_id()); group_session.mark_shared_with(&alice_id(), &alice_device_id());
// Put the outbound session into bobs store. // Put the outbound session into bobs store.
bob_machine bob_machine.outbound_group_sessions.insert(group_session.clone());
.outbound_group_sessions
.insert(group_session.clone());
// Get the request and convert it into a event. // Get the request and convert it into a event.
let requests = alice_machine.outgoing_to_device_requests().await.unwrap(); let requests = alice_machine.outgoing_to_device_requests().await.unwrap();
@ -1515,22 +1339,12 @@ mod test {
let content: RoomKeyRequestToDeviceEventContent = let content: RoomKeyRequestToDeviceEventContent =
serde_json::from_str(content.get()).unwrap(); serde_json::from_str(content.get()).unwrap();
alice_machine alice_machine.mark_outgoing_request_as_sent(id).await.unwrap();
.mark_outgoing_request_as_sent(id)
.await
.unwrap();
let event = ToDeviceEvent { let event = ToDeviceEvent { sender: alice_id(), content };
sender: alice_id(),
content,
};
// Bob doesn't have any outgoing requests. // Bob doesn't have any outgoing requests.
assert!(bob_machine assert!(bob_machine.outgoing_to_device_requests().await.unwrap().is_empty());
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(bob_machine.users_for_key_claim.is_empty()); assert!(bob_machine.users_for_key_claim.is_empty());
assert!(bob_machine.wait_queue.is_empty()); assert!(bob_machine.wait_queue.is_empty());
@ -1538,35 +1352,19 @@ mod test {
bob_machine.receive_incoming_key_request(&event); bob_machine.receive_incoming_key_request(&event);
bob_machine.collect_incoming_key_requests().await.unwrap(); bob_machine.collect_incoming_key_requests().await.unwrap();
// Bob doens't have an outgoing requests since we're lacking a session. // Bob doens't have an outgoing requests since we're lacking a session.
assert!(bob_machine assert!(bob_machine.outgoing_to_device_requests().await.unwrap().is_empty());
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(!bob_machine.users_for_key_claim.is_empty()); assert!(!bob_machine.users_for_key_claim.is_empty());
assert!(!bob_machine.wait_queue.is_empty()); assert!(!bob_machine.wait_queue.is_empty());
// We create a session now. // We create a session now.
alice_machine alice_machine.store.save_sessions(&[alice_session]).await.unwrap();
.store bob_machine.store.save_sessions(&[bob_session]).await.unwrap();
.save_sessions(&[alice_session])
.await
.unwrap();
bob_machine
.store
.save_sessions(&[bob_session])
.await
.unwrap();
bob_machine.retry_keyshare(&alice_id(), &alice_device_id()); bob_machine.retry_keyshare(&alice_id(), &alice_device_id());
assert!(bob_machine.users_for_key_claim.is_empty()); assert!(bob_machine.users_for_key_claim.is_empty());
bob_machine.collect_incoming_key_requests().await.unwrap(); bob_machine.collect_incoming_key_requests().await.unwrap();
// Bob now has an outgoing requests. // Bob now has an outgoing requests.
assert!(!bob_machine assert!(!bob_machine.outgoing_to_device_requests().await.unwrap().is_empty());
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(bob_machine.wait_queue.is_empty()); assert!(bob_machine.wait_queue.is_empty());
// Get the request and convert it to a encrypted to-device event. // Get the request and convert it to a encrypted to-device event.
@ -1588,10 +1386,7 @@ mod test {
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap(); bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();
let event = ToDeviceEvent { let event = ToDeviceEvent { sender: bob_id(), content };
sender: bob_id(),
content,
};
// Check that alice doesn't have the session. // Check that alice doesn't have the session.
assert!(alice_machine assert!(alice_machine
@ -1612,11 +1407,7 @@ mod test {
.receive_forwarded_room_key(&decrypted.sender_key, &mut e) .receive_forwarded_room_key(&decrypted.sender_key, &mut e)
.await .await
.unwrap(); .unwrap();
alice_machine alice_machine.store.save_inbound_group_sessions(&[session.unwrap()]).await.unwrap();
.store
.save_inbound_group_sessions(&[session.unwrap()])
.await
.unwrap();
} else { } else {
panic!("Invalid decrypted event type"); panic!("Invalid decrypted event type");
} }

View File

@ -17,8 +17,6 @@ use std::path::Path;
use std::{collections::BTreeMap, mem, sync::Arc}; use std::{collections::BTreeMap, mem, sync::Arc};
use dashmap::DashMap; use dashmap::DashMap;
use tracing::{debug, error, info, trace, warn};
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::{ api::r0::{
keys::{ keys::{
@ -43,6 +41,7 @@ use matrix_sdk_common::{
uuid::Uuid, uuid::Uuid,
UInt, UInt,
}; };
use tracing::{debug, error, info, trace, warn};
#[cfg(feature = "sled_cryptostore")] #[cfg(feature = "sled_cryptostore")]
use crate::store::sled::SledStore; use crate::store::sled::SledStore;
@ -148,19 +147,12 @@ impl OlmMachine {
let store = Arc::new(store); let store = Arc::new(store);
let verification_machine = let verification_machine =
VerificationMachine::new(account.clone(), user_identity.clone(), store.clone()); VerificationMachine::new(account.clone(), user_identity.clone(), store.clone());
let store = Store::new( let store =
user_id.clone(), Store::new(user_id.clone(), user_identity.clone(), store, verification_machine.clone());
user_identity.clone(),
store,
verification_machine.clone(),
);
let device_id: Arc<DeviceIdBox> = Arc::new(device_id); let device_id: Arc<DeviceIdBox> = Arc::new(device_id);
let users_for_key_claim = Arc::new(DashMap::new()); let users_for_key_claim = Arc::new(DashMap::new());
let account = Account { let account = Account { inner: account, store: store.clone() };
inner: account,
store: store.clone(),
};
let group_session_manager = GroupSessionManager::new(account.clone(), store.clone()); let group_session_manager = GroupSessionManager::new(account.clone(), store.clone());
@ -244,9 +236,7 @@ impl OlmMachine {
} }
}; };
Ok(OlmMachine::new_helper( Ok(OlmMachine::new_helper(&user_id, device_id, store, account, identity))
&user_id, device_id, store, account, identity,
))
} }
/// Create a new machine with the default crypto store. /// Create a new machine with the default crypto store.
@ -296,19 +286,16 @@ impl OlmMachine {
pub async fn outgoing_requests(&self) -> StoreResult<Vec<OutgoingRequest>> { pub async fn outgoing_requests(&self) -> StoreResult<Vec<OutgoingRequest>> {
let mut requests = Vec::new(); let mut requests = Vec::new();
if let Some(r) = self.keys_for_upload().await.map(|r| OutgoingRequest { if let Some(r) = self
request_id: Uuid::new_v4(), .keys_for_upload()
request: Arc::new(r.into()), .await
}) { .map(|r| OutgoingRequest { request_id: Uuid::new_v4(), request: Arc::new(r.into()) })
{
requests.push(r); requests.push(r);
} }
for request in self for request in
.identity_manager self.identity_manager.users_for_key_query().await.into_iter().map(|r| OutgoingRequest {
.users_for_key_query()
.await
.into_iter()
.map(|r| OutgoingRequest {
request_id: Uuid::new_v4(), request_id: Uuid::new_v4(),
request: Arc::new(r.into()), request: Arc::new(r.into()),
}) })
@ -317,12 +304,7 @@ impl OlmMachine {
} }
requests.append(&mut self.verification_machine.outgoing_messages()); requests.append(&mut self.verification_machine.outgoing_messages());
requests.append( requests.append(&mut self.key_request_machine.outgoing_to_device_requests().await?);
&mut self
.key_request_machine
.outgoing_to_device_requests()
.await?,
);
Ok(requests) Ok(requests)
} }
@ -373,10 +355,7 @@ impl OlmMachine {
let identity = self.user_identity.lock().await; let identity = self.user_identity.lock().await;
identity.mark_as_shared(); identity.mark_as_shared();
let changes = Changes { let changes = Changes { private_identity: Some(identity.clone()), ..Default::default() };
private_identity: Some(identity.clone()),
..Default::default()
};
self.store.save_changes(changes).await self.store.save_changes(changes).await
} }
@ -406,10 +385,7 @@ impl OlmMachine {
); );
let changes = Changes { let changes = Changes {
identities: IdentityChanges { identities: IdentityChanges { new: vec![public.into()], ..Default::default() },
new: vec![public.into()],
..Default::default()
},
private_identity: Some(identity.clone()), private_identity: Some(identity.clone()),
..Default::default() ..Default::default()
}; };
@ -421,10 +397,8 @@ impl OlmMachine {
info!("Trying to upload the existing cross signing identity"); info!("Trying to upload the existing cross signing identity");
let request = identity.as_upload_request().await; let request = identity.as_upload_request().await;
// TODO remove this expect. // TODO remove this expect.
let signature_request = identity let signature_request =
.sign_account(&self.account) identity.sign_account(&self.account).await.expect("Can't sign device keys");
.await
.expect("Can't sign device keys");
Ok((request, signature_request)) Ok((request, signature_request))
} }
} }
@ -518,9 +492,7 @@ impl OlmMachine {
/// ///
/// * `response` - The response containing the claimed one-time keys. /// * `response` - The response containing the claimed one-time keys.
async fn receive_keys_claim_response(&self, response: &KeysClaimResponse) -> OlmResult<()> { async fn receive_keys_claim_response(&self, response: &KeysClaimResponse) -> OlmResult<()> {
self.session_manager self.session_manager.receive_keys_claim_response(response).await
.receive_keys_claim_response(response)
.await
} }
/// Receive a successful keys query response. /// Receive a successful keys query response.
@ -536,9 +508,7 @@ impl OlmMachine {
&self, &self,
response: &KeysQueryResponse, response: &KeysQueryResponse,
) -> OlmResult<(DeviceChanges, IdentityChanges)> { ) -> OlmResult<(DeviceChanges, IdentityChanges)> {
self.identity_manager self.identity_manager.receive_keys_query_response(response).await
.receive_keys_query_response(response)
.await
} }
/// Get a request to upload E2EE keys to the server. /// Get a request to upload E2EE keys to the server.
@ -676,9 +646,7 @@ impl OlmMachine {
/// Returns true if a session was invalidated, false if there was no session /// Returns true if a session was invalidated, false if there was no session
/// to invalidate. /// to invalidate.
pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> { pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
self.group_session_manager self.group_session_manager.invalidate_group_session(room_id).await
.invalidate_group_session(room_id)
.await
} }
/// Get to-device requests to share a group session with users in a room. /// Get to-device requests to share a group session with users in a room.
@ -695,9 +663,7 @@ impl OlmMachine {
users: impl Iterator<Item = &UserId>, users: impl Iterator<Item = &UserId>,
encryption_settings: impl Into<EncryptionSettings>, encryption_settings: impl Into<EncryptionSettings>,
) -> OlmResult<Vec<Arc<ToDeviceRequest>>> { ) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
self.group_session_manager self.group_session_manager.share_group_session(room_id, users, encryption_settings).await
.share_group_session(room_id, users, encryption_settings)
.await
} }
/// Receive and properly handle a decrypted to-device event. /// Receive and properly handle a decrypted to-device event.
@ -716,18 +682,15 @@ impl OlmMachine {
let event = match decrypted.event.deserialize() { let event = match decrypted.event.deserialize() {
Ok(e) => e, Ok(e) => e,
Err(e) => { Err(e) => {
warn!( warn!("Decrypted to-device event failed to be parsed correctly {:?}", e);
"Decrypted to-device event failed to be parsed correctly {:?}",
e
);
return Ok((None, None)); return Ok((None, None));
} }
}; };
match event { match event {
AnyToDeviceEvent::RoomKey(mut e) => Ok(self AnyToDeviceEvent::RoomKey(mut e) => {
.add_room_key(&decrypted.sender_key, &decrypted.signing_key, &mut e) Ok(self.add_room_key(&decrypted.sender_key, &decrypted.signing_key, &mut e).await?)
.await?), }
AnyToDeviceEvent::ForwardedRoomKey(mut e) => Ok(self AnyToDeviceEvent::ForwardedRoomKey(mut e) => Ok(self
.key_request_machine .key_request_machine
.receive_forwarded_room_key(&decrypted.sender_key, &mut e) .receive_forwarded_room_key(&decrypted.sender_key, &mut e)
@ -748,14 +711,9 @@ impl OlmMachine {
/// Mark an outgoing to-device requests as sent. /// Mark an outgoing to-device requests as sent.
async fn mark_to_device_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> { async fn mark_to_device_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> {
self.verification_machine.mark_request_as_sent(request_id); self.verification_machine.mark_request_as_sent(request_id);
self.key_request_machine self.key_request_machine.mark_outgoing_request_as_sent(*request_id).await?;
.mark_outgoing_request_as_sent(*request_id) self.group_session_manager.mark_request_as_sent(request_id).await?;
.await?; self.session_manager.mark_outgoing_request_as_sent(request_id);
self.group_session_manager
.mark_request_as_sent(request_id)
.await?;
self.session_manager
.mark_outgoing_request_as_sent(request_id);
Ok(()) Ok(())
} }
@ -830,10 +788,8 @@ impl OlmMachine {
// Always save the account, a new session might get created which also // Always save the account, a new session might get created which also
// touches the account. // touches the account.
let mut changes = Changes { let mut changes =
account: Some(self.account.inner.clone()), Changes { account: Some(self.account.inner.clone()), ..Default::default() };
..Default::default()
};
self.update_one_time_key_count(one_time_keys_counts).await; self.update_one_time_key_count(one_time_keys_counts).await;
@ -850,10 +806,7 @@ impl OlmMachine {
Ok(e) => e, Ok(e) => e,
Err(e) => { Err(e) => {
// Skip invalid events. // Skip invalid events.
warn!( warn!("Received an invalid to-device event {:?} {:?}", e, raw_event);
"Received an invalid to-device event {:?} {:?}",
e, raw_event
);
continue; continue;
} }
}; };
@ -865,10 +818,7 @@ impl OlmMachine {
let decrypted = match self.decrypt_to_device_event(&e).await { let decrypted = match self.decrypt_to_device_event(&e).await {
Ok(e) => e, Ok(e) => e,
Err(err) => { Err(err) => {
warn!( warn!("Failed to decrypt to-device event from {} {}", e.sender, err);
"Failed to decrypt to-device event from {} {}",
e.sender, err
);
if let OlmError::SessionWedged(sender, curve_key) = err { if let OlmError::SessionWedged(sender, curve_key) = err {
if let Err(e) = self if let Err(e) = self
@ -916,10 +866,7 @@ impl OlmMachine {
events.push(raw_event); events.push(raw_event);
} }
let changed_sessions = self let changed_sessions = self.key_request_machine.collect_incoming_key_requests().await?;
.key_request_machine
.collect_incoming_key_requests()
.await?;
changes.sessions.extend(changed_sessions); changes.sessions.extend(changed_sessions);
@ -1036,25 +983,16 @@ impl OlmMachine {
// TODO check if this is from a verified device. // TODO check if this is from a verified device.
let (decrypted_event, _) = session.decrypt(event).await?; let (decrypted_event, _) = session.decrypt(event).await?;
trace!( trace!("Successfully decrypted a Megolm event {:?}", decrypted_event);
"Successfully decrypted a Megolm event {:?}",
decrypted_event
);
if let Ok(e) = decrypted_event.deserialize() { if let Ok(e) = decrypted_event.deserialize() {
self.verification_machine self.verification_machine.receive_room_event(room_id, &e).await?;
.receive_room_event(room_id, &e)
.await?;
} }
let encryption_info = self let encryption_info =
.get_encryption_info(&session, &event.sender, &content.device_id) self.get_encryption_info(&session, &event.sender, &content.device_id).await?;
.await?;
Ok(SyncRoomEvent { Ok(SyncRoomEvent { encryption_info: Some(encryption_info), event: decrypted_event })
encryption_info: Some(encryption_info),
event: decrypted_event,
})
} }
/// Update the tracked users. /// Update the tracked users.
@ -1210,17 +1148,11 @@ impl OlmMachine {
let num_sessions = sessions.len(); let num_sessions = sessions.len();
let changes = Changes { let changes = Changes { inbound_group_sessions: sessions, ..Default::default() };
inbound_group_sessions: sessions,
..Default::default()
};
self.store.save_changes(changes).await?; self.store.save_changes(changes).await?;
info!( info!("Successfully imported {} inbound group sessions", num_sessions);
"Successfully imported {} inbound group sessions",
num_sessions
);
Ok((num_sessions, total_sessions)) Ok((num_sessions, total_sessions))
} }
@ -1288,15 +1220,6 @@ pub(crate) mod test {
}; };
use http::Response; use http::Response;
use serde_json::json;
use crate::{
machine::OlmMachine,
olm::Utility,
verification::test::{outgoing_request_to_event, request_to_event},
EncryptionSettings, ReadOnlyDevice, ToDeviceRequest,
};
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::keys::{claim_keys, get_keys, upload_keys, OneTimeKey}, api::r0::keys::{claim_keys, get_keys, upload_keys, OneTimeKey},
events::{ events::{
@ -1313,6 +1236,14 @@ pub(crate) mod test {
IncomingResponse, Raw, IncomingResponse, Raw,
}; };
use matrix_sdk_test::test_json; use matrix_sdk_test::test_json;
use serde_json::json;
use crate::{
machine::OlmMachine,
olm::Utility,
verification::test::{outgoing_request_to_event, request_to_event},
EncryptionSettings, ReadOnlyDevice, ToDeviceRequest,
};
/// These keys need to be periodically uploaded to the server. /// These keys need to be periodically uploaded to the server.
type OneTimeKeys = BTreeMap<DeviceKeyId, OneTimeKey>; type OneTimeKeys = BTreeMap<DeviceKeyId, OneTimeKey>;
@ -1332,10 +1263,7 @@ pub(crate) mod test {
} }
pub fn response_from_file(json: &serde_json::Value) -> Response<Vec<u8>> { pub fn response_from_file(json: &serde_json::Value) -> Response<Vec<u8>> {
Response::builder() Response::builder().status(200).body(json.to_string().as_bytes().to_vec()).unwrap()
.status(200)
.body(json.to_string().as_bytes().to_vec())
.unwrap()
} }
fn keys_upload_response() -> upload_keys::Response { fn keys_upload_response() -> upload_keys::Response {
@ -1354,15 +1282,7 @@ pub(crate) mod test {
let to_device_request = &requests[0]; let to_device_request = &requests[0];
let content: Raw<EncryptedEventContent> = serde_json::from_str( let content: Raw<EncryptedEventContent> = serde_json::from_str(
to_device_request to_device_request.messages.values().next().unwrap().values().next().unwrap().get(),
.messages
.values()
.next()
.unwrap()
.values()
.next()
.unwrap()
.get(),
) )
.unwrap(); .unwrap();
@ -1372,15 +1292,9 @@ pub(crate) mod test {
pub(crate) async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) { pub(crate) async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) {
let machine = OlmMachine::new(&user_id(), &alice_device_id()); let machine = OlmMachine::new(&user_id(), &alice_device_id());
machine.account.inner.update_uploaded_key_count(0); machine.account.inner.update_uploaded_key_count(0);
let request = machine let request = machine.keys_for_upload().await.expect("Can't prepare initial key upload");
.keys_for_upload()
.await
.expect("Can't prepare initial key upload");
let response = keys_upload_response(); let response = keys_upload_response();
machine machine.receive_keys_upload_response(&response).await.unwrap();
.receive_keys_upload_response(&response)
.await
.unwrap();
(machine, request.one_time_keys.unwrap()) (machine, request.one_time_keys.unwrap())
} }
@ -1389,10 +1303,7 @@ pub(crate) mod test {
let (machine, otk) = get_prepared_machine().await; let (machine, otk) = get_prepared_machine().await;
let response = keys_query_response(); let response = keys_query_response();
machine machine.receive_keys_query_response(&response).await.unwrap();
.receive_keys_query_response(&response)
.await
.unwrap();
(machine, otk) (machine, otk)
} }
@ -1435,28 +1346,15 @@ pub(crate) mod test {
async fn get_machine_pair_with_setup_sessions() -> (OlmMachine, OlmMachine) { async fn get_machine_pair_with_setup_sessions() -> (OlmMachine, OlmMachine) {
let (alice, bob) = get_machine_pair_with_session().await; let (alice, bob) = get_machine_pair_with_session().await;
let bob_device = alice let bob_device = alice.get_device(&bob.user_id, &bob.device_id).await.unwrap().unwrap();
.get_device(&bob.user_id, &bob.device_id)
.await
.unwrap()
.unwrap();
let (session, content) = bob_device let (session, content) = bob_device.encrypt(EventType::Dummy, json!({})).await.unwrap();
.encrypt(EventType::Dummy, json!({}))
.await
.unwrap();
alice.store.save_sessions(&[session]).await.unwrap(); alice.store.save_sessions(&[session]).await.unwrap();
let event = ToDeviceEvent { let event = ToDeviceEvent { sender: alice.user_id().clone(), content };
sender: alice.user_id().clone(),
content,
};
let decrypted = bob.decrypt_to_device_event(&event).await.unwrap(); let decrypted = bob.decrypt_to_device_event(&event).await.unwrap();
bob.store bob.store.save_sessions(&[decrypted.session.session()]).await.unwrap();
.save_sessions(&[decrypted.session.session()])
.await
.unwrap();
(alice, bob) (alice, bob)
} }
@ -1472,34 +1370,18 @@ pub(crate) mod test {
let machine = OlmMachine::new(&user_id(), &alice_device_id()); let machine = OlmMachine::new(&user_id(), &alice_device_id());
let mut response = keys_upload_response(); let mut response = keys_upload_response();
response response.one_time_key_counts.remove(&DeviceKeyAlgorithm::SignedCurve25519).unwrap();
.one_time_key_counts
.remove(&DeviceKeyAlgorithm::SignedCurve25519)
.unwrap();
assert!(machine.should_upload_keys().await); assert!(machine.should_upload_keys().await);
machine machine.receive_keys_upload_response(&response).await.unwrap();
.receive_keys_upload_response(&response)
.await
.unwrap();
assert!(machine.should_upload_keys().await); assert!(machine.should_upload_keys().await);
response response.one_time_key_counts.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(10));
.one_time_key_counts machine.receive_keys_upload_response(&response).await.unwrap();
.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(10));
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
assert!(machine.should_upload_keys().await); assert!(machine.should_upload_keys().await);
response response.one_time_key_counts.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50));
.one_time_key_counts machine.receive_keys_upload_response(&response).await.unwrap();
.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50));
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
assert!(!machine.should_upload_keys().await); assert!(!machine.should_upload_keys().await);
} }
@ -1511,20 +1393,12 @@ pub(crate) mod test {
assert!(machine.should_upload_keys().await); assert!(machine.should_upload_keys().await);
machine machine.receive_keys_upload_response(&response).await.unwrap();
.receive_keys_upload_response(&response)
.await
.unwrap();
assert!(machine.should_upload_keys().await); assert!(machine.should_upload_keys().await);
assert!(machine.account.generate_one_time_keys().await.is_ok()); assert!(machine.account.generate_one_time_keys().await.is_ok());
response response.one_time_key_counts.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50));
.one_time_key_counts machine.receive_keys_upload_response(&response).await.unwrap();
.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50));
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
assert!(machine.account.generate_one_time_keys().await.is_err()); assert!(machine.account.generate_one_time_keys().await.is_err());
} }
@ -1551,14 +1425,8 @@ pub(crate) mod test {
let machine = OlmMachine::new(&user_id(), &alice_device_id()); let machine = OlmMachine::new(&user_id(), &alice_device_id());
let room_id = room_id!("!test:example.org"); let room_id = room_id!("!test:example.org");
machine machine.create_outbound_group_session_with_defaults(&room_id).await.unwrap();
.create_outbound_group_session_with_defaults(&room_id) assert!(machine.group_session_manager.get_outbound_group_session(&room_id).is_some());
.await
.unwrap();
assert!(machine
.group_session_manager
.get_outbound_group_session(&room_id)
.is_some());
machine.invalidate_group_session(&room_id).await.unwrap(); machine.invalidate_group_session(&room_id).await.unwrap();
@ -1614,10 +1482,8 @@ pub(crate) mod test {
let identity_keys = machine.account.identity_keys(); let identity_keys = machine.account.identity_keys();
let ed25519_key = identity_keys.ed25519(); let ed25519_key = identity_keys.ed25519();
let mut request = machine let mut request =
.keys_for_upload() machine.keys_for_upload().await.expect("Can't prepare initial key upload");
.await
.expect("Can't prepare initial key upload");
let utility = Utility::new(); let utility = Utility::new();
let ret = utility.verify_json( let ret = utility.verify_json(
@ -1640,15 +1506,10 @@ pub(crate) mod test {
let mut response = keys_upload_response(); let mut response = keys_upload_response();
response.one_time_key_counts.insert( response.one_time_key_counts.insert(
DeviceKeyAlgorithm::SignedCurve25519, DeviceKeyAlgorithm::SignedCurve25519,
(request.one_time_keys.unwrap().len() as u64) (request.one_time_keys.unwrap().len() as u64).try_into().unwrap(),
.try_into()
.unwrap(),
); );
machine machine.receive_keys_upload_response(&response).await.unwrap();
.receive_keys_upload_response(&response)
.await
.unwrap();
let ret = machine.keys_for_upload().await; let ret = machine.keys_for_upload().await;
assert!(ret.is_none()); assert!(ret.is_none());
@ -1664,17 +1525,9 @@ pub(crate) mod test {
let alice_devices = machine.store.get_user_devices(&alice_id).await.unwrap(); let alice_devices = machine.store.get_user_devices(&alice_id).await.unwrap();
assert!(alice_devices.devices().peekable().peek().is_none()); assert!(alice_devices.devices().peekable().peek().is_none());
machine machine.receive_keys_query_response(&response).await.unwrap();
.receive_keys_query_response(&response)
.await
.unwrap();
let device = machine let device = machine.store.get_device(&alice_id, alice_device_id).await.unwrap().unwrap();
.store
.get_device(&alice_id, alice_device_id)
.await
.unwrap()
.unwrap();
assert_eq!(device.user_id(), &alice_id); assert_eq!(device.user_id(), &alice_id);
assert_eq!(device.device_id(), alice_device_id); assert_eq!(device.device_id(), alice_device_id);
} }
@ -1686,11 +1539,8 @@ pub(crate) mod test {
let alice = alice_id(); let alice = alice_id();
let alice_device = alice_device_id(); let alice_device = alice_device_id();
let (_, missing_sessions) = machine let (_, missing_sessions) =
.get_missing_sessions(&mut [alice.clone()].iter()) machine.get_missing_sessions(&mut [alice.clone()].iter()).await.unwrap().unwrap();
.await
.unwrap()
.unwrap();
assert!(missing_sessions.one_time_keys.contains_key(&alice)); assert!(missing_sessions.one_time_keys.contains_key(&alice));
let user_sessions = missing_sessions.one_time_keys.get(&alice).unwrap(); let user_sessions = missing_sessions.one_time_keys.get(&alice).unwrap();
@ -1713,10 +1563,7 @@ pub(crate) mod test {
let response = claim_keys::Response::new(one_time_keys); let response = claim_keys::Response::new(one_time_keys);
alice_machine alice_machine.receive_keys_claim_response(&response).await.unwrap();
.receive_keys_claim_response(&response)
.await
.unwrap();
let session = alice_machine let session = alice_machine
.store .store
@ -1732,28 +1579,14 @@ pub(crate) mod test {
async fn test_olm_encryption() { async fn test_olm_encryption() {
let (alice, bob) = get_machine_pair_with_session().await; let (alice, bob) = get_machine_pair_with_session().await;
let bob_device = alice let bob_device = alice.get_device(&bob.user_id, &bob.device_id).await.unwrap().unwrap();
.get_device(&bob.user_id, &bob.device_id)
.await
.unwrap()
.unwrap();
let event = ToDeviceEvent { let event = ToDeviceEvent {
sender: alice.user_id().clone(), sender: alice.user_id().clone(),
content: bob_device content: bob_device.encrypt(EventType::Dummy, json!({})).await.unwrap().1,
.encrypt(EventType::Dummy, json!({}))
.await
.unwrap()
.1,
}; };
let event = bob let event = bob.decrypt_to_device_event(&event).await.unwrap().event.deserialize().unwrap();
.decrypt_to_device_event(&event)
.await
.unwrap()
.event
.deserialize()
.unwrap();
if let AnyToDeviceEvent::Dummy(e) = event { if let AnyToDeviceEvent::Dummy(e) = event {
assert_eq!(&e.sender, alice.user_id()); assert_eq!(&e.sender, alice.user_id());
@ -1782,17 +1615,12 @@ pub(crate) mod test {
content: to_device_requests_to_content(to_device_requests), content: to_device_requests_to_content(to_device_requests),
}; };
let alice_session = alice let alice_session =
.group_session_manager alice.group_session_manager.get_outbound_group_session(&room_id).unwrap();
.get_outbound_group_session(&room_id)
.unwrap();
let decrypted = bob.decrypt_to_device_event(&event).await.unwrap(); let decrypted = bob.decrypt_to_device_event(&event).await.unwrap();
bob.store bob.store.save_sessions(&[decrypted.session.session()]).await.unwrap();
.save_sessions(&[decrypted.session.session()])
.await
.unwrap();
bob.store bob.store
.save_inbound_group_sessions(&[decrypted.inbound_group_session.unwrap()]) .save_inbound_group_sessions(&[decrypted.inbound_group_session.unwrap()])
.await .await
@ -1837,25 +1665,16 @@ pub(crate) mod test {
content: to_device_requests_to_content(to_device_requests), content: to_device_requests_to_content(to_device_requests),
}; };
let group_session = bob let group_session =
.decrypt_to_device_event(&event) bob.decrypt_to_device_event(&event).await.unwrap().inbound_group_session;
.await bob.store.save_inbound_group_sessions(&[group_session.unwrap()]).await.unwrap();
.unwrap()
.inbound_group_session;
bob.store
.save_inbound_group_sessions(&[group_session.unwrap()])
.await
.unwrap();
let plaintext = "It is a secret to everybody"; let plaintext = "It is a secret to everybody";
let content = MessageEventContent::text_plain(plaintext); let content = MessageEventContent::text_plain(plaintext);
let encrypted_content = alice let encrypted_content = alice
.encrypt( .encrypt(&room_id, AnyMessageEventContent::RoomMessage(content.clone()))
&room_id,
AnyMessageEventContent::RoomMessage(content.clone()),
)
.await .await
.unwrap(); .unwrap();
@ -1867,13 +1686,8 @@ pub(crate) mod test {
unsigned: Unsigned::default(), unsigned: Unsigned::default(),
}; };
let decrypted_event = bob let decrypted_event =
.decrypt_room_event(&event, &room_id) bob.decrypt_room_event(&event, &room_id).await.unwrap().event.deserialize().unwrap();
.await
.unwrap()
.event
.deserialize()
.unwrap();
if let AnySyncRoomEvent::Message(AnySyncMessageEvent::RoomMessage(SyncMessageEvent { if let AnySyncRoomEvent::Message(AnySyncMessageEvent::RoomMessage(SyncMessageEvent {
sender, sender,
@ -1912,10 +1726,7 @@ pub(crate) mod test {
let device_id = machine.device_id().to_owned(); let device_id = machine.device_id().to_owned();
let ed25519_key = machine.identity_keys().ed25519().to_owned(); let ed25519_key = machine.identity_keys().ed25519().to_owned();
machine machine.receive_keys_upload_response(&keys_upload_response()).await.unwrap();
.receive_keys_upload_response(&keys_upload_response())
.await
.unwrap();
drop(machine); drop(machine);
@ -1937,11 +1748,7 @@ pub(crate) mod test {
async fn interactive_verification() { async fn interactive_verification() {
let (alice, bob) = get_machine_pair_with_setup_sessions().await; let (alice, bob) = get_machine_pair_with_setup_sessions().await;
let bob_device = alice let bob_device = alice.get_device(bob.user_id(), bob.device_id()).await.unwrap().unwrap();
.get_device(bob.user_id(), bob.device_id())
.await
.unwrap()
.unwrap();
assert!(!bob_device.is_trusted()); assert!(!bob_device.is_trusted());
@ -1955,10 +1762,7 @@ pub(crate) mod test {
assert!(alice_sas.emoji().is_none()); assert!(alice_sas.emoji().is_none());
assert!(bob_sas.emoji().is_none()); assert!(bob_sas.emoji().is_none());
let event = bob_sas let event = bob_sas.accept().map(|r| request_to_event(bob.user_id(), &r)).unwrap();
.accept()
.map(|r| request_to_event(bob.user_id(), &r))
.unwrap();
alice.handle_verification_event(&event).await; alice.handle_verification_event(&event).await;
@ -2007,11 +1811,8 @@ pub(crate) mod test {
assert!(alice_sas.is_done()); assert!(alice_sas.is_done());
assert!(bob_device.is_trusted()); assert!(bob_device.is_trusted());
let alice_device = bob let alice_device =
.get_device(alice.user_id(), alice.device_id()) bob.get_device(alice.user_id(), alice.device_id()).await.unwrap().unwrap();
.await
.unwrap()
.unwrap();
assert!(!alice_device.is_trusted()); assert!(!alice_device.is_trusted());
bob.handle_verification_event(&event).await; bob.handle_verification_event(&event).await;

View File

@ -12,10 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use matrix_sdk_common::events::ToDeviceEvent;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use sha2::{Digest, Sha256};
use std::{ use std::{
collections::BTreeMap, collections::BTreeMap,
convert::{TryFrom, TryInto}, convert::{TryFrom, TryInto},
@ -26,7 +22,6 @@ use std::{
Arc, Arc,
}, },
}; };
use tracing::{debug, trace, warn};
#[cfg(test)] #[cfg(test)]
use matrix_sdk_common::events::EventType; use matrix_sdk_common::events::EventType;
@ -37,7 +32,7 @@ use matrix_sdk_common::{
encryption::DeviceKeys, encryption::DeviceKeys,
events::{ events::{
room::encrypted::{EncryptedEventContent, EncryptedEventScheme}, room::encrypted::{EncryptedEventContent, EncryptedEventScheme},
AnyToDeviceEvent, AnyToDeviceEvent, ToDeviceEvent,
}, },
identifiers::{ identifiers::{
DeviceId, DeviceIdBox, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, RoomId, DeviceId, DeviceIdBox, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, RoomId,
@ -53,7 +48,15 @@ use olm_rs::{
session::{OlmMessage, PreKeyMessage}, session::{OlmMessage, PreKeyMessage},
PicklingMode, PicklingMode,
}; };
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use sha2::{Digest, Sha256};
use tracing::{debug, trace, warn};
use super::{
EncryptionSettings, InboundGroupSession, OutboundGroupSession, PrivateCrossSigningIdentity,
Session,
};
use crate::{ use crate::{
error::{EventError, OlmResult, SessionCreationError}, error::{EventError, OlmResult, SessionCreationError},
identities::ReadOnlyDevice, identities::ReadOnlyDevice,
@ -63,11 +66,6 @@ use crate::{
OlmError, OlmError,
}; };
use super::{
EncryptionSettings, InboundGroupSession, OutboundGroupSession, PrivateCrossSigningIdentity,
Session,
};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Account { pub struct Account {
pub(crate) inner: ReadOnlyAccount, pub(crate) inner: ReadOnlyAccount,
@ -141,10 +139,8 @@ impl Account {
// Try to find a ciphertext that was meant for our device. // Try to find a ciphertext that was meant for our device.
if let Some(ciphertext) = own_ciphertext { if let Some(ciphertext) = own_ciphertext {
let message_type: u8 = ciphertext let message_type: u8 =
.message_type ciphertext.message_type.try_into().map_err(|_| EventError::UnsupportedOlmType)?;
.try_into()
.map_err(|_| EventError::UnsupportedOlmType)?;
let sha = Sha256::new() let sha = Sha256::new()
.chain(&content.sender_key) .chain(&content.sender_key)
@ -162,20 +158,18 @@ impl Account {
.map_err(|_| EventError::UnsupportedOlmType)?; .map_err(|_| EventError::UnsupportedOlmType)?;
// Decrypt the OlmMessage and get a Ruma event out of it. // Decrypt the OlmMessage and get a Ruma event out of it.
let (session, event, signing_key) = match self let (session, event, signing_key) =
.decrypt_olm_message(&event.sender, &content.sender_key, message) match self.decrypt_olm_message(&event.sender, &content.sender_key, message).await {
.await Ok(d) => d,
{ Err(OlmError::SessionWedged(user_id, sender_key)) => {
Ok(d) => d, if self.store.is_message_known(&message_hash).await? {
Err(OlmError::SessionWedged(user_id, sender_key)) => { return Err(OlmError::ReplayedMessage(user_id, sender_key));
if self.store.is_message_known(&message_hash).await? { } else {
return Err(OlmError::ReplayedMessage(user_id, sender_key)); return Err(OlmError::SessionWedged(user_id, sender_key));
} else { }
return Err(OlmError::SessionWedged(user_id, sender_key));
} }
} Err(e) => return Err(e),
Err(e) => return Err(e), };
};
debug!("Decrypted a to-device event {:?}", event); debug!("Decrypted a to-device event {:?}", event);
@ -210,9 +204,8 @@ impl Account {
} }
self.inner.mark_as_shared(); self.inner.mark_as_shared();
let one_time_key_count = response let one_time_key_count =
.one_time_key_counts response.one_time_key_counts.get(&DeviceKeyAlgorithm::SignedCurve25519);
.get(&DeviceKeyAlgorithm::SignedCurve25519);
let count: u64 = one_time_key_count.map_or(0, |c| (*c).into()); let count: u64 = one_time_key_count.map_or(0, |c| (*c).into());
debug!( debug!(
@ -297,9 +290,8 @@ impl Account {
message: OlmMessage, message: OlmMessage,
) -> OlmResult<(SessionType, Raw<AnyToDeviceEvent>, String)> { ) -> OlmResult<(SessionType, Raw<AnyToDeviceEvent>, String)> {
// First try to decrypt using an existing session. // First try to decrypt using an existing session.
let (session, plaintext) = if let Some(d) = self let (session, plaintext) = if let Some(d) =
.try_decrypt_olm_message(sender, sender_key, &message) self.try_decrypt_olm_message(sender, sender_key, &message).await?
.await?
{ {
// Decryption succeeded, de-structure the session/plaintext out of // Decryption succeeded, de-structure the session/plaintext out of
// the Option. // the Option.
@ -316,32 +308,26 @@ impl Account {
available sessions {} {}", available sessions {} {}",
sender, sender_key sender, sender_key
); );
return Err(OlmError::SessionWedged( return Err(OlmError::SessionWedged(sender.to_owned(), sender_key.to_owned()));
sender.to_owned(),
sender_key.to_owned(),
));
} }
OlmMessage::PreKey(m) => { OlmMessage::PreKey(m) => {
// Create the new session. // Create the new session.
let session = match self let session =
.inner match self.inner.create_inbound_session(sender_key, m.clone()).await {
.create_inbound_session(sender_key, m.clone()) Ok(s) => s,
.await Err(e) => {
{ warn!(
Ok(s) => s, "Failed to create a new Olm session for {} {}
Err(e) => {
warn!(
"Failed to create a new Olm session for {} {}
from a prekey message: {}", from a prekey message: {}",
sender, sender_key, e sender, sender_key, e
); );
return Err(OlmError::SessionWedged( return Err(OlmError::SessionWedged(
sender.to_owned(), sender.to_owned(),
sender_key.to_owned(), sender_key.to_owned(),
)); ));
} }
}; };
session session
} }
@ -428,9 +414,8 @@ impl Account {
return Err(EventError::MissmatchedKeys.into()); return Err(EventError::MissmatchedKeys.into());
} }
let signing_key = keys let signing_key =
.get(&DeviceKeyAlgorithm::Ed25519) keys.get(&DeviceKeyAlgorithm::Ed25519).ok_or(EventError::MissingSigningKey)?;
.ok_or(EventError::MissingSigningKey)?;
Ok(( Ok((
Raw::from(serde_json::from_value::<AnyToDeviceEvent>(decrypted_json)?), Raw::from(serde_json::from_value::<AnyToDeviceEvent>(decrypted_json)?),
@ -547,8 +532,7 @@ impl ReadOnlyAccount {
/// * `new_count` - The new count that was reported by the server. /// * `new_count` - The new count that was reported by the server.
pub(crate) fn update_uploaded_key_count(&self, new_count: u64) { pub(crate) fn update_uploaded_key_count(&self, new_count: u64) {
let key_count = i64::try_from(new_count).unwrap_or(i64::MAX); let key_count = i64::try_from(new_count).unwrap_or(i64::MAX);
self.uploaded_signed_key_count self.uploaded_signed_key_count.store(key_count, Ordering::Relaxed);
.store(key_count, Ordering::Relaxed);
} }
/// Get the currently known uploaded key count. /// Get the currently known uploaded key count.
@ -631,19 +615,12 @@ impl ReadOnlyAccount {
/// Returns None if no keys need to be uploaded. /// Returns None if no keys need to be uploaded.
pub(crate) async fn keys_for_upload( pub(crate) async fn keys_for_upload(
&self, &self,
) -> Option<( ) -> Option<(Option<DeviceKeys>, Option<BTreeMap<DeviceKeyId, OneTimeKey>>)> {
Option<DeviceKeys>,
Option<BTreeMap<DeviceKeyId, OneTimeKey>>,
)> {
if !self.should_upload_keys().await { if !self.should_upload_keys().await {
return None; return None;
} }
let device_keys = if !self.shared() { let device_keys = if !self.shared() { Some(self.device_keys().await) } else { None };
Some(self.device_keys().await)
} else {
None
};
let one_time_keys = self.signed_one_time_keys().await.ok(); let one_time_keys = self.signed_one_time_keys().await.ok();
@ -666,7 +643,8 @@ impl ReadOnlyAccount {
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `pickle_mode` - The mode that was used to pickle the account, either an /// * `pickle_mode` - The mode that was used to pickle the account, either
/// an
/// unencrypted mode or an encrypted using passphrase. /// unencrypted mode or an encrypted using passphrase.
pub async fn pickle(&self, pickle_mode: PicklingMode) -> PickledAccount { pub async fn pickle(&self, pickle_mode: PicklingMode) -> PickledAccount {
let pickle = AccountPickle(self.inner.lock().await.pickle(pickle_mode)); let pickle = AccountPickle(self.inner.lock().await.pickle(pickle_mode));
@ -686,7 +664,8 @@ impl ReadOnlyAccount {
/// ///
/// * `pickle` - The pickled version of the Account. /// * `pickle` - The pickled version of the Account.
/// ///
/// * `pickle_mode` - The mode that was used to pickle the account, either an /// * `pickle_mode` - The mode that was used to pickle the account, either
/// an
/// unencrypted mode or an encrypted using passphrase. /// unencrypted mode or an encrypted using passphrase.
pub fn from_pickle( pub fn from_pickle(
pickle: PickledAccount, pickle: PickledAccount,
@ -742,25 +721,17 @@ impl ReadOnlyAccount {
"keys": device_keys.keys, "keys": device_keys.keys,
}); });
device_keys device_keys.signatures.entry(self.user_id().clone()).or_insert_with(BTreeMap::new).insert(
.signatures DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, &self.device_id),
.entry(self.user_id().clone()) self.sign_json(json_device_keys).await,
.or_insert_with(BTreeMap::new) );
.insert(
DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, &self.device_id),
self.sign_json(json_device_keys).await,
);
device_keys device_keys
} }
pub(crate) async fn bootstrap_cross_signing( pub(crate) async fn bootstrap_cross_signing(
&self, &self,
) -> ( ) -> (PrivateCrossSigningIdentity, UploadSigningKeysRequest, SignatureUploadRequest) {
PrivateCrossSigningIdentity,
UploadSigningKeysRequest,
SignatureUploadRequest,
) {
PrivateCrossSigningIdentity::new_with_account(self).await PrivateCrossSigningIdentity::new_with_account(self).await
} }
@ -873,8 +844,8 @@ impl ReadOnlyAccount {
/// # Arguments /// # Arguments
/// * `device` - The other account's device. /// * `device` - The other account's device.
/// ///
/// * `key_map` - A map from the algorithm and device id to the one-time /// * `key_map` - A map from the algorithm and device id to the one-time key
/// key that the other account created and shared with us. /// that the other account created and shared with us.
pub(crate) async fn create_outbound_session( pub(crate) async fn create_outbound_session(
&self, &self,
device: ReadOnlyDevice, device: ReadOnlyDevice,
@ -911,24 +882,20 @@ impl ReadOnlyAccount {
) )
})?; })?;
let curve_key = device let curve_key = device.get_key(DeviceKeyAlgorithm::Curve25519).ok_or_else(|| {
.get_key(DeviceKeyAlgorithm::Curve25519) SessionCreationError::DeviceMissingCurveKey(
.ok_or_else(|| { device.user_id().to_owned(),
SessionCreationError::DeviceMissingCurveKey( device.device_id().into(),
device.user_id().to_owned(), )
device.device_id().into(), })?;
)
})?;
self.create_outbound_session_helper(curve_key, &one_time_key) self.create_outbound_session_helper(curve_key, &one_time_key).await.map_err(|e| {
.await SessionCreationError::OlmError(
.map_err(|e| { device.user_id().to_owned(),
SessionCreationError::OlmError( device.device_id().into(),
device.user_id().to_owned(), e,
device.device_id().into(), )
e, })
)
})
} }
/// Create a new session with another account given a pre-key Olm message. /// Create a new session with another account given a pre-key Olm message.
@ -946,17 +913,10 @@ impl ReadOnlyAccount {
their_identity_key: &str, their_identity_key: &str,
message: PreKeyMessage, message: PreKeyMessage,
) -> Result<Session, OlmSessionError> { ) -> Result<Session, OlmSessionError> {
let session = self let session =
.inner self.inner.lock().await.create_inbound_session_from(their_identity_key, message)?;
.lock()
.await
.create_inbound_session_from(their_identity_key, message)?;
self.inner self.inner.lock().await.remove_one_time_keys(&session).expect(
.lock()
.await
.remove_one_time_keys(&session)
.expect(
"Session was successfully created but the account doesn't hold a matching one-time key", "Session was successfully created but the account doesn't hold a matching one-time key",
); );
@ -1028,8 +988,7 @@ impl ReadOnlyAccount {
&self, &self,
room_id: &RoomId, room_id: &RoomId,
) -> Result<(OutboundGroupSession, InboundGroupSession), ()> { ) -> Result<(OutboundGroupSession, InboundGroupSession), ()> {
self.create_group_session_pair(room_id, EncryptionSettings::default()) self.create_group_session_pair(room_id, EncryptionSettings::default()).await
.await
} }
#[cfg(test)] #[cfg(test)]
@ -1039,27 +998,19 @@ impl ReadOnlyAccount {
let device = ReadOnlyDevice::from_account(other).await; let device = ReadOnlyDevice::from_account(other).await;
let mut our_session = self let mut our_session =
.create_outbound_session(device.clone(), &one_time) self.create_outbound_session(device.clone(), &one_time).await.unwrap();
.await
.unwrap();
other.mark_keys_as_published().await; other.mark_keys_as_published().await;
let message = our_session let message = our_session.encrypt(&device, EventType::Dummy, json!({})).await.unwrap();
.encrypt(&device, EventType::Dummy, json!({}))
.await
.unwrap();
let content = if let EncryptedEventScheme::OlmV1Curve25519AesSha2(c) = message.scheme { let content = if let EncryptedEventScheme::OlmV1Curve25519AesSha2(c) = message.scheme {
c c
} else { } else {
panic!("Invalid encrypted event algorithm"); panic!("Invalid encrypted event algorithm");
}; };
let own_ciphertext = content let own_ciphertext = content.ciphertext.get(other.identity_keys.curve25519()).unwrap();
.ciphertext
.get(other.identity_keys.curve25519())
.unwrap();
let message_type: u8 = own_ciphertext.message_type.try_into().unwrap(); let message_type: u8 = own_ciphertext.message_type.try_into().unwrap();
let message = let message =

View File

@ -19,19 +19,6 @@ use std::{
sync::Arc, sync::Arc,
}; };
use olm_rs::{
errors::OlmGroupSessionError, inbound_group_session::OlmInboundGroupSession, PicklingMode,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use zeroize::Zeroizing;
pub use olm_rs::{
account::IdentityKeys,
session::{OlmMessage, PreKeyMessage},
utility::OlmUtility,
};
use matrix_sdk_common::{ use matrix_sdk_common::{
events::{ events::{
forwarded_room_key::ForwardedRoomKeyToDeviceEventContent, forwarded_room_key::ForwardedRoomKeyToDeviceEventContent,
@ -45,6 +32,17 @@ use matrix_sdk_common::{
locks::Mutex, locks::Mutex,
Raw, Raw,
}; };
pub use olm_rs::{
account::IdentityKeys,
session::{OlmMessage, PreKeyMessage},
utility::OlmUtility,
};
use olm_rs::{
errors::OlmGroupSessionError, inbound_group_session::OlmInboundGroupSession, PicklingMode,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use zeroize::Zeroizing;
use super::{ExportedGroupSessionKey, ExportedRoomKey, GroupSessionKey}; use super::{ExportedGroupSessionKey, ExportedRoomKey, GroupSessionKey};
use crate::error::{EventError, MegolmResult}; use crate::error::{EventError, MegolmResult};
@ -149,10 +147,8 @@ impl InboundGroupSession {
forwarding_chains.push(sender_key.to_owned()); forwarding_chains.push(sender_key.to_owned());
let mut sender_claimed_key = BTreeMap::new(); let mut sender_claimed_key = BTreeMap::new();
sender_claimed_key.insert( sender_claimed_key
DeviceKeyAlgorithm::Ed25519, .insert(DeviceKeyAlgorithm::Ed25519, content.sender_claimed_ed25519_key.to_owned());
content.sender_claimed_ed25519_key.to_owned(),
);
Ok(InboundGroupSession { Ok(InboundGroupSession {
inner: Mutex::new(session).into(), inner: Mutex::new(session).into(),
@ -219,11 +215,7 @@ impl InboundGroupSession {
let message_index = std::cmp::max(self.first_known_index(), message_index); let message_index = std::cmp::max(self.first_known_index(), message_index);
let session_key = ExportedGroupSessionKey( let session_key = ExportedGroupSessionKey(
self.inner self.inner.lock().await.export(message_index).expect("Can't export session"),
.lock()
.await
.export(message_index)
.expect("Can't export session"),
); );
ExportedRoomKey { ExportedRoomKey {
@ -316,9 +308,7 @@ impl InboundGroupSession {
let (plaintext, message_index) = self.decrypt_helper(content.ciphertext.clone()).await?; let (plaintext, message_index) = self.decrypt_helper(content.ciphertext.clone()).await?;
let mut decrypted_value = serde_json::from_str::<Value>(&plaintext)?; let mut decrypted_value = serde_json::from_str::<Value>(&plaintext)?;
let decrypted_object = decrypted_value let decrypted_object = decrypted_value.as_object_mut().ok_or(EventError::NotAnObject)?;
.as_object_mut()
.ok_or(EventError::NotAnObject)?;
// TODO better number conversion here. // TODO better number conversion here.
let server_ts = event let server_ts = event
@ -337,10 +327,8 @@ impl InboundGroupSession {
serde_json::to_value(&event.unsigned).unwrap_or_default(), serde_json::to_value(&event.unsigned).unwrap_or_default(),
); );
if let Some(decrypted_content) = decrypted_object if let Some(decrypted_content) =
.get_mut("content") decrypted_object.get_mut("content").map(|c| c.as_object_mut()).flatten()
.map(|c| c.as_object_mut())
.flatten()
{ {
if !decrypted_content.contains_key("m.relates_to") { if !decrypted_content.contains_key("m.relates_to") {
if let Some(relation) = &event.content.relates_to { if let Some(relation) = &event.content.relates_to {
@ -352,19 +340,14 @@ impl InboundGroupSession {
} }
} }
Ok(( Ok((serde_json::from_value::<Raw<AnySyncRoomEvent>>(decrypted_value)?, message_index))
serde_json::from_value::<Raw<AnySyncRoomEvent>>(decrypted_value)?,
message_index,
))
} }
} }
#[cfg(not(tarpaulin_include))] #[cfg(not(tarpaulin_include))]
impl fmt::Debug for InboundGroupSession { impl fmt::Debug for InboundGroupSession {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("InboundGroupSession") f.debug_struct("InboundGroupSession").field("session_id", &self.session_id()).finish()
.field("session_id", &self.session_id())
.finish()
} }
} }
@ -399,7 +382,8 @@ pub struct PickledInboundGroupSession {
pub history_visibility: Option<HistoryVisibility>, pub history_visibility: Option<HistoryVisibility>,
} }
/// The typed representation of a base64 encoded string of the GroupSession pickle. /// The typed representation of a base64 encoded string of the GroupSession
/// pickle.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InboundGroupSessionPickle(String); pub struct InboundGroupSessionPickle(String);

View File

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{collections::BTreeMap, convert::TryInto};
use matrix_sdk_common::{ use matrix_sdk_common::{
events::forwarded_room_key::{ events::forwarded_room_key::{
ForwardedRoomKeyToDeviceEventContent, ForwardedRoomKeyToDeviceEventContentInit, ForwardedRoomKeyToDeviceEventContent, ForwardedRoomKeyToDeviceEventContentInit,
@ -19,7 +21,6 @@ use matrix_sdk_common::{
identifiers::{DeviceKeyAlgorithm, EventEncryptionAlgorithm, RoomId}, identifiers::{DeviceKeyAlgorithm, EventEncryptionAlgorithm, RoomId},
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{collections::BTreeMap, convert::TryInto};
use zeroize::Zeroize; use zeroize::Zeroize;
mod inbound; mod inbound;
@ -107,10 +108,8 @@ impl From<ForwardedRoomKeyToDeviceEventContent> for ExportedRoomKey {
/// Convert the content of a forwarded room key into a exported room key. /// Convert the content of a forwarded room key into a exported room key.
fn from(forwarded_key: ForwardedRoomKeyToDeviceEventContent) -> Self { fn from(forwarded_key: ForwardedRoomKeyToDeviceEventContent) -> Self {
let mut sender_claimed_keys: BTreeMap<DeviceKeyAlgorithm, String> = BTreeMap::new(); let mut sender_claimed_keys: BTreeMap<DeviceKeyAlgorithm, String> = BTreeMap::new();
sender_claimed_keys.insert( sender_claimed_keys
DeviceKeyAlgorithm::Ed25519, .insert(DeviceKeyAlgorithm::Ed25519, forwarded_key.sender_claimed_ed25519_key);
forwarded_key.sender_claimed_ed25519_key,
);
Self { Self {
algorithm: forwarded_key.algorithm, algorithm: forwarded_key.algorithm,
@ -142,10 +141,7 @@ mod test {
#[tokio::test] #[tokio::test]
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
async fn expiration() { async fn expiration() {
let settings = EncryptionSettings { let settings = EncryptionSettings { rotation_period_msgs: 1, ..Default::default() };
rotation_period_msgs: 1,
..Default::default()
};
let account = ReadOnlyAccount::new(&user_id!("@alice:example.org"), "DEVICEID".into()); let account = ReadOnlyAccount::new(&user_id!("@alice:example.org"), "DEVICEID".into());
let (session, _) = account let (session, _) = account
@ -155,9 +151,9 @@ mod test {
assert!(!session.expired()); assert!(!session.expired());
let _ = session let _ = session
.encrypt(AnyMessageEventContent::RoomMessage( .encrypt(AnyMessageEventContent::RoomMessage(MessageEventContent::text_plain(
MessageEventContent::text_plain("Test message"), "Test message",
)) )))
.await; .await;
assert!(session.expired()); assert!(session.expired());

View File

@ -12,15 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use dashmap::DashMap;
use matrix_sdk_common::{
api::r0::to_device::DeviceIdOrAllDevices,
events::room::{
encrypted::MegolmV1AesSha2ContentInit, history_visibility::HistoryVisibility,
message::Relation,
},
uuid::Uuid,
};
use std::{ use std::{
cmp::max, cmp::max,
collections::BTreeMap, collections::BTreeMap,
@ -31,23 +22,24 @@ use std::{
}, },
time::Duration, time::Duration,
}; };
use tracing::{debug, error, trace};
use dashmap::DashMap;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::to_device::DeviceIdOrAllDevices,
events::{ events::{
room::{ room::{
encrypted::{EncryptedEventContent, EncryptedEventScheme}, encrypted::{EncryptedEventContent, EncryptedEventScheme, MegolmV1AesSha2ContentInit},
encryption::EncryptionEventContent, encryption::EncryptionEventContent,
history_visibility::HistoryVisibility,
message::Relation,
}, },
AnyMessageEventContent, EventContent, AnyMessageEventContent, EventContent,
}, },
identifiers::{DeviceId, DeviceIdBox, EventEncryptionAlgorithm, RoomId, UserId}, identifiers::{DeviceId, DeviceIdBox, EventEncryptionAlgorithm, RoomId, UserId},
instant::Instant, instant::Instant,
locks::Mutex, locks::Mutex,
uuid::Uuid,
}; };
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
pub use olm_rs::{ pub use olm_rs::{
account::IdentityKeys, account::IdentityKeys,
session::{OlmMessage, PreKeyMessage}, session::{OlmMessage, PreKeyMessage},
@ -56,13 +48,15 @@ pub use olm_rs::{
use olm_rs::{ use olm_rs::{
errors::OlmGroupSessionError, outbound_group_session::OlmOutboundGroupSession, PicklingMode, errors::OlmGroupSessionError, outbound_group_session::OlmOutboundGroupSession, PicklingMode,
}; };
use serde::{Deserialize, Serialize};
use crate::ToDeviceRequest; use serde_json::{json, Value};
use tracing::{debug, error, trace};
use super::{ use super::{
super::{deserialize_instant, serialize_instant}, super::{deserialize_instant, serialize_instant},
GroupSessionKey, GroupSessionKey,
}; };
use crate::ToDeviceRequest;
const ROTATION_PERIOD: Duration = Duration::from_millis(604800000); const ROTATION_PERIOD: Duration = Duration::from_millis(604800000);
const ROTATION_MESSAGES: u64 = 100; const ROTATION_MESSAGES: u64 = 100;
@ -102,12 +96,10 @@ impl EncryptionSettings {
/// Create new encryption settings using an `EncryptionEventContent` and a /// Create new encryption settings using an `EncryptionEventContent` and a
/// history visibility. /// history visibility.
pub fn new(content: EncryptionEventContent, history_visibility: HistoryVisibility) -> Self { pub fn new(content: EncryptionEventContent, history_visibility: HistoryVisibility) -> Self {
let rotation_period: Duration = content let rotation_period: Duration =
.rotation_period_ms content.rotation_period_ms.map_or(ROTATION_PERIOD, |r| Duration::from_millis(r.into()));
.map_or(ROTATION_PERIOD, |r| Duration::from_millis(r.into())); let rotation_period_msgs: u64 =
let rotation_period_msgs: u64 = content content.rotation_period_msgs.map_or(ROTATION_MESSAGES, Into::into);
.rotation_period_msgs
.map_or(ROTATION_MESSAGES, Into::into);
Self { Self {
algorithm: content.algorithm, algorithm: content.algorithm,
@ -186,8 +178,7 @@ impl OutboundGroupSession {
request: Arc<ToDeviceRequest>, request: Arc<ToDeviceRequest>,
message_index: u32, message_index: u32,
) { ) {
self.to_share_with_set self.to_share_with_set.insert(request_id, (request, message_index));
.insert(request_id, (request, message_index));
} }
/// This should be called if an the user wishes to rotate this session. /// This should be called if an the user wishes to rotate this session.
@ -225,10 +216,7 @@ impl OutboundGroupSession {
}); });
user_pairs.for_each(|(u, d)| { user_pairs.for_each(|(u, d)| {
self.shared_with_set self.shared_with_set.entry(u).or_insert_with(DashMap::new).extend(d);
.entry(u)
.or_insert_with(DashMap::new)
.extend(d);
}); });
if self.to_share_with_set.is_empty() { if self.to_share_with_set.is_empty() {
@ -241,11 +229,8 @@ impl OutboundGroupSession {
self.mark_as_shared(); self.mark_as_shared();
} }
} else { } else {
let request_ids: Vec<String> = self let request_ids: Vec<String> =
.to_share_with_set self.to_share_with_set.iter().map(|e| e.key().to_string()).collect();
.iter()
.map(|e| e.key().to_string())
.collect();
error!( error!(
all_request_ids = ?request_ids, all_request_ids = ?request_ids,
@ -296,11 +281,7 @@ impl OutboundGroupSession {
let relates_to: Option<Relation> = json_content let relates_to: Option<Relation> = json_content
.get("content") .get("content")
.map(|c| { .map(|c| c.get("m.relates_to").cloned().map(|r| serde_json::from_value(r).ok()))
c.get("m.relates_to")
.cloned()
.map(|r| serde_json::from_value(r).ok())
})
.flatten() .flatten()
.flatten(); .flatten();
@ -443,10 +424,7 @@ impl OutboundGroupSession {
/// Get the list of requests that need to be sent out for this session to be /// Get the list of requests that need to be sent out for this session to be
/// marked as shared. /// marked as shared.
pub(crate) fn pending_requests(&self) -> Vec<Arc<ToDeviceRequest>> { pub(crate) fn pending_requests(&self) -> Vec<Arc<ToDeviceRequest>> {
self.to_share_with_set self.to_share_with_set.iter().map(|i| i.value().0.clone()).collect()
.iter()
.map(|i| i.value().0.clone())
.collect()
} }
/// Get the list of request ids this session is waiting for to be sent out. /// Get the list of request ids this session is waiting for to be sent out.
@ -462,10 +440,10 @@ impl OutboundGroupSession {
/// # Arguments /// # Arguments
/// ///
/// * `device_id` - The device id of the device that created this session. /// * `device_id` - The device id of the device that created this session.
/// Put differently, our own device id. /// Put differently, our own device id.
/// ///
/// * `identity_keys` - The identity keys of the device that created this /// * `identity_keys` - The identity keys of the device that created this
/// session, our own identity keys. /// session, our own identity keys.
/// ///
/// * `pickle` - The pickled version of the `OutboundGroupSession`. /// * `pickle` - The pickled version of the `OutboundGroupSession`.
/// ///
@ -507,7 +485,8 @@ impl OutboundGroupSession {
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `pickle_mode` - The mode that should be used to pickle the group session, /// * `pickle_mode` - The mode that should be used to pickle the group
/// session,
/// either an unencrypted mode or an encrypted using passphrase. /// either an unencrypted mode or an encrypted using passphrase.
pub async fn pickle(&self, pickling_mode: PicklingMode) -> PickledOutboundGroupSession { pub async fn pickle(&self, pickling_mode: PicklingMode) -> PickledOutboundGroupSession {
let pickle: OutboundGroupSessionPickle = let pickle: OutboundGroupSessionPickle =
@ -528,10 +507,7 @@ impl OutboundGroupSession {
( (
u.key().clone(), u.key().clone(),
#[allow(clippy::map_clone)] #[allow(clippy::map_clone)]
u.value() u.value().iter().map(|d| (d.key().clone(), *d.value())).collect(),
.iter()
.map(|d| (d.key().clone(), *d.value()))
.collect(),
) )
}) })
.collect(), .collect(),
@ -578,10 +554,7 @@ pub struct PickledOutboundGroupSession {
/// The room id this session is used for. /// The room id this session is used for.
pub room_id: Arc<RoomId>, pub room_id: Arc<RoomId>,
/// The timestamp when this session was created. /// The timestamp when this session was created.
#[serde( #[serde(deserialize_with = "deserialize_instant", serialize_with = "serialize_instant")]
deserialize_with = "deserialize_instant",
serialize_with = "serialize_instant"
)]
pub creation_time: Instant, pub creation_time: Instant,
/// The number of messages this session has already encrypted. /// The number of messages this session has already encrypted.
pub message_count: u64, pub message_count: u64,

View File

@ -30,14 +30,13 @@ pub use group_sessions::{
OutboundGroupSession, PickledInboundGroupSession, PickledOutboundGroupSession, OutboundGroupSession, PickledInboundGroupSession, PickledOutboundGroupSession,
}; };
pub(crate) use group_sessions::{GroupSessionKey, ShareState}; pub(crate) use group_sessions::{GroupSessionKey, ShareState};
use matrix_sdk_common::instant::{Duration, Instant};
pub use olm_rs::{account::IdentityKeys, PicklingMode}; pub use olm_rs::{account::IdentityKeys, PicklingMode};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub use session::{PickledSession, Session, SessionPickle}; pub use session::{PickledSession, Session, SessionPickle};
pub use signing::{PickledCrossSigningIdentity, PrivateCrossSigningIdentity}; pub use signing::{PickledCrossSigningIdentity, PrivateCrossSigningIdentity};
pub(crate) use utility::Utility; pub(crate) use utility::Utility;
use matrix_sdk_common::instant::{Duration, Instant};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub(crate) fn serialize_instant<S>(instant: &Instant, serializer: S) -> Result<S::Ok, S::Error> pub(crate) fn serialize_instant<S>(instant: &Instant, serializer: S) -> Result<S::Ok, S::Error>
where where
S: Serializer, S: Serializer,
@ -60,14 +59,16 @@ where
#[cfg(test)] #[cfg(test)]
pub(crate) mod test { pub(crate) mod test {
use crate::olm::{InboundGroupSession, ReadOnlyAccount, Session}; use std::{collections::BTreeMap, convert::TryInto};
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::keys::SignedKey, api::r0::keys::SignedKey,
events::forwarded_room_key::ForwardedRoomKeyToDeviceEventContent, events::forwarded_room_key::ForwardedRoomKeyToDeviceEventContent,
identifiers::{room_id, user_id, DeviceId, UserId}, identifiers::{room_id, user_id, DeviceId, UserId},
}; };
use olm_rs::session::OlmMessage; use olm_rs::session::OlmMessage;
use std::{collections::BTreeMap, convert::TryInto};
use crate::olm::{InboundGroupSession, ReadOnlyAccount, Session};
fn alice_id() -> UserId { fn alice_id() -> UserId {
user_id!("@alice:example.org") user_id!("@alice:example.org")
@ -90,21 +91,12 @@ pub(crate) mod test {
let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id()); let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id());
bob.generate_one_time_keys_helper(1).await; bob.generate_one_time_keys_helper(1).await;
let one_time_key = bob let one_time_key =
.one_time_keys() bob.one_time_keys().await.curve25519().iter().next().unwrap().1.to_owned();
.await
.curve25519()
.iter()
.next()
.unwrap()
.1
.to_owned();
let one_time_key = SignedKey::new(one_time_key, BTreeMap::new()); let one_time_key = SignedKey::new(one_time_key, BTreeMap::new());
let sender_key = bob.identity_keys().curve25519().to_owned(); let sender_key = bob.identity_keys().curve25519().to_owned();
let session = alice let session =
.create_outbound_session_helper(&sender_key, &one_time_key) alice.create_outbound_session_helper(&sender_key, &one_time_key).await.unwrap();
.await
.unwrap();
(alice, session) (alice, session)
} }
@ -120,10 +112,7 @@ pub(crate) mod test {
assert_ne!(identity_keys.keys().len(), 0); assert_ne!(identity_keys.keys().len(), 0);
assert_ne!(identity_keys.iter().len(), 0); assert_ne!(identity_keys.iter().len(), 0);
assert!(identity_keys.contains_key("ed25519")); assert!(identity_keys.contains_key("ed25519"));
assert_eq!( assert_eq!(identity_keys.ed25519(), identity_keys.get("ed25519").unwrap());
identity_keys.ed25519(),
identity_keys.get("ed25519").unwrap()
);
assert!(!identity_keys.curve25519().is_empty()); assert!(!identity_keys.curve25519().is_empty());
account.mark_as_shared(); account.mark_as_shared();
@ -147,10 +136,7 @@ pub(crate) mod test {
assert_ne!(one_time_keys.iter().len(), 0); assert_ne!(one_time_keys.iter().len(), 0);
assert!(one_time_keys.contains_key("curve25519")); assert!(one_time_keys.contains_key("curve25519"));
assert_eq!(one_time_keys.curve25519().keys().len(), 10); assert_eq!(one_time_keys.curve25519().keys().len(), 10);
assert_eq!( assert_eq!(one_time_keys.curve25519(), one_time_keys.get("curve25519").unwrap());
one_time_keys.curve25519(),
one_time_keys.get("curve25519").unwrap()
);
account.mark_keys_as_published().await; account.mark_keys_as_published().await;
let one_time_keys = account.one_time_keys().await; let one_time_keys = account.one_time_keys().await;
@ -166,13 +152,7 @@ pub(crate) mod test {
let one_time_keys = alice.one_time_keys().await; let one_time_keys = alice.one_time_keys().await;
alice.mark_keys_as_published().await; alice.mark_keys_as_published().await;
let one_time_key = one_time_keys let one_time_key = one_time_keys.curve25519().iter().next().unwrap().1.to_owned();
.curve25519()
.iter()
.next()
.unwrap()
.1
.to_owned();
let one_time_key = SignedKey::new(one_time_key, BTreeMap::new()); let one_time_key = SignedKey::new(one_time_key, BTreeMap::new());
@ -196,10 +176,7 @@ pub(crate) mod test {
.await .await
.unwrap(); .unwrap();
assert!(alice_session assert!(alice_session.matches(bob_keys.curve25519(), prekey_message).await.unwrap());
.matches(bob_keys.curve25519(), prekey_message)
.await
.unwrap());
assert_eq!(bob_session.session_id(), alice_session.session_id()); assert_eq!(bob_session.session_id(), alice_session.session_id());
@ -212,10 +189,7 @@ pub(crate) mod test {
let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id()); let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id());
let room_id = room_id!("!test:localhost"); let room_id = room_id!("!test:localhost");
let (outbound, _) = alice let (outbound, _) = alice.create_group_session_pair_with_defaults(&room_id).await.unwrap();
.create_group_session_pair_with_defaults(&room_id)
.await
.unwrap();
assert_eq!(0, outbound.message_index().await); assert_eq!(0, outbound.message_index().await);
assert!(!outbound.shared()); assert!(!outbound.shared());
@ -238,10 +212,7 @@ pub(crate) mod test {
let plaintext = "This is a secret to everybody".to_owned(); let plaintext = "This is a secret to everybody".to_owned();
let ciphertext = outbound.encrypt_helper(plaintext.clone()).await; let ciphertext = outbound.encrypt_helper(plaintext.clone()).await;
assert_eq!( assert_eq!(plaintext, inbound.decrypt_helper(ciphertext).await.unwrap().0);
plaintext,
inbound.decrypt_helper(ciphertext).await.unwrap().0
);
} }
#[tokio::test] #[tokio::test]
@ -249,10 +220,7 @@ pub(crate) mod test {
let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id()); let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id());
let room_id = room_id!("!test:localhost"); let room_id = room_id!("!test:localhost");
let (_, inbound) = alice let (_, inbound) = alice.create_group_session_pair_with_defaults(&room_id).await.unwrap();
.create_group_session_pair_with_defaults(&room_id)
.await
.unwrap();
let export = inbound.export().await; let export = inbound.export().await;
let export: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap(); let export: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap();

View File

@ -27,20 +27,18 @@ use matrix_sdk_common::{
locks::Mutex, locks::Mutex,
}; };
use olm_rs::{errors::OlmSessionError, session::OlmSession, PicklingMode}; use olm_rs::{errors::OlmSessionError, session::OlmSession, PicklingMode};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use crate::{
error::{EventError, OlmResult, SessionUnpicklingError},
ReadOnlyDevice,
};
pub use olm_rs::{ pub use olm_rs::{
session::{OlmMessage, PreKeyMessage}, session::{OlmMessage, PreKeyMessage},
utility::OlmUtility, utility::OlmUtility,
}; };
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use super::{deserialize_instant, serialize_instant, IdentityKeys}; use super::{deserialize_instant, serialize_instant, IdentityKeys};
use crate::{
error::{EventError, OlmResult, SessionUnpicklingError},
ReadOnlyDevice,
};
/// Cryptographic session that enables secure communication between two /// Cryptographic session that enables secure communication between two
/// `Account`s /// `Account`s
@ -105,8 +103,8 @@ impl Session {
/// # Arguments /// # Arguments
/// ///
/// * `recipient_device` - The device for which this message is going to be /// * `recipient_device` - The device for which this message is going to be
/// encrypted, this needs to be the device that was used to create this /// encrypted, this needs to be the device that was used to create this
/// session with. /// session with.
/// ///
/// * `event_type` - The type of the event. /// * `event_type` - The type of the event.
/// ///
@ -121,10 +119,8 @@ impl Session {
.get_key(DeviceKeyAlgorithm::Ed25519) .get_key(DeviceKeyAlgorithm::Ed25519)
.ok_or(EventError::MissingSigningKey)?; .ok_or(EventError::MissingSigningKey)?;
let relates_to = content let relates_to =
.get("m.relates_to") content.get("m.relates_to").cloned().and_then(|v| serde_json::from_value(v).ok());
.cloned()
.and_then(|v| serde_json::from_value(v).ok());
let payload = json!({ let payload = json!({
"sender": self.user_id.as_str(), "sender": self.user_id.as_str(),
@ -174,10 +170,7 @@ impl Session {
their_identity_key: &str, their_identity_key: &str,
message: PreKeyMessage, message: PreKeyMessage,
) -> Result<bool, OlmSessionError> { ) -> Result<bool, OlmSessionError> {
self.inner self.inner.lock().await.matches_inbound_session_from(their_identity_key, message)
.lock()
.await
.matches_inbound_session_from(their_identity_key, message)
} }
/// Returns the unique identifier for this session. /// Returns the unique identifier for this session.
@ -259,20 +252,15 @@ pub struct PickledSession {
/// The curve25519 key of the other user that we share this session with. /// The curve25519 key of the other user that we share this session with.
pub sender_key: String, pub sender_key: String,
/// The relative time elapsed since the session was created. /// The relative time elapsed since the session was created.
#[serde( #[serde(deserialize_with = "deserialize_instant", serialize_with = "serialize_instant")]
deserialize_with = "deserialize_instant",
serialize_with = "serialize_instant"
)]
pub creation_time: Instant, pub creation_time: Instant,
/// The relative time elapsed since the session was last used. /// The relative time elapsed since the session was last used.
#[serde( #[serde(deserialize_with = "deserialize_instant", serialize_with = "serialize_instant")]
deserialize_with = "deserialize_instant",
serialize_with = "serialize_instant"
)]
pub last_use_time: Instant, pub last_use_time: Instant,
} }
/// The typed representation of a base64 encoded string of the Olm Session pickle. /// The typed representation of a base64 encoded string of the Olm Session
/// pickle.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionPickle(String); pub struct SessionPickle(String);

View File

@ -14,8 +14,6 @@
mod pk_signing; mod pk_signing;
use serde::{Deserialize, Serialize};
use serde_json::Error as JsonError;
use std::{ use std::{
collections::BTreeMap, collections::BTreeMap,
sync::{ sync::{
@ -30,14 +28,15 @@ use matrix_sdk_common::{
identifiers::{DeviceKeyAlgorithm, DeviceKeyId, UserId}, identifiers::{DeviceKeyAlgorithm, DeviceKeyId, UserId},
locks::Mutex, locks::Mutex,
}; };
use pk_signing::{MasterSigning, PickledSignings, SelfSigning, Signing, SigningError, UserSigning};
use serde::{Deserialize, Serialize};
use serde_json::Error as JsonError;
use crate::{ use crate::{
error::SignatureError, requests::UploadSigningKeysRequest, OwnUserIdentity, ReadOnlyAccount, error::SignatureError, requests::UploadSigningKeysRequest, OwnUserIdentity, ReadOnlyAccount,
ReadOnlyDevice, UserIdentity, ReadOnlyDevice, UserIdentity,
}; };
use pk_signing::{MasterSigning, PickledSignings, SelfSigning, Signing, SigningError, UserSigning};
/// Private cross signing identity. /// Private cross signing identity.
/// ///
/// This object holds the private and public ed25519 key triplet that is used /// This object holds the private and public ed25519 key triplet that is used
@ -186,10 +185,7 @@ impl PrivateCrossSigningIdentity {
signed_keys signed_keys
.entry((&*self.user_id).to_owned()) .entry((&*self.user_id).to_owned())
.or_insert_with(BTreeMap::new) .or_insert_with(BTreeMap::new)
.insert( .insert(device_keys.device_id.to_string(), serde_json::to_value(device_keys)?);
device_keys.device_id.to_string(),
serde_json::to_value(device_keys)?,
);
Ok(SignatureUploadRequest::new(signed_keys)) Ok(SignatureUploadRequest::new(signed_keys))
} }
@ -229,10 +225,7 @@ impl PrivateCrossSigningIdentity {
signature, signature,
); );
let master = MasterSigning { let master = MasterSigning { inner: master, public_key: public_key.into() };
inner: master,
public_key: public_key.into(),
};
let identity = Self::new_helper(account.user_id(), master).await; let identity = Self::new_helper(account.user_id(), master).await;
let signature_request = identity let signature_request = identity
@ -250,20 +243,14 @@ impl PrivateCrossSigningIdentity {
let mut public_key = user.cross_signing_key(user_id.to_owned(), KeyUsage::UserSigning); let mut public_key = user.cross_signing_key(user_id.to_owned(), KeyUsage::UserSigning);
master.sign_subkey(&mut public_key).await; master.sign_subkey(&mut public_key).await;
let user = UserSigning { let user = UserSigning { inner: user, public_key: public_key.into() };
inner: user,
public_key: public_key.into(),
};
let self_signing = Signing::new(); let self_signing = Signing::new();
let mut public_key = let mut public_key =
self_signing.cross_signing_key(user_id.to_owned(), KeyUsage::SelfSigning); self_signing.cross_signing_key(user_id.to_owned(), KeyUsage::SelfSigning);
master.sign_subkey(&mut public_key).await; master.sign_subkey(&mut public_key).await;
let self_signing = SelfSigning { let self_signing = SelfSigning { inner: self_signing, public_key: public_key.into() };
inner: self_signing,
public_key: public_key.into(),
};
Self { Self {
user_id: Arc::new(user_id.to_owned()), user_id: Arc::new(user_id.to_owned()),
@ -281,10 +268,7 @@ impl PrivateCrossSigningIdentity {
let master = Signing::new(); let master = Signing::new();
let public_key = master.cross_signing_key(user_id.clone(), KeyUsage::Master); let public_key = master.cross_signing_key(user_id.clone(), KeyUsage::Master);
let master = MasterSigning { let master = MasterSigning { inner: master, public_key: public_key.into() };
inner: master,
public_key: public_key.into(),
};
Self::new_helper(&user_id, master).await Self::new_helper(&user_id, master).await
} }
@ -334,11 +318,7 @@ impl PrivateCrossSigningIdentity {
None None
}; };
let pickle = PickledSignings { let pickle = PickledSignings { master_key, user_signing_key, self_signing_key };
master_key,
user_signing_key,
self_signing_key,
};
let pickle = serde_json::to_string(&pickle)?; let pickle = serde_json::to_string(&pickle)?;
@ -390,54 +370,35 @@ impl PrivateCrossSigningIdentity {
/// Get the upload request that is needed to share the public keys of this /// Get the upload request that is needed to share the public keys of this
/// identity. /// identity.
pub(crate) async fn as_upload_request(&self) -> UploadSigningKeysRequest { pub(crate) async fn as_upload_request(&self) -> UploadSigningKeysRequest {
let master_key = self let master_key =
.master_key self.master_key.lock().await.as_ref().cloned().map(|k| k.public_key.into());
.lock()
.await
.as_ref()
.cloned()
.map(|k| k.public_key.into());
let user_signing_key = self let user_signing_key =
.user_signing_key self.user_signing_key.lock().await.as_ref().cloned().map(|k| k.public_key.into());
.lock()
.await
.as_ref()
.cloned()
.map(|k| k.public_key.into());
let self_signing_key = self let self_signing_key =
.self_signing_key self.self_signing_key.lock().await.as_ref().cloned().map(|k| k.public_key.into());
.lock()
.await
.as_ref()
.cloned()
.map(|k| k.public_key.into());
UploadSigningKeysRequest { UploadSigningKeysRequest { master_key, self_signing_key, user_signing_key }
master_key,
self_signing_key,
user_signing_key,
}
} }
} }
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::{
identities::{ReadOnlyDevice, UserIdentity},
olm::ReadOnlyAccount,
};
use std::{collections::BTreeMap, sync::Arc}; use std::{collections::BTreeMap, sync::Arc};
use super::{PrivateCrossSigningIdentity, Signing};
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::keys::CrossSigningKey, api::r0::keys::CrossSigningKey,
identifiers::{user_id, UserId}, identifiers::{user_id, UserId},
}; };
use matrix_sdk_test::async_test; use matrix_sdk_test::async_test;
use super::{PrivateCrossSigningIdentity, Signing};
use crate::{
identities::{ReadOnlyDevice, UserIdentity},
olm::ReadOnlyAccount,
};
fn user_id() -> UserId { fn user_id() -> UserId {
user_id!("@example:localhost") user_id!("@example:localhost")
} }
@ -481,28 +442,12 @@ mod test {
assert!(master_key assert!(master_key
.public_key .public_key
.verify_subkey( .verify_subkey(&identity.self_signing_key.lock().await.as_ref().unwrap().public_key,)
&identity
.self_signing_key
.lock()
.await
.as_ref()
.unwrap()
.public_key,
)
.is_ok()); .is_ok());
assert!(master_key assert!(master_key
.public_key .public_key
.verify_subkey( .verify_subkey(&identity.user_signing_key.lock().await.as_ref().unwrap().public_key,)
&identity
.user_signing_key
.lock()
.await
.as_ref()
.unwrap()
.public_key,
)
.is_ok()); .is_ok());
} }
@ -512,15 +457,11 @@ mod test {
let pickled = identity.pickle(pickle_key()).await.unwrap(); let pickled = identity.pickle(pickle_key()).await.unwrap();
let unpickled = PrivateCrossSigningIdentity::from_pickle(pickled, pickle_key()) let unpickled =
.await PrivateCrossSigningIdentity::from_pickle(pickled, pickle_key()).await.unwrap();
.unwrap();
assert_eq!(identity.user_id, unpickled.user_id); assert_eq!(identity.user_id, unpickled.user_id);
assert_eq!( assert_eq!(&*identity.master_key.lock().await, &*unpickled.master_key.lock().await);
&*identity.master_key.lock().await,
&*unpickled.master_key.lock().await
);
assert_eq!( assert_eq!(
&*identity.user_signing_key.lock().await, &*identity.user_signing_key.lock().await,
&*unpickled.user_signing_key.lock().await &*unpickled.user_signing_key.lock().await
@ -591,9 +532,6 @@ mod test {
bob_public.master_key = master.into(); bob_public.master_key = master.into();
user_signing user_signing.public_key.verify_master_key(bob_public.master_key()).unwrap();
.public_key
.verify_master_key(bob_public.master_key())
.unwrap();
} }
} }

View File

@ -12,32 +12,27 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{collections::BTreeMap, convert::TryInto, sync::Arc};
use aes_gcm::{ use aes_gcm::{
aead::{generic_array::GenericArray, Aead, NewAead}, aead::{generic_array::GenericArray, Aead, NewAead},
Aes256Gcm, Aes256Gcm,
}; };
use getrandom::getrandom; use getrandom::getrandom;
use matrix_sdk_common::{
encryption::DeviceKeys,
identifiers::{DeviceKeyAlgorithm, DeviceKeyId},
};
use serde::{Deserialize, Serialize};
use serde_json::{json, Error as JsonError, Value};
use std::{collections::BTreeMap, convert::TryInto, sync::Arc};
use thiserror::Error;
use zeroize::Zeroizing;
use olm_rs::pk::OlmPkSigning;
#[cfg(test)]
use olm_rs::{errors::OlmUtilityError, utility::OlmUtility};
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::keys::{CrossSigningKey, KeyUsage}, api::r0::keys::{CrossSigningKey, KeyUsage},
identifiers::UserId, encryption::DeviceKeys,
identifiers::{DeviceKeyAlgorithm, DeviceKeyId, UserId},
locks::Mutex, locks::Mutex,
CanonicalJsonValue, CanonicalJsonValue,
}; };
use olm_rs::pk::OlmPkSigning;
#[cfg(test)]
use olm_rs::{errors::OlmUtilityError, utility::OlmUtility};
use serde::{Deserialize, Serialize};
use serde_json::{json, Error as JsonError, Value};
use thiserror::Error;
use zeroize::Zeroizing;
use crate::{ use crate::{
error::SignatureError, error::SignatureError,
@ -73,9 +68,7 @@ pub struct Signing {
impl std::fmt::Debug for Signing { impl std::fmt::Debug for Signing {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Signing") f.debug_struct("Signing").field("public_key", &self.public_key.as_str()).finish()
.field("public_key", &self.public_key.as_str())
.finish()
} }
} }
@ -156,10 +149,7 @@ impl MasterSigning {
) -> Result<Self, SigningError> { ) -> Result<Self, SigningError> {
let inner = Signing::from_pickle(pickle.pickle, pickle_key)?; let inner = Signing::from_pickle(pickle.pickle, pickle_key)?;
Ok(Self { Ok(Self { inner, public_key: pickle.public_key.into() })
inner,
public_key: pickle.public_key.into(),
})
} }
pub async fn sign_subkey<'a>(&self, subkey: &mut CrossSigningKey) { pub async fn sign_subkey<'a>(&self, subkey: &mut CrossSigningKey) {
@ -200,10 +190,7 @@ impl UserSigning {
user: &UserIdentity, user: &UserIdentity,
) -> Result<BTreeMap<UserId, BTreeMap<String, Value>>, SignatureError> { ) -> Result<BTreeMap<UserId, BTreeMap<String, Value>>, SignatureError> {
let user_master: &CrossSigningKey = user.master_key().as_ref(); let user_master: &CrossSigningKey = user.master_key().as_ref();
let signature = self let signature = self.inner.sign_json(serde_json::to_value(user_master)?).await?;
.inner
.sign_json(serde_json::to_value(user_master)?)
.await?;
let mut signatures = BTreeMap::new(); let mut signatures = BTreeMap::new();
@ -228,10 +215,7 @@ impl UserSigning {
) -> Result<Self, SigningError> { ) -> Result<Self, SigningError> {
let inner = Signing::from_pickle(pickle.pickle, pickle_key)?; let inner = Signing::from_pickle(pickle.pickle, pickle_key)?;
Ok(Self { Ok(Self { inner, public_key: pickle.public_key.into() })
inner,
public_key: pickle.public_key.into(),
})
} }
} }
@ -279,10 +263,7 @@ impl SelfSigning {
) -> Result<Self, SigningError> { ) -> Result<Self, SigningError> {
let inner = Signing::from_pickle(pickle.pickle, pickle_key)?; let inner = Signing::from_pickle(pickle.pickle, pickle_key)?;
Ok(Self { Ok(Self { inner, public_key: pickle.public_key.into() })
inner,
public_key: pickle.public_key.into(),
})
} }
} }
@ -353,17 +334,12 @@ impl Signing {
getrandom(&mut nonce).expect("Can't generate nonce to pickle the signing object"); getrandom(&mut nonce).expect("Can't generate nonce to pickle the signing object");
let nonce = GenericArray::from_slice(nonce.as_slice()); let nonce = GenericArray::from_slice(nonce.as_slice());
let ciphertext = cipher let ciphertext =
.encrypt(nonce, self.seed.as_slice()) cipher.encrypt(nonce, self.seed.as_slice()).expect("Can't encrypt signing pickle");
.expect("Can't encrypt signing pickle");
let ciphertext = encode(ciphertext); let ciphertext = encode(ciphertext);
let pickle = InnerPickle { let pickle = InnerPickle { version: 1, nonce: encode(nonce.as_slice()), ciphertext };
version: 1,
nonce: encode(nonce.as_slice()),
ciphertext,
};
PickledSigning(serde_json::to_string(&pickle).expect("Can't encode pickled signing")) PickledSigning(serde_json::to_string(&pickle).expect("Can't encode pickled signing"))
} }
@ -376,11 +352,8 @@ impl Signing {
let mut keys = BTreeMap::new(); let mut keys = BTreeMap::new();
keys.insert( keys.insert(
DeviceKeyId::from_parts( DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, self.public_key().as_str().into())
DeviceKeyAlgorithm::Ed25519, .to_string(),
self.public_key().as_str().into(),
)
.to_string(),
self.public_key().to_string(), self.public_key().to_string(),
); );

View File

@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use olm_rs::utility::OlmUtility;
use serde_json::Value;
use std::convert::TryInto; use std::convert::TryInto;
use matrix_sdk_common::{ use matrix_sdk_common::{
identifiers::{DeviceKeyAlgorithm, DeviceKeyId, UserId}, identifiers::{DeviceKeyAlgorithm, DeviceKeyId, UserId},
CanonicalJsonValue, CanonicalJsonValue,
}; };
use olm_rs::utility::OlmUtility;
use serde_json::Value;
use crate::error::SignatureError; use crate::error::SignatureError;
@ -29,9 +29,7 @@ pub(crate) struct Utility {
impl Utility { impl Utility {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self { inner: OlmUtility::new() }
inner: OlmUtility::new(),
}
} }
/// Verify a signed JSON object. /// Verify a signed JSON object.
@ -49,7 +47,7 @@ impl Utility {
/// * `key_id` - The id of the key that signed the JSON object. /// * `key_id` - The id of the key that signed the JSON object.
/// ///
/// * `signing_key` - The public ed25519 key which was used to sign the JSON /// * `signing_key` - The public ed25519 key which was used to sign the JSON
/// object. /// object.
/// ///
/// * `json` - The JSON object that should be verified. /// * `json` - The JSON object that should be verified.
pub(crate) fn verify_json( pub(crate) fn verify_json(
@ -67,29 +65,20 @@ impl Utility {
let unsigned = json_object.remove("unsigned"); let unsigned = json_object.remove("unsigned");
let signatures = json_object.remove("signatures"); let signatures = json_object.remove("signatures");
let canonical_json: CanonicalJsonValue = json let canonical_json: CanonicalJsonValue =
.clone() json.clone().try_into().map_err(|_| SignatureError::NotAnObject)?;
.try_into()
.map_err(|_| SignatureError::NotAnObject)?;
let canonical_json: String = canonical_json.to_string(); let canonical_json: String = canonical_json.to_string();
let signatures = signatures.ok_or(SignatureError::NoSignatureFound)?; let signatures = signatures.ok_or(SignatureError::NoSignatureFound)?;
let signature_object = signatures let signature_object = signatures.as_object().ok_or(SignatureError::NoSignatureFound)?;
.as_object() let signature =
.ok_or(SignatureError::NoSignatureFound)?; signature_object.get(user_id.as_str()).ok_or(SignatureError::NoSignatureFound)?;
let signature = signature_object let signature =
.get(user_id.as_str()) signature.get(key_id.to_string()).ok_or(SignatureError::NoSignatureFound)?;
.ok_or(SignatureError::NoSignatureFound)?;
let signature = signature
.get(key_id.to_string())
.ok_or(SignatureError::NoSignatureFound)?;
let signature = signature.as_str().ok_or(SignatureError::NoSignatureFound)?; let signature = signature.as_str().ok_or(SignatureError::NoSignatureFound)?;
let ret = match self let ret = match self.inner.ed25519_verify(signing_key, &canonical_json, signature) {
.inner
.ed25519_verify(signing_key, &canonical_json, signature)
{
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(_) => Err(SignatureError::VerificationError), Err(_) => Err(SignatureError::VerificationError),
}; };
@ -108,10 +97,11 @@ impl Utility {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::Utility;
use matrix_sdk_common::identifiers::{user_id, DeviceKeyAlgorithm, DeviceKeyId}; use matrix_sdk_common::identifiers::{user_id, DeviceKeyAlgorithm, DeviceKeyId};
use serde_json::json; use serde_json::json;
use super::Utility;
#[test] #[test]
fn signature_test() { fn signature_test() {
let mut device_keys = json!({ let mut device_keys = json!({

View File

@ -35,18 +35,19 @@ use matrix_sdk_common::{
identifiers::{DeviceIdBox, RoomId, UserId}, identifiers::{DeviceIdBox, RoomId, UserId},
uuid::Uuid, uuid::Uuid,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::value::RawValue as RawJsonValue; use serde_json::value::RawValue as RawJsonValue;
/// Customized version of `ruma_client_api::r0::to_device::send_event_to_device::Request`, using a /// Customized version of
/// `ruma_client_api::r0::to_device::send_event_to_device::Request`, using a
/// UUID for the transaction ID. /// UUID for the transaction ID.
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToDeviceRequest { pub struct ToDeviceRequest {
/// Type of event being sent to each device. /// Type of event being sent to each device.
pub event_type: EventType, pub event_type: EventType,
/// A request identifier unique to the access token used to send the request. /// A request identifier unique to the access token used to send the
/// request.
pub txn_id: Uuid, pub txn_id: Uuid,
/// A map of users to devices to a content for a message event to be /// A map of users to devices to a content for a message event to be
@ -80,15 +81,18 @@ impl ToDeviceRequest {
pub struct UploadSigningKeysRequest { pub struct UploadSigningKeysRequest {
/// The user's master key. /// The user's master key.
pub master_key: Option<CrossSigningKey>, pub master_key: Option<CrossSigningKey>,
/// The user's self-signing key. Must be signed with the accompanied master, or by the /// The user's self-signing key. Must be signed with the accompanied master,
/// user's most recently uploaded master key if no master key is included in the request. /// or by the user's most recently uploaded master key if no master key
/// is included in the request.
pub self_signing_key: Option<CrossSigningKey>, pub self_signing_key: Option<CrossSigningKey>,
/// The user's user-signing key. Must be signed with the accompanied master, or by the /// The user's user-signing key. Must be signed with the accompanied master,
/// user's most recently uploaded master key if no master key is included in the request. /// or by the user's most recently uploaded master key if no master key
/// is included in the request.
pub user_signing_key: Option<CrossSigningKey>, pub user_signing_key: Option<CrossSigningKey>,
} }
/// Customized version of `ruma_client_api::r0::keys::get_keys::Request`, without any references. /// Customized version of `ruma_client_api::r0::keys::get_keys::Request`,
/// without any references.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct KeysQueryRequest { pub struct KeysQueryRequest {
/// The time (in milliseconds) to wait when downloading keys from remote /// The time (in milliseconds) to wait when downloading keys from remote
@ -109,11 +113,7 @@ pub struct KeysQueryRequest {
impl KeysQueryRequest { impl KeysQueryRequest {
pub(crate) fn new(device_keys: BTreeMap<UserId, Vec<DeviceIdBox>>) -> Self { pub(crate) fn new(device_keys: BTreeMap<UserId, Vec<DeviceIdBox>>) -> Self {
Self { Self { timeout: None, device_keys, token: None }
timeout: None,
device_keys,
token: None,
}
} }
} }
@ -177,19 +177,13 @@ impl From<SignatureUploadRequest> for OutgoingRequests {
impl From<OutgoingVerificationRequest> for OutgoingRequest { impl From<OutgoingVerificationRequest> for OutgoingRequest {
fn from(r: OutgoingVerificationRequest) -> Self { fn from(r: OutgoingVerificationRequest) -> Self {
Self { Self { request_id: r.request_id(), request: Arc::new(r.into()) }
request_id: r.request_id(),
request: Arc::new(r.into()),
}
} }
} }
impl From<SignatureUploadRequest> for OutgoingRequest { impl From<SignatureUploadRequest> for OutgoingRequest {
fn from(r: SignatureUploadRequest) -> Self { fn from(r: SignatureUploadRequest) -> Self {
Self { Self { request_id: Uuid::new_v4(), request: Arc::new(r.into()) }
request_id: Uuid::new_v4(),
request: Arc::new(r.into()),
}
} }
} }

View File

@ -17,9 +17,8 @@ use std::{
sync::Arc, sync::Arc,
}; };
use futures::future::join_all;
use dashmap::DashMap; use dashmap::DashMap;
use futures::future::join_all;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::to_device::DeviceIdOrAllDevices, api::r0::to_device::DeviceIdOrAllDevices,
events::{ events::{
@ -105,10 +104,7 @@ impl GroupSessionCache {
room_id: &RoomId, room_id: &RoomId,
session_id: &str, session_id: &str,
) -> StoreResult<Option<OutboundGroupSession>> { ) -> StoreResult<Option<OutboundGroupSession>> {
Ok(self Ok(self.get_or_load(room_id).await?.filter(|o| session_id == o.session_id()))
.get_or_load(room_id)
.await?
.filter(|o| session_id == o.session_id()))
} }
} }
@ -127,11 +123,7 @@ impl GroupSessionManager {
const MAX_TO_DEVICE_MESSAGES: usize = 250; const MAX_TO_DEVICE_MESSAGES: usize = 250;
pub(crate) fn new(account: Account, store: Store) -> Self { pub(crate) fn new(account: Account, store: Store) -> Self {
Self { Self { account, store: store.clone(), sessions: GroupSessionCache::new(store) }
account,
store: store.clone(),
sessions: GroupSessionCache::new(store),
}
} }
pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> { pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
@ -231,9 +223,7 @@ impl GroupSessionManager {
Ok((s, None)) Ok((s, None))
} }
} else { } else {
self.create_outbound_group_session(room_id, settings) self.create_outbound_group_session(room_id, settings).await.map(|(o, i)| (o, i.into()))
.await
.map(|(o, i)| (o, i.into()))
} }
} }
@ -253,13 +243,10 @@ impl GroupSessionManager {
let used_session = match encrypted { let used_session = match encrypted {
Ok((session, encrypted)) => { Ok((session, encrypted)) => {
message message.entry(device.user_id().clone()).or_insert_with(BTreeMap::new).insert(
.entry(device.user_id().clone()) DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
.or_insert_with(BTreeMap::new) serde_json::value::to_raw_value(&encrypted)?,
.insert( );
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
serde_json::value::to_raw_value(&encrypted)?,
);
Some(session) Some(session)
} }
// TODO we'll want to create m.room_key.withheld here. // TODO we'll want to create m.room_key.withheld here.
@ -271,10 +258,8 @@ impl GroupSessionManager {
Ok((used_session, message)) Ok((used_session, message))
}; };
let tasks: Vec<_> = devices let tasks: Vec<_> =
.iter() devices.iter().map(|d| spawn(encrypt(d.clone(), content.clone()))).collect();
.map(|d| spawn(encrypt(d.clone(), content.clone())))
.collect();
let results = join_all(tasks).await; let results = join_all(tasks).await;
@ -286,20 +271,14 @@ impl GroupSessionManager {
} }
for (user, device_messages) in message.into_iter() { for (user, device_messages) in message.into_iter() {
messages messages.entry(user).or_insert_with(BTreeMap::new).extend(device_messages);
.entry(user)
.or_insert_with(BTreeMap::new)
.extend(device_messages);
} }
} }
let id = Uuid::new_v4(); let id = Uuid::new_v4();
let request = ToDeviceRequest { let request =
event_type: EventType::RoomEncrypted, ToDeviceRequest { event_type: EventType::RoomEncrypted, txn_id: id, messages };
txn_id: id,
messages,
};
trace!( trace!(
recipient_count = request.message_count(), recipient_count = request.message_count(),
@ -331,20 +310,14 @@ impl GroupSessionManager {
"Calculating group session recipients" "Calculating group session recipients"
); );
let users_shared_with: HashSet<UserId> = outbound let users_shared_with: HashSet<UserId> =
.shared_with_set outbound.shared_with_set.iter().map(|k| k.key().clone()).collect();
.iter()
.map(|k| k.key().clone())
.collect();
let users_shared_with: HashSet<&UserId> = users_shared_with.iter().collect(); let users_shared_with: HashSet<&UserId> = users_shared_with.iter().collect();
// A user left if a user is missing from the set of users that should // A user left if a user is missing from the set of users that should
// get the session but is in the set of users that received the session. // get the session but is in the set of users that received the session.
let user_left = !users_shared_with let user_left = !users_shared_with.difference(&users).collect::<HashSet<_>>().is_empty();
.difference(&users)
.collect::<HashSet<_>>()
.is_empty();
let visibility_changed = outbound.settings().history_visibility != history_visibility; let visibility_changed = outbound.settings().history_visibility != history_visibility;
@ -359,10 +332,8 @@ impl GroupSessionManager {
for user_id in users { for user_id in users {
let user_devices = self.store.get_user_devices(&user_id).await?; let user_devices = self.store.get_user_devices(&user_id).await?;
let non_blacklisted_devices: Vec<Device> = user_devices let non_blacklisted_devices: Vec<Device> =
.devices() user_devices.devices().filter(|d| !d.is_blacklisted()).collect();
.filter(|d| !d.is_blacklisted())
.collect();
// If we haven't already concluded that the session should be // If we haven't already concluded that the session should be
// rotated for other reasons, we also need to check whether any // rotated for other reasons, we also need to check whether any
@ -370,10 +341,8 @@ impl GroupSessionManager {
// meantime. If so, we should also rotate the session. // meantime. If so, we should also rotate the session.
if !should_rotate { if !should_rotate {
// Device IDs that should receive this session // Device IDs that should receive this session
let non_blacklisted_device_ids: HashSet<&DeviceId> = non_blacklisted_devices let non_blacklisted_device_ids: HashSet<&DeviceId> =
.iter() non_blacklisted_devices.iter().map(|d| d.device_id()).collect();
.map(|d| d.device_id())
.collect();
if let Some(shared) = outbound.shared_with_set.get(user_id) { if let Some(shared) = outbound.shared_with_set.get(user_id) {
#[allow(clippy::map_clone)] #[allow(clippy::map_clone)]
@ -389,9 +358,8 @@ impl GroupSessionManager {
// //
// represents newly deleted or blacklisted devices. If this // represents newly deleted or blacklisted devices. If this
// set is non-empty, we must rotate. // set is non-empty, we must rotate.
let newly_deleted_or_blacklisted = shared let newly_deleted_or_blacklisted =
.difference(&non_blacklisted_device_ids) shared.difference(&non_blacklisted_device_ids).collect::<HashSet<_>>();
.collect::<HashSet<_>>();
if !newly_deleted_or_blacklisted.is_empty() { if !newly_deleted_or_blacklisted.is_empty() {
should_rotate = true; should_rotate = true;
@ -399,10 +367,7 @@ impl GroupSessionManager {
}; };
} }
devices devices.entry(user_id.clone()).or_insert_with(Vec::new).extend(non_blacklisted_devices);
.entry(user_id.clone())
.or_insert_with(Vec::new)
.extend(non_blacklisted_devices);
} }
debug!( debug!(
@ -462,25 +427,22 @@ impl GroupSessionManager {
let history_visibility = encryption_settings.history_visibility.clone(); let history_visibility = encryption_settings.history_visibility.clone();
let mut changes = Changes::default(); let mut changes = Changes::default();
let (outbound, inbound) = self let (outbound, inbound) =
.get_or_create_outbound_session(room_id, encryption_settings.clone()) self.get_or_create_outbound_session(room_id, encryption_settings.clone()).await?;
.await?;
if let Some(inbound) = inbound { if let Some(inbound) = inbound {
changes.outbound_group_sessions.push(outbound.clone()); changes.outbound_group_sessions.push(outbound.clone());
changes.inbound_group_sessions.push(inbound); changes.inbound_group_sessions.push(inbound);
} }
let (should_rotate, devices) = self let (should_rotate, devices) =
.collect_session_recipients(users, history_visibility, &outbound) self.collect_session_recipients(users, history_visibility, &outbound).await?;
.await?;
let outbound = if should_rotate { let outbound = if should_rotate {
let old_session_id = outbound.session_id(); let old_session_id = outbound.session_id();
let (outbound, inbound) = self let (outbound, inbound) =
.create_outbound_group_session(room_id, encryption_settings) self.create_outbound_group_session(room_id, encryption_settings).await?;
.await?;
changes.outbound_group_sessions.push(outbound.clone()); changes.outbound_group_sessions.push(outbound.clone());
changes.inbound_group_sessions.push(inbound); changes.inbound_group_sessions.push(inbound);
@ -515,9 +477,7 @@ impl GroupSessionManager {
if !devices.is_empty() { if !devices.is_empty() {
let users = devices.iter().fold(BTreeMap::new(), |mut acc, d| { let users = devices.iter().fold(BTreeMap::new(), |mut acc, d| {
acc.entry(d.user_id()) acc.entry(d.user_id()).or_insert_with(BTreeSet::new).insert(d.device_id());
.or_insert_with(BTreeSet::new)
.insert(d.device_id());
acc acc
}); });
@ -626,14 +586,8 @@ mod test {
let machine = OlmMachine::new(&alice_id(), &alice_device_id()); let machine = OlmMachine::new(&alice_id(), &alice_device_id());
machine machine.mark_request_as_sent(&uuid, &keys_query).await.unwrap();
.mark_request_as_sent(&uuid, &keys_query) machine.mark_request_as_sent(&uuid, &keys_claim).await.unwrap();
.await
.unwrap();
machine
.mark_request_as_sent(&uuid, &keys_claim)
.await
.unwrap();
machine machine
} }
@ -647,11 +601,7 @@ mod test {
let users: Vec<_> = keys_claim.one_time_keys.keys().collect(); let users: Vec<_> = keys_claim.one_time_keys.keys().collect();
let requests = machine let requests = machine
.share_group_session( .share_group_session(&room_id, users.clone().into_iter(), EncryptionSettings::default())
&room_id,
users.clone().into_iter(),
EncryptionSettings::default(),
)
.await .await
.unwrap(); .unwrap();

View File

@ -77,11 +77,7 @@ impl SessionManager {
} }
pub async fn mark_device_as_wedged(&self, sender: &UserId, curve_key: &str) -> StoreResult<()> { pub async fn mark_device_as_wedged(&self, sender: &UserId, curve_key: &str) -> StoreResult<()> {
if let Some(device) = self if let Some(device) = self.store.get_device_from_curve_key(sender, curve_key).await? {
.store
.get_device_from_curve_key(sender, curve_key)
.await?
{
let sessions = device.get_sessions().await?; let sessions = device.get_sessions().await?;
if let Some(sessions) = sessions { if let Some(sessions) = sessions {
@ -120,25 +116,16 @@ impl SessionManager {
/// ///
/// If the device was wedged this will queue up a dummy to-device message. /// If the device was wedged this will queue up a dummy to-device message.
async fn check_if_unwedged(&self, user_id: &UserId, device_id: &DeviceId) -> OlmResult<()> { async fn check_if_unwedged(&self, user_id: &UserId, device_id: &DeviceId) -> OlmResult<()> {
if self if self.wedged_devices.get(user_id).map(|d| d.remove(device_id)).flatten().is_some() {
.wedged_devices
.get(user_id)
.map(|d| d.remove(device_id))
.flatten()
.is_some()
{
if let Some(device) = self.store.get_device(user_id, device_id).await? { if let Some(device) = self.store.get_device(user_id, device_id).await? {
let (_, content) = device.encrypt(EventType::Dummy, json!({})).await?; let (_, content) = device.encrypt(EventType::Dummy, json!({})).await?;
let id = Uuid::new_v4(); let id = Uuid::new_v4();
let mut messages = BTreeMap::new(); let mut messages = BTreeMap::new();
messages messages.entry(device.user_id().to_owned()).or_insert_with(BTreeMap::new).insert(
.entry(device.user_id().to_owned()) DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
.or_insert_with(BTreeMap::new) to_raw_value(&content)?,
.insert( );
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
to_raw_value(&content)?,
);
let request = OutgoingRequest { let request = OutgoingRequest {
request_id: id, request_id: id,
@ -307,13 +294,13 @@ impl SessionManager {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use dashmap::DashMap;
use matrix_sdk_common::locks::Mutex;
use std::{collections::BTreeMap, sync::Arc}; use std::{collections::BTreeMap, sync::Arc};
use dashmap::DashMap;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::keys::claim_keys::Response as KeyClaimResponse, api::r0::keys::claim_keys::Response as KeyClaimResponse,
identifiers::{user_id, DeviceIdBox, UserId}, identifiers::{user_id, DeviceIdBox, UserId},
locks::Mutex,
}; };
use matrix_sdk_test::async_test; use matrix_sdk_test::async_test;
@ -347,9 +334,7 @@ mod test {
let account = ReadOnlyAccount::new(&user_id, &device_id); let account = ReadOnlyAccount::new(&user_id, &device_id);
let store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new())); let store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
store.save_account(account.clone()).await.unwrap(); store.save_account(account.clone()).await.unwrap();
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty( let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(user_id.clone())));
user_id.clone(),
)));
let verification = let verification =
VerificationMachine::new(account.clone(), identity.clone(), store.clone()); VerificationMachine::new(account.clone(), identity.clone(), store.clone());
@ -358,10 +343,7 @@ mod test {
let store = Store::new(user_id.clone(), identity, store, verification); let store = Store::new(user_id.clone(), identity, store, verification);
let account = Account { let account = Account { inner: account, store: store.clone() };
inner: account,
store: store.clone(),
};
let session_cache = GroupSessionCache::new(store.clone()); let session_cache = GroupSessionCache::new(store.clone());
@ -405,10 +387,7 @@ mod test {
let response = KeyClaimResponse::new(one_time_keys); let response = KeyClaimResponse::new(one_time_keys);
manager manager.receive_keys_claim_response(&response).await.unwrap();
.receive_keys_claim_response(&response)
.await
.unwrap();
assert!(manager assert!(manager
.get_missing_sessions(&mut [bob.user_id().clone()].iter()) .get_missing_sessions(&mut [bob.user_id().clone()].iter())
@ -434,11 +413,7 @@ mod test {
let bob_device = ReadOnlyDevice::from_account(&bob).await; let bob_device = ReadOnlyDevice::from_account(&bob).await;
session.creation_time = Arc::new(Instant::now() - Duration::from_secs(3601)); session.creation_time = Arc::new(Instant::now() - Duration::from_secs(3601));
manager manager.store.save_devices(&[bob_device.clone()]).await.unwrap();
.store
.save_devices(&[bob_device.clone()])
.await
.unwrap();
manager.store.save_sessions(&[session]).await.unwrap(); manager.store.save_sessions(&[session]).await.unwrap();
assert!(manager assert!(manager
@ -451,10 +426,7 @@ mod test {
assert!(!manager.users_for_key_claim.contains_key(bob.user_id())); assert!(!manager.users_for_key_claim.contains_key(bob.user_id()));
assert!(!manager.is_device_wedged(&bob_device)); assert!(!manager.is_device_wedged(&bob_device));
manager manager.mark_device_as_wedged(bob_device.user_id(), &curve_key).await.unwrap();
.mark_device_as_wedged(bob_device.user_id(), &curve_key)
.await
.unwrap();
assert!(manager.is_device_wedged(&bob_device)); assert!(manager.is_device_wedged(&bob_device));
assert!(manager.users_for_key_claim.contains_key(bob.user_id())); assert!(manager.users_for_key_claim.contains_key(bob.user_id()));
@ -480,10 +452,7 @@ mod test {
assert!(manager.outgoing_to_device_requests.is_empty()); assert!(manager.outgoing_to_device_requests.is_empty());
manager manager.receive_keys_claim_response(&response).await.unwrap();
.receive_keys_claim_response(&response)
.await
.unwrap();
assert!(!manager.is_device_wedged(&bob_device)); assert!(!manager.is_device_wedged(&bob_device));
assert!(manager assert!(manager

View File

@ -39,9 +39,7 @@ pub struct SessionStore {
impl SessionStore { impl SessionStore {
/// Create a new empty Session store. /// Create a new empty Session store.
pub fn new() -> Self { pub fn new() -> Self {
SessionStore { SessionStore { entries: Arc::new(DashMap::new()) }
entries: Arc::new(DashMap::new()),
}
} }
/// Add a session to the store. /// Add a session to the store.
@ -72,8 +70,7 @@ impl SessionStore {
/// Add a list of sessions belonging to the sender key. /// Add a list of sessions belonging to the sender key.
pub fn set_for_sender(&self, sender_key: &str, sessions: Vec<Session>) { pub fn set_for_sender(&self, sender_key: &str, sessions: Vec<Session>) {
self.entries self.entries.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
} }
} }
@ -87,9 +84,7 @@ pub struct GroupSessionStore {
impl GroupSessionStore { impl GroupSessionStore {
/// Create a new empty store. /// Create a new empty store.
pub fn new() -> Self { pub fn new() -> Self {
GroupSessionStore { GroupSessionStore { entries: Arc::new(DashMap::new()) }
entries: Arc::new(DashMap::new()),
}
} }
/// Add an inbound group session to the store. /// Add an inbound group session to the store.
@ -148,9 +143,7 @@ pub struct DeviceStore {
impl DeviceStore { impl DeviceStore {
/// Create a new empty device store. /// Create a new empty device store.
pub fn new() -> Self { pub fn new() -> Self {
DeviceStore { DeviceStore { entries: Arc::new(DashMap::new()) }
entries: Arc::new(DashMap::new()),
}
} }
/// Add a device to the store. /// Add a device to the store.
@ -167,19 +160,15 @@ impl DeviceStore {
/// Get the device with the given device_id and belonging to the given user. /// Get the device with the given device_id and belonging to the given user.
pub fn get(&self, user_id: &UserId, device_id: &DeviceId) -> Option<ReadOnlyDevice> { pub fn get(&self, user_id: &UserId, device_id: &DeviceId) -> Option<ReadOnlyDevice> {
self.entries self.entries.get(user_id).and_then(|m| m.get(device_id).map(|d| d.value().clone()))
.get(user_id)
.and_then(|m| m.get(device_id).map(|d| d.value().clone()))
} }
/// Remove the device with the given device_id and belonging to the given user. /// Remove the device with the given device_id and belonging to the given
/// user.
/// ///
/// Returns the device if it was removed, None if it wasn't in the store. /// Returns the device if it was removed, None if it wasn't in the store.
pub fn remove(&self, user_id: &UserId, device_id: &DeviceId) -> Option<ReadOnlyDevice> { pub fn remove(&self, user_id: &UserId, device_id: &DeviceId) -> Option<ReadOnlyDevice> {
self.entries self.entries.get(user_id).and_then(|m| m.remove(device_id)).map(|(_, d)| d)
.get(user_id)
.and_then(|m| m.remove(device_id))
.map(|(_, d)| d)
} }
/// Get a read-only view over all devices of the given user. /// Get a read-only view over all devices of the given user.
@ -195,12 +184,13 @@ impl DeviceStore {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use matrix_sdk_common::identifiers::room_id;
use crate::{ use crate::{
identities::device::test::get_device, identities::device::test::get_device,
olm::{test::get_account_and_session, InboundGroupSession}, olm::{test::get_account_and_session, InboundGroupSession},
store::caches::{DeviceStore, GroupSessionStore, SessionStore}, store::caches::{DeviceStore, GroupSessionStore, SessionStore},
}; };
use matrix_sdk_common::identifiers::room_id;
#[tokio::test] #[tokio::test]
async fn test_session_store() { async fn test_session_store() {
@ -239,10 +229,8 @@ mod test {
let (account, _) = get_account_and_session().await; let (account, _) = get_account_and_session().await;
let room_id = room_id!("!test:localhost"); let room_id = room_id!("!test:localhost");
let (outbound, _) = account let (outbound, _) =
.create_group_session_pair_with_defaults(&room_id) account.create_group_session_pair_with_defaults(&room_id).await.unwrap();
.await
.unwrap();
assert_eq!(0, outbound.message_index().await); assert_eq!(0, outbound.message_index().await);
assert!(!outbound.shared()); assert!(!outbound.shared());
@ -261,9 +249,7 @@ mod test {
let store = GroupSessionStore::new(); let store = GroupSessionStore::new();
store.add(inbound.clone()); store.add(inbound.clone());
let loaded_session = store let loaded_session = store.get(&room_id, "test_key", outbound.session_id()).unwrap();
.get(&room_id, "test_key", outbound.session_id())
.unwrap();
assert_eq!(inbound, loaded_session); assert_eq!(inbound, loaded_session);
} }

View File

@ -37,10 +37,7 @@ use crate::{
}; };
fn encode_key_info(info: &RequestedKeyInfo) -> String { fn encode_key_info(info: &RequestedKeyInfo) -> String {
format!( format!("{}{}{}{}", info.room_id, info.sender_key, info.algorithm, info.session_id)
"{}{}{}{}",
info.room_id, info.sender_key, info.algorithm, info.session_id
)
} }
/// An in-memory only store that will forget all the E2EE key once it's dropped. /// An in-memory only store that will forget all the E2EE key once it's dropped.
@ -121,22 +118,14 @@ impl CryptoStore for MemoryStore {
async fn save_changes(&self, mut changes: Changes) -> Result<()> { async fn save_changes(&self, mut changes: Changes) -> Result<()> {
self.save_sessions(changes.sessions).await; self.save_sessions(changes.sessions).await;
self.save_inbound_group_sessions(changes.inbound_group_sessions) self.save_inbound_group_sessions(changes.inbound_group_sessions).await;
.await;
self.save_devices(changes.devices.new).await; self.save_devices(changes.devices.new).await;
self.save_devices(changes.devices.changed).await; self.save_devices(changes.devices.changed).await;
self.delete_devices(changes.devices.deleted).await; self.delete_devices(changes.devices.deleted).await;
for identity in changes for identity in changes.identities.new.drain(..).chain(changes.identities.changed) {
.identities let _ = self.identities.insert(identity.user_id().to_owned(), identity.clone());
.new
.drain(..)
.chain(changes.identities.changed)
{
let _ = self
.identities
.insert(identity.user_id().to_owned(), identity.clone());
} }
for hash in changes.message_hashes { for hash in changes.message_hashes {
@ -167,9 +156,7 @@ impl CryptoStore for MemoryStore {
sender_key: &str, sender_key: &str,
session_id: &str, session_id: &str,
) -> Result<Option<InboundGroupSession>> { ) -> Result<Option<InboundGroupSession>> {
Ok(self Ok(self.inbound_group_sessions.get(room_id, sender_key, session_id))
.inbound_group_sessions
.get(room_id, sender_key, session_id))
} }
async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> { async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
@ -250,10 +237,7 @@ impl CryptoStore for MemoryStore {
&self, &self,
request_id: Uuid, request_id: Uuid,
) -> Result<Option<OutgoingKeyRequest>> { ) -> Result<Option<OutgoingKeyRequest>> {
Ok(self Ok(self.outgoing_key_requests.get(&request_id).map(|r| r.clone()))
.outgoing_key_requests
.get(&request_id)
.map(|r| r.clone()))
} }
async fn get_key_request_by_info( async fn get_key_request_by_info(
@ -278,12 +262,10 @@ impl CryptoStore for MemoryStore {
} }
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> { async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> {
self.outgoing_key_requests self.outgoing_key_requests.remove(&request_id).and_then(|(_, i)| {
.remove(&request_id) let key_info_string = encode_key_info(&i.info);
.and_then(|(_, i)| { self.key_requests_by_info.remove(&key_info_string)
let key_info_string = encode_key_info(&i.info); });
self.key_requests_by_info.remove(&key_info_string)
});
Ok(()) Ok(())
} }
@ -291,12 +273,13 @@ impl CryptoStore for MemoryStore {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use matrix_sdk_common::identifiers::room_id;
use crate::{ use crate::{
identities::device::test::get_device, identities::device::test::get_device,
olm::{test::get_account_and_session, InboundGroupSession, OlmMessageHash}, olm::{test::get_account_and_session, InboundGroupSession, OlmMessageHash},
store::{memorystore::MemoryStore, Changes, CryptoStore}, store::{memorystore::MemoryStore, Changes, CryptoStore},
}; };
use matrix_sdk_common::identifiers::room_id;
#[tokio::test] #[tokio::test]
async fn test_session_store() { async fn test_session_store() {
@ -308,11 +291,7 @@ mod test {
store.save_sessions(vec![session.clone()]).await; store.save_sessions(vec![session.clone()]).await;
let sessions = store let sessions = store.get_sessions(&session.sender_key).await.unwrap().unwrap();
.get_sessions(&session.sender_key)
.await
.unwrap()
.unwrap();
let sessions = sessions.lock().await; let sessions = sessions.lock().await;
let loaded_session = &sessions[0]; let loaded_session = &sessions[0];
@ -325,10 +304,8 @@ mod test {
let (account, _) = get_account_and_session().await; let (account, _) = get_account_and_session().await;
let room_id = room_id!("!test:localhost"); let room_id = room_id!("!test:localhost");
let (outbound, _) = account let (outbound, _) =
.create_group_session_pair_with_defaults(&room_id) account.create_group_session_pair_with_defaults(&room_id).await.unwrap();
.await
.unwrap();
let inbound = InboundGroupSession::new( let inbound = InboundGroupSession::new(
"test_key", "test_key",
"test_key", "test_key",
@ -339,9 +316,7 @@ mod test {
.unwrap(); .unwrap();
let store = MemoryStore::new(); let store = MemoryStore::new();
let _ = store let _ = store.save_inbound_group_sessions(vec![inbound.clone()]).await;
.save_inbound_group_sessions(vec![inbound.clone()])
.await;
let loaded_session = store let loaded_session = store
.get_inbound_group_session(&room_id, "test_key", outbound.session_id()) .get_inbound_group_session(&room_id, "test_key", outbound.session_id())
@ -358,11 +333,8 @@ mod test {
store.save_devices(vec![device.clone()]).await; store.save_devices(vec![device.clone()]).await;
let loaded_device = store let loaded_device =
.get_device(device.user_id(), device.device_id()) store.get_device(device.user_id(), device.device_id()).await.unwrap().unwrap();
.await
.unwrap()
.unwrap();
assert_eq!(device, loaded_device); assert_eq!(device, loaded_device);
@ -376,11 +348,7 @@ mod test {
assert_eq!(&device, loaded_device); assert_eq!(&device, loaded_device);
store.delete_devices(vec![device.clone()]).await; store.delete_devices(vec![device.clone()]).await;
assert!(store assert!(store.get_device(device.user_id(), device.device_id()).await.unwrap().is_none());
.get_device(device.user_id(), device.device_id())
.await
.unwrap()
.is_none());
} }
#[tokio::test] #[tokio::test]
@ -388,14 +356,8 @@ mod test {
let device = get_device(); let device = get_device();
let store = MemoryStore::new(); let store = MemoryStore::new();
assert!(store assert!(store.update_tracked_user(device.user_id(), false).await.unwrap());
.update_tracked_user(device.user_id(), false) assert!(!store.update_tracked_user(device.user_id(), false).await.unwrap());
.await
.unwrap());
assert!(!store
.update_tracked_user(device.user_id(), false)
.await
.unwrap());
assert!(store.is_user_tracked(device.user_id())); assert!(store.is_user_tracked(device.user_id()));
} }
@ -404,10 +366,8 @@ mod test {
async fn test_message_hash() { async fn test_message_hash() {
let store = MemoryStore::new(); let store = MemoryStore::new();
let hash = OlmMessageHash { let hash =
sender_key: "test_sender".to_owned(), OlmMessageHash { sender_key: "test_sender".to_owned(), hash: "test_hash".to_owned() };
hash: "test_hash".to_owned(),
};
let mut changes = Changes::default(); let mut changes = Changes::default();
changes.message_hashes.push(hash.clone()); changes.message_hashes.push(hash.clone());

View File

@ -43,11 +43,6 @@ mod pickle_key;
#[cfg(feature = "sled_cryptostore")] #[cfg(feature = "sled_cryptostore")]
pub(crate) mod sled; pub(crate) mod sled;
#[cfg(feature = "sled_cryptostore")]
pub use self::sled::SledStore;
pub use memorystore::MemoryStore;
pub use pickle_key::{EncryptedPickleKey, PickleKey};
use std::{ use std::{
collections::{HashMap, HashSet}, collections::{HashMap, HashSet},
fmt::Debug, fmt::Debug,
@ -56,10 +51,6 @@ use std::{
sync::Arc, sync::Arc,
}; };
use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
use serde_json::Error as SerdeError;
use thiserror::Error;
use matrix_sdk_common::{ use matrix_sdk_common::{
async_trait, async_trait,
events::room_key_request::RequestedKeyInfo, events::room_key_request::RequestedKeyInfo,
@ -71,7 +62,14 @@ use matrix_sdk_common::{
uuid::Uuid, uuid::Uuid,
AsyncTraitDeps, AsyncTraitDeps,
}; };
pub use memorystore::MemoryStore;
use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
pub use pickle_key::{EncryptedPickleKey, PickleKey};
use serde_json::Error as SerdeError;
use thiserror::Error;
#[cfg(feature = "sled_cryptostore")]
pub use self::sled::SledStore;
use crate::{ use crate::{
error::SessionUnpicklingError, error::SessionUnpicklingError,
identities::{Device, ReadOnlyDevice, UserDevices, UserIdentities}, identities::{Device, ReadOnlyDevice, UserDevices, UserIdentities},
@ -145,12 +143,7 @@ impl Store {
store: Arc<Box<dyn CryptoStore>>, store: Arc<Box<dyn CryptoStore>>,
verification_machine: VerificationMachine, verification_machine: VerificationMachine,
) -> Self { ) -> Self {
Self { Self { user_id, identity, inner: store, verification_machine }
user_id,
identity,
inner: store,
verification_machine,
}
} }
pub async fn get_readonly_device( pub async fn get_readonly_device(
@ -162,10 +155,7 @@ impl Store {
} }
pub async fn save_sessions(&self, sessions: &[Session]) -> Result<()> { pub async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
let changes = Changes { let changes = Changes { sessions: sessions.to_vec(), ..Default::default() };
sessions: sessions.to_vec(),
..Default::default()
};
self.save_changes(changes).await self.save_changes(changes).await
} }
@ -173,10 +163,7 @@ impl Store {
#[cfg(test)] #[cfg(test)]
pub async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()> { pub async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()> {
let changes = Changes { let changes = Changes {
devices: DeviceChanges { devices: DeviceChanges { changed: devices.to_vec(), ..Default::default() },
changed: devices.to_vec(),
..Default::default()
},
..Default::default() ..Default::default()
}; };
@ -188,10 +175,7 @@ impl Store {
&self, &self,
sessions: &[InboundGroupSession], sessions: &[InboundGroupSession],
) -> Result<()> { ) -> Result<()> {
let changes = Changes { let changes = Changes { inbound_group_sessions: sessions.to_vec(), ..Default::default() };
inbound_group_sessions: sessions.to_vec(),
..Default::default()
};
self.save_changes(changes).await self.save_changes(changes).await
} }
@ -210,8 +194,7 @@ impl Store {
) -> Result<Option<Device>> { ) -> Result<Option<Device>> {
self.get_user_devices(user_id).await.map(|d| { self.get_user_devices(user_id).await.map(|d| {
d.devices().find(|d| { d.devices().find(|d| {
d.get_key(DeviceKeyAlgorithm::Curve25519) d.get_key(DeviceKeyAlgorithm::Curve25519).map_or(false, |k| k == curve_key)
.map_or(false, |k| k == curve_key)
}) })
}) })
} }
@ -219,12 +202,8 @@ impl Store {
pub async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices> { pub async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices> {
let devices = self.inner.get_user_devices(user_id).await?; let devices = self.inner.get_user_devices(user_id).await?;
let own_identity = self let own_identity =
.inner self.inner.get_user_identity(&self.user_id).await?.map(|i| i.own().cloned()).flatten();
.get_user_identity(&self.user_id)
.await?
.map(|i| i.own().cloned())
.flatten();
let device_owner_identity = self.inner.get_user_identity(user_id).await.ok().flatten(); let device_owner_identity = self.inner.get_user_identity(user_id).await.ok().flatten();
Ok(UserDevices { Ok(UserDevices {
@ -241,24 +220,17 @@ impl Store {
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
) -> Result<Option<Device>> { ) -> Result<Option<Device>> {
let own_identity = self let own_identity =
.get_user_identity(&self.user_id) self.get_user_identity(&self.user_id).await?.map(|i| i.own().cloned()).flatten();
.await?
.map(|i| i.own().cloned())
.flatten();
let device_owner_identity = self.get_user_identity(user_id).await?; let device_owner_identity = self.get_user_identity(user_id).await?;
Ok(self Ok(self.inner.get_device(user_id, device_id).await?.map(|d| Device {
.inner inner: d,
.get_device(user_id, device_id) private_identity: self.identity.clone(),
.await? verification_machine: self.verification_machine.clone(),
.map(|d| Device { own_identity,
inner: d, device_owner_identity,
private_identity: self.identity.clone(), }))
verification_machine: self.verification_machine.clone(),
own_identity,
device_owner_identity,
}))
} }
} }
@ -366,7 +338,8 @@ pub trait CryptoStore: AsyncTraitDeps {
/// Get all the inbound group sessions we have stored. /// Get all the inbound group sessions we have stored.
async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>>; async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>>;
/// Get the outobund group sessions we have stored that is used for the given room. /// Get the outobund group sessions we have stored that is used for the
/// given room.
async fn get_outbound_group_sessions( async fn get_outbound_group_sessions(
&self, &self,
room_id: &RoomId, room_id: &RoomId,

View File

@ -22,11 +22,10 @@ use getrandom::getrandom;
use hmac::Hmac; use hmac::Hmac;
use olm_rs::PicklingMode; use olm_rs::PicklingMode;
use pbkdf2::pbkdf2; use pbkdf2::pbkdf2;
use serde::{Deserialize, Serialize};
use sha2::Sha256; use sha2::Sha256;
use zeroize::{Zeroize, Zeroizing}; use zeroize::{Zeroize, Zeroizing};
use serde::{Deserialize, Serialize};
const KEY_SIZE: usize = 32; const KEY_SIZE: usize = 32;
const NONCE_SIZE: usize = 12; const NONCE_SIZE: usize = 12;
const KDF_SALT_SIZE: usize = 32; const KDF_SALT_SIZE: usize = 32;
@ -114,9 +113,7 @@ impl PickleKey {
/// Get a `PicklingMode` version of this pickle key. /// Get a `PicklingMode` version of this pickle key.
pub fn pickle_mode(&self) -> PicklingMode { pub fn pickle_mode(&self) -> PicklingMode {
PicklingMode::Encrypted { PicklingMode::Encrypted { key: self.aes256_key.clone() }
key: self.aes256_key.clone(),
}
} }
/// Get the raw AES256 key. /// Get the raw AES256 key.
@ -142,10 +139,7 @@ impl PickleKey {
getrandom(&mut nonce).expect("Can't generate new random nonce for the pickle key"); getrandom(&mut nonce).expect("Can't generate new random nonce for the pickle key");
let ciphertext = cipher let ciphertext = cipher
.encrypt( .encrypt(&GenericArray::from_slice(nonce.as_ref()), self.aes256_key.as_slice())
&GenericArray::from_slice(nonce.as_ref()),
self.aes256_key.as_slice(),
)
.expect("Can't encrypt pickle key"); .expect("Can't encrypt pickle key");
EncryptedPickleKey { EncryptedPickleKey {
@ -181,9 +175,7 @@ impl PickleKey {
} }
}; };
Ok(Self { Ok(Self { aes256_key: decrypted })
aes256_key: decrypted,
})
} }
} }

View File

@ -20,13 +20,6 @@ use std::{
}; };
use dashmap::DashSet; use dashmap::DashSet;
use olm_rs::{account::IdentityKeys, PicklingMode};
pub use sled::Error;
use sled::{
transaction::{ConflictableTransactionError, TransactionError},
Config, Db, Transactional, Tree,
};
use matrix_sdk_common::{ use matrix_sdk_common::{
async_trait, async_trait,
events::room_key_request::RequestedKeyInfo, events::room_key_request::RequestedKeyInfo,
@ -34,6 +27,12 @@ use matrix_sdk_common::{
locks::Mutex, locks::Mutex,
uuid, uuid,
}; };
use olm_rs::{account::IdentityKeys, PicklingMode};
pub use sled::Error;
use sled::{
transaction::{ConflictableTransactionError, TransactionError},
Config, Db, Transactional, Tree,
};
use uuid::Uuid; use uuid::Uuid;
use super::{ use super::{
@ -97,13 +96,7 @@ impl EncodeKey for &str {
impl EncodeKey for (&str, &str) { impl EncodeKey for (&str, &str) {
fn encode(&self) -> Vec<u8> { fn encode(&self) -> Vec<u8> {
[ [self.0.as_bytes(), &[Self::SEPARATOR], self.1.as_bytes(), &[Self::SEPARATOR]].concat()
self.0.as_bytes(),
&[Self::SEPARATOR],
self.1.as_bytes(),
&[Self::SEPARATOR],
]
.concat()
} }
} }
@ -164,9 +157,7 @@ impl std::fmt::Debug for SledStore {
if let Some(path) = &self.path { if let Some(path) = &self.path {
f.debug_struct("SledStore").field("path", &path).finish() f.debug_struct("SledStore").field("path", &path).finish()
} else { } else {
f.debug_struct("SledStore") f.debug_struct("SledStore").field("path", &"memory store").finish()
.field("path", &"memory store")
.finish()
} }
} }
} }
@ -253,9 +244,8 @@ impl SledStore {
} }
fn get_or_create_pickle_key(passphrase: &str, database: &Db) -> Result<PickleKey> { fn get_or_create_pickle_key(passphrase: &str, database: &Db) -> Result<PickleKey> {
let key = if let Some(key) = database let key = if let Some(key) =
.get("pickle_key".encode())? database.get("pickle_key".encode())?.map(|v| serde_json::from_slice(&v))
.map(|v| serde_json::from_slice(&v))
{ {
PickleKey::from_encrypted(passphrase, key?) PickleKey::from_encrypted(passphrase, key?)
.map_err(|_| CryptoStoreError::UnpicklingError)? .map_err(|_| CryptoStoreError::UnpicklingError)?
@ -297,9 +287,7 @@ impl SledStore {
&self, &self,
room_id: &RoomId, room_id: &RoomId,
) -> Result<Option<OutboundGroupSession>> { ) -> Result<Option<OutboundGroupSession>> {
let account_info = self let account_info = self.get_account_info().ok_or(CryptoStoreError::AccountUnset)?;
.get_account_info()
.ok_or(CryptoStoreError::AccountUnset)?;
self.outbound_group_sessions self.outbound_group_sessions
.get(room_id.encode())? .get(room_id.encode())?
@ -501,17 +489,11 @@ impl SledStore {
&self, &self,
id: &[u8], id: &[u8],
) -> Result<Option<OutgoingKeyRequest>> { ) -> Result<Option<OutgoingKeyRequest>> {
let request = self let request =
.outgoing_key_requests self.outgoing_key_requests.get(id)?.map(|r| serde_json::from_slice(&r)).transpose()?;
.get(id)?
.map(|r| serde_json::from_slice(&r))
.transpose()?;
let request = if request.is_none() { let request = if request.is_none() {
self.unsent_key_requests self.unsent_key_requests.get(id)?.map(|r| serde_json::from_slice(&r)).transpose()?
.get(id)?
.map(|r| serde_json::from_slice(&r))
.transpose()?
} else { } else {
request request
}; };
@ -553,10 +535,7 @@ impl CryptoStore for SledStore {
*self.account_info.write().unwrap() = Some(account_info); *self.account_info.write().unwrap() = Some(account_info);
let changes = Changes { let changes = Changes { account: Some(account), ..Default::default() };
account: Some(account),
..Default::default()
};
self.save_changes(changes).await self.save_changes(changes).await
} }
@ -579,9 +558,7 @@ impl CryptoStore for SledStore {
} }
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> { async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
let account_info = self let account_info = self.get_account_info().ok_or(CryptoStoreError::AccountUnset)?;
.get_account_info()
.ok_or(CryptoStoreError::AccountUnset)?;
if self.session_cache.get(sender_key).is_none() { if self.session_cache.get(sender_key).is_none() {
let sessions: Result<Vec<Session>> = self let sessions: Result<Vec<Session>> = self
@ -613,16 +590,10 @@ impl CryptoStore for SledStore {
session_id: &str, session_id: &str,
) -> Result<Option<InboundGroupSession>> { ) -> Result<Option<InboundGroupSession>> {
let key = (room_id.as_str(), sender_key, session_id).encode(); let key = (room_id.as_str(), sender_key, session_id).encode();
let pickle = self let pickle = self.inbound_group_sessions.get(&key)?.map(|p| serde_json::from_slice(&p));
.inbound_group_sessions
.get(&key)?
.map(|p| serde_json::from_slice(&p));
if let Some(pickle) = pickle { if let Some(pickle) = pickle {
Ok(Some(InboundGroupSession::from_pickle( Ok(Some(InboundGroupSession::from_pickle(pickle?, self.get_pickle_mode())?))
pickle?,
self.get_pickle_mode(),
)?))
} else { } else {
Ok(None) Ok(None)
} }
@ -658,10 +629,7 @@ impl CryptoStore for SledStore {
fn users_for_key_query(&self) -> HashSet<UserId> { fn users_for_key_query(&self) -> HashSet<UserId> {
#[allow(clippy::map_clone)] #[allow(clippy::map_clone)]
self.users_for_key_query_cache self.users_for_key_query_cache.iter().map(|u| u.clone()).collect()
.iter()
.map(|u| u.clone())
.collect()
} }
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> { async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
@ -715,9 +683,7 @@ impl CryptoStore for SledStore {
} }
async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result<bool> { async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result<bool> {
Ok(self Ok(self.olm_hashes.contains_key(serde_json::to_vec(message_hash)?)?)
.olm_hashes
.contains_key(serde_json::to_vec(message_hash)?)?)
} }
async fn get_outgoing_key_request( async fn get_outgoing_key_request(
@ -753,36 +719,33 @@ impl CryptoStore for SledStore {
} }
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> { async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> {
let ret: Result<(), TransactionError<serde_json::Error>> = ( let ret: Result<(), TransactionError<serde_json::Error>> =
&self.outgoing_key_requests, (&self.outgoing_key_requests, &self.unsent_key_requests, &self.key_requests_by_info)
&self.unsent_key_requests, .transaction(
&self.key_requests_by_info, |(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| {
) let sent_request: Option<OutgoingKeyRequest> = outgoing_key_requests
.transaction( .remove(request_id.encode())?
|(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| { .map(|r| serde_json::from_slice(&r))
let sent_request: Option<OutgoingKeyRequest> = outgoing_key_requests .transpose()
.remove(request_id.encode())? .map_err(ConflictableTransactionError::Abort)?;
.map(|r| serde_json::from_slice(&r))
.transpose()
.map_err(ConflictableTransactionError::Abort)?;
let unsent_request: Option<OutgoingKeyRequest> = unsent_key_requests let unsent_request: Option<OutgoingKeyRequest> = unsent_key_requests
.remove(request_id.encode())? .remove(request_id.encode())?
.map(|r| serde_json::from_slice(&r)) .map(|r| serde_json::from_slice(&r))
.transpose() .transpose()
.map_err(ConflictableTransactionError::Abort)?; .map_err(ConflictableTransactionError::Abort)?;
if let Some(request) = sent_request { if let Some(request) = sent_request {
key_requests_by_info.remove((&request.info).encode())?; key_requests_by_info.remove((&request.info).encode())?;
} }
if let Some(request) = unsent_request { if let Some(request) = unsent_request {
key_requests_by_info.remove((&request.info).encode())?; key_requests_by_info.remove((&request.info).encode())?;
} }
Ok(()) Ok(())
}, },
); );
ret?; ret?;
self.inner.flush_async().await?; self.inner.flush_async().await?;
@ -793,6 +756,19 @@ impl CryptoStore for SledStore {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use std::collections::BTreeMap;
use matrix_sdk_common::{
api::r0::keys::SignedKey,
events::room_key_request::RequestedKeyInfo,
identifiers::{room_id, user_id, DeviceId, EventEncryptionAlgorithm, UserId},
uuid::Uuid,
};
use matrix_sdk_test::async_test;
use olm_rs::outbound_group_session::OlmOutboundGroupSession;
use tempfile::tempdir;
use super::{CryptoStore, OutgoingKeyRequest, SledStore};
use crate::{ use crate::{
identities::{ identities::{
device::test::get_device, device::test::get_device,
@ -804,18 +780,6 @@ mod test {
}, },
store::{Changes, DeviceChanges, IdentityChanges}, store::{Changes, DeviceChanges, IdentityChanges},
}; };
use matrix_sdk_common::{
api::r0::keys::SignedKey,
events::room_key_request::RequestedKeyInfo,
identifiers::{room_id, user_id, DeviceId, EventEncryptionAlgorithm, UserId},
uuid::Uuid,
};
use matrix_sdk_test::async_test;
use olm_rs::outbound_group_session::OlmOutboundGroupSession;
use std::collections::BTreeMap;
use tempfile::tempdir;
use super::{CryptoStore, OutgoingKeyRequest, SledStore};
fn alice_id() -> UserId { fn alice_id() -> UserId {
user_id!("@alice:example.org") user_id!("@alice:example.org")
@ -846,10 +810,7 @@ mod test {
async fn get_loaded_store() -> (ReadOnlyAccount, SledStore, tempfile::TempDir) { async fn get_loaded_store() -> (ReadOnlyAccount, SledStore, tempfile::TempDir) {
let (store, dir) = get_store(None).await; let (store, dir) = get_store(None).await;
let account = get_account(); let account = get_account();
store store.save_account(account.clone()).await.expect("Can't save account");
.save_account(account.clone())
.await
.expect("Can't save account");
(account, store, dir) (account, store, dir)
} }
@ -863,21 +824,12 @@ mod test {
let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id()); let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id());
bob.generate_one_time_keys_helper(1).await; bob.generate_one_time_keys_helper(1).await;
let one_time_key = bob let one_time_key =
.one_time_keys() bob.one_time_keys().await.curve25519().iter().next().unwrap().1.to_owned();
.await
.curve25519()
.iter()
.next()
.unwrap()
.1
.to_owned();
let one_time_key = SignedKey::new(one_time_key, BTreeMap::new()); let one_time_key = SignedKey::new(one_time_key, BTreeMap::new());
let sender_key = bob.identity_keys().curve25519().to_owned(); let sender_key = bob.identity_keys().curve25519().to_owned();
let session = alice let session =
.create_outbound_session_helper(&sender_key, &one_time_key) alice.create_outbound_session_helper(&sender_key, &one_time_key).await.unwrap();
.await
.unwrap();
(alice, session) (alice, session)
} }
@ -895,10 +847,7 @@ mod test {
assert!(store.load_account().await.unwrap().is_none()); assert!(store.load_account().await.unwrap().is_none());
let account = get_account(); let account = get_account();
store store.save_account(account).await.expect("Can't save account");
.save_account(account)
.await
.expect("Can't save account");
} }
#[async_test] #[async_test]
@ -906,10 +855,7 @@ mod test {
let (store, _dir) = get_store(None).await; let (store, _dir) = get_store(None).await;
let account = get_account(); let account = get_account();
store store.save_account(account.clone()).await.expect("Can't save account");
.save_account(account.clone())
.await
.expect("Can't save account");
let loaded_account = store.load_account().await.expect("Can't load account"); let loaded_account = store.load_account().await.expect("Can't load account");
let loaded_account = loaded_account.unwrap(); let loaded_account = loaded_account.unwrap();
@ -922,10 +868,7 @@ mod test {
let (store, _dir) = get_store(Some("secret_passphrase")).await; let (store, _dir) = get_store(Some("secret_passphrase")).await;
let account = get_account(); let account = get_account();
store store.save_account(account.clone()).await.expect("Can't save account");
.save_account(account.clone())
.await
.expect("Can't save account");
let loaded_account = store.load_account().await.expect("Can't load account"); let loaded_account = store.load_account().await.expect("Can't load account");
let loaded_account = loaded_account.unwrap(); let loaded_account = loaded_account.unwrap();
@ -938,50 +881,32 @@ mod test {
let (store, _dir) = get_store(None).await; let (store, _dir) = get_store(None).await;
let account = get_account(); let account = get_account();
store store.save_account(account.clone()).await.expect("Can't save account");
.save_account(account.clone())
.await
.expect("Can't save account");
account.mark_as_shared(); account.mark_as_shared();
account.update_uploaded_key_count(50); account.update_uploaded_key_count(50);
store store.save_account(account.clone()).await.expect("Can't save account");
.save_account(account.clone())
.await
.expect("Can't save account");
let loaded_account = store.load_account().await.expect("Can't load account"); let loaded_account = store.load_account().await.expect("Can't load account");
let loaded_account = loaded_account.unwrap(); let loaded_account = loaded_account.unwrap();
assert_eq!(account, loaded_account); assert_eq!(account, loaded_account);
assert_eq!( assert_eq!(account.uploaded_key_count(), loaded_account.uploaded_key_count());
account.uploaded_key_count(),
loaded_account.uploaded_key_count()
);
} }
#[async_test] #[async_test]
async fn load_sessions() { async fn load_sessions() {
let (store, _dir) = get_store(None).await; let (store, _dir) = get_store(None).await;
let (account, session) = get_account_and_session().await; let (account, session) = get_account_and_session().await;
store store.save_account(account.clone()).await.expect("Can't save account");
.save_account(account.clone())
.await
.expect("Can't save account");
let changes = Changes { let changes = Changes { sessions: vec![session.clone()], ..Default::default() };
sessions: vec![session.clone()],
..Default::default()
};
store.save_changes(changes).await.unwrap(); store.save_changes(changes).await.unwrap();
let sessions = store let sessions =
.get_sessions(&session.sender_key) store.get_sessions(&session.sender_key).await.expect("Can't load sessions").unwrap();
.await
.expect("Can't load sessions")
.unwrap();
let loaded_session = sessions.lock().await.get(0).cloned().unwrap(); let loaded_session = sessions.lock().await.get(0).cloned().unwrap();
assert_eq!(&session, &loaded_session); assert_eq!(&session, &loaded_session);
@ -994,15 +919,9 @@ mod test {
let sender_key = session.sender_key.to_owned(); let sender_key = session.sender_key.to_owned();
let session_id = session.session_id().to_owned(); let session_id = session.session_id().to_owned();
store store.save_account(account.clone()).await.expect("Can't save account");
.save_account(account.clone())
.await
.expect("Can't save account");
let changes = Changes { let changes = Changes { sessions: vec![session.clone()], ..Default::default() };
sessions: vec![session.clone()],
..Default::default()
};
store.save_changes(changes).await.unwrap(); store.save_changes(changes).await.unwrap();
let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap(); let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap();
@ -1040,15 +959,9 @@ mod test {
) )
.expect("Can't create session"); .expect("Can't create session");
let changes = Changes { let changes = Changes { inbound_group_sessions: vec![session], ..Default::default() };
inbound_group_sessions: vec![session],
..Default::default()
};
store store.save_changes(changes).await.expect("Can't save group session");
.save_changes(changes)
.await
.expect("Can't save group session");
} }
#[async_test] #[async_test]
@ -1072,15 +985,10 @@ mod test {
let session = InboundGroupSession::from_export(export).unwrap(); let session = InboundGroupSession::from_export(export).unwrap();
let changes = Changes { let changes =
inbound_group_sessions: vec![session.clone()], Changes { inbound_group_sessions: vec![session.clone()], ..Default::default() };
..Default::default()
};
store store.save_changes(changes).await.expect("Can't save group session");
.save_changes(changes)
.await
.expect("Can't save group session");
drop(store); drop(store);
@ -1103,21 +1011,12 @@ mod test {
let (_account, store, dir) = get_loaded_store().await; let (_account, store, dir) = get_loaded_store().await;
let device = get_device(); let device = get_device();
assert!(store assert!(store.update_tracked_user(device.user_id(), false).await.unwrap());
.update_tracked_user(device.user_id(), false) assert!(!store.update_tracked_user(device.user_id(), false).await.unwrap());
.await
.unwrap());
assert!(!store
.update_tracked_user(device.user_id(), false)
.await
.unwrap());
assert!(store.is_user_tracked(device.user_id())); assert!(store.is_user_tracked(device.user_id()));
assert!(!store.users_for_key_query().contains(device.user_id())); assert!(!store.users_for_key_query().contains(device.user_id()));
assert!(!store assert!(!store.update_tracked_user(device.user_id(), true).await.unwrap());
.update_tracked_user(device.user_id(), true)
.await
.unwrap());
assert!(store.users_for_key_query().contains(device.user_id())); assert!(store.users_for_key_query().contains(device.user_id()));
drop(store); drop(store);
@ -1128,10 +1027,7 @@ mod test {
assert!(store.is_user_tracked(device.user_id())); assert!(store.is_user_tracked(device.user_id()));
assert!(store.users_for_key_query().contains(device.user_id())); assert!(store.users_for_key_query().contains(device.user_id()));
store store.update_tracked_user(device.user_id(), false).await.unwrap();
.update_tracked_user(device.user_id(), false)
.await
.unwrap();
assert!(!store.users_for_key_query().contains(device.user_id())); assert!(!store.users_for_key_query().contains(device.user_id()));
drop(store); drop(store);
@ -1148,10 +1044,7 @@ mod test {
let device = get_device(); let device = get_device();
let changes = Changes { let changes = Changes {
devices: DeviceChanges { devices: DeviceChanges { changed: vec![device.clone()], ..Default::default() },
changed: vec![device.clone()],
..Default::default()
},
..Default::default() ..Default::default()
}; };
@ -1163,11 +1056,8 @@ mod test {
store.load_account().await.unwrap(); store.load_account().await.unwrap();
let loaded_device = store let loaded_device =
.get_device(device.user_id(), device.device_id()) store.get_device(device.user_id(), device.device_id()).await.unwrap().unwrap();
.await
.unwrap()
.unwrap();
assert_eq!(device, loaded_device); assert_eq!(device, loaded_device);
@ -1188,20 +1078,14 @@ mod test {
let device = get_device(); let device = get_device();
let changes = Changes { let changes = Changes {
devices: DeviceChanges { devices: DeviceChanges { changed: vec![device.clone()], ..Default::default() },
changed: vec![device.clone()],
..Default::default()
},
..Default::default() ..Default::default()
}; };
store.save_changes(changes).await.unwrap(); store.save_changes(changes).await.unwrap();
let changes = Changes { let changes = Changes {
devices: DeviceChanges { devices: DeviceChanges { deleted: vec![device.clone()], ..Default::default() },
deleted: vec![device.clone()],
..Default::default()
},
..Default::default() ..Default::default()
}; };
@ -1212,10 +1096,7 @@ mod test {
store.load_account().await.unwrap(); store.load_account().await.unwrap();
let loaded_device = store let loaded_device = store.get_device(device.user_id(), device.device_id()).await.unwrap();
.get_device(device.user_id(), device.device_id())
.await
.unwrap();
assert!(loaded_device.is_none()); assert!(loaded_device.is_none());
} }
@ -1232,10 +1113,7 @@ mod test {
let account = ReadOnlyAccount::new(&user_id, &device_id); let account = ReadOnlyAccount::new(&user_id, &device_id);
store store.save_account(account.clone()).await.expect("Can't save account");
.save_account(account.clone())
.await
.expect("Can't save account");
let own_identity = get_own_identity(); let own_identity = get_own_identity();
@ -1247,10 +1125,7 @@ mod test {
..Default::default() ..Default::default()
}; };
store store.save_changes(changes).await.expect("Can't save identity");
.save_changes(changes)
.await
.expect("Can't save identity");
drop(store); drop(store);
@ -1258,17 +1133,10 @@ mod test {
store.load_account().await.unwrap(); store.load_account().await.unwrap();
let loaded_user = store let loaded_user = store.get_user_identity(own_identity.user_id()).await.unwrap().unwrap();
.get_user_identity(own_identity.user_id())
.await
.unwrap()
.unwrap();
assert_eq!(loaded_user.master_key(), own_identity.master_key()); assert_eq!(loaded_user.master_key(), own_identity.master_key());
assert_eq!( assert_eq!(loaded_user.self_signing_key(), own_identity.self_signing_key());
loaded_user.self_signing_key(),
own_identity.self_signing_key()
);
assert_eq!(loaded_user, own_identity.clone().into()); assert_eq!(loaded_user, own_identity.clone().into());
let other_identity = get_other_identity(); let other_identity = get_other_identity();
@ -1283,17 +1151,10 @@ mod test {
store.save_changes(changes).await.unwrap(); store.save_changes(changes).await.unwrap();
let loaded_user = store let loaded_user = store.get_user_identity(other_identity.user_id()).await.unwrap().unwrap();
.get_user_identity(other_identity.user_id())
.await
.unwrap()
.unwrap();
assert_eq!(loaded_user.master_key(), other_identity.master_key()); assert_eq!(loaded_user.master_key(), other_identity.master_key());
assert_eq!( assert_eq!(loaded_user.self_signing_key(), other_identity.self_signing_key());
loaded_user.self_signing_key(),
other_identity.self_signing_key()
);
assert_eq!(loaded_user, other_identity.into()); assert_eq!(loaded_user, other_identity.into());
own_identity.mark_as_verified(); own_identity.mark_as_verified();
@ -1317,10 +1178,7 @@ mod test {
assert!(store.load_identity().await.unwrap().is_none()); assert!(store.load_identity().await.unwrap().is_none());
let identity = PrivateCrossSigningIdentity::new(alice_id()).await; let identity = PrivateCrossSigningIdentity::new(alice_id()).await;
let changes = Changes { let changes = Changes { private_identity: Some(identity.clone()), ..Default::default() };
private_identity: Some(identity.clone()),
..Default::default()
};
store.save_changes(changes).await.unwrap(); store.save_changes(changes).await.unwrap();
let loaded_identity = store.load_identity().await.unwrap().unwrap(); let loaded_identity = store.load_identity().await.unwrap().unwrap();
@ -1331,10 +1189,8 @@ mod test {
async fn olm_hash_saving() { async fn olm_hash_saving() {
let (_, store, _dir) = get_loaded_store().await; let (_, store, _dir) = get_loaded_store().await;
let hash = OlmMessageHash { let hash =
sender_key: "test_sender".to_owned(), OlmMessageHash { sender_key: "test_sender".to_owned(), hash: "test_hash".to_owned() };
hash: "test_hash".to_owned(),
};
let mut changes = Changes::default(); let mut changes = Changes::default();
changes.message_hashes.push(hash.clone()); changes.message_hashes.push(hash.clone());

View File

@ -15,9 +15,6 @@
use std::{convert::TryFrom, sync::Arc}; use std::{convert::TryFrom, sync::Arc};
use dashmap::DashMap; use dashmap::DashMap;
use tracing::{info, trace, warn};
use matrix_sdk_common::{ use matrix_sdk_common::{
events::{ events::{
room::message::MessageType, AnyMessageEvent, AnySyncMessageEvent, AnySyncRoomEvent, room::message::MessageType, AnyMessageEvent, AnySyncMessageEvent, AnySyncRoomEvent,
@ -27,12 +24,12 @@ use matrix_sdk_common::{
locks::Mutex, locks::Mutex,
uuid::Uuid, uuid::Uuid,
}; };
use tracing::{info, trace, warn};
use super::{ use super::{
requests::VerificationRequest, requests::VerificationRequest,
sas::{content_to_request, OutgoingContent, Sas, VerificationResult}, sas::{content_to_request, OutgoingContent, Sas, VerificationResult},
}; };
use crate::{ use crate::{
olm::PrivateCrossSigningIdentity, olm::PrivateCrossSigningIdentity,
requests::OutgoingRequest, requests::OutgoingRequest,
@ -85,18 +82,14 @@ impl VerificationMachine {
); );
let request = match content.into() { let request = match content.into() {
OutgoingContent::Room(r, c) => RoomMessageRequest { OutgoingContent::Room(r, c) => {
room_id: r, RoomMessageRequest { room_id: r, txn_id: Uuid::new_v4(), content: c }.into()
txn_id: Uuid::new_v4(),
content: c,
} }
.into(),
OutgoingContent::ToDevice(c) => { OutgoingContent::ToDevice(c) => {
let request = let request =
content_to_request(device.user_id(), device.device_id().to_owned(), c); content_to_request(device.user_id(), device.device_id().to_owned(), c);
self.verifications self.verifications.insert(sas.flow_id().as_str().to_owned(), sas.clone());
.insert(sas.flow_id().as_str().to_owned(), sas.clone());
request.into() request.into()
} }
@ -136,10 +129,7 @@ impl VerificationMachine {
let request = content_to_request(recipient, recipient_device.to_owned(), c); let request = content_to_request(recipient, recipient_device.to_owned(), c);
let request_id = request.txn_id; let request_id = request.txn_id;
let request = OutgoingRequest { let request = OutgoingRequest { request_id, request: Arc::new(request.into()) };
request_id,
request: Arc::new(request.into()),
};
self.outgoing_messages.insert(request_id, request); self.outgoing_messages.insert(request_id, request);
} }
@ -149,12 +139,7 @@ impl VerificationMachine {
let request = OutgoingRequest { let request = OutgoingRequest {
request: Arc::new( request: Arc::new(
RoomMessageRequest { RoomMessageRequest { room_id: r, txn_id: request_id, content: c }.into(),
room_id: r,
txn_id: request_id,
content: c,
}
.into(),
), ),
request_id, request_id,
}; };
@ -181,24 +166,17 @@ impl VerificationMachine {
} }
pub fn outgoing_messages(&self) -> Vec<OutgoingRequest> { pub fn outgoing_messages(&self) -> Vec<OutgoingRequest> {
self.outgoing_messages self.outgoing_messages.iter().map(|r| (*r).clone()).collect()
.iter()
.map(|r| (*r).clone())
.collect()
} }
pub fn garbage_collect(&self) { pub fn garbage_collect(&self) {
self.verifications self.verifications.retain(|_, s| !(s.is_done() || s.is_canceled()));
.retain(|_, s| !(s.is_done() || s.is_canceled()));
for sas in self.verifications.iter() { for sas in self.verifications.iter() {
if let Some(r) = sas.cancel_if_timed_out() { if let Some(r) = sas.cancel_if_timed_out() {
self.outgoing_messages.insert( self.outgoing_messages.insert(
r.request_id(), r.request_id(),
OutgoingRequest { OutgoingRequest { request_id: r.request_id(), request: Arc::new(r.into()) },
request_id: r.request_id(),
request: Arc::new(r.into()),
},
); );
} }
} }
@ -239,8 +217,7 @@ impl VerificationMachine {
r, r,
); );
self.requests self.requests.insert(request.flow_id().as_str().to_owned(), request);
.insert(request.flow_id().as_str().to_owned(), request);
} }
} }
} }
@ -261,10 +238,8 @@ impl VerificationMachine {
if let Some((_, request)) = if let Some((_, request)) =
self.requests.remove(e.content.relation.event_id.as_str()) self.requests.remove(e.content.relation.event_id.as_str())
{ {
if let Some(d) = self if let Some(d) =
.store self.store.get_device(&e.sender, &e.content.from_device).await?
.get_device(&e.sender, &e.content.from_device)
.await?
{ {
match request.into_started_sas( match request.into_started_sas(
e, e,
@ -370,8 +345,7 @@ impl VerificationMachine {
&e.content, &e.content,
); );
self.requests self.requests.insert(request.flow_id().as_str().to_string(), request);
.insert(request.flow_id().as_str().to_string(), request);
} }
AnyToDeviceEvent::KeyVerificationReady(e) => { AnyToDeviceEvent::KeyVerificationReady(e) => {
if let Some(request) = self.requests.get(&e.content.transaction_id) { if let Some(request) = self.requests.get(&e.content.transaction_id) {
@ -388,11 +362,7 @@ impl VerificationMachine {
e.content.from_device e.content.from_device
); );
if let Some(d) = self if let Some(d) = self.store.get_device(&e.sender, &e.content.from_device).await? {
.store
.get_device(&e.sender, &e.content.from_device)
.await?
{
let private_identity = self.private_identity.lock().await.clone(); let private_identity = self.private_identity.lock().await.clone();
match Sas::from_start_event( match Sas::from_start_event(
self.account.clone(), self.account.clone(),
@ -403,8 +373,7 @@ impl VerificationMachine {
self.store.get_user_identity(&e.sender).await?, self.store.get_user_identity(&e.sender).await?,
) { ) {
Ok(s) => { Ok(s) => {
self.verifications self.verifications.insert(e.content.transaction_id.clone(), s);
.insert(e.content.transaction_id.clone(), s);
} }
Err(c) => { Err(c) => {
warn!( warn!(
@ -455,10 +424,7 @@ impl VerificationMachine {
self.outgoing_messages.insert( self.outgoing_messages.insert(
request_id, request_id,
OutgoingRequest { OutgoingRequest { request_id, request: Arc::new(r.into()) },
request_id,
request: Arc::new(r.into()),
},
); );
} }
} }
@ -535,10 +501,7 @@ mod test {
); );
machine machine
.receive_event(&wrap_any_to_device_content( .receive_event(&wrap_any_to_device_content(bob_sas.user_id(), start_content.into()))
bob_sas.user_id(),
start_content.into(),
))
.await .await
.unwrap(); .unwrap();

View File

@ -17,11 +17,10 @@ mod requests;
mod sas; mod sas;
pub use machine::VerificationMachine; pub use machine::VerificationMachine;
use matrix_sdk_common::identifiers::{EventId, RoomId};
pub use requests::VerificationRequest; pub use requests::VerificationRequest;
pub use sas::{AcceptSettings, Sas, VerificationResult}; pub use sas::{AcceptSettings, Sas, VerificationResult};
use matrix_sdk_common::identifiers::{EventId, RoomId};
#[derive(Clone, Debug, Hash, PartialEq, PartialOrd)] #[derive(Clone, Debug, Hash, PartialEq, PartialOrd)]
pub enum FlowId { pub enum FlowId {
ToDevice(String), ToDevice(String),
@ -59,18 +58,17 @@ impl From<(RoomId, EventId)> for FlowId {
#[cfg(test)] #[cfg(test)]
pub(crate) mod test { pub(crate) mod test {
use crate::{
requests::{OutgoingRequest, OutgoingRequests},
OutgoingVerificationRequest,
};
use serde_json::Value;
use matrix_sdk_common::{ use matrix_sdk_common::{
events::{AnyToDeviceEvent, AnyToDeviceEventContent, EventType, ToDeviceEvent}, events::{AnyToDeviceEvent, AnyToDeviceEventContent, EventType, ToDeviceEvent},
identifiers::UserId, identifiers::UserId,
}; };
use serde_json::Value;
use super::sas::OutgoingContent; use super::sas::OutgoingContent;
use crate::{
requests::{OutgoingRequest, OutgoingRequests},
OutgoingVerificationRequest,
};
pub(crate) fn request_to_event( pub(crate) fn request_to_event(
sender: &UserId, sender: &UserId,
@ -94,11 +92,7 @@ pub(crate) mod test {
sender: &UserId, sender: &UserId,
content: OutgoingContent, content: OutgoingContent,
) -> AnyToDeviceEvent { ) -> AnyToDeviceEvent {
let content = if let OutgoingContent::ToDevice(c) = content { let content = if let OutgoingContent::ToDevice(c) = content { c } else { unreachable!() };
c
} else {
unreachable!()
};
match content { match content {
AnyToDeviceEventContent::KeyVerificationKey(c) => { AnyToDeviceEventContent::KeyVerificationKey(c) => {
@ -133,22 +127,11 @@ pub(crate) mod test {
pub(crate) fn get_content_from_request( pub(crate) fn get_content_from_request(
request: &OutgoingVerificationRequest, request: &OutgoingVerificationRequest,
) -> OutgoingContent { ) -> OutgoingContent {
let request = if let OutgoingVerificationRequest::ToDevice(r) = request { let request =
r if let OutgoingVerificationRequest::ToDevice(r) = request { r } else { unreachable!() };
} else {
unreachable!()
};
let json: Value = serde_json::from_str( let json: Value = serde_json::from_str(
request request.messages.values().next().unwrap().values().next().unwrap().get(),
.messages
.values()
.next()
.unwrap()
.values()
.next()
.unwrap()
.get(),
) )
.unwrap(); .unwrap();

View File

@ -35,6 +35,10 @@ use matrix_sdk_common::{
uuid::Uuid, uuid::Uuid,
}; };
use super::{
sas::{content_to_request, OutgoingContent, StartContent},
FlowId,
};
use crate::{ use crate::{
olm::{PrivateCrossSigningIdentity, ReadOnlyAccount}, olm::{PrivateCrossSigningIdentity, ReadOnlyAccount},
store::CryptoStore, store::CryptoStore,
@ -42,11 +46,6 @@ use crate::{
UserIdentities, UserIdentities,
}; };
use super::{
sas::{content_to_request, OutgoingContent, StartContent},
FlowId,
};
const SUPPORTED_METHODS: &[VerificationMethod] = &[VerificationMethod::MSasV1]; const SUPPORTED_METHODS: &[VerificationMethod] = &[VerificationMethod::MSasV1];
pub enum RequestContent<'a> { pub enum RequestContent<'a> {
@ -256,15 +255,13 @@ impl VerificationRequest {
content: RequestContent, content: RequestContent,
) -> Self { ) -> Self {
Self { Self {
inner: Arc::new(Mutex::new(InnerRequest::Requested( inner: Arc::new(Mutex::new(InnerRequest::Requested(RequestState::from_request_event(
RequestState::from_request_event( account.user_id(),
account.user_id(), account.device_id(),
account.device_id(), sender,
sender, &flow_id,
&flow_id, content,
content, )))),
),
))),
account, account,
other_user_id: sender.clone().into(), other_user_id: sender.clone().into(),
private_cross_signing_identity, private_cross_signing_identity,
@ -278,15 +275,12 @@ impl VerificationRequest {
let mut inner = self.inner.lock().unwrap(); let mut inner = self.inner.lock().unwrap();
inner.accept().map(|c| match c { inner.accept().map(|c| match c {
OutgoingContent::ToDevice(content) => self OutgoingContent::ToDevice(content) => {
.content_to_request(inner.other_device_id(), content) self.content_to_request(inner.other_device_id(), content).into()
.into(), }
OutgoingContent::Room(room_id, content) => RoomMessageRequest { OutgoingContent::Room(room_id, content) => {
room_id, RoomMessageRequest { room_id, txn_id: Uuid::new_v4(), content }.into()
txn_id: Uuid::new_v4(),
content,
} }
.into(),
}) })
} }
@ -445,10 +439,7 @@ impl RequestState<Created> {
own_user_id: own_user_id.to_owned(), own_user_id: own_user_id.to_owned(),
own_device_id: own_device_id.to_owned(), own_device_id: own_device_id.to_owned(),
other_user_id: other_user_id.to_owned(), other_user_id: other_user_id.to_owned(),
state: Created { state: Created { methods: SUPPORTED_METHODS.to_vec(), flow_id: flow_id.to_owned() },
methods: SUPPORTED_METHODS.to_vec(),
flow_id: flow_id.to_owned(),
},
} }
} }
@ -589,14 +580,9 @@ impl RequestState<Ready> {
other_identity: Option<UserIdentities>, other_identity: Option<UserIdentities>,
) -> (Sas, StartContent) { ) -> (Sas, StartContent) {
match self.state.flow_id { match self.state.flow_id {
FlowId::ToDevice(t) => Sas::start( FlowId::ToDevice(t) => {
account, Sas::start(account, private_identity, other_device, store, other_identity, Some(t))
private_identity, }
other_device,
store,
other_identity,
Some(t),
),
FlowId::InRoom(r, e) => Sas::start_in_room( FlowId::InRoom(r, e) => Sas::start_in_room(
e, e,
r, r,
@ -630,6 +616,7 @@ mod test {
}; };
use matrix_sdk_test::async_test; use matrix_sdk_test::async_test;
use super::VerificationRequest;
use crate::{ use crate::{
olm::{PrivateCrossSigningIdentity, ReadOnlyAccount}, olm::{PrivateCrossSigningIdentity, ReadOnlyAccount},
store::{CryptoStore, MemoryStore}, store::{CryptoStore, MemoryStore},
@ -640,8 +627,6 @@ mod test {
ReadOnlyDevice, ReadOnlyDevice,
}; };
use super::VerificationRequest;
fn alice_id() -> UserId { fn alice_id() -> UserId {
UserId::try_from("@alice:example.org").unwrap() UserId::try_from("@alice:example.org").unwrap()
} }
@ -760,9 +745,7 @@ mod test {
panic!("Invalid start event content type"); panic!("Invalid start event content type");
}; };
let alice_sas = alice_request let alice_sas = alice_request.into_started_sas(&event, bob_device, None).unwrap();
.into_started_sas(&event, bob_device, None)
.unwrap();
assert!(!bob_sas.is_canceled()); assert!(!bob_sas.is_canceled());
assert!(!alice_sas.is_canceled()); assert!(!alice_sas.is_canceled());

View File

@ -59,10 +59,7 @@ impl StartContent {
StartContent::Room(_, c) => serde_json::to_value(c), StartContent::Room(_, c) => serde_json::to_value(c),
}; };
content content.expect("Can't serialize content").try_into().expect("Can't canonicalize content")
.expect("Can't serialize content")
.try_into()
.expect("Can't canonicalize content")
} }
} }
@ -287,14 +284,7 @@ impl From<OutgoingVerificationRequest> for OutgoingContent {
match request { match request {
OutgoingVerificationRequest::ToDevice(r) => { OutgoingVerificationRequest::ToDevice(r) => {
let json: Value = serde_json::from_str( let json: Value = serde_json::from_str(
r.messages r.messages.values().next().unwrap().values().next().unwrap().get(),
.values()
.next()
.unwrap()
.values()
.next()
.unwrap()
.get(),
) )
.unwrap(); .unwrap();

View File

@ -14,11 +14,6 @@
use std::{collections::BTreeMap, convert::TryInto}; use std::{collections::BTreeMap, convert::TryInto};
use sha2::{Digest, Sha256};
use tracing::{trace, warn};
use olm_rs::sas::OlmSas;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::to_device::DeviceIdOrAllDevices, api::r0::to_device::DeviceIdOrAllDevices,
events::{ events::{
@ -32,17 +27,19 @@ use matrix_sdk_common::{
identifiers::{DeviceKeyAlgorithm, DeviceKeyId, UserId}, identifiers::{DeviceKeyAlgorithm, DeviceKeyId, UserId},
uuid::Uuid, uuid::Uuid,
}; };
use olm_rs::sas::OlmSas;
use crate::{ use sha2::{Digest, Sha256};
identities::{ReadOnlyDevice, UserIdentities}, use tracing::{trace, warn};
utilities::encode,
ReadOnlyAccount, ToDeviceRequest,
};
use super::{ use super::{
event_enums::{MacContent, StartContent}, event_enums::{MacContent, StartContent},
FlowId, FlowId,
}; };
use crate::{
identities::{ReadOnlyDevice, UserIdentities},
utilities::encode,
ReadOnlyAccount, ToDeviceRequest,
};
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct SasIds { pub struct SasIds {
@ -65,12 +62,7 @@ pub fn calculate_commitment(public_key: &str, content: impl Into<StartContent>)
let content = content.into().canonical_json(); let content = content.into().canonical_json();
let content_string = content.to_string(); let content_string = content.to_string();
encode( encode(Sha256::new().chain(&public_key).chain(&content_string).finalize())
Sha256::new()
.chain(&public_key)
.chain(&content_string)
.finalize(),
)
} }
/// Get a tuple of an emoji and a description of the emoji using a number. /// Get a tuple of an emoji and a description of the emoji using a number.
@ -234,11 +226,7 @@ pub fn receive_mac_event(
.calculate_mac(key, &format!("{}{}", info, key_id)) .calculate_mac(key, &format!("{}{}", info, key_id))
.expect("Can't calculate SAS MAC") .expect("Can't calculate SAS MAC")
{ {
trace!( trace!("Successfully verified the device key {} from {}", key_id, sender);
"Successfully verified the device key {} from {}",
key_id,
sender
);
verified_devices.push(ids.other_device.clone()); verified_devices.push(ids.other_device.clone());
} else { } else {
@ -253,11 +241,7 @@ pub fn receive_mac_event(
.calculate_mac(key, &format!("{}{}", info, key_id)) .calculate_mac(key, &format!("{}{}", info, key_id))
.expect("Can't calculate SAS MAC") .expect("Can't calculate SAS MAC")
{ {
trace!( trace!("Successfully verified the master key {} from {}", key_id, sender);
"Successfully verified the master key {} from {}",
key_id,
sender
);
verified_identities.push(identity.clone()) verified_identities.push(identity.clone())
} else { } else {
return Err(CancelCode::KeyMismatch); return Err(CancelCode::KeyMismatch);
@ -319,8 +303,7 @@ pub fn get_mac_content(sas: &OlmSas, ids: &SasIds, flow_id: &FlowId) -> MacConte
mac.insert( mac.insert(
key_id.to_string(), key_id.to_string(),
sas.calculate_mac(key, &format!("{}{}", info, key_id)) sas.calculate_mac(key, &format!("{}{}", info, key_id)).expect("Can't calculate SAS MAC"),
.expect("Can't calculate SAS MAC"),
); );
// TODO Add the cross signing master key here if we trust/have it. // TODO Add the cross signing master key here if we trust/have it.
@ -332,23 +315,13 @@ pub fn get_mac_content(sas: &OlmSas, ids: &SasIds, flow_id: &FlowId) -> MacConte
.expect("Can't calculate SAS MAC"); .expect("Can't calculate SAS MAC");
match flow_id { match flow_id {
FlowId::ToDevice(s) => MacToDeviceEventContent { FlowId::ToDevice(s) => {
transaction_id: s.to_string(), MacToDeviceEventContent { transaction_id: s.to_string(), keys, mac }.into()
keys, }
mac, FlowId::InRoom(r, e) => {
(r.clone(), MacEventContent { mac, keys, relation: Relation { event_id: e.clone() } })
.into()
} }
.into(),
FlowId::InRoom(r, e) => (
r.clone(),
MacEventContent {
mac,
keys,
relation: Relation {
event_id: e.clone(),
},
},
)
.into(),
} }
} }
@ -369,24 +342,12 @@ fn extra_info_sas(
flow_id: &str, flow_id: &str,
we_started: bool, we_started: bool,
) -> String { ) -> String {
let our_info = format!( let our_info = format!("{}|{}|{}", ids.account.user_id(), ids.account.device_id(), own_pubkey);
"{}|{}|{}", let their_info =
ids.account.user_id(), format!("{}|{}|{}", ids.other_device.user_id(), ids.other_device.device_id(), their_pubkey);
ids.account.device_id(),
own_pubkey
);
let their_info = format!(
"{}|{}|{}",
ids.other_device.user_id(),
ids.other_device.device_id(),
their_pubkey
);
let (first_info, second_info) = if we_started { let (first_info, second_info) =
(our_info, their_info) if we_started { (our_info, their_info) } else { (their_info, our_info) };
} else {
(their_info, our_info)
};
let info = format!( let info = format!(
"MATRIX_KEY_VERIFICATION_SAS|{first_info}|{second_info}|{flow_id}", "MATRIX_KEY_VERIFICATION_SAS|{first_info}|{second_info}|{flow_id}",
@ -585,11 +546,7 @@ pub fn content_to_request(
_ => unreachable!(), _ => unreachable!(),
}; };
ToDeviceRequest { ToDeviceRequest { txn_id: Uuid::new_v4(), event_type, messages }
txn_id: Uuid::new_v4(),
event_type,
messages,
}
} }
#[cfg(test)] #[cfg(test)]
@ -627,18 +584,14 @@ mod test {
#[test] #[test]
fn emoji_generation() { fn emoji_generation() {
let bytes = vec![0, 0, 0, 0, 0, 0]; let bytes = vec![0, 0, 0, 0, 0, 0];
let index: Vec<(&'static str, &'static str)> = vec![0, 0, 0, 0, 0, 0, 0] let index: Vec<(&'static str, &'static str)> =
.into_iter() vec![0, 0, 0, 0, 0, 0, 0].into_iter().map(emoji_from_index).collect();
.map(emoji_from_index)
.collect();
assert_eq!(bytes_to_emoji(bytes), index.as_ref()); assert_eq!(bytes_to_emoji(bytes), index.as_ref());
let bytes = vec![0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF]; let bytes = vec![0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF];
let index: Vec<(&'static str, &'static str)> = vec![63, 63, 63, 63, 63, 63, 63] let index: Vec<(&'static str, &'static str)> =
.into_iter() vec![63, 63, 63, 63, 63, 63, 63].into_iter().map(emoji_from_index).collect();
.map(emoji_from_index)
.collect();
assert_eq!(bytes_to_emoji(bytes), index.as_ref()); assert_eq!(bytes_to_emoji(bytes), index.as_ref());
} }

View File

@ -12,11 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::sync::Arc;
#[cfg(test)] #[cfg(test)]
use std::time::Instant; use std::time::Instant;
use std::sync::Arc;
use matrix_sdk_common::{ use matrix_sdk_common::{
events::{ events::{
key::verification::{cancel::CancelCode, ShortAuthenticationString}, key::verification::{cancel::CancelCode, ShortAuthenticationString},
@ -25,11 +24,6 @@ use matrix_sdk_common::{
identifiers::{EventId, RoomId}, identifiers::{EventId, RoomId},
}; };
use crate::{
identities::{ReadOnlyDevice, UserIdentities},
ReadOnlyAccount,
};
use super::{ use super::{
event_enums::{AcceptContent, CancelContent, MacContent, OutgoingContent}, event_enums::{AcceptContent, CancelContent, MacContent, OutgoingContent},
sas_state::{ sas_state::{
@ -38,6 +32,10 @@ use super::{
}, },
FlowId, StartContent, FlowId, StartContent,
}; };
use crate::{
identities::{ReadOnlyDevice, UserIdentities},
ReadOnlyAccount,
};
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum InnerSas { pub enum InnerSas {
@ -315,14 +313,15 @@ impl InnerSas {
_ => (self, None), _ => (self, None),
}, },
AnyToDeviceEvent::KeyVerificationMac(e) => match self { AnyToDeviceEvent::KeyVerificationMac(e) => match self {
InnerSas::KeyRecieved(s) => match s.into_mac_received(&e.sender, e.content.clone()) InnerSas::KeyRecieved(s) => {
{ match s.into_mac_received(&e.sender, e.content.clone()) {
Ok(s) => (InnerSas::MacReceived(s), None), Ok(s) => (InnerSas::MacReceived(s), None),
Err(s) => { Err(s) => {
let content = s.as_content(); let content = s.as_content();
(InnerSas::Canceled(s), Some(content.into())) (InnerSas::Canceled(s), Some(content.into()))
}
} }
}, }
InnerSas::Confirmed(s) => match s.into_done(&e.sender, e.content.clone()) { InnerSas::Confirmed(s) => match s.into_done(&e.sender, e.content.clone()) {
Ok(s) => (InnerSas::Done(s), None), Ok(s) => (InnerSas::Done(s), None),
Err(s) => { Err(s) => {

View File

@ -17,13 +17,14 @@ mod helpers;
mod inner_sas; mod inner_sas;
mod sas_state; mod sas_state;
use std::sync::{Arc, Mutex};
#[cfg(test)] #[cfg(test)]
use std::time::Instant; use std::time::Instant;
use event_enums::AcceptContent; use event_enums::AcceptContent;
use std::sync::{Arc, Mutex}; pub use event_enums::{CancelContent, OutgoingContent, StartContent};
use tracing::{error, info, trace, warn}; pub use helpers::content_to_request;
use inner_sas::InnerSas;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::keys::upload_signatures::Request as SignatureUploadRequest, api::r0::keys::upload_signatures::Request as SignatureUploadRequest,
events::{ events::{
@ -37,7 +38,9 @@ use matrix_sdk_common::{
identifiers::{DeviceId, EventId, RoomId, UserId}, identifiers::{DeviceId, EventId, RoomId, UserId},
uuid::Uuid, uuid::Uuid,
}; };
use tracing::{error, info, trace, warn};
use super::FlowId;
use crate::{ use crate::{
error::SignatureError, error::SignatureError,
identities::{LocalTrust, ReadOnlyDevice, UserIdentities}, identities::{LocalTrust, ReadOnlyDevice, UserIdentities},
@ -47,12 +50,6 @@ use crate::{
ReadOnlyAccount, ToDeviceRequest, ReadOnlyAccount, ToDeviceRequest,
}; };
use super::FlowId;
pub use event_enums::{CancelContent, OutgoingContent, StartContent};
pub use helpers::content_to_request;
use inner_sas::InnerSas;
#[derive(Debug)] #[derive(Debug)]
/// A result of a verification flow. /// A result of a verification flow.
#[allow(clippy::large_enum_variant)] #[allow(clippy::large_enum_variant)]
@ -275,22 +272,18 @@ impl Sas {
&self, &self,
settings: AcceptSettings, settings: AcceptSettings,
) -> Option<OutgoingVerificationRequest> { ) -> Option<OutgoingVerificationRequest> {
self.inner self.inner.lock().unwrap().accept().map(|c| match settings.apply(c) {
.lock() AcceptContent::ToDevice(c) => {
.unwrap() let content = AnyToDeviceEventContent::KeyVerificationAccept(c);
.accept() self.content_to_request(content).into()
.map(|c| match settings.apply(c) { }
AcceptContent::ToDevice(c) => { AcceptContent::Room(room_id, content) => RoomMessageRequest {
let content = AnyToDeviceEventContent::KeyVerificationAccept(c); room_id,
self.content_to_request(content).into() txn_id: Uuid::new_v4(),
} content: AnyMessageEventContent::KeyVerificationAccept(content),
AcceptContent::Room(room_id, content) => RoomMessageRequest { }
room_id, .into(),
txn_id: Uuid::new_v4(), })
content: AnyMessageEventContent::KeyVerificationAccept(content),
}
.into(),
})
} }
/// Confirm the Sas verification. /// Confirm the Sas verification.
@ -303,10 +296,7 @@ impl Sas {
pub async fn confirm( pub async fn confirm(
&self, &self,
) -> Result< ) -> Result<
( (Option<OutgoingVerificationRequest>, Option<SignatureUploadRequest>),
Option<OutgoingVerificationRequest>,
Option<SignatureUploadRequest>,
),
CryptoStoreError, CryptoStoreError,
> { > {
let (content, done) = { let (content, done) = {
@ -319,9 +309,9 @@ impl Sas {
}; };
let mac_request = content.map(|c| match c { let mac_request = content.map(|c| match c {
event_enums::MacContent::ToDevice(c) => self event_enums::MacContent::ToDevice(c) => {
.content_to_request(AnyToDeviceEventContent::KeyVerificationMac(c)) self.content_to_request(AnyToDeviceEventContent::KeyVerificationMac(c)).into()
.into(), }
event_enums::MacContent::Room(r, c) => RoomMessageRequest { event_enums::MacContent::Room(r, c) => RoomMessageRequest {
room_id: r, room_id: r,
txn_id: Uuid::new_v4(), txn_id: Uuid::new_v4(),
@ -374,10 +364,7 @@ impl Sas {
}; };
let mut changes = Changes { let mut changes = Changes {
devices: DeviceChanges { devices: DeviceChanges { changed: vec![device], ..Default::default() },
changed: vec![device],
..Default::default()
},
..Default::default() ..Default::default()
}; };
@ -437,10 +424,7 @@ impl Sas {
.map(VerificationResult::SignatureUpload) .map(VerificationResult::SignatureUpload)
.unwrap_or(VerificationResult::Ok)) .unwrap_or(VerificationResult::Ok))
} else { } else {
Ok(self Ok(self.cancel().map(VerificationResult::Cancel).unwrap_or(VerificationResult::Ok))
.cancel()
.map(VerificationResult::Cancel)
.unwrap_or(VerificationResult::Ok))
} }
} }
@ -463,14 +447,8 @@ impl Sas {
.as_ref() .as_ref()
.map_or(false, |i| i.master_key() == identity.master_key()) .map_or(false, |i| i.master_key() == identity.master_key())
{ {
if self if self.verified_identities().map_or(false, |i| i.contains(&identity)) {
.verified_identities() trace!("Marking user identity of {} as verified.", identity.user_id(),);
.map_or(false, |i| i.contains(&identity))
{
trace!(
"Marking user identity of {} as verified.",
identity.user_id(),
);
if let UserIdentities::Own(i) = &identity { if let UserIdentities::Own(i) = &identity {
i.mark_as_verified(); i.mark_as_verified();
@ -509,17 +487,11 @@ impl Sas {
pub(crate) async fn mark_device_as_verified( pub(crate) async fn mark_device_as_verified(
&self, &self,
) -> Result<Option<ReadOnlyDevice>, CryptoStoreError> { ) -> Result<Option<ReadOnlyDevice>, CryptoStoreError> {
let device = self let device = self.store.get_device(self.other_user_id(), self.other_device_id()).await?;
.store
.get_device(self.other_user_id(), self.other_device_id())
.await?;
if let Some(device) = device { if let Some(device) = device {
if device.keys() == self.other_device.keys() { if device.keys() == self.other_device.keys() {
if self if self.verified_devices().map_or(false, |v| v.contains(&device)) {
.verified_devices()
.map_or(false, |v| v.contains(&device))
{
trace!( trace!(
"Marking device {} {} as verified.", "Marking device {} {} as verified.",
device.user_id(), device.user_id(),
@ -580,9 +552,9 @@ impl Sas {
content: AnyMessageEventContent::KeyVerificationCancel(content), content: AnyMessageEventContent::KeyVerificationCancel(content),
} }
.into(), .into(),
CancelContent::ToDevice(c) => self CancelContent::ToDevice(c) => {
.content_to_request(AnyToDeviceEventContent::KeyVerificationCancel(c)) self.content_to_request(AnyToDeviceEventContent::KeyVerificationCancel(c)).into()
.into(), }
}) })
} }
@ -684,11 +656,7 @@ impl Sas {
} }
pub(crate) fn content_to_request(&self, content: AnyToDeviceEventContent) -> ToDeviceRequest { pub(crate) fn content_to_request(&self, content: AnyToDeviceEventContent) -> ToDeviceRequest {
content_to_request( content_to_request(self.other_user_id(), self.other_device_id().to_owned(), content)
self.other_user_id(),
self.other_device_id().to_owned(),
content,
)
} }
} }
@ -717,9 +685,7 @@ impl AcceptSettings {
/// ///
/// * `methods` - The methods this client allows at most /// * `methods` - The methods this client allows at most
pub fn with_allowed_methods(methods: Vec<ShortAuthenticationString>) -> Self { pub fn with_allowed_methods(methods: Vec<ShortAuthenticationString>) -> Self {
Self { Self { allowed_methods: methods }
allowed_methods: methods,
}
} }
fn apply(self, mut content: AcceptContent) -> AcceptContent { fn apply(self, mut content: AcceptContent) -> AcceptContent {
@ -728,15 +694,8 @@ impl AcceptSettings {
method: AcceptMethod::MSasV1(c), method: AcceptMethod::MSasV1(c),
.. ..
}) })
| AcceptContent::Room( | AcceptContent::Room(_, AcceptEventContent { method: AcceptMethod::MSasV1(c), .. }) => {
_, c.short_authentication_string.retain(|sas| self.allowed_methods.contains(sas));
AcceptEventContent {
method: AcceptMethod::MSasV1(c),
..
},
) => {
c.short_authentication_string
.retain(|sas| self.allowed_methods.contains(sas));
content content
} }
_ => content, _ => content,
@ -750,6 +709,7 @@ mod test {
use matrix_sdk_common::identifiers::{DeviceId, UserId}; use matrix_sdk_common::identifiers::{DeviceId, UserId};
use super::Sas;
use crate::{ use crate::{
olm::PrivateCrossSigningIdentity, olm::PrivateCrossSigningIdentity,
store::{CryptoStore, MemoryStore}, store::{CryptoStore, MemoryStore},
@ -757,8 +717,6 @@ mod test {
ReadOnlyAccount, ReadOnlyDevice, ReadOnlyAccount, ReadOnlyDevice,
}; };
use super::Sas;
fn alice_id() -> UserId { fn alice_id() -> UserId {
UserId::try_from("@alice:example.org").unwrap() UserId::try_from("@alice:example.org").unwrap()
} }
@ -841,13 +799,7 @@ mod test {
); );
alice.receive_event(&event); alice.receive_event(&event);
assert!(alice assert!(alice.verified_devices().unwrap().contains(&alice.other_device()));
.verified_devices() assert!(bob.verified_devices().unwrap().contains(&bob.other_device()));
.unwrap()
.contains(&alice.other_device()));
assert!(bob
.verified_devices()
.unwrap()
.contains(&bob.other_device()));
} }
} }

View File

@ -19,8 +19,6 @@ use std::{
time::{Duration, Instant}, time::{Duration, Instant},
}; };
use olm_rs::sas::OlmSas;
use matrix_sdk_common::{ use matrix_sdk_common::{
events::key::verification::{ events::key::verification::{
accept::{ accept::{
@ -40,6 +38,7 @@ use matrix_sdk_common::{
identifiers::{DeviceId, EventId, RoomId, UserId}, identifiers::{DeviceId, EventId, RoomId, UserId},
uuid::Uuid, uuid::Uuid,
}; };
use olm_rs::sas::OlmSas;
use tracing::info; use tracing::info;
use super::{ use super::{
@ -51,7 +50,6 @@ use super::{
receive_mac_event, SasIds, receive_mac_event, SasIds,
}, },
}; };
use crate::{ use crate::{
identities::{ReadOnlyDevice, UserIdentities}, identities::{ReadOnlyDevice, UserIdentities},
verification::FlowId, verification::FlowId,
@ -62,10 +60,8 @@ const KEY_AGREEMENT_PROTOCOLS: &[KeyAgreementProtocol] =
&[KeyAgreementProtocol::Curve25519HkdfSha256]; &[KeyAgreementProtocol::Curve25519HkdfSha256];
const HASHES: &[HashAlgorithm] = &[HashAlgorithm::Sha256]; const HASHES: &[HashAlgorithm] = &[HashAlgorithm::Sha256];
const MACS: &[MessageAuthenticationCode] = &[MessageAuthenticationCode::HkdfHmacSha256]; const MACS: &[MessageAuthenticationCode] = &[MessageAuthenticationCode::HkdfHmacSha256];
const STRINGS: &[ShortAuthenticationString] = &[ const STRINGS: &[ShortAuthenticationString] =
ShortAuthenticationString::Decimal, &[ShortAuthenticationString::Decimal, ShortAuthenticationString::Emoji];
ShortAuthenticationString::Emoji,
];
// The max time a SAS flow can take from start to done. // The max time a SAS flow can take from start to done.
const MAX_AGE: Duration = Duration::from_secs(60 * 5); const MAX_AGE: Duration = Duration::from_secs(60 * 5);
@ -91,9 +87,7 @@ impl TryFrom<AcceptV1Content> for AcceptedProtocols {
if !KEY_AGREEMENT_PROTOCOLS.contains(&content.key_agreement_protocol) if !KEY_AGREEMENT_PROTOCOLS.contains(&content.key_agreement_protocol)
|| !HASHES.contains(&content.hash) || !HASHES.contains(&content.hash)
|| !MACS.contains(&content.message_authentication_code) || !MACS.contains(&content.message_authentication_code)
|| (!content || (!content.short_authentication_string.contains(&ShortAuthenticationString::Emoji)
.short_authentication_string
.contains(&ShortAuthenticationString::Emoji)
&& !content && !content
.short_authentication_string .short_authentication_string
.contains(&ShortAuthenticationString::Decimal)) .contains(&ShortAuthenticationString::Decimal))
@ -402,11 +396,7 @@ impl SasState<Created> {
) -> SasState<Created> { ) -> SasState<Created> {
SasState { SasState {
inner: Arc::new(Mutex::new(OlmSas::new())), inner: Arc::new(Mutex::new(OlmSas::new())),
ids: SasIds { ids: SasIds { account, other_device, other_identity },
account,
other_device,
other_identity,
},
verification_flow_id: flow_id.into(), verification_flow_id: flow_id.into(),
creation_time: Arc::new(Instant::now()), creation_time: Arc::new(Instant::now()),
@ -441,9 +431,7 @@ impl SasState<Created> {
MSasV1Content::new(self.state.protocol_definitions.clone()) MSasV1Content::new(self.state.protocol_definitions.clone())
.expect("Invalid initial protocol definitions."), .expect("Invalid initial protocol definitions."),
), ),
relation: Relation { relation: Relation { event_id: e.clone() },
event_id: e.clone(),
},
}, },
), ),
} }
@ -490,8 +478,8 @@ impl SasState<Created> {
} }
impl SasState<Started> { impl SasState<Started> {
/// Create a new SAS verification flow from an in-room m.key.verification.start /// Create a new SAS verification flow from an in-room
/// event. /// m.key.verification.start event.
/// ///
/// This will put us in the `started` state. /// This will put us in the `started` state.
/// ///
@ -549,11 +537,7 @@ impl SasState<Started> {
Ok(SasState { Ok(SasState {
inner: Arc::new(Mutex::new(sas)), inner: Arc::new(Mutex::new(sas)),
ids: SasIds { ids: SasIds { account, other_device, other_identity },
account,
other_device,
other_identity,
},
creation_time: Arc::new(Instant::now()), creation_time: Arc::new(Instant::now()),
last_event_time: Arc::new(Instant::now()), last_event_time: Arc::new(Instant::now()),
@ -605,19 +589,12 @@ impl SasState<Started> {
); );
match self.verification_flow_id.as_ref() { match self.verification_flow_id.as_ref() {
FlowId::ToDevice(s) => AcceptToDeviceEventContent { FlowId::ToDevice(s) => {
transaction_id: s.to_string(), AcceptToDeviceEventContent { transaction_id: s.to_string(), method }.into()
method,
} }
.into(),
FlowId::InRoom(r, e) => ( FlowId::InRoom(r, e) => (
r.clone(), r.clone(),
AcceptEventContent { AcceptEventContent { method, relation: Relation { event_id: e.clone() } },
method,
relation: Relation {
event_id: e.clone(),
},
},
) )
.into(), .into(),
} }
@ -683,10 +660,8 @@ impl SasState<Accepted> {
self.check_event(&sender, content.flow_id().as_str()) self.check_event(&sender, content.flow_id().as_str())
.map_err(|c| self.clone().cancel(c))?; .map_err(|c| self.clone().cancel(c))?;
let commitment = calculate_commitment( let commitment =
content.public_key(), calculate_commitment(content.public_key(), self.state.start_content.as_ref().clone());
self.state.start_content.as_ref().clone(),
);
if self.state.commitment != commitment { if self.state.commitment != commitment {
Err(self.cancel(CancelCode::InvalidMessage)) Err(self.cancel(CancelCode::InvalidMessage))
@ -728,9 +703,7 @@ impl SasState<Accepted> {
r.clone(), r.clone(),
KeyEventContent { KeyEventContent {
key: self.inner.lock().unwrap().public_key(), key: self.inner.lock().unwrap().public_key(),
relation: Relation { relation: Relation { event_id: e.clone() },
event_id: e.clone(),
},
}, },
) )
.into(), .into(),
@ -754,9 +727,7 @@ impl SasState<KeyReceived> {
r.clone(), r.clone(),
KeyEventContent { KeyEventContent {
key: self.inner.lock().unwrap().public_key(), key: self.inner.lock().unwrap().public_key(),
relation: Relation { relation: Relation { event_id: e.clone() },
event_id: e.clone(),
},
}, },
) )
.into(), .into(),
@ -779,8 +750,8 @@ impl SasState<KeyReceived> {
/// Get the index of the emoji of the short authentication string. /// Get the index of the emoji of the short authentication string.
/// ///
/// Returns seven u8 numbers in the range from 0 to 63 inclusive, those numbers /// Returns seven u8 numbers in the range from 0 to 63 inclusive, those
/// can be converted to a unique emoji defined by the spec. /// numbers can be converted to a unique emoji defined by the spec.
pub fn get_emoji_index(&self) -> [u8; 7] { pub fn get_emoji_index(&self) -> [u8; 7] {
get_emoji_index( get_emoji_index(
&self.inner.lock().unwrap(), &self.inner.lock().unwrap(),
@ -952,11 +923,7 @@ impl SasState<Confirmed> {
/// ///
/// The content needs to be automatically sent to the other side. /// The content needs to be automatically sent to the other side.
pub fn as_content(&self) -> MacContent { pub fn as_content(&self) -> MacContent {
get_mac_content( get_mac_content(&self.inner.lock().unwrap(), &self.ids, &self.verification_flow_id)
&self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
)
} }
} }
@ -1015,8 +982,8 @@ impl SasState<MacReceived> {
/// Get the index of the emoji of the short authentication string. /// Get the index of the emoji of the short authentication string.
/// ///
/// Returns seven u8 numbers in the range from 0 to 63 inclusive, those numbers /// Returns seven u8 numbers in the range from 0 to 63 inclusive, those
/// can be converted to a unique emoji defined by the spec. /// numbers can be converted to a unique emoji defined by the spec.
pub fn get_emoji_index(&self) -> [u8; 7] { pub fn get_emoji_index(&self) -> [u8; 7] {
get_emoji_index( get_emoji_index(
&self.inner.lock().unwrap(), &self.inner.lock().unwrap(),
@ -1048,11 +1015,7 @@ impl SasState<WaitingForDone> {
/// The content needs to be automatically sent to the other side if it /// The content needs to be automatically sent to the other side if it
/// wasn't already sent. /// wasn't already sent.
pub fn as_content(&self) -> MacContent { pub fn as_content(&self) -> MacContent {
get_mac_content( get_mac_content(&self.inner.lock().unwrap(), &self.ids, &self.verification_flow_id)
&self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
)
} }
pub fn done_content(&self) -> DoneContent { pub fn done_content(&self) -> DoneContent {
@ -1060,15 +1023,9 @@ impl SasState<WaitingForDone> {
FlowId::ToDevice(_) => { FlowId::ToDevice(_) => {
unreachable!("The done content isn't supported yet for to-device verifications") unreachable!("The done content isn't supported yet for to-device verifications")
} }
FlowId::InRoom(r, e) => ( FlowId::InRoom(r, e) => {
r.clone(), (r.clone(), DoneEventContent { relation: Relation { event_id: e.clone() } }).into()
DoneEventContent { }
relation: Relation {
event_id: e.clone(),
},
},
)
.into(),
} }
} }
@ -1110,11 +1067,7 @@ impl SasState<Done> {
/// The content needs to be automatically sent to the other side if it /// The content needs to be automatically sent to the other side if it
/// wasn't already sent. /// wasn't already sent.
pub fn as_content(&self) -> MacContent { pub fn as_content(&self) -> MacContent {
get_mac_content( get_mac_content(&self.inner.lock().unwrap(), &self.ids, &self.verification_flow_id)
&self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
)
} }
pub fn done_content(&self) -> DoneContent { pub fn done_content(&self) -> DoneContent {
@ -1122,15 +1075,9 @@ impl SasState<Done> {
FlowId::ToDevice(_) => { FlowId::ToDevice(_) => {
unreachable!("The done content isn't supported yet for to-device verifications") unreachable!("The done content isn't supported yet for to-device verifications")
} }
FlowId::InRoom(r, e) => ( FlowId::InRoom(r, e) => {
r.clone(), (r.clone(), DoneEventContent { relation: Relation { event_id: e.clone() } }).into()
DoneEventContent { }
relation: Relation {
event_id: e.clone(),
},
},
)
.into(),
} }
} }
@ -1166,10 +1113,7 @@ impl Canceled {
_ => unimplemented!(), _ => unimplemented!(),
}; };
Canceled { Canceled { cancel_code: code, reason }
cancel_code: code,
reason,
}
} }
} }
@ -1188,9 +1132,7 @@ impl SasState<Canceled> {
CancelEventContent { CancelEventContent {
reason: self.state.reason.to_string(), reason: self.state.reason.to_string(),
code: self.state.cancel_code.clone(), code: self.state.cancel_code.clone(),
relation: Relation { relation: Relation { event_id: e.clone() },
event_id: e.clone(),
},
}, },
) )
.into(), .into(),
@ -1202,10 +1144,6 @@ impl SasState<Canceled> {
mod test { mod test {
use std::convert::TryFrom; use std::convert::TryFrom;
use crate::{
verification::sas::{event_enums::AcceptContent, StartContent},
ReadOnlyAccount, ReadOnlyDevice,
};
use matrix_sdk_common::{ use matrix_sdk_common::{
events::key::verification::{ events::key::verification::{
accept::{AcceptMethod, CustomContent}, accept::{AcceptMethod, CustomContent},
@ -1215,6 +1153,10 @@ mod test {
}; };
use super::{Accepted, Created, SasState, Started}; use super::{Accepted, Created, SasState, Started};
use crate::{
verification::sas::{event_enums::AcceptContent, StartContent},
ReadOnlyAccount, ReadOnlyDevice,
};
fn alice_id() -> UserId { fn alice_id() -> UserId {
UserId::try_from("@alice:example.org").unwrap() UserId::try_from("@alice:example.org").unwrap()
@ -1353,9 +1295,7 @@ mod test {
let content = bob.as_content(); let content = bob.as_content();
let sender = UserId::try_from("@malory:example.org").unwrap(); let sender = UserId::try_from("@malory:example.org").unwrap();
alice alice.into_accepted(&sender, content).expect_err("Didn't cancel on a invalid sender");
.into_accepted(&sender, content)
.expect_err("Didn't cancel on a invalid sender");
} }
#[tokio::test] #[tokio::test]

View File

@ -1,7 +1,6 @@
use std::{collections::HashMap, panic}; use std::{collections::HashMap, panic};
use http::Response; use http::Response;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::sync::sync_events::Response as SyncResponse, api::r0::sync::sync_events::Response as SyncResponse,
events::{ events::{
@ -11,9 +10,8 @@ use matrix_sdk_common::{
identifiers::{room_id, RoomId}, identifiers::{room_id, RoomId},
IncomingResponse, IncomingResponse,
}; };
use serde_json::Value as JsonValue;
pub use matrix_sdk_test_macros::async_test; pub use matrix_sdk_test_macros::async_test;
use serde_json::Value as JsonValue;
pub mod test_json; pub mod test_json;
@ -44,16 +42,17 @@ pub enum EventsJson {
Typing, Typing,
} }
/// The `EventBuilder` struct can be used to easily generate valid sync responses for testing. /// The `EventBuilder` struct can be used to easily generate valid sync
/// These can be then fed into either `Client` or `Room`. /// responses for testing. These can be then fed into either `Client` or `Room`.
/// ///
/// It supports generated a number of canned events, such as a member entering a room, his power /// It supports generated a number of canned events, such as a member entering a
/// level and display name changing and similar. It also supports insertion of custom events in the /// room, his power level and display name changing and similar. It also
/// form of `EventsJson` values. /// supports insertion of custom events in the form of `EventsJson` values.
/// ///
/// **Important** You *must* use the *same* builder when sending multiple sync responses to /// **Important** You *must* use the *same* builder when sending multiple sync
/// a single client. Otherwise, the subsequent responses will be *ignored* by the client because /// responses to a single client. Otherwise, the subsequent responses will be
/// the `next_batch` sync token will not be rotated properly. /// *ignored* by the client because the `next_batch` sync token will not be
/// rotated properly.
/// ///
/// # Example usage /// # Example usage
/// ///
@ -94,7 +93,8 @@ pub struct EventBuilder {
ephemeral: Vec<AnySyncEphemeralRoomEvent>, ephemeral: Vec<AnySyncEphemeralRoomEvent>,
/// The account data events that determine the state of a `Room`. /// The account data events that determine the state of a `Room`.
account_data: Vec<AnyGlobalAccountDataEvent>, account_data: Vec<AnyGlobalAccountDataEvent>,
/// Internal counter to enable the `prev_batch` and `next_batch` of each sync response to vary. /// Internal counter to enable the `prev_batch` and `next_batch` of each
/// sync response to vary.
batch_counter: i64, batch_counter: i64,
} }
@ -154,10 +154,7 @@ impl EventBuilder {
} }
fn add_joined_event(&mut self, room_id: &RoomId, event: AnySyncRoomEvent) { fn add_joined_event(&mut self, room_id: &RoomId, event: AnySyncRoomEvent) {
self.joined_room_events self.joined_room_events.entry(room_id.clone()).or_insert_with(Vec::new).push(event);
.entry(room_id.clone())
.or_insert_with(Vec::new)
.push(event);
} }
pub fn add_custom_invited_event( pub fn add_custom_invited_event(
@ -166,10 +163,7 @@ impl EventBuilder {
event: serde_json::Value, event: serde_json::Value,
) -> &mut Self { ) -> &mut Self {
let event = serde_json::from_value::<AnySyncStateEvent>(event).unwrap(); let event = serde_json::from_value::<AnySyncStateEvent>(event).unwrap();
self.invited_room_events self.invited_room_events.entry(room_id.clone()).or_insert_with(Vec::new).push(event);
.entry(room_id.clone())
.or_insert_with(Vec::new)
.push(event);
self self
} }
@ -179,10 +173,7 @@ impl EventBuilder {
event: serde_json::Value, event: serde_json::Value,
) -> &mut Self { ) -> &mut Self {
let event = serde_json::from_value::<AnySyncRoomEvent>(event).unwrap(); let event = serde_json::from_value::<AnySyncRoomEvent>(event).unwrap();
self.left_room_events self.left_room_events.entry(room_id.clone()).or_insert_with(Vec::new).push(event);
.entry(room_id.clone())
.or_insert_with(Vec::new)
.push(event);
self self
} }
@ -227,7 +218,8 @@ impl EventBuilder {
pub fn build_json_sync_response(&mut self) -> JsonValue { pub fn build_json_sync_response(&mut self) -> JsonValue {
let main_room_id = room_id!("!SVkFJHzfwvuaIEawgC:localhost"); let main_room_id = room_id!("!SVkFJHzfwvuaIEawgC:localhost");
// First time building a sync response, so initialize the `prev_batch` to a default one. // First time building a sync response, so initialize the `prev_batch` to a
// default one.
let prev_batch = self.generate_sync_token(); let prev_batch = self.generate_sync_token();
self.batch_counter += 1; self.batch_counter += 1;
let next_batch = self.generate_sync_token(); let next_batch = self.generate_sync_token();
@ -352,9 +344,7 @@ impl EventBuilder {
pub fn build_sync_response(&mut self) -> SyncResponse { pub fn build_sync_response(&mut self) -> SyncResponse {
let body = self.build_json_sync_response(); let body = self.build_json_sync_response();
let response = Response::builder() let response = Response::builder().body(serde_json::to_vec(&body).unwrap()).unwrap();
.body(serde_json::to_vec(&body).unwrap())
.unwrap();
SyncResponse::try_from_http_response(response).unwrap() SyncResponse::try_from_http_response(response).unwrap()
} }
@ -395,15 +385,10 @@ pub fn sync_response(kind: SyncResponseFile) -> SyncResponse {
SyncResponseFile::Voip => &test_json::VOIP_SYNC, SyncResponseFile::Voip => &test_json::VOIP_SYNC,
}; };
let response = Response::builder() let response = Response::builder().body(data.to_string().as_bytes().to_vec()).unwrap();
.body(data.to_string().as_bytes().to_vec())
.unwrap();
SyncResponse::try_from_http_response(response).unwrap() SyncResponse::try_from_http_response(response).unwrap()
} }
pub fn response_from_file(json: &serde_json::Value) -> Response<Vec<u8>> { pub fn response_from_file(json: &serde_json::Value) -> Response<Vec<u8>> {
Response::builder() Response::builder().status(200).body(json.to_string().as_bytes().to_vec()).unwrap()
.status(200)
.body(json.to_string().as_bytes().to_vec())
.unwrap()
} }

View File

@ -1,8 +1,8 @@
//! Test data for the matrix-sdk crates. //! Test data for the matrix-sdk crates.
//! //!
//! Exporting each const allows all the test data to have a single source of truth. //! Exporting each const allows all the test data to have a single source of
//! When running `cargo publish` no external folders are allowed so all the //! truth. When running `cargo publish` no external folders are allowed so all
//! test data needs to be contained within this crate. //! the test data needs to be contained within this crate.
use lazy_static::lazy_static; use lazy_static::lazy_static;
use serde_json::{json, Value as JsonValue}; use serde_json::{json, Value as JsonValue};
@ -17,12 +17,11 @@ pub use events::{
PUBLIC_ROOMS, REACTION, REDACTED, REDACTED_INVALID, REDACTED_STATE, REDACTION, PUBLIC_ROOMS, REACTION, REDACTED, REDACTED_INVALID, REDACTED_STATE, REDACTION,
REGISTRATION_RESPONSE_ERR, ROOM_ID, ROOM_MESSAGES, TYPING, REGISTRATION_RESPONSE_ERR, ROOM_ID, ROOM_MESSAGES, TYPING,
}; };
pub use members::MEMBERS;
pub use sync::{ pub use sync::{
DEFAULT_SYNC_SUMMARY, INVITE_SYNC, LEAVE_SYNC, LEAVE_SYNC_EVENT, MORE_SYNC, SYNC, VOIP_SYNC, DEFAULT_SYNC_SUMMARY, INVITE_SYNC, LEAVE_SYNC, LEAVE_SYNC_EVENT, MORE_SYNC, SYNC, VOIP_SYNC,
}; };
pub use members::MEMBERS;
lazy_static! { lazy_static! {
pub static ref DEVICES: JsonValue = json!({ pub static ref DEVICES: JsonValue = json!({
"devices": [ "devices": [

View File

@ -1,6 +1,7 @@
use proc_macro::TokenStream; use proc_macro::TokenStream;
/// Attribute to use `wasm_bindgen_test` for wasm32 targets and `tokio::test` for everything else /// Attribute to use `wasm_bindgen_test` for wasm32 targets and `tokio::test`
/// for everything else
#[proc_macro_attribute] #[proc_macro_attribute]
pub fn async_test(_attr: TokenStream, item: TokenStream) -> TokenStream { pub fn async_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
let attrs = r#" let attrs = r#"