Merge branch 'master' into sas-longer-flow

This commit is contained in:
Damir Jelić 2021-05-13 11:26:40 +02:00
commit 3f57a2a9f2
79 changed files with 2020 additions and 4180 deletions

View file

@ -20,7 +20,7 @@ jobs:
- name: Install rust
uses: actions-rs/toolchain@v1
with:
toolchain: stable
toolchain: nightly
components: rustfmt
profile: minimal
override: true

6
.rustfmt.toml Normal file
View file

@ -0,0 +1,6 @@
max_width = 100
comment_width = 80
wrap_comments = true
imports_granularity = "Crate"
use_small_heuristics = "Max"
group_imports = "StdExternalCrate"

View file

@ -1,5 +1,4 @@
use std::{env, process::exit};
use tokio::time::{sleep, Duration};
use matrix_sdk::{
self, async_trait,
@ -7,6 +6,7 @@ use matrix_sdk::{
room::Room,
Client, ClientConfig, EventHandler, SyncSettings,
};
use tokio::time::{sleep, Duration};
use url::Url;
struct AutoJoinBot {
@ -72,15 +72,11 @@ async fn login_and_sync(
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
let client = Client::new_with_config(homeserver_url, client_config).unwrap();
client
.login(username, password, None, Some("autojoin bot"))
.await?;
client.login(username, password, None, Some("autojoin bot")).await?;
println!("logged in as {}", username);
client
.set_event_handler(Box::new(AutoJoinBot::new(client.clone())))
.await;
client.set_event_handler(Box::new(AutoJoinBot::new(client.clone()))).await;
client.sync(SyncSettings::default()).await;

View file

@ -69,24 +69,23 @@ async fn login_and_sync(
// create a new Client with the given homeserver url and config
let client = Client::new_with_config(homeserver_url, client_config).unwrap();
client
.login(&username, &password, None, Some("command bot"))
.await?;
client.login(&username, &password, None, Some("command bot")).await?;
println!("logged in as {}", username);
// An initial sync to set up state and so our bot doesn't respond to old messages.
// If the `StateStore` finds saved state in the location given the initial sync will
// be skipped in favor of loading state from the store
// An initial sync to set up state and so our bot doesn't respond to old
// messages. If the `StateStore` finds saved state in the location given the
// initial sync will be skipped in favor of loading state from the store
client.sync_once(SyncSettings::default()).await.unwrap();
// add our CommandBot to be notified of incoming messages, we do this after the initial
// sync to avoid responding to messages before the bot was running.
// add our CommandBot to be notified of incoming messages, we do this after the
// initial sync to avoid responding to messages before the bot was running.
client.set_event_handler(Box::new(CommandBot::new())).await;
// since we called `sync_once` before we entered our sync loop we must pass
// that sync token to `sync`
let settings = SyncSettings::default().token(client.sync_token().await.unwrap());
// this keeps state from the server streaming in to CommandBot via the EventHandler trait
// this keeps state from the server streaming in to CommandBot via the
// EventHandler trait
client.sync(settings).await;
Ok(())

View file

@ -5,12 +5,11 @@ use std::{
sync::atomic::{AtomicBool, Ordering},
};
use serde_json::json;
use url::Url;
use matrix_sdk::{
self, api::r0::uiaa::AuthData, identifiers::UserId, Client, LoopCtrl, SyncSettings,
};
use serde_json::json;
use url::Url;
fn auth_data<'a>(user: &UserId, password: &str, session: Option<&'a str>) -> AuthData<'a> {
let mut auth_parameters = BTreeMap::new();
@ -22,11 +21,7 @@ fn auth_data<'a>(user: &UserId, password: &str, session: Option<&'a str>) -> Aut
auth_parameters.insert("identifier".to_owned(), identifier);
auth_parameters.insert("password".to_owned(), password.to_owned().into());
AuthData::DirectRequest {
kind: "m.login.password",
auth_parameters,
session,
}
AuthData::DirectRequest { kind: "m.login.password", auth_parameters, session }
}
async fn bootstrap(client: Client, user_id: UserId, password: String) {
@ -34,9 +29,7 @@ async fn bootstrap(client: Client, user_id: UserId, password: String) {
let mut input = String::new();
io::stdin()
.read_line(&mut input)
.expect("error: unable to read user input");
io::stdin().read_line(&mut input).expect("error: unable to read user input");
#[cfg(feature = "encryption")]
if let Err(e) = client.bootstrap_cross_signing(None).await {
@ -63,9 +56,7 @@ async fn login(
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
let client = Client::new(homeserver_url).unwrap();
let response = client
.login(username, password, None, Some("rust-sdk"))
.await?;
let response = client.login(username, password, None, Some("rust-sdk")).await?;
let user_id = &response.user_id;
let client_ref = &client;

View file

@ -6,7 +6,6 @@ use std::{
Arc,
},
};
use url::Url;
use matrix_sdk::{
self,
@ -14,14 +13,13 @@ use matrix_sdk::{
identifiers::UserId,
Client, LoopCtrl, Sas, SyncSettings,
};
use url::Url;
async fn wait_for_confirmation(client: Client, sas: Sas) {
println!("Does the emoji match: {:?}", sas.emoji());
let mut input = String::new();
io::stdin()
.read_line(&mut input)
.expect("error: unable to read user input");
io::stdin().read_line(&mut input).expect("error: unable to read user input");
match input.trim().to_lowercase().as_ref() {
"yes" | "true" | "ok" => {
@ -68,9 +66,7 @@ async fn login(
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
let client = Client::new(homeserver_url).unwrap();
client
.login(username, password, None, Some("rust-sdk"))
.await?;
client.login(username, password, None, Some("rust-sdk")).await?;
let client_ref = &client;
let initial_sync = Arc::new(AtomicBool::from(true));
@ -81,12 +77,7 @@ async fn login(
let client = &client_ref;
let initial = &initial_ref;
for event in response
.to_device
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{
for event in response.to_device.events.iter().filter_map(|e| e.deserialize().ok()) {
match event {
AnyToDeviceEvent::KeyVerificationStart(e) => {
let sas = client
@ -129,11 +120,8 @@ async fn login(
if !initial.load(Ordering::SeqCst) {
for (_room_id, room_info) in response.rooms.join {
for event in room_info
.timeline
.events
.iter()
.filter_map(|e| e.event.deserialize().ok())
for event in
room_info.timeline.events.iter().filter_map(|e| e.event.deserialize().ok())
{
if let AnySyncRoomEvent::Message(event) = event {
match event {

View file

@ -1,13 +1,12 @@
use std::{convert::TryFrom, env, process::exit};
use url::Url;
use matrix_sdk::{
self,
api::r0::profile,
identifiers::{MxcUri, UserId},
Client, Result as MatrixResult,
};
use url::Url;
#[derive(Debug)]
struct UserProfile {
@ -29,10 +28,7 @@ async fn get_profile(client: Client, mxid: &UserId) -> MatrixResult<UserProfile>
// Use the response and construct a UserProfile struct.
// See https://docs.rs/ruma-client-api/0.9.0/ruma_client_api/r0/profile/get_profile/struct.Response.html
// for details on the Response for this Request
let user_profile = UserProfile {
avatar_url: resp.avatar_url,
displayname: resp.displayname,
};
let user_profile = UserProfile { avatar_url: resp.avatar_url, displayname: resp.displayname };
Ok(user_profile)
}
@ -44,9 +40,7 @@ async fn login(
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
let client = Client::new(homeserver_url).unwrap();
client
.login(username, password, None, Some("rust-sdk"))
.await?;
client.login(username, password, None, Some("rust-sdk")).await?;
Ok(client)
}

View file

@ -6,7 +6,6 @@ use std::{
process::exit,
sync::Arc,
};
use tokio::sync::Mutex;
use matrix_sdk::{
self, async_trait,
@ -17,6 +16,7 @@ use matrix_sdk::{
room::Room,
Client, EventHandler, SyncSettings,
};
use tokio::sync::Mutex;
use url::Url;
struct ImageBot {
@ -52,9 +52,7 @@ impl EventHandler for ImageBot {
println!("sending image");
let mut image = self.image.lock().await;
room.send_attachment("cat", &mime::IMAGE_JPEG, &mut *image, None)
.await
.unwrap();
room.send_attachment("cat", &mime::IMAGE_JPEG, &mut *image, None).await.unwrap();
image.seek(SeekFrom::Start(0)).unwrap();
@ -73,14 +71,10 @@ async fn login_and_sync(
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
let client = Client::new(homeserver_url).unwrap();
client
.login(&username, &password, None, Some("command bot"))
.await?;
client.login(&username, &password, None, Some("command bot")).await?;
client.sync_once(SyncSettings::default()).await.unwrap();
client
.set_event_handler(Box::new(ImageBot::new(image)))
.await;
client.set_event_handler(Box::new(ImageBot::new(image))).await;
let settings = SyncSettings::default().token(client.sync_token().await.unwrap());
client.sync(settings).await;
@ -91,26 +85,19 @@ async fn login_and_sync(
#[tokio::main]
async fn main() -> Result<(), matrix_sdk::Error> {
tracing_subscriber::fmt::init();
let (homeserver_url, username, password, image_path) = match (
env::args().nth(1),
env::args().nth(2),
env::args().nth(3),
env::args().nth(4),
) {
(Some(a), Some(b), Some(c), Some(d)) => (a, b, c, d),
_ => {
eprintln!(
"Usage: {} <homeserver_url> <username> <password> <image>",
env::args().next().unwrap()
);
exit(1)
}
};
let (homeserver_url, username, password, image_path) =
match (env::args().nth(1), env::args().nth(2), env::args().nth(3), env::args().nth(4)) {
(Some(a), Some(b), Some(c), Some(d)) => (a, b, c, d),
_ => {
eprintln!(
"Usage: {} <homeserver_url> <username> <password> <image>",
env::args().next().unwrap()
);
exit(1)
}
};
println!(
"helloooo {} {} {} {:#?}",
homeserver_url, username, password, image_path
);
println!("helloooo {} {} {} {:#?}", homeserver_url, username, password, image_path);
let path = PathBuf::from(image_path);
let image = File::open(path).expect("Can't open image file.");

View file

@ -1,5 +1,4 @@
use std::{env, process::exit};
use url::Url;
use matrix_sdk::{
self, async_trait,
@ -10,6 +9,7 @@ use matrix_sdk::{
room::Room,
Client, EventHandler, SyncSettings,
};
use url::Url;
struct EventCallback;
@ -28,9 +28,7 @@ impl EventHandler for EventCallback {
} = event
{
let member = room.get_member(&sender).await.unwrap().unwrap();
let name = member
.display_name()
.unwrap_or_else(|| member.user_id().as_str());
let name = member.display_name().unwrap_or_else(|| member.user_id().as_str());
println!("{}: {}", name, msg_body);
}
}
@ -47,9 +45,7 @@ async fn login(
client.set_event_handler(Box::new(EventCallback)).await;
client
.login(username, password, None, Some("rust-sdk"))
.await?;
client.login(username, password, None, Some("rust-sdk")).await?;
client.sync(SyncSettings::new()).await;
Ok(())

File diff suppressed because it is too large Load diff

View file

@ -40,7 +40,8 @@ impl Deref for Device {
impl Device {
/// Start a interactive verification with this `Device`
///
/// Returns a `Sas` object that represents the interactive verification flow.
/// Returns a `Sas` object that represents the interactive verification
/// flow.
///
/// # Example
///
@ -65,10 +66,7 @@ impl Device {
let (sas, request) = self.inner.start_verification().await?;
self.client.send_to_device(&request).await?;
Ok(Sas {
inner: sas,
client: self.client.clone(),
})
Ok(Sas { inner: sas, client: self.client.clone() })
}
/// Is the device trusted.
@ -102,10 +100,7 @@ pub struct UserDevices {
impl UserDevices {
/// Get the specific device with the given device id.
pub fn get(&self, device_id: &DeviceId) -> Option<Device> {
self.inner.get(device_id).map(|d| Device {
inner: d,
client: self.client.clone(),
})
self.inner.get(device_id).map(|d| Device { inner: d, client: self.client.clone() })
}
/// Iterator over all the device ids of the user devices.
@ -117,9 +112,6 @@ impl UserDevices {
pub fn devices(&self) -> impl Iterator<Item = Device> + '_ {
let client = self.client.clone();
self.inner.devices().map(move |d| Device {
inner: d,
client: client.clone(),
})
self.inner.devices().map(move |d| Device { inner: d, client: client.clone() })
}
}

View file

@ -14,7 +14,11 @@
//! Error conditions.
use std::io::Error as IoError;
use http::StatusCode;
#[cfg(feature = "encryption")]
use matrix_sdk_base::crypto::store::CryptoStoreError;
use matrix_sdk_base::{Error as MatrixError, StoreError};
use matrix_sdk_common::{
api::{
@ -26,12 +30,8 @@ use matrix_sdk_common::{
};
use reqwest::Error as ReqwestError;
use serde_json::Error as JsonError;
use std::io::Error as IoError;
use thiserror::Error;
#[cfg(feature = "encryption")]
use matrix_sdk_base::crypto::store::CryptoStoreError;
/// Result type of the rust-sdk.
pub type Result<T> = std::result::Result<T, Error>;
@ -43,11 +43,13 @@ pub enum HttpError {
#[error(transparent)]
Reqwest(#[from] ReqwestError),
/// Queried endpoint requires authentication but was called on an anonymous client.
/// Queried endpoint requires authentication but was called on an anonymous
/// client.
#[error("the queried endpoint requires authentication but was called before logging in")]
AuthenticationRequired,
/// Client tried to force authentication but did not provide an access token.
/// Client tried to force authentication but did not provide an access
/// token.
#[error("tried to force authentication but no access token was provided")]
ForcedAuthenticationWithoutAccessToken,
@ -69,9 +71,10 @@ pub enum HttpError {
/// An error occurred while authenticating.
///
/// When registering or authenticating the Matrix server can send a `UiaaResponse`
/// as the error type, this is a User-Interactive Authentication API response. This
/// represents an error with information about how to authenticate the user.
/// When registering or authenticating the Matrix server can send a
/// `UiaaResponse` as the error type, this is a User-Interactive
/// Authentication API response. This represents an error with
/// information about how to authenticate the user.
#[error(transparent)]
UiaaError(#[from] FromHttpResponseError<UiaaError>),
@ -96,7 +99,8 @@ pub enum Error {
#[error(transparent)]
Http(#[from] HttpError),
/// Queried endpoint requires authentication but was called on an anonymous client.
/// Queried endpoint requires authentication but was called on an anonymous
/// client.
#[error("the queried endpoint requires authentication but was called before logging in")]
AuthenticationRequired,

View file

@ -16,6 +16,7 @@ use std::ops::Deref;
use matrix_sdk_common::{
api::r0::push::get_notifications::Notification,
async_trait,
events::{
fully_read::FullyReadEventContent, AnySyncRoomEvent, GlobalAccountDataEvent,
RoomAccountDataEvent,
@ -56,7 +57,6 @@ use crate::{
room::Room,
Client,
};
use matrix_sdk_common::async_trait;
pub(crate) struct Handler {
pub(crate) inner: Box<dyn EventHandler>,
@ -77,50 +77,29 @@ impl Handler {
}
pub(crate) async fn handle_sync(&self, response: &SyncResponse) {
for event in response
.account_data
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{
for event in response.account_data.events.iter().filter_map(|e| e.deserialize().ok()) {
self.handle_account_data_event(&event).await;
}
for (room_id, room_info) in &response.rooms.join {
if let Some(room) = self.get_room(room_id) {
for event in room_info
.ephemeral
.events
.iter()
.filter_map(|e| e.deserialize().ok())
for event in room_info.ephemeral.events.iter().filter_map(|e| e.deserialize().ok())
{
self.handle_ephemeral_event(room.clone(), &event).await;
}
for event in room_info
.account_data
.events
.iter()
.filter_map(|e| e.deserialize().ok())
for event in
room_info.account_data.events.iter().filter_map(|e| e.deserialize().ok())
{
self.handle_room_account_data_event(room.clone(), &event)
.await;
self.handle_room_account_data_event(room.clone(), &event).await;
}
for event in room_info
.state
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{
for event in room_info.state.events.iter().filter_map(|e| e.deserialize().ok()) {
self.handle_state_event(room.clone(), &event).await;
}
for event in room_info
.timeline
.events
.iter()
.filter_map(|e| e.event.deserialize().ok())
for event in
room_info.timeline.events.iter().filter_map(|e| e.event.deserialize().ok())
{
self.handle_timeline_event(room.clone(), &event).await;
}
@ -129,30 +108,18 @@ impl Handler {
for (room_id, room_info) in &response.rooms.leave {
if let Some(room) = self.get_room(room_id) {
for event in room_info
.account_data
.events
.iter()
.filter_map(|e| e.deserialize().ok())
for event in
room_info.account_data.events.iter().filter_map(|e| e.deserialize().ok())
{
self.handle_room_account_data_event(room.clone(), &event)
.await;
self.handle_room_account_data_event(room.clone(), &event).await;
}
for event in room_info
.state
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{
for event in room_info.state.events.iter().filter_map(|e| e.deserialize().ok()) {
self.handle_state_event(room.clone(), &event).await;
}
for event in room_info
.timeline
.events
.iter()
.filter_map(|e| e.event.deserialize().ok())
for event in
room_info.timeline.events.iter().filter_map(|e| e.event.deserialize().ok())
{
self.handle_timeline_event(room.clone(), &event).await;
}
@ -161,31 +128,22 @@ impl Handler {
for (room_id, room_info) in &response.rooms.invite {
if let Some(room) = self.get_room(room_id) {
for event in room_info
.invite_state
.events
.iter()
.filter_map(|e| e.deserialize().ok())
for event in
room_info.invite_state.events.iter().filter_map(|e| e.deserialize().ok())
{
self.handle_stripped_state_event(room.clone(), &event).await;
}
}
}
for event in response
.presence
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{
for event in response.presence.events.iter().filter_map(|e| e.deserialize().ok()) {
self.on_presence_event(&event).await;
}
for (room_id, notifications) in &response.notifications {
if let Some(room) = self.get_room(&room_id) {
for notification in notifications {
self.on_room_notification(room.clone(), notification.clone())
.await;
self.on_room_notification(room.clone(), notification.clone()).await;
}
}
}
@ -249,8 +207,7 @@ impl Handler {
self.on_room_tombstone(room, &tomb).await
}
AnySyncStateEvent::Custom(custom) => {
self.on_custom_event(room, &CustomEvent::State(custom))
.await
self.on_custom_event(room, &CustomEvent::State(custom)).await
}
_ => {}
}
@ -268,8 +225,7 @@ impl Handler {
}
AnyStrippedStateEvent::RoomName(name) => self.on_stripped_state_name(room, &name).await,
AnyStrippedStateEvent::RoomCanonicalAlias(canonical) => {
self.on_stripped_state_canonical_alias(room, &canonical)
.await
self.on_stripped_state_canonical_alias(room, &canonical).await
}
AnyStrippedStateEvent::RoomAliases(aliases) => {
self.on_stripped_state_aliases(room, &aliases).await
@ -341,8 +297,9 @@ pub enum CustomEvent<'c> {
StrippedState(&'c StrippedStateEvent<CustomEventContent>),
}
/// This trait allows any type implementing `EventHandler` to specify event callbacks for each event.
/// The `Client` calls each method when the corresponding event is received.
/// This trait allows any type implementing `EventHandler` to specify event
/// callbacks for each event. The `Client` calls each method when the
/// corresponding event is received.
///
/// # Examples
/// ```
@ -427,8 +384,8 @@ pub trait EventHandler: Send + Sync {
/// Fires when `Client` receives a `RoomEvent::Tombstone` event.
async fn on_room_tombstone(&self, _: Room, _: &SyncStateEvent<TombstoneEventContent>) {}
/// Fires when `Client` receives room events that trigger notifications according to
/// the push rules of the user.
/// Fires when `Client` receives room events that trigger notifications
/// according to the push rules of the user.
async fn on_room_notification(&self, _: Room, _: Notification) {}
// `RoomEvent`s from `IncomingState`
@ -453,7 +410,8 @@ pub trait EventHandler: Send + Sync {
async fn on_state_join_rules(&self, _: Room, _: &SyncStateEvent<JoinRulesEventContent>) {}
// `AnyStrippedStateEvent`s
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomMember` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomMember` event.
async fn on_stripped_state_member(
&self,
_: Room,
@ -461,32 +419,38 @@ pub trait EventHandler: Send + Sync {
_: Option<MemberEventContent>,
) {
}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomName` event.
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomName`
/// event.
async fn on_stripped_state_name(&self, _: Room, _: &StrippedStateEvent<NameEventContent>) {}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event.
async fn on_stripped_state_canonical_alias(
&self,
_: Room,
_: &StrippedStateEvent<CanonicalAliasEventContent>,
) {
}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAliases` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomAliases` event.
async fn on_stripped_state_aliases(
&self,
_: Room,
_: &StrippedStateEvent<AliasesEventContent>,
) {
}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAvatar` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomAvatar` event.
async fn on_stripped_state_avatar(&self, _: Room, _: &StrippedStateEvent<AvatarEventContent>) {}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomPowerLevels` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomPowerLevels` event.
async fn on_stripped_state_power_levels(
&self,
_: Room,
_: &StrippedStateEvent<PowerLevelsEventContent>,
) {
}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomJoinRules` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomJoinRules` event.
async fn on_stripped_state_join_rules(
&self,
_: Room,
@ -523,31 +487,33 @@ pub trait EventHandler: Send + Sync {
/// Fires when `Client` receives a `NonRoomEvent::RoomAliases` event.
async fn on_presence_event(&self, _: &PresenceEvent) {}
/// Fires when `Client` receives a `Event::Custom` event or if deserialization fails
/// because the event was unknown to ruma.
/// Fires when `Client` receives a `Event::Custom` event or if
/// deserialization fails because the event was unknown to ruma.
///
/// The only guarantee this method can give about the event is that it is valid JSON.
/// The only guarantee this method can give about the event is that it is
/// valid JSON.
async fn on_unrecognized_event(&self, _: Room, _: &RawJsonValue) {}
/// Fires when `Client` receives a `Event::Custom` event or if deserialization fails
/// because the event was unknown to ruma.
/// Fires when `Client` receives a `Event::Custom` event or if
/// deserialization fails because the event was unknown to ruma.
///
/// The only guarantee this method can give about the event is that it is in the
/// shape of a valid matrix event.
/// The only guarantee this method can give about the event is that it is in
/// the shape of a valid matrix event.
async fn on_custom_event(&self, _: Room, _: &CustomEvent<'_>) {}
}
#[cfg(test)]
mod test {
use super::*;
use std::{sync::Arc, time::Duration};
use matrix_sdk_common::{async_trait, locks::Mutex};
use matrix_sdk_test::{async_test, test_json};
use mockito::{mock, Matcher};
use std::{sync::Arc, time::Duration};
#[cfg(target_arch = "wasm32")]
pub use wasm_bindgen_test::*;
use super::*;
#[derive(Clone)]
pub struct EvHandlerTest(Arc<Mutex<Vec<String>>>);
@ -640,56 +606,50 @@ mod test {
}
// `AnyStrippedStateEvent`s
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomMember` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomMember` event.
async fn on_stripped_state_member(
&self,
_: Room,
_: &StrippedStateEvent<MemberEventContent>,
_: Option<MemberEventContent>,
) {
self.0
.lock()
.await
.push("stripped state member".to_string())
self.0.lock().await.push("stripped state member".to_string())
}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomName` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomName` event.
async fn on_stripped_state_name(&self, _: Room, _: &StrippedStateEvent<NameEventContent>) {
self.0.lock().await.push("stripped state name".to_string())
}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event.
async fn on_stripped_state_canonical_alias(
&self,
_: Room,
_: &StrippedStateEvent<CanonicalAliasEventContent>,
) {
self.0
.lock()
.await
.push("stripped state canonical".to_string())
self.0.lock().await.push("stripped state canonical".to_string())
}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAliases` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomAliases` event.
async fn on_stripped_state_aliases(
&self,
_: Room,
_: &StrippedStateEvent<AliasesEventContent>,
) {
self.0
.lock()
.await
.push("stripped state aliases".to_string())
self.0.lock().await.push("stripped state aliases".to_string())
}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAvatar` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomAvatar` event.
async fn on_stripped_state_avatar(
&self,
_: Room,
_: &StrippedStateEvent<AvatarEventContent>,
) {
self.0
.lock()
.await
.push("stripped state avatar".to_string())
self.0.lock().await.push("stripped state avatar".to_string())
}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomPowerLevels` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomPowerLevels` event.
async fn on_stripped_state_power_levels(
&self,
_: Room,
@ -697,7 +657,8 @@ mod test {
) {
self.0.lock().await.push("stripped state power".to_string())
}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomJoinRules` event.
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomJoinRules` event.
async fn on_stripped_state_join_rules(
&self,
_: Room,
@ -768,14 +729,11 @@ mod test {
}
async fn mock_sync(client: &Client, response: String) {
let _m = mock(
"GET",
Matcher::Regex(r"^/_matrix/client/r0/sync\?.*$".to_string()),
)
.with_status(200)
.match_header("authorization", "Bearer 1234")
.with_body(response)
.create();
let _m = mock("GET", Matcher::Regex(r"^/_matrix/client/r0/sync\?.*$".to_string()))
.with_status(200)
.match_header("authorization", "Bearer 1234")
.with_body(response)
.create();
let sync_settings = SyncSettings::new().timeout(Duration::from_millis(3000));
let _response = client.sync_once(sync_settings).await.unwrap();
@ -823,14 +781,7 @@ mod test {
mock_sync(&client, test_json::INVITE_SYNC.to_string()).await;
let v = test_vec.lock().await;
assert_eq!(
v.as_slice(),
[
"stripped state name",
"stripped state member",
"presence event"
],
)
assert_eq!(v.as_slice(), ["stripped state name", "stripped state member", "presence event"],)
}
#[async_test]
@ -897,15 +848,7 @@ mod test {
mock_sync(&client, test_json::VOIP_SYNC.to_string()).await;
let v = test_vec.lock().await;
assert_eq!(
v.as_slice(),
[
"call invite",
"call answer",
"call candidates",
"call hangup",
],
)
assert_eq!(v.as_slice(), ["call invite", "call answer", "call candidates", "call hangup",],)
}
#[async_test]

View file

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#[cfg(all(not(target_arch = "wasm32")))]
use std::sync::atomic::{AtomicU64, Ordering};
use std::{convert::TryFrom, fmt::Debug, sync::Arc};
#[cfg(all(not(target_arch = "wasm32")))]
@ -19,16 +21,13 @@ use backoff::{future::retry, Error as RetryError, ExponentialBackoff};
#[cfg(all(not(target_arch = "wasm32")))]
use http::StatusCode;
use http::{HeaderValue, Response as HttpResponse};
use reqwest::{Client, Response};
#[cfg(all(not(target_arch = "wasm32")))]
use std::sync::atomic::{AtomicU64, Ordering};
use tracing::trace;
use url::Url;
use matrix_sdk_common::{
api::r0::media::create_content, async_trait, locks::RwLock, AsyncTraitDeps, AuthScheme,
FromHttpResponseError, IncomingResponse, SendAccessToken,
};
use reqwest::{Client, Response};
use tracing::trace;
use url::Url;
use crate::{
error::HttpError, Bytes, BytesMut, ClientConfig, OutgoingRequest, RequestConfig, Session,
@ -39,13 +38,16 @@ use crate::{
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait HttpSend: AsyncTraitDeps {
/// The method abstracting sending request types and receiving response types.
/// The method abstracting sending request types and receiving response
/// types.
///
/// This is called by the client every time it wants to send anything to a homeserver.
/// This is called by the client every time it wants to send anything to a
/// homeserver.
///
/// # Arguments
///
/// * `request` - The http request that has been converted from a ruma `Request`.
/// * `request` - The http request that has been converted from a ruma
/// `Request`.
///
/// * `request_config` - The config used for this request.
///
@ -122,8 +124,7 @@ impl HttpClient {
let request = if !self.request_config.assert_identity {
self.try_into_http_request(request, session, config).await?
} else {
self.try_into_http_request_with_identy_assertion(request, session, config)
.await?
self.try_into_http_request_with_identy_assertion(request, session, config).await?
};
self.inner.send_request(request, config).await
@ -202,9 +203,7 @@ impl HttpClient {
request: create_content::Request<'_>,
config: Option<RequestConfig>,
) -> Result<create_content::Response, HttpError> {
let response = self
.send_request(request, self.session.clone(), config)
.await?;
let response = self.send_request(request, self.session.clone(), config).await?;
Ok(create_content::Response::try_from_http_response(response)?)
}
@ -217,9 +216,7 @@ impl HttpClient {
Request: OutgoingRequest + Debug,
HttpError: From<FromHttpResponseError<Request::EndpointError>>,
{
let response = self
.send_request(request, self.session.clone(), config)
.await?;
let response = self.send_request(request, self.session.clone(), config).await?;
trace!("Got response: {:?}", response);
@ -256,9 +253,7 @@ pub(crate) fn client_with_config(config: &ClientConfig) -> Result<Client, HttpEr
headers.insert(reqwest::header::USER_AGENT, user_agent);
http_client
.default_headers(headers)
.timeout(config.request_config.timeout)
http_client.default_headers(headers).timeout(config.request_config.timeout)
};
#[cfg(target_arch = "wasm32")]
@ -274,9 +269,7 @@ async fn response_to_http_response(
let status = response.status();
let mut http_builder = HttpResponse::builder().status(status);
let headers = http_builder
.headers_mut()
.expect("Can't get the response builder headers");
let headers = http_builder.headers_mut().expect("Can't get the response builder headers");
for (k, v) in response.headers_mut().drain() {
if let Some(key) = k {
@ -286,9 +279,7 @@ async fn response_to_http_response(
let body = response.bytes().await?;
Ok(http_builder
.body(body)
.expect("Can't construct a response using the given body"))
Ok(http_builder.body(body).expect("Can't construct a response using the given body"))
}
#[cfg(any(target_arch = "wasm32"))]
@ -329,18 +320,12 @@ async fn send_request(
};
// Turn errors into permanent errors when the retry limit is reached
let error_type = if stop {
RetryError::Permanent
} else {
RetryError::Transient
};
let error_type = if stop { RetryError::Permanent } else { RetryError::Transient };
let request = request.try_clone().ok_or(HttpError::UnableToCloneRequest)?;
let response = client
.execute(request)
.await
.map_err(|e| error_type(HttpError::Reqwest(e)))?;
let response =
client.execute(request).await.map_err(|e| error_type(HttpError::Reqwest(e)))?;
let status_code = response.status();
// TODO TOO_MANY_REQUESTS will have a retry timeout which we should

View file

@ -17,19 +17,21 @@
//!
//! # Enabling logging
//!
//! Users of the matrix-sdk crate can enable log output by depending on the `tracing-subscriber`
//! crate and including the following line in their application (e.g. at the start of `main`):
//! Users of the matrix-sdk crate can enable log output by depending on the
//! `tracing-subscriber` crate and including the following line in their
//! application (e.g. at the start of `main`):
//!
//! ```rust
//! tracing_subscriber::fmt::init();
//! ```
//!
//! The log output is controlled via the `RUST_LOG` environment variable by setting it to one of
//! the `error`, `warn`, `info`, `debug` or `trace` levels. The output is printed to stdout.
//! The log output is controlled via the `RUST_LOG` environment variable by
//! setting it to one of the `error`, `warn`, `info`, `debug` or `trace` levels.
//! The output is printed to stdout.
//!
//! The `RUST_LOG` variable also supports a more advanced syntax for filtering log output more
//! precisely, for instance with crate-level granularity. For more information on this, check out
//! the [tracing_subscriber
//! The `RUST_LOG` variable also supports a more advanced syntax for filtering
//! log output more precisely, for instance with crate-level granularity. For
//! more information on this, check out the [tracing_subscriber
//! documentation](https://tracing.rs/tracing_subscriber/filter/struct.envfilter).
//!
//! # Crate Feature Flags
@ -44,10 +46,13 @@
//! * `markdown`: Support for sending markdown formatted messages.
//! * `socks`: Enables SOCKS support in reqwest, the default HTTP client.
//! * `sso_login`: Enables SSO login with a local http server.
//! * `require_auth_for_profile_requests`: Whether to send the access token in the authentication
//! header when calling endpoints that retrieve profile data. This matches the synapse
//! configuration `require_auth_for_profile_requests`. Enabled by default.
//! * `appservice`: Enables low-level appservice functionality. For an high-level API there's the
//! * `require_auth_for_profile_requests`: Whether to send the access token in
//! the authentication
//! header when calling endpoints that retrieve profile data. This matches the
//! synapse configuration `require_auth_for_profile_requests`. Enabled by
//! default.
//! * `appservice`: Enables low-level appservice functionality. For an
//! high-level API there's the
//! `matrix-sdk-appservice` crate
#![deny(
@ -71,6 +76,7 @@ compile_error!("only one of 'native-tls' or 'rustls-tls' features can be enabled
#[cfg(all(feature = "sso_login", target_arch = "wasm32"))]
compile_error!("'sso_login' cannot be enabled on 'wasm32' arch");
pub use bytes::{Bytes, BytesMut};
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub use matrix_sdk_base::crypto::{EncryptionInfo, LocalTrust};
@ -78,8 +84,6 @@ pub use matrix_sdk_base::{
Error as BaseError, Room as BaseRoom, RoomInfo, RoomMember as BaseRoomMember, RoomType,
Session, StateChanges, StoreError,
};
pub use bytes::{Bytes, BytesMut};
pub use matrix_sdk_common::*;
pub use reqwest;

View file

@ -1,3 +1,5 @@
use std::{ops::Deref, sync::Arc};
use matrix_sdk_base::{deserialized_responses::MembersResponse, identifiers::UserId};
use matrix_sdk_common::{
api::r0::{
@ -8,11 +10,10 @@ use matrix_sdk_common::{
locks::Mutex,
};
use std::{ops::Deref, sync::Arc};
use crate::{BaseRoom, Client, Result, RoomMember};
/// A struct containing methodes that are common for Joined, Invited and Left Rooms
/// A struct containing methodes that are common for Joined, Invited and Left
/// Rooms
#[derive(Debug, Clone)]
pub struct Common {
inner: BaseRoom,
@ -36,10 +37,7 @@ impl Common {
/// * `room` - The underlaying room.
pub fn new(client: Client, room: BaseRoom) -> Self {
// TODO: Make this private
Self {
inner: room,
client,
}
Self { inner: room, client }
}
/// Leave this room.
@ -111,9 +109,9 @@ impl Common {
}
}
/// Sends a request to `/_matrix/client/r0/rooms/{room_id}/messages` and returns
/// a `get_message_events::Response` that contains a chunk of room and state events
/// (`AnyRoomEvent` and `AnyStateEvent`).
/// Sends a request to `/_matrix/client/r0/rooms/{room_id}/messages` and
/// returns a `get_message_events::Response` that contains a chunk of
/// room and state events (`AnyRoomEvent` and `AnyStateEvent`).
///
/// # Arguments
///
@ -152,35 +150,25 @@ impl Common {
pub(crate) async fn request_members(&self) -> Result<Option<MembersResponse>> {
#[allow(clippy::map_clone)]
if let Some(mutex) = self
.client
.members_request_locks
.get(self.inner.room_id())
.map(|m| m.clone())
if let Some(mutex) =
self.client.members_request_locks.get(self.inner.room_id()).map(|m| m.clone())
{
mutex.lock().await;
Ok(None)
} else {
let mutex = Arc::new(Mutex::new(()));
self.client
.members_request_locks
.insert(self.inner.room_id().clone(), mutex.clone());
self.client.members_request_locks.insert(self.inner.room_id().clone(), mutex.clone());
let _guard = mutex.lock().await;
let request = get_member_events::Request::new(self.inner.room_id());
let response = self.client.send(request, None).await?;
let response = self
.client
.base_client
.receive_members(self.inner.room_id(), &response)
.await?;
let response =
self.client.base_client.receive_members(self.inner.room_id(), &response).await?;
self.client
.members_request_locks
.remove(self.inner.room_id());
self.client.members_request_locks.remove(self.inner.room_id());
Ok(Some(response))
}
@ -248,9 +236,9 @@ impl Common {
/// Get all the joined members of this room.
///
/// *Note*: This method will not fetch the members from the homeserver if the
/// member list isn't synchronized due to member lazy loading. Thus, members
/// could be missing from the list.
/// *Note*: This method will not fetch the members from the homeserver if
/// the member list isn't synchronized due to member lazy loading. Thus,
/// members could be missing from the list.
///
/// Use [joined_members()](#method.joined_members) if you want to ensure to
/// always get the full member list.
@ -284,9 +272,9 @@ impl Common {
/// Get a specific member of this room.
///
/// *Note*: This method will not fetch the members from the homeserver if the
/// member list isn't synchronized due to member lazy loading. Thus, members
/// could be missing.
/// *Note*: This method will not fetch the members from the homeserver if
/// the member list isn't synchronized due to member lazy loading. Thus,
/// members could be missing.
///
/// Use [get_member()](#method.get_member) if you want to ensure to always
/// have the full member list to chose from.
@ -295,7 +283,6 @@ impl Common {
///
/// * `user_id` - The ID of the user that should be fetched out of the
/// store.
///
pub async fn get_member_no_sync(&self, user_id: &UserId) -> Result<Option<RoomMember>> {
Ok(self
.inner
@ -304,7 +291,8 @@ impl Common {
.map(|member| RoomMember::new(self.client.clone(), member)))
}
/// Get all members for this room, includes invited, joined and left members.
/// Get all members for this room, includes invited, joined and left
/// members.
///
/// *Note*: This method will fetch the members from the homeserver if the
/// member list isn't synchronized due to member lazy loading. Because of
@ -317,11 +305,12 @@ impl Common {
self.members_no_sync().await
}
/// Get all members for this room, includes invited, joined and left members.
/// Get all members for this room, includes invited, joined and left
/// members.
///
/// *Note*: This method will not fetch the members from the homeserver if the
/// member list isn't synchronized due to member lazy loading. Thus, members
/// could be missing.
/// *Note*: This method will not fetch the members from the homeserver if
/// the member list isn't synchronized due to member lazy loading. Thus,
/// members could be missing.
///
/// Use [members()](#method.members) if you want to ensure to always get
/// the full member list.

View file

@ -1,17 +1,20 @@
use crate::{room::Common, BaseRoom, Client, Result, RoomType};
use std::ops::Deref;
use crate::{room::Common, BaseRoom, Client, Result, RoomType};
/// A room in the invited state.
///
/// This struct contains all methodes specific to a `Room` with type `RoomType::Invited`.
/// Operations may fail once the underlaying `Room` changes `RoomType`.
/// This struct contains all methodes specific to a `Room` with type
/// `RoomType::Invited`. Operations may fail once the underlaying `Room` changes
/// `RoomType`.
#[derive(Debug, Clone)]
pub struct Invited {
pub(crate) inner: Common,
}
impl Invited {
/// Create a new `room::Invited` if the underlaying `Room` has type `RoomType::Invited`.
/// Create a new `room::Invited` if the underlaying `Room` has type
/// `RoomType::Invited`.
///
/// # Arguments
/// * `client` - The client used to make requests.
@ -20,9 +23,7 @@ impl Invited {
pub fn new(client: Client, room: BaseRoom) -> Option<Self> {
// TODO: Make this private
if room.room_type() == RoomType::Invited {
Some(Self {
inner: Common::new(client, room),
})
Some(Self { inner: Common::new(client, room) })
} else {
None
}

View file

@ -1,9 +1,11 @@
use crate::{room::Common, BaseRoom, Client, Result, RoomType};
#[cfg(feature = "encryption")]
use std::sync::Arc;
use std::{io::Read, ops::Deref};
#[cfg(feature = "encryption")]
use std::sync::Arc;
use matrix_sdk_base::crypto::AttachmentEncryptor;
#[cfg(feature = "encryption")]
use matrix_sdk_common::locks::Mutex;
use matrix_sdk_common::{
api::r0::{
membership::{
@ -34,25 +36,20 @@ use matrix_sdk_common::{
receipt::ReceiptType,
uuid::Uuid,
};
use mime::{self, Mime};
#[cfg(feature = "encryption")]
use matrix_sdk_common::locks::Mutex;
#[cfg(feature = "encryption")]
use matrix_sdk_base::crypto::AttachmentEncryptor;
#[cfg(feature = "encryption")]
use tracing::instrument;
use crate::{room::Common, BaseRoom, Client, Result, RoomType};
const TYPING_NOTICE_TIMEOUT: Duration = Duration::from_secs(4);
const TYPING_NOTICE_RESEND_TIMEOUT: Duration = Duration::from_secs(3);
/// A room in the joined state.
///
/// The `JoinedRoom` contains all methodes specific to a `Room` with type `RoomType::Joined`.
/// Operations may fail once the underlaying `Room` changes `RoomType`.
/// The `JoinedRoom` contains all methodes specific to a `Room` with type
/// `RoomType::Joined`. Operations may fail once the underlaying `Room` changes
/// `RoomType`.
#[derive(Debug, Clone)]
pub struct Joined {
pub(crate) inner: Common,
@ -67,7 +64,8 @@ impl Deref for Joined {
}
impl Joined {
/// Create a new `room::Joined` if the underlaying `BaseRoom` has type `RoomType::Joined`.
/// Create a new `room::Joined` if the underlaying `BaseRoom` has type
/// `RoomType::Joined`.
///
/// # Arguments
/// * `client` - The client used to make requests.
@ -76,9 +74,7 @@ impl Joined {
pub fn new(client: Client, room: BaseRoom) -> Option<Self> {
// TODO: Make this private
if room.room_type() == RoomType::Joined {
Some(Self {
inner: Common::new(client, room),
})
Some(Self { inner: Common::new(client, room) })
} else {
None
}
@ -97,9 +93,7 @@ impl Joined {
///
/// * `reason` - The reason for banning this user.
pub async fn ban_user(&self, user_id: &UserId, reason: Option<&str>) -> Result<()> {
let request = assign!(ban_user::Request::new(self.inner.room_id(), user_id), {
reason
});
let request = assign!(ban_user::Request::new(self.inner.room_id(), user_id), { reason });
self.client.send(request, None).await?;
Ok(())
}
@ -108,13 +102,12 @@ impl Joined {
///
/// # Arguments
///
/// * `user_id` - The `UserId` of the user that should be kicked out of the room.
/// * `user_id` - The `UserId` of the user that should be kicked out of the
/// room.
///
/// * `reason` - Optional reason why the room member is being kicked out.
pub async fn kick_user(&self, user_id: &UserId, reason: Option<&str>) -> Result<()> {
let request = assign!(kick_user::Request::new(self.inner.room_id(), user_id), {
reason
});
let request = assign!(kick_user::Request::new(self.inner.room_id(), user_id), { reason });
self.client.send(request, None).await?;
Ok(())
}
@ -148,9 +141,10 @@ impl Joined {
/// Activate typing notice for this room.
///
/// The typing notice remains active for 4s. It can be deactivate at any point by setting
/// typing to `false`. If this method is called while the typing notice is active nothing will happen.
/// This method can be called on every key stroke, since it will do nothing while typing is
/// The typing notice remains active for 4s. It can be deactivate at any
/// point by setting typing to `false`. If this method is called while
/// the typing notice is active nothing will happen. This method can be
/// called on every key stroke, since it will do nothing while typing is
/// active.
///
/// # Arguments
@ -183,21 +177,23 @@ impl Joined {
/// # });
/// ```
pub async fn typing_notice(&self, typing: bool) -> Result<()> {
// Only send a request to the homeserver if the old timeout has elapsed or the typing
// notice changed state within the TYPING_NOTICE_TIMEOUT
// Only send a request to the homeserver if the old timeout has elapsed
// or the typing notice changed state within the
// TYPING_NOTICE_TIMEOUT
let send =
if let Some(typing_time) = self.client.typing_notice_times.get(self.inner.room_id()) {
if typing_time.elapsed() > TYPING_NOTICE_RESEND_TIMEOUT {
// We always reactivate the typing notice if typing is true or we may need to
// deactivate it if it's currently active if typing is false
// We always reactivate the typing notice if typing is true or
// we may need to deactivate it if it's
// currently active if typing is false
typing || typing_time.elapsed() <= TYPING_NOTICE_TIMEOUT
} else {
// Only send a request when we need to deactivate typing
!typing
}
} else {
// Typing notice is currently deactivated, therefore, send a request only when it's
// about to be activated
// Typing notice is currently deactivated, therefore, send a request
// only when it's about to be activated
typing
};
@ -220,11 +216,13 @@ impl Joined {
Ok(())
}
/// Send a request to notify this room that the user has read specific event.
/// Send a request to notify this room that the user has read specific
/// event.
///
/// # Arguments
///
/// * `event_id` - The `EventId` specifies the event to set the read receipt on.
/// * `event_id` - The `EventId` specifies the event to set the read receipt
/// on.
pub async fn read_receipt(&self, event_id: &EventId) -> Result<()> {
let request =
create_receipt::Request::new(self.inner.room_id(), ReceiptType::Read, event_id);
@ -233,22 +231,23 @@ impl Joined {
Ok(())
}
/// Send a request to notify this room that the user has read up to specific event.
/// Send a request to notify this room that the user has read up to specific
/// event.
///
/// # Arguments
///
/// * fully_read - The `EventId` of the event the user has read to.
///
/// * read_receipt - An `EventId` to specify the event to set the read receipt on.
/// * read_receipt - An `EventId` to specify the event to set the read
/// receipt on.
pub async fn read_marker(
&self,
fully_read: &EventId,
read_receipt: Option<&EventId>,
) -> Result<()> {
let request = assign!(
set_read_marker::Request::new(self.inner.room_id(), fully_read),
{ read_receipt }
);
let request = assign!(set_read_marker::Request::new(self.inner.room_id(), fully_read), {
read_receipt
});
self.client.send(request, None).await?;
Ok(())
@ -266,11 +265,8 @@ impl Joined {
// TODO expose this publicly so people can pre-share a group session if
// e.g. a user starts to type a message for a room.
#[allow(clippy::map_clone)]
if let Some(mutex) = self
.client
.group_session_locks
.get(self.inner.room_id())
.map(|m| m.clone())
if let Some(mutex) =
self.client.group_session_locks.get(self.inner.room_id()).map(|m| m.clone())
{
// If a group session share request is already going on,
// await the release of the lock.
@ -279,23 +275,14 @@ impl Joined {
// Otherwise create a new lock and share the group
// session.
let mutex = Arc::new(Mutex::new(()));
self.client
.group_session_locks
.insert(self.inner.room_id().clone(), mutex.clone());
self.client.group_session_locks.insert(self.inner.room_id().clone(), mutex.clone());
let _guard = mutex.lock().await;
{
let joined = self
.client
.store()
.get_joined_user_ids(self.inner.room_id())
.await?;
let invited = self
.client
.store()
.get_invited_user_ids(self.inner.room_id())
.await?;
let joined = self.client.store().get_joined_user_ids(self.inner.room_id()).await?;
let invited =
self.client.store().get_invited_user_ids(self.inner.room_id()).await?;
let members = joined.iter().chain(&invited);
self.client.claim_one_time_keys(members).await?;
};
@ -308,10 +295,7 @@ impl Joined {
// session as using it would end up in undecryptable
// messages.
if let Err(r) = response {
self.client
.base_client
.invalidate_group_session(self.inner.room_id())
.await?;
self.client.base_client.invalidate_group_session(self.inner.room_id()).await?;
return Err(r);
}
}
@ -328,19 +312,13 @@ impl Joined {
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
#[instrument]
async fn share_group_session(&self) -> Result<()> {
let mut requests = self
.client
.base_client
.share_group_session(self.inner.room_id())
.await?;
let mut requests =
self.client.base_client.share_group_session(self.inner.room_id()).await?;
for request in requests.drain(..) {
let response = self.client.send_to_device(&request).await?;
self.client
.base_client
.mark_request_as_sent(&request.txn_id, &response)
.await?;
self.client.base_client.mark_request_as_sent(&request.txn_id, &response).await?;
}
Ok(())
@ -407,10 +385,7 @@ impl Joined {
self.preshare_group_session().await?;
AnyMessageEventContent::RoomEncrypted(
self.client
.base_client
.encrypt(self.inner.room_id(), content)
.await?,
self.client.base_client.encrypt(self.inner.room_id(), content).await?,
)
} else {
content.into()
@ -430,8 +405,9 @@ impl Joined {
/// If the room is encrypted and the encryption feature is enabled the
/// upload will be encrypted.
///
/// This is a convenience method that calls the [`Client::upload()`](#Client::method.upload)
/// and afterwards the [`send()`](#method.send).
/// This is a convenience method that calls the
/// [`Client::upload()`](#Client::method.upload) and afterwards the
/// [`send()`](#method.send).
///
/// # Arguments
/// * `body` - A textual representation of the media that is going to be
@ -538,11 +514,8 @@ impl Joined {
}),
};
self.send(
AnyMessageEventContent::RoomMessage(MessageEventContent::new(content)),
txn_id,
)
.await
self.send(AnyMessageEventContent::RoomMessage(MessageEventContent::new(content)), txn_id)
.await
}
/// Send a room state event to the homeserver.
@ -639,10 +612,10 @@ impl Joined {
txn_id: Option<Uuid>,
) -> Result<redact_event::Response> {
let txn_id = txn_id.unwrap_or_else(Uuid::new_v4).to_string();
let request = assign!(
redact_event::Request::new(self.inner.room_id(), event_id, &txn_id),
{ reason }
);
let request =
assign!(redact_event::Request::new(self.inner.room_id(), event_id, &txn_id), {
reason
});
self.client.send(request, None).await
}

View file

@ -1,19 +1,22 @@
use crate::{room::Common, BaseRoom, Client, Result, RoomType};
use std::ops::Deref;
use matrix_sdk_common::api::r0::membership::forget_room;
use crate::{room::Common, BaseRoom, Client, Result, RoomType};
/// A room in the left state.
///
/// This struct contains all methodes specific to a `Room` with type `RoomType::Left`.
/// Operations may fail once the underlaying `Room` changes `RoomType`.
/// This struct contains all methodes specific to a `Room` with type
/// `RoomType::Left`. Operations may fail once the underlaying `Room` changes
/// `RoomType`.
#[derive(Debug, Clone)]
pub struct Left {
pub(crate) inner: Common,
}
impl Left {
/// Create a new `room::Left` if the underlaying `Room` has type `RoomType::Left`.
/// Create a new `room::Left` if the underlaying `Room` has type
/// `RoomType::Left`.
///
/// # Arguments
/// * `client` - The client used to make requests.
@ -22,9 +25,7 @@ impl Left {
pub fn new(client: Client, room: BaseRoom) -> Option<Self> {
// TODO: Make this private
if room.room_type() == RoomType::Left {
Some(Self {
inner: Common::new(client, room),
})
Some(Self { inner: Common::new(client, room) })
} else {
None
}

View file

@ -1,7 +1,7 @@
use matrix_sdk_common::api::r0::media::{get_content, get_content_thumbnail};
use std::ops::Deref;
use matrix_sdk_common::api::r0::media::{get_content, get_content_thumbnail};
use crate::{BaseRoomMember, Client, Result};
/// The high-level `RoomMember` representation
@ -21,10 +21,7 @@ impl Deref for RoomMember {
impl RoomMember {
pub(crate) fn new(client: Client, member: BaseRoomMember) -> Self {
Self {
inner: member,
client,
}
Self { inner: member, client }
}
/// Gets the avatar of this member, if set.

View file

@ -26,11 +26,7 @@ impl AppserviceEventHandler {
#[async_trait]
impl EventHandler for AppserviceEventHandler {
async fn on_room_member(&self, room: Room, event: &SyncStateEvent<MemberEventContent>) {
if !self
.appservice
.user_id_is_in_namespace(&event.state_key)
.unwrap()
{
if !self.appservice.user_id_is_in_namespace(&event.state_key).unwrap() {
dbg!("not an appservice user");
return;
}
@ -38,11 +34,7 @@ impl EventHandler for AppserviceEventHandler {
if let MembershipState::Invite = event.content.membership {
let user_id = UserId::try_from(event.state_key.clone()).unwrap();
let client = self
.appservice
.client_with_localpart(user_id.localpart())
.await
.unwrap();
let client = self.appservice.client_with_localpart(user_id.localpart()).await.unwrap();
client.join_room_by_id(room.room_id()).await.unwrap();
}
@ -51,10 +43,7 @@ impl EventHandler for AppserviceEventHandler {
#[actix_web::main]
pub async fn main() -> std::io::Result<()> {
env::set_var(
"RUST_LOG",
"actix_web=debug,actix_server=info,matrix_sdk=debug",
);
env::set_var("RUST_LOG", "actix_web=debug,actix_server=info,matrix_sdk=debug");
tracing_subscriber::fmt::init();
let homeserver_url = "http://localhost:8008";
@ -62,16 +51,11 @@ pub async fn main() -> std::io::Result<()> {
let registration =
AppserviceRegistration::try_from_yaml_file("./tests/registration.yaml").unwrap();
let appservice = Appservice::new(homeserver_url, server_name, registration)
.await
.unwrap();
let appservice = Appservice::new(homeserver_url, server_name, registration).await.unwrap();
let event_handler = AppserviceEventHandler::new(appservice.clone());
appservice
.client()
.set_event_handler(Box::new(event_handler))
.await;
appservice.client().set_event_handler(Box::new(event_handler)).await;
HttpServer::new(move || App::new().service(appservice.actix_service()))
.bind(("0.0.0.0", 8090))?

View file

@ -17,6 +17,7 @@ use std::{
pin::Pin,
};
pub use actix_web::Scope;
use actix_web::{
dev::Payload,
error::PayloadError,
@ -30,8 +31,6 @@ use futures::Future;
use futures_util::{TryFutureExt, TryStreamExt};
use matrix_sdk::api_appservice as api;
pub use actix_web::Scope;
use crate::{error::Error, Appservice};
pub async fn run_server(
@ -53,10 +52,7 @@ pub fn get_scope() -> Scope {
}
fn gen_scope(scope: &str) -> Scope {
web::scope(scope)
.service(push_transactions)
.service(query_user_id)
.service(query_room_alias)
web::scope(scope).service(push_transactions).service(query_user_id).service(query_room_alias)
}
#[tracing::instrument]
@ -69,11 +65,7 @@ async fn push_transactions(
return Ok(HttpResponse::Unauthorized().finish());
}
appservice
.client()
.receive_transaction(request.incoming)
.await
.unwrap();
appservice.client().receive_transaction(request.incoming).await.unwrap();
Ok(HttpResponse::Ok().json("{}"))
}
@ -136,13 +128,9 @@ impl<T: matrix_sdk::IncomingRequest> FromRequest for IncomingRequest<T> {
uri
};
let mut builder = http::request::Builder::new()
.method(request.method())
.uri(uri);
let mut builder = http::request::Builder::new().method(request.method()).uri(uri);
let headers = builder
.headers_mut()
.ok_or(Error::UnknownHttpRequestBuilder)?;
let headers = builder.headers_mut().ok_or(Error::UnknownHttpRequestBuilder)?;
for (key, value) in request.headers().iter() {
headers.append(key, value.to_owned());
}
@ -158,10 +146,7 @@ impl<T: matrix_sdk::IncomingRequest> FromRequest for IncomingRequest<T> {
let access_token = match request.uri().query() {
Some(query) => {
let query: Vec<(String, String)> = matrix_sdk::urlencoded::from_str(query)?;
query
.into_iter()
.find(|(key, _)| key == "access_token")
.map(|(_, value)| value)
query.into_iter().find(|(key, _)| key == "access_token").map(|(_, value)| value)
}
None => None,
};

View file

@ -14,11 +14,14 @@
//! Matrix [Application Service] library
//!
//! The appservice crate aims to provide a batteries-included experience. That means that we
//! * ship with functionality to configure your webserver crate or simply run the webserver for you
//! The appservice crate aims to provide a batteries-included experience. That
//! means that we
//! * ship with functionality to configure your webserver crate or simply run
//! the webserver for you
//! * receive and validate requests from the homeserver correctly
//! * allow calling the homeserver with proper virtual user identity assertion
//! * have the goal to have a consistent room state available by leveraging the stores that the matrix-sdk provides
//! * have the goal to have a consistent room state available by leveraging the
//! stores that the matrix-sdk provides
//!
//! # Quickstart
//!
@ -62,6 +65,8 @@ use std::{
};
use http::Uri;
#[doc(inline)]
pub use matrix_sdk::api_appservice as api;
use matrix_sdk::{
api::{
error::ErrorKind,
@ -81,9 +86,6 @@ use regex::Regex;
use tracing::error;
use tracing::warn;
#[doc(inline)]
pub use matrix_sdk::api_appservice as api;
#[cfg(feature = "actix")]
mod actix;
mod error;
@ -104,9 +106,7 @@ impl AppserviceRegistration {
///
/// See the fields of [`Registration`] for the required format
pub fn try_from_yaml_str(value: impl AsRef<str>) -> Result<Self> {
Ok(Self {
inner: serde_yaml::from_str(value.as_ref())?,
})
Ok(Self { inner: serde_yaml::from_str(value.as_ref())? })
}
/// Try to load registration from yaml file
@ -115,9 +115,7 @@ impl AppserviceRegistration {
pub fn try_from_yaml_file(path: impl Into<PathBuf>) -> Result<Self> {
let file = File::open(path.into())?;
Ok(Self {
inner: serde_yaml::from_reader(file)?,
})
Ok(Self { inner: serde_yaml::from_reader(file)? })
}
}
@ -177,8 +175,10 @@ impl Appservice {
/// # Arguments
///
/// * `homeserver_url` - The homeserver that the client should connect to.
/// * `server_name` - The server name to use when constructing user ids from the localpart.
/// * `registration` - The [Appservice Registration] to use when interacting with the homserver.
/// * `server_name` - The server name to use when constructing user ids from
/// the localpart.
/// * `registration` - The [Appservice Registration] to use when interacting
/// with the homserver.
///
/// [Appservice Registration]: https://matrix.org/docs/spec/application_service/r0.1.2#registration
pub async fn new(
@ -209,8 +209,9 @@ impl Appservice {
/// Get `Client` for the given `localpart`
///
/// If the `localpart` is covered by the `namespaces` in the [registration] all requests to the
/// homeserver will [assert the identity] to the according virtual user.
/// If the `localpart` is covered by the `namespaces` in the [registration]
/// all requests to the homeserver will [assert the identity] to the
/// according virtual user.
///
/// [registration]: https://matrix.org/docs/spec/application_service/r0.1.2#registration
/// [assert the identity]:
@ -291,7 +292,8 @@ impl Appservice {
/// Get the host and port from the registration URL
///
/// If no port is found it falls back to scheme defaults: 80 for http and 443 for https
/// If no port is found it falls back to scheme defaults: 80 for http and
/// 443 for https
pub fn get_host_and_port_from_registration(&self) -> Result<(Host, Port)> {
let uri = Uri::try_from(&self.registration.url)?;
@ -315,9 +317,11 @@ impl Appservice {
actix::get_scope().data(self.clone())
}
/// Convenience method that runs an http server depending on the selected server feature
/// Convenience method that runs an http server depending on the selected
/// server feature
///
/// This is a blocking call that tries to listen on the provided host and port
/// This is a blocking call that tries to listen on the provided host and
/// port
pub async fn run(&self, host: impl AsRef<str>, port: impl Into<u16>) -> Result<()> {
#[cfg(feature = "actix")]
{

View file

@ -1,14 +1,12 @@
#[cfg(feature = "actix")]
mod actix {
use actix_web::{test, App};
use matrix_sdk_appservice::*;
use std::env;
use actix_web::{test, App};
use matrix_sdk_appservice::*;
async fn appservice() -> Appservice {
env::set_var(
"RUST_LOG",
"mockito=debug,matrix_sdk=debug,ruma=debug,actix_web=debug",
);
env::set_var("RUST_LOG", "mockito=debug,matrix_sdk=debug,ruma=debug,actix_web=debug");
let _ = tracing_subscriber::fmt::try_init();
Appservice::new(
@ -109,7 +107,8 @@ mod actix {
let resp = test::call_service(&app, req).await;
// TODO: this should actually return a 401 but is 500 because something in the extractor fails
// TODO: this should actually return a 401 but is 500 because something in the
// extractor fails
assert_eq!(resp.status(), 500);
}
}

View file

@ -76,10 +76,7 @@ async fn test_event_handler() -> Result<()> {
}
}
appservice
.client()
.set_event_handler(Box::new(Example::new()))
.await;
appservice.client().set_event_handler(Box::new(Example::new())).await;
let event = serde_json::from_value::<AnyStateEvent>(member_json()).unwrap();
let event: Raw<AnyRoomEvent> = AnyRoomEvent::State(event).into();

View file

@ -1,12 +1,15 @@
use std::{convert::TryFrom, fmt::Debug, sync::Arc};
use futures::executor::block_on;
use serde::Serialize;
#[cfg(not(target_arch = "wasm32"))]
use atty::Stream;
#[cfg(not(target_arch = "wasm32"))]
use clap::{App as Argparse, AppSettings as ArgParseSettings, Arg, ArgMatches, SubCommand};
use futures::executor::block_on;
use matrix_sdk_base::{
events::EventType,
identifiers::{RoomId, UserId},
RoomInfo, Store,
};
#[cfg(not(target_arch = "wasm32"))]
use rustyline::{
completion::{Completer, Pair},
@ -18,7 +21,7 @@ use rustyline::{
};
#[cfg(not(target_arch = "wasm32"))]
use rustyline_derive::Helper;
use serde::Serialize;
#[cfg(not(target_arch = "wasm32"))]
use syntect::{
dumps::from_binary,
@ -28,12 +31,6 @@ use syntect::{
util::{as_24_bit_terminal_escaped, LinesWithEndings},
};
use matrix_sdk_base::{
events::EventType,
identifiers::{RoomId, UserId},
RoomInfo, Store,
};
#[derive(Clone)]
#[cfg(not(target_arch = "wasm32"))]
struct Inspector {
@ -79,17 +76,8 @@ impl InspectorHelper {
fn complete_event_types(&self, arg: Option<&&str>) -> Vec<Pair> {
Self::EVENT_TYPES
.iter()
.map(|t| Pair {
display: t.to_string(),
replacement: format!("{} ", t),
})
.filter(|r| {
if let Some(arg) = arg {
r.replacement.starts_with(arg)
} else {
true
}
})
.map(|t| Pair { display: t.to_string(), replacement: format!("{} ", t) })
.filter(|r| if let Some(arg) = arg { r.replacement.starts_with(arg) } else { true })
.collect()
}
@ -102,13 +90,7 @@ impl InspectorHelper {
display: r.room_id.to_string(),
replacement: format!("{} ", r.room_id.to_string()),
})
.filter(|r| {
if let Some(arg) = arg {
r.replacement.starts_with(arg)
} else {
true
}
})
.filter(|r| if let Some(arg) = arg { r.replacement.starts_with(arg) } else { true })
.collect()
}
}
@ -127,15 +109,9 @@ impl Completer for InspectorHelper {
let commands = vec![
("get-state", "get a state event in the given room"),
(
"get-profiles",
"get all the stored profiles in the given room",
),
("get-profiles", "get all the stored profiles in the given room"),
("list-rooms", "list all rooms"),
(
"get-members",
"get all the membership events in the given room",
),
("get-members", "get all the membership events in the given room"),
]
.iter()
.map(|(r, d)| Pair {
@ -154,19 +130,13 @@ impl Completer for InspectorHelper {
} else {
Ok((
0,
commands
.into_iter()
.filter(|c| c.replacement.starts_with(args[0]))
.collect(),
commands.into_iter().filter(|c| c.replacement.starts_with(args[0])).collect(),
))
}
} else if args.len() == 2 {
if args[0] == "get-state" {
if line.ends_with(' ') {
Ok((
args[0].len() + args[1].len() + 2,
self.complete_event_types(args.get(2)),
))
Ok((args[0].len() + args[1].len() + 2, self.complete_event_types(args.get(2))))
} else {
Ok((args[0].len() + 1, self.complete_rooms(args.get(1))))
}
@ -177,10 +147,7 @@ impl Completer for InspectorHelper {
}
} else if args.len() == 3 {
if args[0] == "get-state" {
Ok((
args[0].len() + args[1].len() + 2,
self.complete_event_types(args.get(2)),
))
Ok((args[0].len() + args[1].len() + 2, self.complete_event_types(args.get(2))))
} else {
Ok((pos, vec![]))
}
@ -216,12 +183,7 @@ impl Printer {
let syntax_set: SyntaxSet = from_binary(include_bytes!("./syntaxes.bin"));
let themes: ThemeSet = from_binary(include_bytes!("./themes.bin"));
Self {
ps: syntax_set.into(),
ts: themes.into(),
json,
color,
}
Self { ps: syntax_set.into(), ts: themes.into(), json, color }
}
fn pretty_print_struct<T: Debug + Serialize>(&self, data: &T) {
@ -232,13 +194,9 @@ impl Printer {
};
let syntax = if self.json {
self.ps
.find_syntax_by_extension("rs")
.expect("Can't find rust syntax extension")
self.ps.find_syntax_by_extension("rs").expect("Can't find rust syntax extension")
} else {
self.ps
.find_syntax_by_extension("json")
.expect("Can't find json syntax extension")
self.ps.find_syntax_by_extension("json").expect("Can't find json syntax extension")
};
if self.color {
@ -305,11 +263,7 @@ impl Inspector {
}
async fn get_display_name_owners(&self, room_id: RoomId, display_name: String) {
let users = self
.store
.get_users_with_display_name(&room_id, &display_name)
.await
.unwrap();
let users = self.store.get_users_with_display_name(&room_id, &display_name).await.unwrap();
self.printer.pretty_print_struct(&users);
}
@ -326,22 +280,14 @@ impl Inspector {
let joined: Vec<UserId> = self.store.get_joined_user_ids(&room_id).await.unwrap();
for member in joined {
let event = self
.store
.get_member_event(&room_id, &member)
.await
.unwrap();
let event = self.store.get_member_event(&room_id, &member).await.unwrap();
self.printer.pretty_print_struct(&event);
}
}
async fn get_state(&self, room_id: RoomId, event_type: EventType) {
self.printer.pretty_print_struct(
&self
.store
.get_state_event(&room_id, event_type, "")
.await
.unwrap(),
&self.store.get_state_event(&room_id, event_type, "").await.unwrap(),
);
}
@ -350,35 +296,25 @@ impl Inspector {
SubCommand::with_name("list-rooms"),
SubCommand::with_name("get-members").arg(
Arg::with_name("room-id").required(true).validator(|r| {
RoomId::try_from(r)
.map(|_| ())
.map_err(|_| "Invalid room id given".to_owned())
RoomId::try_from(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned())
}),
),
SubCommand::with_name("get-profiles").arg(
Arg::with_name("room-id").required(true).validator(|r| {
RoomId::try_from(r)
.map(|_| ())
.map_err(|_| "Invalid room id given".to_owned())
RoomId::try_from(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned())
}),
),
SubCommand::with_name("get-display-names")
.arg(Arg::with_name("room-id").required(true).validator(|r| {
RoomId::try_from(r)
.map(|_| ())
.map_err(|_| "Invalid room id given".to_owned())
RoomId::try_from(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned())
}))
.arg(Arg::with_name("display-name").required(true)),
SubCommand::with_name("get-state")
.arg(Arg::with_name("room-id").required(true).validator(|r| {
RoomId::try_from(r)
.map(|_| ())
.map_err(|_| "Invalid room id given".to_owned())
RoomId::try_from(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned())
}))
.arg(Arg::with_name("event-type").required(true).validator(|e| {
EventType::try_from(e)
.map(|_| ())
.map_err(|_| "Invalid event type".to_string())
EventType::try_from(e).map(|_| ()).map_err(|_| "Invalid event type".to_string())
})),
]
}

View file

@ -87,20 +87,21 @@ pub struct AdditionalUnsignedData {
pub prev_content: Option<Raw<MemberEventContent>>,
}
/// Transform state event by hoisting `prev_content` field from `unsigned` to the top level.
/// Transform state event by hoisting `prev_content` field from `unsigned` to
/// the top level.
///
/// Due to a [bug in synapse][synapse-bug], `prev_content` often ends up in `unsigned` contrary to
/// the C2S spec. Some more discussion can be found [here][discussion]. Until this is fixed in
/// synapse or handled in Ruma, we use this to hoist up `prev_content` to the top level.
/// Due to a [bug in synapse][synapse-bug], `prev_content` often ends up in
/// `unsigned` contrary to the C2S spec. Some more discussion can be found
/// [here][discussion]. Until this is fixed in synapse or handled in Ruma, we
/// use this to hoist up `prev_content` to the top level.
///
/// [synapse-bug]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668>
/// [discussion]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668>
pub fn hoist_and_deserialize_state_event(
event: &Raw<AnySyncStateEvent>,
) -> StdResult<AnySyncStateEvent, serde_json::Error> {
let prev_content = serde_json::from_str::<AdditionalEventData>(event.json().get())?
.unsigned
.prev_content;
let prev_content =
serde_json::from_str::<AdditionalEventData>(event.json().get())?.unsigned.prev_content;
let mut ev = event.deserialize()?;
@ -116,9 +117,8 @@ pub fn hoist_and_deserialize_state_event(
fn hoist_member_event(
event: &Raw<StateEvent<MemberEventContent>>,
) -> StdResult<StateEvent<MemberEventContent>, serde_json::Error> {
let prev_content = serde_json::from_str::<AdditionalEventData>(event.json().get())?
.unsigned
.prev_content;
let prev_content =
serde_json::from_str::<AdditionalEventData>(event.json().get())?.unsigned.prev_content;
let mut e = event.deserialize()?;
@ -340,7 +340,8 @@ impl BaseClient {
///
/// # Arguments
///
/// * `response` - A successful login response that contains our access token
/// * `response` - A successful login response that contains our access
/// token
/// and device id.
pub async fn receive_login_response(
&self,
@ -440,9 +441,7 @@ impl BaseClient {
AnySyncRoomEvent::State(s) => match s {
AnySyncStateEvent::RoomMember(member) => {
if let Ok(member) = MemberEvent::try_from(member.clone()) {
ambiguity_cache
.handle_event(changes, room_id, &member)
.await?;
ambiguity_cache.handle_event(changes, room_id, &member).await?;
match member.content.membership {
MembershipState::Join | MembershipState::Invite => {
@ -500,8 +499,7 @@ impl BaseClient {
}
if let Some(context) = &mut push_context {
self.update_push_room_context(context, user_id, room_info, changes)
.await;
self.update_push_room_context(context, user_id, room_info, changes).await;
} else {
push_context = self.get_push_room_context(room, room_info, changes).await?;
}
@ -521,10 +519,13 @@ impl BaseClient {
),
);
}
// TODO if there is an Action::SetTweak(Tweak::Highlight) we need to store
// its value with the event so a client can show if the event is highlighted
// TODO if there is an
// Action::SetTweak(Tweak::Highlight) we need to store
// its value with the event so a client can show if the
// event is highlighted
// in the UI.
// Requires the possibility to associate custom data with events and to
// Requires the possibility to associate custom data
// with events and to
// store them.
}
}
@ -762,18 +763,14 @@ impl BaseClient {
let mut changes = StateChanges::new(next_batch.clone());
let mut ambiguity_cache = AmbiguityCache::new(self.store.clone());
self.handle_account_data(&account_data.events, &mut changes)
.await;
self.handle_account_data(&account_data.events, &mut changes).await;
let push_rules = self.get_push_rules(&changes).await?;
let mut new_rooms = Rooms::default();
for (room_id, new_info) in rooms.join {
let room = self
.store
.get_or_create_room(&room_id, RoomType::Joined)
.await;
let room = self.store.get_or_create_room(&room_id, RoomType::Joined).await;
let mut room_info = room.clone_info();
room_info.mark_as_joined();
@ -844,10 +841,7 @@ impl BaseClient {
}
for (room_id, new_info) in rooms.leave {
let room = self
.store
.get_or_create_room(&room_id, RoomType::Left)
.await;
let room = self.store.get_or_create_room(&room_id, RoomType::Left).await;
let mut room_info = room.clone_info();
room_info.mark_as_left();
@ -876,18 +870,14 @@ impl BaseClient {
.await;
changes.add_room(room_info);
new_rooms.leave.insert(
room_id,
LeftRoom::new(timeline, new_info.state, new_info.account_data),
);
new_rooms
.leave
.insert(room_id, LeftRoom::new(timeline, new_info.state, new_info.account_data));
}
for (room_id, new_info) in rooms.invite {
{
let room = self
.store
.get_or_create_room(&room_id, RoomType::Invited)
.await;
let room = self.store.get_or_create_room(&room_id, RoomType::Invited).await;
let mut room_info = room.clone_info();
room_info.mark_as_invited();
changes.add_room(room_info);
@ -934,9 +924,7 @@ impl BaseClient {
.into_iter()
.map(|(k, v)| (k, v.into()))
.collect(),
ambiguity_changes: AmbiguityChanges {
changes: ambiguity_cache.changes,
},
ambiguity_changes: AmbiguityChanges { changes: ambiguity_cache.changes },
notifications: changes.notifications,
};
@ -968,11 +956,7 @@ impl BaseClient {
let members: Vec<MemberEvent> = response
.chunk
.iter()
.filter_map(|e| {
hoist_member_event(e)
.ok()
.and_then(|e| MemberEvent::try_from(e).ok())
})
.filter_map(|e| hoist_member_event(e).ok().and_then(|e| MemberEvent::try_from(e).ok()))
.collect();
let mut ambiguity_cache = AmbiguityCache::new(self.store.clone());
@ -986,12 +970,7 @@ impl BaseClient {
let mut user_ids = BTreeSet::new();
for member in &members {
if self
.store
.get_member_event(&room_id, &member.state_key)
.await?
.is_none()
{
if self.store.get_member_event(&room_id, &member.state_key).await?.is_none() {
#[cfg(feature = "encryption")]
match member.content.membership {
MembershipState::Join | MembershipState::Invite => {
@ -1000,9 +979,7 @@ impl BaseClient {
_ => (),
}
ambiguity_cache
.handle_event(&changes, room_id, &member)
.await?;
ambiguity_cache.handle_event(&changes, room_id, &member).await?;
if member.state_key == member.sender {
changes
@ -1036,9 +1013,7 @@ impl BaseClient {
Ok(MembersResponse {
chunk: members,
ambiguity_changes: AmbiguityChanges {
changes: ambiguity_cache.changes,
},
ambiguity_changes: AmbiguityChanges { changes: ambiguity_cache.changes },
})
}
@ -1050,7 +1025,8 @@ impl BaseClient {
///
/// # Arguments
///
/// * `filter_name` - The name that should be used to persist the filter id in
/// * `filter_name` - The name that should be used to persist the filter id
/// in
/// the store.
///
/// * `response` - The successful filter upload response containing the
@ -1062,10 +1038,7 @@ impl BaseClient {
filter_name: &str,
response: &api::filter::create_filter::Response,
) -> Result<()> {
Ok(self
.store
.save_filter(filter_name, &response.filter_id)
.await?)
Ok(self.store.save_filter(filter_name, &response.filter_id).await?)
}
/// Get the filter id of a previously uploaded filter.
@ -1224,18 +1197,14 @@ impl BaseClient {
/// # Arguments
///
/// * `flow_id` - The unique id that identifies a interactive verification
/// flow. For in-room verifications this will be the event id of the
/// *m.key.verification.request* event that started the flow, for the
/// to-device verification flows this will be the transaction id of the
/// *m.key.verification.start* event.
/// flow. For in-room verifications this will be the event id of the
/// *m.key.verification.request* event that started the flow, for the
/// to-device verification flows this will be the transaction id of the
/// *m.key.verification.start* event.
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn get_verification(&self, flow_id: &str) -> Option<Sas> {
self.olm
.lock()
.await
.as_ref()
.and_then(|o| o.get_verification(flow_id))
self.olm.lock().await.as_ref().and_then(|o| o.get_verification(flow_id))
}
/// Get a specific device of a user.
@ -1284,10 +1253,12 @@ impl BaseClient {
/// Get the user login session.
///
/// If the client is currently logged in, this will return a `matrix_sdk::Session` object which
/// can later be given to `restore_login`.
/// If the client is currently logged in, this will return a
/// `matrix_sdk::Session` object which can later be given to
/// `restore_login`.
///
/// Returns a session object if the client is logged in. Otherwise returns `None`.
/// Returns a session object if the client is logged in. Otherwise returns
/// `None`.
pub async fn get_session(&self) -> Option<Session> {
self.session.read().await.clone()
}
@ -1349,8 +1320,9 @@ impl BaseClient {
/// Get the push rules.
///
/// Gets the push rules from `changes` if they have been updated, otherwise get them from the
/// store. As a fallback, uses `Ruleset::server_default` if the user is logged in.
/// Gets the push rules from `changes` if they have been updated, otherwise
/// get them from the store. As a fallback, uses
/// `Ruleset::server_default` if the user is logged in.
pub async fn get_push_rules(&self, changes: &StateChanges) -> Result<Ruleset> {
if let Some(AnyGlobalAccountDataEvent::PushRules(event)) = changes
.account_data
@ -1374,11 +1346,11 @@ impl BaseClient {
/// Get the push context for the given room.
///
/// Tries to get the data from `changes` or the up to date `room_info`. Loads the data from the
/// store otherwise.
/// Tries to get the data from `changes` or the up to date `room_info`.
/// Loads the data from the store otherwise.
///
/// Returns `None` if some data couldn't be found. This should only happen in brand new rooms,
/// while we process its state.
/// Returns `None` if some data couldn't be found. This should only happen
/// in brand new rooms, while we process its state.
pub async fn get_push_room_context(
&self,
room: &Room,
@ -1390,16 +1362,10 @@ impl BaseClient {
let member_count = room_info.active_members_count();
let user_display_name = if let Some(member) = changes
.members
.get(room_id)
.and_then(|members| members.get(user_id))
let user_display_name = if let Some(member) =
changes.members.get(room_id).and_then(|members| members.get(user_id))
{
member
.content
.displayname
.clone()
.unwrap_or_else(|| user_id.localpart().to_owned())
member.content.displayname.clone().unwrap_or_else(|| user_id.localpart().to_owned())
} else if let Some(member) = room.get_member(user_id).await? {
member.name().to_owned()
} else {
@ -1449,16 +1415,10 @@ impl BaseClient {
push_rules.member_count = UInt::new(room_info.active_members_count()).unwrap_or(UInt::MAX);
if let Some(member) = changes
.members
.get(room_id)
.and_then(|members| members.get(user_id))
if let Some(member) = changes.members.get(room_id).and_then(|members| members.get(user_id))
{
push_rules.user_display_name = member
.content
.displayname
.clone()
.unwrap_or_else(|| user_id.localpart().to_owned())
push_rules.user_display_name =
member.content.displayname.clone().unwrap_or_else(|| user_id.localpart().to_owned())
}
if let Some(AnySyncStateEvent::RoomPowerLevels(event)) = changes

View file

@ -15,12 +15,12 @@
//! Error conditions.
use serde_json::Error as JsonError;
use std::io::Error as IoError;
use thiserror::Error;
#[cfg(feature = "encryption")]
use matrix_sdk_crypto::{CryptoStoreError, MegolmError, OlmError};
use serde_json::Error as JsonError;
use thiserror::Error;
/// Result type of the rust-sdk.
pub type Result<T, E = Error> = std::result::Result<T, E>;
@ -28,7 +28,8 @@ pub type Result<T, E = Error> = std::result::Result<T, E>;
/// Internal representation of errors.
#[derive(Error, Debug)]
pub enum Error {
/// Queried endpoint requires authentication but was called on an anonymous client.
/// Queried endpoint requires authentication but was called on an anonymous
/// client.
#[error("the queried endpoint requires authentication but was called before logging in")]
AuthenticationRequired,

View file

@ -36,11 +36,12 @@
)]
#![cfg_attr(feature = "docs", feature(doc_cfg))]
pub use matrix_sdk_common::*;
pub use crate::{
error::{Error, Result},
session::Session,
};
pub use matrix_sdk_common::*;
mod client;
mod error;
@ -48,11 +49,9 @@ mod rooms;
mod session;
mod store;
pub use rooms::{Room, RoomInfo, RoomMember, RoomType};
pub use store::{StateChanges, StateStore, Store, StoreError};
pub use client::{BaseClient, BaseClientConfig};
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub use matrix_sdk_crypto as crypto;
pub use rooms::{Room, RoomInfo, RoomMember, RoomType};
pub use store::{StateChanges, StateStore, Store, StoreError};

View file

@ -1,27 +1,22 @@
mod members;
mod normal;
use matrix_sdk_common::{
events::room::{
create::CreateEventContent, guest_access::GuestAccess,
history_visibility::HistoryVisibility, join_rules::JoinRule,
},
identifiers::{MxcUri, UserId},
};
pub use normal::{Room, RoomInfo, RoomType};
pub use members::RoomMember;
use serde::{Deserialize, Serialize};
use std::cmp::max;
use matrix_sdk_common::{
events::{
room::{encryption::EncryptionEventContent, tombstone::TombstoneEventContent},
room::{
create::CreateEventContent, encryption::EncryptionEventContent,
guest_access::GuestAccess, history_visibility::HistoryVisibility, join_rules::JoinRule,
tombstone::TombstoneEventContent,
},
AnyStateEventContent,
},
identifiers::RoomAliasId,
identifiers::{MxcUri, RoomAliasId, UserId},
};
pub use members::RoomMember;
pub use normal::{Room, RoomInfo, RoomType};
use serde::{Deserialize, Serialize};
/// A base room info struct that is the backbone of normal as well as stripped
/// rooms. Holds all the state events that are important to present a room to
@ -71,20 +66,12 @@ impl BaseRoomInfo {
let invited_joined = (invited_member_count + joined_member_count).saturating_sub(1);
if heroes_count >= invited_joined {
let mut names = heroes
.iter()
.take(3)
.map(|mem| mem.name())
.collect::<Vec<&str>>();
let mut names = heroes.iter().take(3).map(|mem| mem.name()).collect::<Vec<&str>>();
// stabilize ordering
names.sort_unstable();
names.join(", ")
} else if heroes_count < invited_joined && invited_joined > 1 {
let mut names = heroes
.iter()
.take(3)
.map(|mem| mem.name())
.collect::<Vec<&str>>();
let mut names = heroes.iter().take(3).map(|mem| mem.name()).collect::<Vec<&str>>();
names.sort_unstable();
// TODO: What length does the spec want us to use here and in
@ -149,10 +136,8 @@ impl BaseRoomInfo {
true
}
AnyStateEventContent::RoomPowerLevels(p) => {
let max_power_level = p
.users
.values()
.fold(self.max_power_level, |acc, p| max(acc, (*p).into()));
let max_power_level =
p.users.values().fold(self.max_power_level, |acc, p| max(acc, (*p).into()));
self.max_power_level = max_power_level;
true
}

View file

@ -37,14 +37,14 @@ use matrix_sdk_common::{
use serde::{Deserialize, Serialize};
use tracing::info;
use super::{BaseRoomInfo, RoomMember};
use crate::{
deserialized_responses::UnreadNotificationsCount,
store::{Result as StoreResult, StateStore},
};
use super::{BaseRoomInfo, RoomMember};
/// The underlying room data structure collecting state for joined, left and invtied rooms.
/// The underlying room data structure collecting state for joined, left and
/// invtied rooms.
#[derive(Debug, Clone)]
pub struct Room {
room_id: Arc<RoomId>,
@ -135,7 +135,8 @@ impl Room {
/// Check if the room has it's members fully synced.
///
/// Members might be missing if lazy member loading was enabled for the sync.
/// Members might be missing if lazy member loading was enabled for the
/// sync.
///
/// Returns true if no members are missing, false otherwise.
pub fn are_members_synced(&self) -> bool {
@ -200,12 +201,7 @@ impl Room {
/// Get the history visibility policy of this room.
pub fn history_visibility(&self) -> HistoryVisibility {
self.inner
.read()
.unwrap()
.base_info
.history_visibility
.clone()
self.inner.read().unwrap().base_info.history_visibility.clone()
}
/// Is the room considered to be public.
@ -367,9 +363,7 @@ impl Room {
);
let inner = self.inner.read().unwrap();
Ok(inner
.base_info
.calculate_room_name(joined, invited, members))
Ok(inner.base_info.calculate_room_name(joined, invited, members))
}
pub(crate) fn clone_info(&self) -> RoomInfo {
@ -394,11 +388,8 @@ impl Room {
return Ok(None);
};
let presence = self
.store
.get_presence_event(user_id)
.await?
.and_then(|e| e.deserialize().ok());
let presence =
self.store.get_presence_event(user_id).await?.and_then(|e| e.deserialize().ok());
let profile = self.store.get_profile(self.room_id(), user_id).await?;
let max_power_level = self.max_power_level();
let is_room_creator = self
@ -411,28 +402,24 @@ impl Room {
.map(|c| &c.creator == user_id)
.unwrap_or(false);
let power = self
.store
.get_state_event(self.room_id(), EventType::RoomPowerLevels, "")
.await?
.and_then(|e| e.deserialize().ok())
.and_then(|e| {
if let AnySyncStateEvent::RoomPowerLevels(e) = e {
Some(e)
} else {
None
}
});
let power =
self.store
.get_state_event(self.room_id(), EventType::RoomPowerLevels, "")
.await?
.and_then(|e| e.deserialize().ok())
.and_then(|e| {
if let AnySyncStateEvent::RoomPowerLevels(e) = e {
Some(e)
} else {
None
}
});
let ambiguous = self
.store
.get_users_with_display_name(
self.room_id(),
member_event
.content
.displayname
.as_deref()
.unwrap_or_else(|| user_id.localpart()),
member_event.content.displayname.as_deref().unwrap_or_else(|| user_id.localpart()),
)
.await?
.len()
@ -558,8 +545,6 @@ impl RoomInfo {
///
/// The return value is saturated at `u64::MAX`.
pub fn active_members_count(&self) -> u64 {
self.summary
.joined_member_count
.saturating_add(self.summary.invited_member_count)
self.summary.joined_member_count.saturating_add(self.summary.invited_member_count)
}
}

View file

@ -15,9 +15,8 @@
//! User sessions.
use serde::{Deserialize, Serialize};
use matrix_sdk_common::identifiers::{DeviceId, UserId};
use serde::{Deserialize, Serialize};
/// A user session, containing an access token and information about the
/// associated user account.

View file

@ -19,12 +19,10 @@ use matrix_sdk_common::{
events::room::member::MembershipState,
identifiers::{EventId, RoomId, UserId},
};
use tracing::trace;
use crate::Store;
use super::{Result, StateChanges};
use crate::Store;
#[derive(Clone, Debug)]
pub struct AmbiguityCache {
@ -51,11 +49,8 @@ impl AmbiguityMap {
}
fn add(&mut self, user_id: UserId) -> Option<UserId> {
let ambiguous_user = if self.user_count() == 1 {
self.users.iter().next().cloned()
} else {
None
};
let ambiguous_user =
if self.user_count() == 1 { self.users.iter().next().cloned() } else { None };
self.users.insert(user_id);
@ -73,11 +68,7 @@ impl AmbiguityMap {
impl AmbiguityCache {
pub fn new(store: Store) -> Self {
Self {
store,
cache: BTreeMap::new(),
changes: BTreeMap::new(),
}
Self { store, cache: BTreeMap::new(), changes: BTreeMap::new() }
}
pub async fn handle_event(
@ -115,12 +106,9 @@ impl AmbiguityCache {
return Ok(());
}
let disambiguated_member = old_map
.as_mut()
.and_then(|o| o.remove(&member_event.state_key));
let ambiguated_member = new_map
.as_mut()
.and_then(|n| n.add(member_event.state_key.clone()));
let disambiguated_member = old_map.as_mut().and_then(|o| o.remove(&member_event.state_key));
let ambiguated_member =
new_map.as_mut().and_then(|n| n.add(member_event.state_key.clone()));
let ambiguous = new_map.as_ref().map(|n| n.is_ambiguous()).unwrap_or(false);
self.update(room_id, old_map, new_map);
@ -131,11 +119,7 @@ impl AmbiguityCache {
member_ambiguous: ambiguous,
};
trace!(
"Handling display name ambiguity for {}: {:#?}",
member_event.state_key,
change
);
trace!("Handling display name ambiguity for {}: {:#?}", member_event.state_key, change);
self.add_change(room_id, member_event.event_id.clone(), change);
@ -148,10 +132,7 @@ impl AmbiguityCache {
old_map: Option<AmbiguityMap>,
new_map: Option<AmbiguityMap>,
) {
let entry = self
.cache
.entry(room_id.clone())
.or_insert_with(BTreeMap::new);
let entry = self.cache.entry(room_id.clone()).or_insert_with(BTreeMap::new);
if let Some(old) = old_map {
entry.insert(old.display_name, old.users);
@ -163,10 +144,7 @@ impl AmbiguityCache {
}
fn add_change(&mut self, room_id: &RoomId, event_id: EventId, change: AmbiguityChange) {
self.changes
.entry(room_id.clone())
.or_insert_with(BTreeMap::new)
.insert(event_id, change);
self.changes.entry(room_id.clone()).or_insert_with(BTreeMap::new).insert(event_id, change);
}
async fn get(
@ -177,16 +155,12 @@ impl AmbiguityCache {
) -> Result<(Option<AmbiguityMap>, Option<AmbiguityMap>)> {
use MembershipState::*;
let old_event = if let Some(m) = changes
.members
.get(room_id)
.and_then(|m| m.get(&member_event.state_key))
let old_event = if let Some(m) =
changes.members.get(room_id).and_then(|m| m.get(&member_event.state_key))
{
Some(m.clone())
} else {
self.store
.get_member_event(room_id, &member_event.state_key)
.await?
self.store.get_member_event(room_id, &member_event.state_key).await?
};
let old_display_name = if let Some(event) = old_event {
@ -218,23 +192,15 @@ impl AmbiguityCache {
};
let old_map = if let Some(old_name) = old_display_name.as_deref() {
let old_display_name_map = if let Some(u) = self
.cache
.entry(room_id.clone())
.or_insert_with(BTreeMap::new)
.get(old_name)
let old_display_name_map = if let Some(u) =
self.cache.entry(room_id.clone()).or_insert_with(BTreeMap::new).get(old_name)
{
u.clone()
} else {
self.store
.get_users_with_display_name(&room_id, &old_name)
.await?
self.store.get_users_with_display_name(&room_id, &old_name).await?
};
Some(AmbiguityMap {
display_name: old_name.to_string(),
users: old_display_name_map,
})
Some(AmbiguityMap { display_name: old_name.to_string(), users: old_display_name_map })
} else {
None
};
@ -246,8 +212,9 @@ impl AmbiguityCache {
.as_deref()
.unwrap_or_else(|| member_event.state_key.localpart());
// We don't allow other users to set the display name, so if we have
// a more trusted version of the display name use that.
// We don't allow other users to set the display name, so if we
// have a more trusted version of the display
// name use that.
let new_display_name = if member_event.sender.as_str() == member_event.state_key {
new
} else if let Some(old) = old_display_name.as_deref() {
@ -264,9 +231,7 @@ impl AmbiguityCache {
{
u.clone()
} else {
self.store
.get_users_with_display_name(&room_id, &new_display_name)
.await?
self.store.get_users_with_display_name(&room_id, &new_display_name).await?
};
Some(AmbiguityMap {

View file

@ -30,12 +30,10 @@ use matrix_sdk_common::{
instant::Instant,
Raw,
};
use tracing::info;
use crate::deserialized_responses::{MemberEvent, StrippedMemberEvent};
use super::{Result, RoomInfo, StateChanges, StateStore};
use crate::deserialized_responses::{MemberEvent, StrippedMemberEvent};
#[derive(Debug, Clone)]
pub struct MemoryStore {
@ -82,8 +80,7 @@ impl MemoryStore {
}
async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> {
self.filters
.insert(filter_name.to_string(), filter_id.to_string());
self.filters.insert(filter_name.to_string(), filter_id.to_string());
Ok(())
}
@ -164,8 +161,7 @@ impl MemoryStore {
}
for (event_type, event) in &changes.account_data {
self.account_data
.insert(event_type.to_string(), event.clone());
self.account_data.insert(event_type.to_string(), event.clone());
}
for (room, events) in &changes.room_account_data {
@ -199,8 +195,7 @@ impl MemoryStore {
}
for (room_id, info) in &changes.invited_room_info {
self.stripped_room_info
.insert(room_id.clone(), info.clone());
self.stripped_room_info.insert(room_id.clone(), info.clone());
}
for (room, events) in &changes.stripped_members {
@ -243,8 +238,7 @@ impl MemoryStore {
) -> Result<Option<Raw<AnySyncStateEvent>>> {
#[allow(clippy::map_clone)]
Ok(self.room_state.get(room_id).and_then(|e| {
e.get(event_type.as_ref())
.and_then(|s| s.get(state_key).map(|e| e.clone()))
e.get(event_type.as_ref()).and_then(|s| s.get(state_key).map(|e| e.clone()))
}))
}
@ -254,10 +248,7 @@ impl MemoryStore {
user_id: &UserId,
) -> Result<Option<MemberEventContent>> {
#[allow(clippy::map_clone)]
Ok(self
.profiles
.get(room_id)
.and_then(|p| p.get(user_id).map(|p| p.clone())))
Ok(self.profiles.get(room_id).and_then(|p| p.get(user_id).map(|p| p.clone())))
}
async fn get_member_event(
@ -266,10 +257,7 @@ impl MemoryStore {
state_key: &UserId,
) -> Result<Option<MemberEvent>> {
#[allow(clippy::map_clone)]
Ok(self
.members
.get(room_id)
.and_then(|m| m.get(state_key).map(|m| m.clone())))
Ok(self.members.get(room_id).and_then(|m| m.get(state_key).map(|m| m.clone())))
}
fn get_user_ids(&self, room_id: &RoomId) -> Vec<UserId> {
@ -310,10 +298,7 @@ impl MemoryStore {
&self,
event_type: EventType,
) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> {
Ok(self
.account_data
.get(event_type.as_ref())
.map(|e| e.clone()))
Ok(self.account_data.get(event_type.as_ref()).map(|e| e.clone()))
}
async fn get_room_account_data_event(

View file

@ -12,15 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#[cfg(feature = "sled_state_store")]
use std::path::Path;
use std::{
collections::{BTreeMap, BTreeSet},
ops::Deref,
sync::Arc,
};
#[cfg(feature = "sled_state_store")]
use std::path::Path;
use dashmap::DashMap;
use matrix_sdk_common::{
api::r0::push::get_notifications::Notification,
@ -201,7 +200,8 @@ pub trait StateStore: AsyncTraitDeps {
///
/// # Arguments
///
/// * `room_id` - The id of the room for which the room account data event should
/// * `room_id` - The id of the room for which the room account data event
/// should
/// be fetched.
///
/// * `event_type` - The event type of the room account data event.
@ -298,20 +298,16 @@ impl Store {
/// Get all the rooms this store knows about.
pub fn get_rooms(&self) -> Vec<Room> {
self.rooms
.iter()
.filter_map(|r| self.get_room(r.key()))
.collect()
self.rooms.iter().filter_map(|r| self.get_room(r.key())).collect()
}
/// Get the room with the given room id.
pub fn get_room(&self, room_id: &RoomId) -> Option<Room> {
self.get_bare_room(room_id)
.and_then(|r| match r.room_type() {
RoomType::Joined => Some(r),
RoomType::Left => Some(r),
RoomType::Invited => self.get_stripped_room(room_id),
})
self.get_bare_room(room_id).and_then(|r| match r.room_type() {
RoomType::Joined => Some(r),
RoomType::Left => Some(r),
RoomType::Invited => self.get_stripped_room(room_id),
})
}
fn get_stripped_room(&self, room_id: &RoomId) -> Option<Room> {
@ -321,10 +317,7 @@ impl Store {
pub(crate) async fn get_or_create_stripped_room(&self, room_id: &RoomId) -> Room {
let session = self.session.read().await;
let user_id = &session
.as_ref()
.expect("Creating room while not being logged in")
.user_id;
let user_id = &session.as_ref().expect("Creating room while not being logged in").user_id;
self.stripped_rooms
.entry(room_id.clone())
@ -334,10 +327,7 @@ impl Store {
pub(crate) async fn get_or_create_room(&self, room_id: &RoomId, room_type: RoomType) -> Room {
let session = self.session.read().await;
let user_id = &session
.as_ref()
.expect("Creating room while not being logged in")
.user_id;
let user_id = &session.as_ref().expect("Creating room while not being logged in").user_id;
self.rooms
.entry(room_id.clone())
@ -359,7 +349,8 @@ impl Deref for Store {
pub struct StateChanges {
/// The sync token that relates to this update.
pub sync_token: Option<String>,
/// A user session, containing an access token and information about the associated user account.
/// A user session, containing an access token and information about the
/// associated user account.
pub session: Option<Session>,
/// A mapping of event type string to `AnyBasicEvent`.
pub account_data: BTreeMap<String, Raw<AnyGlobalAccountDataEvent>>,
@ -371,14 +362,16 @@ pub struct StateChanges {
/// A mapping of `RoomId` to a map of users and their `MemberEventContent`.
pub profiles: BTreeMap<RoomId, BTreeMap<UserId, MemberEventContent>>,
/// A mapping of `RoomId` to a map of event type string to a state key and `AnySyncStateEvent`.
/// A mapping of `RoomId` to a map of event type string to a state key and
/// `AnySyncStateEvent`.
pub state: BTreeMap<RoomId, BTreeMap<String, BTreeMap<String, Raw<AnySyncStateEvent>>>>,
/// A mapping of `RoomId` to a map of event type string to `AnyBasicEvent`.
pub room_account_data: BTreeMap<RoomId, BTreeMap<String, Raw<AnyRoomAccountDataEvent>>>,
/// A map of `RoomId` to `RoomInfo`.
pub room_infos: BTreeMap<RoomId, RoomInfo>,
/// A mapping of `RoomId` to a map of event type to a map of state key to `AnyStrippedStateEvent`.
/// A mapping of `RoomId` to a map of event type to a map of state key to
/// `AnyStrippedStateEvent`.
pub stripped_state:
BTreeMap<RoomId, BTreeMap<String, BTreeMap<String, Raw<AnyStrippedStateEvent>>>>,
/// A mapping of `RoomId` to a map of users and their `StrippedMemberEvent`.
@ -396,10 +389,7 @@ pub struct StateChanges {
impl StateChanges {
/// Create a new `StateChanges` struct with the given sync_token.
pub fn new(sync_token: String) -> Self {
Self {
sync_token: Some(sync_token),
..Default::default()
}
Self { sync_token: Some(sync_token), ..Default::default() }
}
/// Update the `StateChanges` struct with the given `PresenceEvent`.
@ -409,14 +399,12 @@ impl StateChanges {
/// Update the `StateChanges` struct with the given `RoomInfo`.
pub fn add_room(&mut self, room: RoomInfo) {
self.room_infos
.insert(room.room_id.as_ref().to_owned(), room);
self.room_infos.insert(room.room_id.as_ref().to_owned(), room);
}
/// Update the `StateChanges` struct with the given `RoomInfo`.
pub fn add_stripped_room(&mut self, room: RoomInfo) {
self.invited_room_info
.insert(room.room_id.as_ref().to_owned(), room);
self.invited_room_info.insert(room.room_id.as_ref().to_owned(), room);
}
/// Update the `StateChanges` struct with the given `AnyBasicEvent`.
@ -425,11 +413,11 @@ impl StateChanges {
event: AnyGlobalAccountDataEvent,
raw_event: Raw<AnyGlobalAccountDataEvent>,
) {
self.account_data
.insert(event.content().event_type().to_owned(), raw_event);
self.account_data.insert(event.content().event_type().to_owned(), raw_event);
}
/// Update the `StateChanges` struct with the given room with a new `AnyBasicEvent`.
/// Update the `StateChanges` struct with the given room with a new
/// `AnyBasicEvent`.
pub fn add_room_account_data(
&mut self,
room_id: &RoomId,
@ -442,7 +430,8 @@ impl StateChanges {
.insert(event.content().event_type().to_owned(), raw_event);
}
/// Update the `StateChanges` struct with the given room with a new `StrippedMemberEvent`.
/// Update the `StateChanges` struct with the given room with a new
/// `StrippedMemberEvent`.
pub fn add_stripped_member(&mut self, room_id: &RoomId, event: StrippedMemberEvent) {
let user_id = event.state_key.clone();
@ -452,7 +441,8 @@ impl StateChanges {
.insert(user_id, event);
}
/// Update the `StateChanges` struct with the given room with a new `AnySyncStateEvent`.
/// Update the `StateChanges` struct with the given room with a new
/// `AnySyncStateEvent`.
pub fn add_state_event(
&mut self,
room_id: &RoomId,
@ -467,11 +457,9 @@ impl StateChanges {
.insert(event.state_key().to_string(), raw_event);
}
/// Update the `StateChanges` struct with the given room with a new `Notification`.
/// Update the `StateChanges` struct with the given room with a new
/// `Notification`.
pub fn add_notification(&mut self, room_id: &RoomId, notification: Notification) {
self.notifications
.entry(room_id.to_owned())
.or_insert_with(Vec::new)
.push(notification);
self.notifications.entry(room_id.to_owned()).or_insert_with(Vec::new).push(notification);
}
}

View file

@ -37,18 +37,15 @@ use matrix_sdk_common::{
Raw,
};
use serde::{Deserialize, Serialize};
use sled::{
transaction::{ConflictableTransactionError, TransactionError},
Config, Db, Transactional, Tree,
};
use tracing::info;
use crate::deserialized_responses::MemberEvent;
use self::store_key::{EncryptedEvent, StoreKey};
use super::{Result, RoomInfo, StateChanges, StateStore, StoreError};
use crate::deserialized_responses::MemberEvent;
#[derive(Debug, Serialize, Deserialize)]
pub enum DatabaseType {
@ -111,13 +108,7 @@ impl EncodeKey for &str {
impl EncodeKey for (&str, &str) {
fn encode(&self) -> Vec<u8> {
[
self.0.as_bytes(),
&[Self::SEPARATOR],
self.1.as_bytes(),
&[Self::SEPARATOR],
]
.concat()
[self.0.as_bytes(), &[Self::SEPARATOR], self.1.as_bytes(), &[Self::SEPARATOR]].concat()
}
}
@ -167,9 +158,7 @@ impl std::fmt::Debug for SledStore {
if let Some(path) = &self.path {
f.debug_struct("SledStore").field("path", &path).finish()
} else {
f.debug_struct("SledStore")
.field("path", &"memory store")
.finish()
f.debug_struct("SledStore").field("path", &"memory store").finish()
}
}
}
@ -239,8 +228,7 @@ impl SledStore {
} else {
let key = StoreKey::new().map_err::<StoreError, _>(|e| e.into())?;
let encrypted_key = DatabaseType::Encrypted(
key.export(passphrase)
.map_err::<StoreError, _>(|e| e.into())?,
key.export(passphrase).map_err::<StoreError, _>(|e| e.into())?,
);
db.insert("store_key".encode(), serde_json::to_vec(&encrypted_key)?)?;
key
@ -278,8 +266,7 @@ impl SledStore {
}
pub async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> {
self.session
.insert(("filter", filter_name).encode(), filter_id)?;
self.session.insert(("filter", filter_name).encode(), filter_id)?;
Ok(())
}
@ -479,11 +466,7 @@ impl SledStore {
}
pub async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<Raw<PresenceEvent>>> {
Ok(self
.presence
.get(user_id.encode())?
.map(|e| self.deserialize_event(&e))
.transpose()?)
Ok(self.presence.get(user_id.encode())?.map(|e| self.deserialize_event(&e)).transpose()?)
}
pub async fn get_state_event(
@ -534,14 +517,10 @@ impl SledStore {
&self,
room_id: &RoomId,
) -> impl Stream<Item = Result<UserId>> {
stream::iter(
self.invited_user_ids
.scan_prefix(room_id.encode())
.map(|u| {
UserId::try_from(String::from_utf8_lossy(&u?.1).to_string())
.map_err(StoreError::Identifier)
}),
)
stream::iter(self.invited_user_ids.scan_prefix(room_id.encode()).map(|u| {
UserId::try_from(String::from_utf8_lossy(&u?.1).to_string())
.map_err(StoreError::Identifier)
}))
}
pub async fn get_joined_user_ids(
@ -557,9 +536,7 @@ impl SledStore {
pub async fn get_room_infos(&self) -> impl Stream<Item = Result<RoomInfo>> {
let db = self.clone();
stream::iter(
self.room_info
.iter()
.map(move |r| db.deserialize_event(&r?.1).map_err(|e| e.into())),
self.room_info.iter().map(move |r| db.deserialize_event(&r?.1).map_err(|e| e.into())),
)
}
@ -683,8 +660,7 @@ impl StateStore for SledStore {
room_id: &RoomId,
display_name: &str,
) -> Result<BTreeSet<UserId>> {
self.get_users_with_display_name(room_id, display_name)
.await
self.get_users_with_display_name(room_id, display_name).await
}
async fn get_account_data_event(
@ -770,11 +746,7 @@ mod test {
let room_id = room_id!("!test:localhost");
let user_id = user_id();
assert!(store
.get_member_event(&room_id, &user_id)
.await
.unwrap()
.is_none());
assert!(store.get_member_event(&room_id, &user_id).await.unwrap().is_none());
let mut changes = StateChanges::default();
changes
.members
@ -783,11 +755,7 @@ mod test {
.insert(user_id.clone(), membership_event());
store.save_changes(&changes).await.unwrap();
assert!(store
.get_member_event(&room_id, &user_id)
.await
.unwrap()
.is_some());
assert!(store.get_member_event(&room_id, &user_id).await.unwrap().is_some());
}
#[async_test]

View file

@ -21,11 +21,10 @@ use chacha20poly1305::{
use hmac::Hmac;
use pbkdf2::pbkdf2;
use rand::{thread_rng, Error as RngError, Fill};
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use zeroize::{Zeroize, Zeroizing};
use serde::{Deserialize, Serialize};
use crate::StoreError;
const VERSION: u8 = 1;
@ -76,9 +75,11 @@ pub struct EncryptedEvent {
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub enum KdfInfo {
Pbkdf2ToChaCha20Poly1305 {
/// The number of PBKDF rounds that were used when deriving the store key.
/// The number of PBKDF rounds that were used when deriving the store
/// key.
rounds: u32,
/// The salt that was used when the passphrase was expanded into a store key.
/// The salt that was used when the passphrase was expanded into a store
/// key.
kdf_salt: Vec<u8>,
},
}
@ -170,10 +171,7 @@ impl StoreKey {
cipher.encrypt(Nonce::from_slice(nonce.as_ref()), self.inner.as_slice())?;
Ok(EncryptedStoreKey {
kdf_info: KdfInfo::Pbkdf2ToChaCha20Poly1305 {
rounds: KDF_ROUNDS,
kdf_salt: salt,
},
kdf_info: KdfInfo::Pbkdf2ToChaCha20Poly1305 { rounds: KDF_ROUNDS, kdf_salt: salt },
ciphertext_info: CipherTextInfo::ChaCha20Poly1305 { nonce, ciphertext },
})
}
@ -196,11 +194,7 @@ impl StoreKey {
let ciphertext = cipher.encrypt(xnonce, event.as_ref())?;
Ok(EncryptedEvent {
version: VERSION,
ciphertext,
nonce,
})
Ok(EncryptedEvent { version: VERSION, ciphertext, nonce })
}
pub fn decrypt<T: for<'b> Deserialize<'b>>(&self, event: EncryptedEvent) -> Result<T, Error> {
@ -248,9 +242,10 @@ impl StoreKey {
#[cfg(test)]
mod test {
use super::StoreKey;
use serde_json::{json, Value};
use super::StoreKey;
#[test]
fn generating() {
StoreKey::new().unwrap();

View file

@ -1,3 +1,5 @@
use std::{collections::BTreeMap, convert::TryFrom, time::SystemTime};
use ruma::{
api::client::r0::sync::sync_events::{
Ephemeral, InvitedRoom, Presence, RoomAccountData, State, ToDevice,
@ -6,7 +8,6 @@ use ruma::{
DeviceIdBox,
};
use serde::{Deserialize, Serialize};
use std::{collections::BTreeMap, convert::TryFrom, time::SystemTime};
use super::{
api::r0::{
@ -103,16 +104,14 @@ pub struct SyncRoomEvent {
impl From<Raw<AnySyncRoomEvent>> for SyncRoomEvent {
fn from(inner: Raw<AnySyncRoomEvent>) -> Self {
Self {
encryption_info: None,
event: inner,
}
Self { encryption_info: None, event: inner }
}
}
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct SyncResponse {
/// The batch token to supply in the `since` param of the next `/sync` request.
/// The batch token to supply in the `since` param of the next `/sync`
/// request.
pub next_batch: String,
/// Updates to rooms.
pub rooms: Rooms,
@ -137,10 +136,7 @@ pub struct SyncResponse {
impl SyncResponse {
pub fn new(next_batch: String) -> Self {
Self {
next_batch,
..Default::default()
}
Self { next_batch, ..Default::default() }
}
}
@ -161,14 +157,15 @@ pub struct JoinedRoom {
pub unread_notifications: UnreadNotificationsCount,
/// The timeline of messages and state changes in the room.
pub timeline: Timeline,
/// Updates to the state, between the time indicated by the `since` parameter, and the start
/// of the `timeline` (or all state up to the start of the `timeline`, if `since` is not
/// given, or `full_state` is true).
/// Updates to the state, between the time indicated by the `since`
/// parameter, and the start of the `timeline` (or all state up to the
/// start of the `timeline`, if `since` is not given, or `full_state` is
/// true).
pub state: State,
/// The private data that this user has attached to this room.
pub account_data: RoomAccountData,
/// The ephemeral events in the room that aren't recorded in the timeline or state of the
/// room. e.g. typing.
/// The ephemeral events in the room that aren't recorded in the timeline or
/// state of the room. e.g. typing.
pub ephemeral: Ephemeral,
}
@ -180,20 +177,15 @@ impl JoinedRoom {
ephemeral: Ephemeral,
unread_notifications: UnreadNotificationsCount,
) -> Self {
Self {
unread_notifications,
timeline,
state,
account_data,
ephemeral,
}
Self { unread_notifications, timeline, state, account_data, ephemeral }
}
}
/// Counts of unread notifications for a room.
#[derive(Copy, Clone, Debug, Default, Deserialize, Serialize)]
pub struct UnreadNotificationsCount {
/// The number of unread notifications for this room with the highlight flag set.
/// The number of unread notifications for this room with the highlight flag
/// set.
pub highlight_count: u64,
/// The total number of unread notifications for this room.
pub notification_count: u64,
@ -203,10 +195,7 @@ impl From<RumaUnreadNotificationsCount> for UnreadNotificationsCount {
fn from(notifications: RumaUnreadNotificationsCount) -> Self {
Self {
highlight_count: notifications.highlight_count.map(|c| c.into()).unwrap_or(0),
notification_count: notifications
.notification_count
.map(|c| c.into())
.unwrap_or(0),
notification_count: notifications.notification_count.map(|c| c.into()).unwrap_or(0),
}
}
}
@ -216,9 +205,10 @@ pub struct LeftRoom {
/// The timeline of messages and state changes in the room up to the point
/// when the user left.
pub timeline: Timeline,
/// Updates to the state, between the time indicated by the `since` parameter, and the start
/// of the `timeline` (or all state up to the start of the `timeline`, if `since` is not
/// given, or `full_state` is true).
/// Updates to the state, between the time indicated by the `since`
/// parameter, and the start of the `timeline` (or all state up to the
/// start of the `timeline`, if `since` is not given, or `full_state` is
/// true).
pub state: State,
/// The private data that this user has attached to this room.
pub account_data: RoomAccountData,
@ -226,18 +216,15 @@ pub struct LeftRoom {
impl LeftRoom {
pub fn new(timeline: Timeline, state: State, account_data: RoomAccountData) -> Self {
Self {
timeline,
state,
account_data,
}
Self { timeline, state, account_data }
}
}
/// Events in the room.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Timeline {
/// True if the number of events returned was limited by the `limit` on the filter.
/// True if the number of events returned was limited by the `limit` on the
/// filter.
pub limited: bool,
/// A token that can be supplied to to the `from` parameter of the
@ -250,11 +237,7 @@ pub struct Timeline {
impl Timeline {
pub fn new(limited: bool, prev_batch: Option<String>) -> Self {
Self {
limited,
prev_batch,
..Default::default()
}
Self { limited, prev_batch, ..Default::default() }
}
}

View file

@ -6,14 +6,12 @@ use std::{
task::{Context, Poll},
};
#[cfg(not(target_arch = "wasm32"))]
pub use tokio::spawn;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_futures::spawn_local;
#[cfg(target_arch = "wasm32")]
use futures::{future::RemoteHandle, Future, FutureExt};
#[cfg(not(target_arch = "wasm32"))]
pub use tokio::spawn;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_futures::spawn_local;
#[cfg(target_arch = "wasm32")]
pub fn spawn<F, T>(future: F) -> JoinHandle<T>

View file

@ -17,7 +17,6 @@ pub use ruma::{
serde::{CanonicalJsonValue, Raw},
thirdparty, uint, Int, Outgoing, UInt,
};
pub use uuid;
pub mod deserialized_responses;

View file

@ -4,6 +4,5 @@
#[cfg(target_arch = "wasm32")]
pub use futures_locks::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
#[cfg(not(target_arch = "wasm32"))]
pub use tokio::sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};

View file

@ -4,7 +4,6 @@ mod perf;
use std::sync::Arc;
use criterion::*;
use matrix_sdk_common::{
api::r0::{
keys::{claim_keys, get_keys},
@ -50,17 +49,12 @@ fn huge_keys_query_resopnse() -> get_keys::Response {
}
pub fn keys_query(c: &mut Criterion) {
let runtime = Builder::new_multi_thread()
.build()
.expect("Can't create runtime");
let runtime = Builder::new_multi_thread().build().expect("Can't create runtime");
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
let response = keys_query_response();
let uuid = Uuid::new_v4();
let count = response
.device_keys
.values()
.fold(0, |acc, d| acc + d.len())
let count = response.device_keys.values().fold(0, |acc, d| acc + d.len())
+ response.master_keys.len()
+ response.self_signing_keys.len()
+ response.user_signing_keys.len();
@ -70,14 +64,10 @@ pub fn keys_query(c: &mut Criterion) {
let name = format!("{} device and cross signing keys", count);
group.bench_with_input(
BenchmarkId::new("memory store", &name),
&response,
|b, response| {
b.to_async(&runtime)
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
},
);
group.bench_with_input(BenchmarkId::new("memory store", &name), &response, |b, response| {
b.to_async(&runtime)
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
});
let dir = tempfile::tempdir().unwrap();
let machine = runtime
@ -89,99 +79,74 @@ pub fn keys_query(c: &mut Criterion) {
))
.unwrap();
group.bench_with_input(
BenchmarkId::new("sled store", &name),
&response,
|b, response| {
b.to_async(&runtime)
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
},
);
group.bench_with_input(BenchmarkId::new("sled store", &name), &response, |b, response| {
b.to_async(&runtime)
.iter(|| async { machine.mark_request_as_sent(&uuid, response).await.unwrap() })
});
group.finish()
}
pub fn keys_claiming(c: &mut Criterion) {
let runtime = Arc::new(
Builder::new_multi_thread()
.build()
.expect("Can't create runtime"),
);
let runtime = Arc::new(Builder::new_multi_thread().build().expect("Can't create runtime"));
let keys_query_response = keys_query_response();
let uuid = Uuid::new_v4();
let response = keys_claim_response();
let count = response
.one_time_keys
.values()
.fold(0, |acc, d| acc + d.len());
let count = response.one_time_keys.values().fold(0, |acc, d| acc + d.len());
let mut group = c.benchmark_group("Olm session creation");
group.throughput(Throughput::Elements(count as u64));
let name = format!("{} one-time keys", count);
group.bench_with_input(
BenchmarkId::new("memory store", &name),
&response,
|b, response| {
b.iter_batched(
|| {
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
runtime
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
.unwrap();
(machine, runtime.clone())
},
move |(machine, runtime)| {
runtime
.block_on(machine.mark_request_as_sent(&uuid, response))
.unwrap()
},
BatchSize::SmallInput,
)
},
);
group.bench_with_input(BenchmarkId::new("memory store", &name), &response, |b, response| {
b.iter_batched(
|| {
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
runtime
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
.unwrap();
(machine, runtime.clone())
},
move |(machine, runtime)| {
runtime.block_on(machine.mark_request_as_sent(&uuid, response)).unwrap()
},
BatchSize::SmallInput,
)
});
group.bench_with_input(
BenchmarkId::new("sled store", &name),
&response,
|b, response| {
b.iter_batched(
|| {
let dir = tempfile::tempdir().unwrap();
let machine = runtime
.block_on(OlmMachine::new_with_default_store(
&alice_id(),
&alice_device_id(),
dir.path(),
None,
))
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
.unwrap();
(machine, runtime.clone())
},
move |(machine, runtime)| {
runtime
.block_on(machine.mark_request_as_sent(&uuid, response))
.unwrap()
},
BatchSize::SmallInput,
)
},
);
group.bench_with_input(BenchmarkId::new("sled store", &name), &response, |b, response| {
b.iter_batched(
|| {
let dir = tempfile::tempdir().unwrap();
let machine = runtime
.block_on(OlmMachine::new_with_default_store(
&alice_id(),
&alice_device_id(),
dir.path(),
None,
))
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
.unwrap();
(machine, runtime.clone())
},
move |(machine, runtime)| {
runtime.block_on(machine.mark_request_as_sent(&uuid, response)).unwrap()
},
BatchSize::SmallInput,
)
});
group.finish()
}
pub fn room_key_sharing(c: &mut Criterion) {
let runtime = Builder::new_multi_thread()
.build()
.expect("Can't create runtime");
let runtime = Builder::new_multi_thread().build().expect("Can't create runtime");
let keys_query_response = keys_query_response();
let uuid = Uuid::new_v4();
@ -191,18 +156,11 @@ pub fn room_key_sharing(c: &mut Criterion) {
let to_device_response = ToDeviceResponse::new();
let users: Vec<UserId> = keys_query_response.device_keys.keys().cloned().collect();
let count = response
.one_time_keys
.values()
.fold(0, |acc, d| acc + d.len());
let count = response.one_time_keys.values().fold(0, |acc, d| acc + d.len());
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
runtime
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
runtime.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)).unwrap();
runtime.block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
let mut group = c.benchmark_group("Room key sharing");
group.throughput(Throughput::Elements(count as u64));
@ -218,10 +176,7 @@ pub fn room_key_sharing(c: &mut Criterion) {
assert!(!requests.is_empty());
for request in requests {
machine
.mark_request_as_sent(&request.txn_id, &to_device_response)
.await
.unwrap();
machine.mark_request_as_sent(&request.txn_id, &to_device_response).await.unwrap();
}
machine.invalidate_group_session(&room_id).await.unwrap();
@ -237,12 +192,8 @@ pub fn room_key_sharing(c: &mut Criterion) {
None,
))
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response))
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
runtime.block_on(machine.mark_request_as_sent(&uuid, &keys_query_response)).unwrap();
runtime.block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
group.bench_function(BenchmarkId::new("sled store", &name), |b| {
b.to_async(&runtime).iter(|| async {
@ -254,10 +205,7 @@ pub fn room_key_sharing(c: &mut Criterion) {
assert!(!requests.is_empty());
for request in requests {
machine
.mark_request_as_sent(&request.txn_id, &to_device_response)
.await
.unwrap();
machine.mark_request_as_sent(&request.txn_id, &to_device_response).await.unwrap();
}
machine.invalidate_group_session(&room_id).await.unwrap();
@ -268,28 +216,21 @@ pub fn room_key_sharing(c: &mut Criterion) {
}
pub fn devices_missing_sessions_collecting(c: &mut Criterion) {
let runtime = Builder::new_multi_thread()
.build()
.expect("Can't create runtime");
let runtime = Builder::new_multi_thread().build().expect("Can't create runtime");
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
let response = huge_keys_query_resopnse();
let uuid = Uuid::new_v4();
let users: Vec<UserId> = response.device_keys.keys().cloned().collect();
let count = response
.device_keys
.values()
.fold(0, |acc, d| acc + d.len());
let count = response.device_keys.values().fold(0, |acc, d| acc + d.len());
let mut group = c.benchmark_group("Devices missing sessions collecting");
group.throughput(Throughput::Elements(count as u64));
let name = format!("{} devices", count);
runtime
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
runtime.block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
group.bench_function(BenchmarkId::new("memory store", &name), |b| {
b.to_async(&runtime).iter_with_large_drop(|| async {
@ -307,9 +248,7 @@ pub fn devices_missing_sessions_collecting(c: &mut Criterion) {
))
.unwrap();
runtime
.block_on(machine.mark_request_as_sent(&uuid, &response))
.unwrap();
runtime.block_on(machine.mark_request_as_sent(&uuid, &response)).unwrap();
group.bench_function(BenchmarkId::new("sled store", &name), |b| {
b.to_async(&runtime)

View file

@ -6,8 +6,9 @@ use std::{fs::File, os::raw::c_int, path::Path};
use criterion::profiler::Profiler;
use pprof::ProfilerGuard;
/// Small custom profiler that can be used with Criterion to create a flamegraph for benchmarks.
/// Also see [the Criterion documentation on this][custom-profiler].
/// Small custom profiler that can be used with Criterion to create a flamegraph
/// for benchmarks. Also see [the Criterion documentation on
/// this][custom-profiler].
///
/// ## Example on how to enable the custom profiler:
///
@ -30,12 +31,12 @@ use pprof::ProfilerGuard;
/// }
/// ```
///
/// The neat thing about this is that it will sample _only_ the benchmark, and not other stuff like
/// the setup process.
/// The neat thing about this is that it will sample _only_ the benchmark, and
/// not other stuff like the setup process.
///
/// Further, it will only kick in if `--profile-time <time>` is passed to the benchmark binary.
/// A flamegraph will be created for each individual benchmark in its report directory under
/// `profile/flamegraph.svg`.
/// Further, it will only kick in if `--profile-time <time>` is passed to the
/// benchmark binary. A flamegraph will be created for each individual benchmark
/// in its report directory under `profile/flamegraph.svg`.
///
/// [custom-profiler]: https://bheisler.github.io/criterion.rs/book/user_guide/profiling.html#implementing-in-process-profiling-hooks
pub struct FlamegraphProfiler<'a> {
@ -45,10 +46,7 @@ pub struct FlamegraphProfiler<'a> {
impl<'a> FlamegraphProfiler<'a> {
pub fn new(frequency: c_int) -> Self {
FlamegraphProfiler {
frequency,
active_profiler: None,
}
FlamegraphProfiler { frequency, active_profiler: None }
}
}

View file

@ -17,21 +17,17 @@ use std::{
io::{Error as IoError, ErrorKind, Read},
};
use thiserror::Error;
use zeroize::Zeroizing;
use serde::{Deserialize, Serialize};
use matrix_sdk_common::events::room::JsonWebKey;
use getrandom::getrandom;
use aes_ctr::{
cipher::{NewStreamCipher, SyncStreamCipher},
Aes256Ctr,
};
use base64::DecodeError;
use getrandom::getrandom;
use matrix_sdk_common::events::room::JsonWebKey;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use thiserror::Error;
use zeroize::Zeroizing;
use crate::utilities::{decode, decode_url_safe, encode, encode_url_safe};
@ -59,10 +55,7 @@ impl<'a, R: Read> Read for AttachmentDecryptor<'a, R> {
if hash.as_slice() == self.expected_hash.as_slice() {
Ok(0)
} else {
Err(IoError::new(
ErrorKind::Other,
"Hash missmatch while decrypting",
))
Err(IoError::new(ErrorKind::Other, "Hash missmatch while decrypting"))
}
} else {
self.sha.update(&buf[0..read_bytes]);
@ -130,23 +123,14 @@ impl<'a, R: Read + 'a> AttachmentDecryptor<'a, R> {
return Err(DecryptorError::UnknownVersion);
}
let hash = decode(
info.hashes
.get("sha256")
.ok_or(DecryptorError::MissingHash)?,
)?;
let hash = decode(info.hashes.get("sha256").ok_or(DecryptorError::MissingHash)?)?;
let key = Zeroizing::from(decode_url_safe(info.web_key.k)?);
let iv = decode(info.iv)?;
let sha = Sha256::default();
let aes = Aes256Ctr::new_var(&key, &iv).map_err(|_| DecryptorError::KeyNonceLength)?;
Ok(AttachmentDecryptor {
inner_reader: input,
expected_hash: hash,
sha,
aes,
})
Ok(AttachmentDecryptor { inner_reader: input, expected_hash: hash, sha, aes })
}
}
@ -168,9 +152,7 @@ impl<'a, R: Read + 'a> Read for AttachmentEncryptor<'a, R> {
if read_bytes == 0 {
let hash = self.sha.finalize_reset();
self.hashes
.entry("sha256".to_owned())
.or_insert_with(|| encode(hash));
self.hashes.entry("sha256".to_owned()).or_insert_with(|| encode(hash));
Ok(0)
} else {
self.aes.apply_keystream(&mut buf[0..read_bytes]);
@ -244,9 +226,7 @@ impl<'a, R: Read + 'a> AttachmentEncryptor<'a, R> {
/// Consume the encryptor and get the encryption key.
pub fn finish(mut self) -> EncryptionInfo {
let hash = self.sha.finalize();
self.hashes
.entry("sha256".to_owned())
.or_insert_with(|| encode(hash));
self.hashes.entry("sha256".to_owned()).or_insert_with(|| encode(hash));
EncryptionInfo {
version: VERSION.to_string(),
@ -274,10 +254,12 @@ pub struct EncryptionInfo {
#[cfg(test)]
mod test {
use super::{AttachmentDecryptor, AttachmentEncryptor, EncryptionInfo};
use serde_json::json;
use std::io::{Cursor, Read};
use serde_json::json;
use super::{AttachmentDecryptor, AttachmentEncryptor, EncryptionInfo};
const EXAMPLE_DATA: &[u8] = &[
179, 154, 118, 127, 186, 127, 110, 33, 203, 33, 33, 134, 67, 100, 173, 46, 235, 27, 215,
172, 36, 26, 75, 47, 33, 160,

View file

@ -12,20 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use serde_json::Error as SerdeError;
use std::io::{Cursor, Read, Seek, SeekFrom};
use thiserror::Error;
use byteorder::{BigEndian, ReadBytesExt};
use getrandom::getrandom;
use aes_ctr::{
cipher::{NewStreamCipher, SyncStreamCipher},
Aes256Ctr,
};
use byteorder::{BigEndian, ReadBytesExt};
use getrandom::getrandom;
use hmac::{Hmac, Mac, NewMac};
use pbkdf2::pbkdf2;
use serde_json::Error as SerdeError;
use sha2::{Sha256, Sha512};
use thiserror::Error;
use crate::{
olm::ExportedRoomKey,
@ -99,14 +98,10 @@ pub fn decrypt_key_export(
return Err(KeyExportError::InvalidHeaders);
}
let payload: String = x
.lines()
.filter(|l| !(l.starts_with(HEADER) || l.starts_with(FOOTER)))
.collect();
let payload: String =
x.lines().filter(|l| !(l.starts_with(HEADER) || l.starts_with(FOOTER))).collect();
Ok(serde_json::from_str(&decrypt_helper(
&payload, passphrase,
)?)?)
Ok(serde_json::from_str(&decrypt_helper(&payload, passphrase)?)?)
}
/// Encrypt the list of exported room keys using the given passphrase.
@ -231,12 +226,12 @@ fn decrypt_helper(ciphertext: &str, passphrase: &str) -> Result<String, KeyExpor
#[cfg(test)]
mod test {
use indoc::indoc;
use proptest::prelude::*;
use std::io::Cursor;
use indoc::indoc;
use matrix_sdk_common::identifiers::room_id;
use matrix_sdk_test::async_test;
use proptest::prelude::*;
use super::{decode, decrypt_helper, decrypt_key_export, encrypt_helper, encrypt_key_export};
use crate::machine::test::get_prepared_machine;
@ -261,10 +256,7 @@ mod test {
"};
fn export_wihtout_headers() -> String {
TEST_EXPORT
.lines()
.filter(|l| !l.starts_with("-----"))
.collect()
TEST_EXPORT.lines().filter(|l| !l.starts_with("-----")).collect()
}
#[test]
@ -301,14 +293,8 @@ mod test {
let (machine, _) = get_prepared_machine().await;
let room_id = room_id!("!test:localhost");
machine
.create_outbound_group_session_with_defaults(&room_id)
.await
.unwrap();
let export = machine
.export_keys(|s| s.room_id() == &room_id)
.await
.unwrap();
machine.create_outbound_group_session_with_defaults(&room_id).await.unwrap();
let export = machine.export_keys(|s| s.room_id() == &room_id).await.unwrap();
assert!(!export.is_empty());
@ -316,10 +302,7 @@ mod test {
let decrypted = decrypt_key_export(Cursor::new(encrypted), "1234").unwrap();
assert_eq!(export, decrypted);
assert_eq!(
machine.import_keys(decrypted, |_, _| {}).await.unwrap(),
(0, 1)
);
assert_eq!(machine.import_keys(decrypted, |_, _| {}).await.unwrap(), (0, 1));
}
#[test]

View file

@ -39,24 +39,17 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_json::{json, Value};
use tracing::warn;
use crate::{
olm::{InboundGroupSession, PrivateCrossSigningIdentity, Session},
store::{Changes, DeviceChanges},
OutgoingVerificationRequest,
};
#[cfg(test)]
use crate::{OlmMachine, ReadOnlyAccount};
use super::{atomic_bool_deserializer, atomic_bool_serializer};
use crate::{
error::{EventError, OlmError, OlmResult, SignatureError},
identities::{OwnUserIdentity, UserIdentities},
olm::Utility,
store::{CryptoStore, Result as StoreResult},
olm::{InboundGroupSession, PrivateCrossSigningIdentity, Session, Utility},
store::{Changes, CryptoStore, DeviceChanges, Result as StoreResult},
verification::VerificationMachine,
Sas, ToDeviceRequest,
OutgoingVerificationRequest, Sas, ToDeviceRequest,
};
use super::{atomic_bool_deserializer, atomic_bool_serializer};
#[cfg(test)]
use crate::{OlmMachine, ReadOnlyAccount};
/// A read-only version of a `Device`.
#[derive(Clone, Serialize, Deserialize)]
@ -120,9 +113,7 @@ pub struct Device {
impl std::fmt::Debug for Device {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Device")
.field("device", &self.inner)
.finish()
f.debug_struct("Device").field("device", &self.inner).finish()
}
}
@ -139,10 +130,7 @@ impl Device {
///
/// Returns a `Sas` object and to-device request that needs to be sent out.
pub async fn start_verification(&self) -> StoreResult<(Sas, ToDeviceRequest)> {
let (sas, request) = self
.verification_machine
.start_sas(self.inner.clone())
.await?;
let (sas, request) = self.verification_machine.start_sas(self.inner.clone()).await?;
if let OutgoingVerificationRequest::ToDevice(r) = request {
Ok((sas, r))
@ -162,8 +150,7 @@ impl Device {
/// Get the trust state of the device.
pub fn trust_state(&self) -> bool {
self.inner
.trust_state(&self.own_identity, &self.device_owner_identity)
self.inner.trust_state(&self.own_identity, &self.device_owner_identity)
}
/// Set the local trust state of the device to the given state.
@ -178,10 +165,7 @@ impl Device {
self.inner.set_trust_state(trust_state);
let changes = Changes {
devices: DeviceChanges {
changed: vec![self.inner.clone()],
..Default::default()
},
devices: DeviceChanges { changed: vec![self.inner.clone()], ..Default::default() },
..Default::default()
};
@ -200,9 +184,7 @@ impl Device {
event_type: EventType,
content: Value,
) -> OlmResult<(Session, EncryptedEventContent)> {
self.inner
.encrypt(&**self.verification_machine.store, event_type, content)
.await
self.inner.encrypt(&**self.verification_machine.store, event_type, content).await
}
/// Encrypt the given inbound group session as a forwarded room key for this
@ -261,9 +243,7 @@ impl UserDevices {
/// Returns true if there is at least one devices of this user that is
/// considered to be verified, false otherwise.
pub fn is_any_verified(&self) -> bool {
self.inner
.values()
.any(|d| d.trust_state(&self.own_identity, &self.device_owner_identity))
self.inner.values().any(|d| d.trust_state(&self.own_identity, &self.device_owner_identity))
}
/// Iterator over all the device ids of the user devices.
@ -348,8 +328,7 @@ impl ReadOnlyDevice {
/// Get the key of the given key algorithm belonging to this device.
pub fn get_key(&self, algorithm: DeviceKeyAlgorithm) -> Option<&String> {
self.keys
.get(&DeviceKeyId::from_parts(algorithm, &self.device_id))
self.keys.get(&DeviceKeyId::from_parts(algorithm, &self.device_id))
}
/// Get a map containing all the device keys.
@ -496,9 +475,8 @@ impl ReadOnlyDevice {
}
fn is_signed_by_device(&self, json: &mut Value) -> Result<(), SignatureError> {
let signing_key = self
.get_key(DeviceKeyAlgorithm::Ed25519)
.ok_or(SignatureError::MissingSigningKey)?;
let signing_key =
self.get_key(DeviceKeyAlgorithm::Ed25519).ok_or(SignatureError::MissingSigningKey)?;
let utility = Utility::new();
@ -590,14 +568,15 @@ impl PartialEq for ReadOnlyDevice {
#[cfg(test)]
pub(crate) mod test {
use serde_json::json;
use std::convert::TryFrom;
use crate::identities::{LocalTrust, ReadOnlyDevice};
use matrix_sdk_common::{
encryption::DeviceKeys,
identifiers::{user_id, DeviceKeyAlgorithm},
};
use serde_json::json;
use crate::identities::{LocalTrust, ReadOnlyDevice};
fn device_keys() -> DeviceKeys {
let device_keys = json!({
@ -640,10 +619,7 @@ pub(crate) mod test {
assert_eq!(device_id, device.device_id());
assert_eq!(device.algorithms.len(), 2);
assert_eq!(LocalTrust::Unset, device.local_trust_state());
assert_eq!(
"Alice's mobile phone",
device.display_name().as_ref().unwrap()
);
assert_eq!("Alice's mobile phone", device.display_name().as_ref().unwrap());
assert_eq!(
device.get_key(DeviceKeyAlgorithm::Curve25519).unwrap(),
"xfgbLIC5WAl1OIkpOzoxpCe8FsRDT6nch7NQsOb15nc"
@ -658,10 +634,7 @@ pub(crate) mod test {
fn update_a_device() {
let mut device = get_device();
assert_eq!(
"Alice's mobile phone",
device.display_name().as_ref().unwrap()
);
assert_eq!("Alice's mobile phone", device.display_name().as_ref().unwrap());
let display_name = "Alice's work computer".to_owned();

View file

@ -12,20 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use futures::future::join_all;
use std::{
collections::{BTreeMap, HashSet},
convert::TryFrom,
sync::Arc,
};
use tracing::{trace, warn};
use futures::future::join_all;
use matrix_sdk_common::{
api::r0::keys::get_keys::Response as KeysQueryResponse,
encryption::DeviceKeys,
executor::spawn,
identifiers::{DeviceIdBox, UserId},
};
use tracing::{trace, warn};
use crate::{
error::OlmResult,
@ -54,11 +54,7 @@ impl IdentityManager {
const MAX_KEY_QUERY_USERS: usize = 250;
pub fn new(user_id: Arc<UserId>, device_id: Arc<DeviceIdBox>, store: Store) -> Self {
IdentityManager {
user_id,
device_id,
store,
}
IdentityManager { user_id, device_id, store }
}
fn user_id(&self) -> &UserId {
@ -78,9 +74,8 @@ impl IdentityManager {
&self,
response: &KeysQueryResponse,
) -> OlmResult<(DeviceChanges, IdentityChanges)> {
let changed_devices = self
.handle_devices_from_key_query(response.device_keys.clone())
.await?;
let changed_devices =
self.handle_devices_from_key_query(response.device_keys.clone()).await?;
let changed_identities = self.handle_cross_singing_keys(response).await?;
let changes = Changes {
@ -104,9 +99,8 @@ impl IdentityManager {
store: Store,
device_keys: DeviceKeys,
) -> StoreResult<DeviceChange> {
let old_device = store
.get_readonly_device(&device_keys.user_id, &device_keys.device_id)
.await?;
let old_device =
store.get_readonly_device(&device_keys.user_id, &device_keys.device_id).await?;
if let Some(mut device) = old_device {
if let Err(e) = device.update_device(&device_keys) {
@ -148,25 +142,20 @@ impl IdentityManager {
let current_devices: HashSet<DeviceIdBox> = device_map.keys().cloned().collect();
let tasks = device_map
.into_iter()
.filter_map(|(device_id, device_keys)| {
// We don't need our own device in the device store.
if user_id == *own_user_id && device_id == *own_device_id {
None
} else if user_id != device_keys.user_id || device_id != device_keys.device_id {
warn!(
"Mismatch in device keys payload of device {}|{} from user {}|{}",
device_id, device_keys.device_id, user_id, device_keys.user_id
);
None
} else {
Some(spawn(Self::update_or_create_device(
store.clone(),
device_keys,
)))
}
});
let tasks = device_map.into_iter().filter_map(|(device_id, device_keys)| {
// We don't need our own device in the device store.
if user_id == *own_user_id && device_id == *own_device_id {
None
} else if user_id != device_keys.user_id || device_id != device_keys.device_id {
warn!(
"Mismatch in device keys payload of device {}|{} from user {}|{}",
device_id, device_keys.device_id, user_id, device_keys.user_id
);
None
} else {
Some(spawn(Self::update_or_create_device(store.clone(), device_keys)))
}
});
let results = join_all(tasks).await;
@ -211,17 +200,15 @@ impl IdentityManager {
) -> StoreResult<DeviceChanges> {
let mut changes = DeviceChanges::default();
let tasks = device_keys_map
.into_iter()
.map(|(user_id, device_keys_map)| {
spawn(Self::update_user_devices(
self.store.clone(),
self.user_id.clone(),
self.device_id.clone(),
user_id,
device_keys_map,
))
});
let tasks = device_keys_map.into_iter().map(|(user_id, device_keys_map)| {
spawn(Self::update_user_devices(
self.store.clone(),
self.user_id.clone(),
self.device_id.clone(),
user_id,
device_keys_map,
))
});
let results = join_all(tasks).await;
@ -254,10 +241,7 @@ impl IdentityManager {
let self_signing = if let Some(s) = response.self_signing_keys.get(user_id) {
SelfSigningPubkey::from(s)
} else {
warn!(
"User identity for user {} didn't contain a self signing pubkey",
user_id
);
warn!("User identity for user {} didn't contain a self signing pubkey", user_id);
continue;
};
@ -276,13 +260,11 @@ impl IdentityManager {
continue;
};
identity
.update(master_key, self_signing, user_signing)
.map(|_| (i, false))
identity.update(master_key, self_signing, user_signing).map(|_| (i, false))
}
UserIdentities::Other(ref mut identity) => {
identity.update(master_key, self_signing).map(|_| (i, false))
}
UserIdentities::Other(ref mut identity) => identity
.update(master_key, self_signing)
.map(|_| (i, false)),
}
} else if user_id == self.user_id() {
if let Some(s) = response.user_signing_keys.get(user_id) {
@ -310,10 +292,7 @@ impl IdentityManager {
continue;
}
} else if master_key.user_id() != user_id || self_signing.user_id() != user_id {
warn!(
"User id mismatch in one of the cross signing keys for user {}",
user_id
);
warn!("User id mismatch in one of the cross signing keys for user {}", user_id);
continue;
} else {
UserIdentity::new(master_key, self_signing)
@ -322,11 +301,7 @@ impl IdentityManager {
match result {
Ok((i, new)) => {
trace!(
"Updated or created new user identity for {}: {:?}",
user_id,
i
);
trace!("Updated or created new user identity for {}: {:?}", user_id, i);
if new {
changes.new.push(i);
} else {
@ -334,10 +309,7 @@ impl IdentityManager {
}
}
Err(e) => {
warn!(
"Couldn't update or create new user identity for {}: {:?}",
user_id, e
);
warn!("Couldn't update or create new user identity for {}: {:?}", user_id, e);
continue;
}
}
@ -424,9 +396,7 @@ pub(crate) mod test {
locks::Mutex,
IncomingResponse,
};
use matrix_sdk_test::async_test;
use serde_json::json;
use crate::{
@ -637,10 +607,7 @@ pub(crate) mod test {
let devices = manager.store.get_user_devices(&other_user).await.unwrap();
assert_eq!(devices.devices().count(), 0);
manager
.receive_keys_query_response(&other_key_query())
.await
.unwrap();
manager.receive_keys_query_response(&other_key_query()).await.unwrap();
let devices = manager.store.get_user_devices(&other_user).await.unwrap();
assert_eq!(devices.devices().count(), 1);
@ -651,12 +618,7 @@ pub(crate) mod test {
.await
.unwrap()
.unwrap();
let identity = manager
.store
.get_user_identity(&other_user)
.await
.unwrap()
.unwrap();
let identity = manager.store.get_user_identity(&other_user).await.unwrap().unwrap();
let identity = identity.other().unwrap();
assert!(identity.is_device_signed(&device).is_ok())
@ -669,10 +631,7 @@ pub(crate) mod test {
let devices = manager.store.get_user_devices(&other_user).await.unwrap();
assert_eq!(devices.devices().count(), 0);
manager
.receive_keys_query_response(&other_key_query())
.await
.unwrap();
manager.receive_keys_query_response(&other_key_query()).await.unwrap();
let devices = manager.store.get_user_devices(&other_user).await.unwrap();
assert_eq!(devices.devices().count(), 1);
@ -683,12 +642,7 @@ pub(crate) mod test {
.await
.unwrap()
.unwrap();
let identity = manager
.store
.get_user_identity(&other_user)
.await
.unwrap()
.unwrap();
let identity = manager.store.get_user_identity(&other_user).await.unwrap().unwrap();
let identity = identity.other().unwrap();
assert!(identity.is_device_signed(&device).is_ok())

View file

@ -29,10 +29,10 @@
//!
//! ## User
//!
//! Cross-signing capable devices will upload 3 additional (master, self-signing,
//! user-signing) public keys which represent the user identity owning all the
//! devices. This is represented in two ways, as a `UserIdentity` for other
//! users and as `OwnUserIdentity` for our own user.
//! Cross-signing capable devices will upload 3 additional (master,
//! self-signing, user-signing) public keys which represent the user identity
//! owning all the devices. This is represented in two ways, as a `UserIdentity`
//! for other users and as `OwnUserIdentity` for our own user.
//!
//! This is done because the server will only give us access to 2 of the 3
//! additional public keys for other users, while it will give us access to all
@ -44,19 +44,19 @@ pub(crate) mod device;
mod manager;
pub(crate) mod user;
pub use device::{Device, LocalTrust, ReadOnlyDevice, UserDevices};
pub(crate) use manager::IdentityManager;
pub use user::{
MasterPubkey, OwnUserIdentity, SelfSigningPubkey, UserIdentities, UserIdentity,
UserSigningPubkey,
};
use serde::{Deserialize, Deserializer, Serializer};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
pub use device::{Device, LocalTrust, ReadOnlyDevice, UserDevices};
pub(crate) use manager::IdentityManager;
use serde::{Deserialize, Deserializer, Serializer};
pub use user::{
MasterPubkey, OwnUserIdentity, SelfSigningPubkey, UserIdentities, UserIdentity,
UserSigningPubkey,
};
// These methods are only here because Serialize and Deserialize don't seem to
// be implemented for WASM.
fn atomic_bool_serializer<S>(x: &AtomicBool, s: S) -> Result<S::Ok, S::Error>

View file

@ -21,20 +21,18 @@ use std::{
},
};
use serde::{Deserialize, Serialize};
use serde_json::to_value;
use matrix_sdk_common::{
api::r0::keys::{CrossSigningKey, KeyUsage},
identifiers::{DeviceKeyId, UserId},
};
use serde::{Deserialize, Serialize};
use serde_json::to_value;
use super::{atomic_bool_deserializer, atomic_bool_serializer};
#[cfg(test)]
use crate::olm::PrivateCrossSigningIdentity;
use crate::{error::SignatureError, olm::Utility, ReadOnlyDevice};
use super::{atomic_bool_deserializer, atomic_bool_serializer};
/// Wrapper for a cross signing key marking it as the master key.
///
/// Master keys are used to sign other cross signing keys, the self signing and
@ -227,12 +225,7 @@ impl MasterPubkey {
&self,
subkey: impl Into<CrossSigningSubKeys<'a>>,
) -> Result<(), SignatureError> {
let (key_id, key) = self
.0
.keys
.iter()
.next()
.ok_or(SignatureError::MissingSigningKey)?;
let (key_id, key) = self.0.keys.iter().next().ok_or(SignatureError::MissingSigningKey)?;
let key_id = DeviceKeyId::try_from(key_id.as_str())?;
@ -289,12 +282,7 @@ impl UserSigningPubkey {
&self,
master_key: &MasterPubkey,
) -> Result<(), SignatureError> {
let (key_id, key) = self
.0
.keys
.iter()
.next()
.ok_or(SignatureError::MissingSigningKey)?;
let (key_id, key) = self.0.keys.iter().next().ok_or(SignatureError::MissingSigningKey)?;
// TODO check that the usage is OK.
@ -337,12 +325,7 @@ impl SelfSigningPubkey {
/// Returns an empty result if the signature check succeeded, otherwise a
/// SignatureError indicating why the check failed.
pub(crate) fn verify_device(&self, device: &ReadOnlyDevice) -> Result<(), SignatureError> {
let (key_id, key) = self
.0
.keys
.iter()
.next()
.ok_or(SignatureError::MissingSigningKey)?;
let (key_id, key) = self.0.keys.iter().next().ok_or(SignatureError::MissingSigningKey)?;
// TODO check that the usage is OK.
@ -474,37 +457,16 @@ impl UserIdentity {
) -> Result<Self, SignatureError> {
master_key.verify_subkey(&self_signing_key)?;
Ok(Self {
user_id: Arc::new(master_key.0.user_id.clone()),
master_key,
self_signing_key,
})
Ok(Self { user_id: Arc::new(master_key.0.user_id.clone()), master_key, self_signing_key })
}
#[cfg(test)]
pub async fn from_private(identity: &PrivateCrossSigningIdentity) -> Self {
let master_key = identity
.master_key
.lock()
.await
.as_ref()
.unwrap()
.public_key
.clone();
let self_signing_key = identity
.self_signing_key
.lock()
.await
.as_ref()
.unwrap()
.public_key
.clone();
let master_key = identity.master_key.lock().await.as_ref().unwrap().public_key.clone();
let self_signing_key =
identity.self_signing_key.lock().await.as_ref().unwrap().public_key.clone();
Self {
user_id: Arc::new(identity.user_id().clone()),
master_key,
self_signing_key,
}
Self { user_id: Arc::new(identity.user_id().clone()), master_key, self_signing_key }
}
/// Get the user id of this identity.
@ -646,8 +608,7 @@ impl OwnUserIdentity {
/// Returns an empty result if the signature check succeeded, otherwise a
/// SignatureError indicating why the check failed.
pub fn is_identity_signed(&self, identity: &UserIdentity) -> Result<(), SignatureError> {
self.user_signing_key
.verify_master_key(&identity.master_key)
self.user_signing_key.verify_master_key(&identity.master_key)
}
/// Check if the given device has been signed by this identity.
@ -719,6 +680,12 @@ impl OwnUserIdentity {
pub(crate) mod test {
use std::{convert::TryFrom, sync::Arc};
use matrix_sdk_common::{
api::r0::keys::get_keys::Response as KeyQueryResponse, identifiers::user_id, locks::Mutex,
};
use matrix_sdk_test::async_test;
use super::{OwnUserIdentity, UserIdentities, UserIdentity};
use crate::{
identities::{
manager::test::{other_key_query, own_key_query},
@ -729,13 +696,6 @@ pub(crate) mod test {
verification::VerificationMachine,
};
use matrix_sdk_common::{
api::r0::keys::get_keys::Response as KeyQueryResponse, identifiers::user_id, locks::Mutex,
};
use matrix_sdk_test::async_test;
use super::{OwnUserIdentity, UserIdentities, UserIdentity};
fn device(response: &KeyQueryResponse) -> (ReadOnlyDevice, ReadOnlyDevice) {
let mut devices = response.device_keys.values().next().unwrap().values();
let first = ReadOnlyDevice::try_from(devices.next().unwrap()).unwrap();
@ -793,9 +753,8 @@ pub(crate) mod test {
assert!(identity.is_device_signed(&first).is_err());
assert!(identity.is_device_signed(&second).is_ok());
let private_identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(
second.user_id().clone(),
)));
let private_identity =
Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(second.user_id().clone())));
let verification_machine = VerificationMachine::new(
ReadOnlyAccount::new(second.user_id(), second.device_id()),
private_identity.clone(),

View file

@ -20,13 +20,9 @@
// If we don't trust the device store an object that remembers the request and
// let the users introspect that object.
use dashmap::{mapref::entry::Entry, DashMap, DashSet};
use serde::{Deserialize, Serialize};
use serde_json::value::to_raw_value;
use std::{collections::BTreeMap, sync::Arc};
use thiserror::Error;
use tracing::{error, info, trace, warn};
use dashmap::{mapref::entry::Entry, DashMap, DashSet};
use matrix_sdk_common::{
api::r0::to_device::DeviceIdOrAllDevices,
events::{
@ -37,6 +33,10 @@ use matrix_sdk_common::{
identifiers::{DeviceId, DeviceIdBox, EventEncryptionAlgorithm, RoomId, UserId},
uuid::Uuid,
};
use serde::{Deserialize, Serialize};
use serde_json::value::to_raw_value;
use thiserror::Error;
use tracing::{error, info, trace, warn};
use crate::{
error::{OlmError, OlmResult},
@ -105,10 +105,8 @@ impl WaitQueue {
&self,
user_id: &UserId,
device_id: &DeviceId,
) -> Vec<(
(UserId, DeviceIdBox, String),
ToDeviceEvent<RoomKeyRequestToDeviceEventContent>,
)> {
) -> Vec<((UserId, DeviceIdBox, String), ToDeviceEvent<RoomKeyRequestToDeviceEventContent>)>
{
self.requests_ids_waiting
.remove(&(user_id.to_owned(), device_id.into()))
.map(|(_, request_ids)| {
@ -204,12 +202,7 @@ fn wrap_key_request_content(
Ok(OutgoingRequest {
request_id: id,
request: Arc::new(
ToDeviceRequest {
event_type: EventType::RoomKeyRequest,
txn_id: id,
messages,
}
.into(),
ToDeviceRequest { event_type: EventType::RoomKeyRequest, txn_id: id, messages }.into(),
),
})
}
@ -241,10 +234,7 @@ impl KeyRequestMachine {
.await?
.into_iter()
.filter(|i| !i.sent_out)
.map(|info| {
info.to_request(self.device_id())
.map_err(CryptoStoreError::from)
})
.map(|info| info.to_request(self.device_id()).map_err(CryptoStoreError::from))
.collect()
}
@ -262,11 +252,8 @@ impl KeyRequestMachine {
&self,
) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
let mut key_requests = self.load_outgoing_requests().await?;
let key_forwards: Vec<OutgoingRequest> = self
.outgoing_to_device_requests
.iter()
.map(|i| i.value().clone())
.collect();
let key_forwards: Vec<OutgoingRequest> =
self.outgoing_to_device_requests.iter().map(|i| i.value().clone()).collect();
key_requests.extend(key_forwards);
Ok(key_requests)
@ -281,8 +268,7 @@ impl KeyRequestMachine {
let device_id = event.content.requesting_device_id.clone();
let request_id = event.content.request_id.clone();
self.incoming_key_requests
.insert((sender, device_id, request_id), event.clone());
self.incoming_key_requests.insert((sender, device_id, request_id), event.clone());
}
/// Handle all the incoming key requests that are queued up and empty our
@ -401,10 +387,8 @@ impl KeyRequestMachine {
return Ok(None);
};
let device = self
.store
.get_device(&event.sender, &event.content.requesting_device_id)
.await?;
let device =
self.store.get_device(&event.sender, &event.content.requesting_device_id).await?;
if let Some(device) = device {
match self.should_share_key(&device, &session).await {
@ -461,30 +445,22 @@ impl KeyRequestMachine {
device: &Device,
message_index: Option<u32>,
) -> OlmResult<Session> {
let (used_session, content) = device
.encrypt_session(session.clone(), message_index)
.await?;
let (used_session, content) =
device.encrypt_session(session.clone(), message_index).await?;
let id = Uuid::new_v4();
let mut messages = BTreeMap::new();
messages
.entry(device.user_id().to_owned())
.or_insert_with(BTreeMap::new)
.insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
to_raw_value(&content)?,
);
messages.entry(device.user_id().to_owned()).or_insert_with(BTreeMap::new).insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
to_raw_value(&content)?,
);
let request = OutgoingRequest {
request_id: id,
request: Arc::new(
ToDeviceRequest {
event_type: EventType::RoomEncrypted,
txn_id: id,
messages,
}
.into(),
ToDeviceRequest { event_type: EventType::RoomEncrypted, txn_id: id, messages }
.into(),
),
};
@ -542,8 +518,8 @@ impl KeyRequestMachine {
} else {
Err(KeyshareDecision::OutboundSessionNotShared)
}
// Else just check if it's one of our own devices that requested the key and
// check if the device is trusted.
// Else just check if it's one of our own devices that requested the key
// and check if the device is trusted.
} else if device.user_id() == self.user_id() {
own_device_check()
// Otherwise, there's not enough info to decide if we can safely share
@ -711,9 +687,7 @@ impl KeyRequestMachine {
/// Delete the given outgoing key info.
async fn delete_key_info(&self, info: &OutgoingKeyRequest) -> Result<(), CryptoStoreError> {
self.store
.delete_outgoing_key_request(info.request_id)
.await
self.store.delete_outgoing_key_request(info.request_id).await
}
/// Mark the outgoing request as sent.
@ -736,20 +710,15 @@ impl KeyRequestMachine {
/// This will queue up a request cancelation.
async fn mark_as_done(&self, key_info: OutgoingKeyRequest) -> Result<(), CryptoStoreError> {
// TODO perhaps only remove the key info if the first known index is 0.
trace!(
"Successfully received a forwarded room key for {:#?}",
key_info
);
trace!("Successfully received a forwarded room key for {:#?}", key_info);
self.outgoing_to_device_requests
.remove(&key_info.request_id);
self.outgoing_to_device_requests.remove(&key_info.request_id);
// TODO return the key info instead of deleting it so the sync handler
// can delete it in one transaction.
self.delete_key_info(&key_info).await?;
let request = key_info.to_cancelation(self.device_id())?;
self.outgoing_to_device_requests
.insert(request.request_id, request);
self.outgoing_to_device_requests.insert(request.request_id, request);
Ok(())
}
@ -801,10 +770,7 @@ impl KeyRequestMachine {
);
}
Ok((
Some(AnyToDeviceEvent::ForwardedRoomKey(event.clone())),
session,
))
Ok((Some(AnyToDeviceEvent::ForwardedRoomKey(event.clone())), session))
} else {
info!(
"Received a forwarded room key from {}, but no key info was found.",
@ -817,6 +783,8 @@ impl KeyRequestMachine {
#[cfg(test)]
mod test {
use std::{convert::TryInto, sync::Arc};
use dashmap::DashMap;
use matrix_sdk_common::{
api::r0::to_device::DeviceIdOrAllDevices,
@ -829,8 +797,8 @@ mod test {
locks::Mutex,
};
use matrix_sdk_test::async_test;
use std::{convert::TryInto, sync::Arc};
use super::{KeyRequestMachine, KeyshareDecision};
use crate::{
identities::{LocalTrust, ReadOnlyDevice},
olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount},
@ -839,8 +807,6 @@ mod test {
verification::VerificationMachine,
};
use super::{KeyRequestMachine, KeyshareDecision};
fn alice_id() -> UserId {
user_id!("@alice:example.org")
}
@ -919,11 +885,7 @@ mod test {
async fn create_machine() {
let machine = get_machine().await;
assert!(machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty());
}
#[async_test]
@ -931,16 +893,10 @@ mod test {
let machine = get_machine().await;
let account = account();
let (_, session) = account
.create_group_session_pair_with_defaults(&room_id())
.await
.unwrap();
let (_, session) =
account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
assert!(machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty());
let (cancel, request) = machine
.request_key(session.room_id(), &session.sender_key, session.session_id())
.await
@ -948,10 +904,7 @@ mod test {
assert!(cancel.is_none());
machine
.mark_outgoing_request_as_sent(request.request_id)
.await
.unwrap();
machine.mark_outgoing_request_as_sent(request.request_id).await.unwrap();
let (cancel, _) = machine
.request_key(session.room_id(), &session.sender_key, session.session_id())
@ -972,16 +925,10 @@ mod test {
alice_device.set_trust_state(LocalTrust::Verified);
machine.store.save_devices(&[alice_device]).await.unwrap();
let (_, session) = account
.create_group_session_pair_with_defaults(&room_id())
.await
.unwrap();
let (_, session) =
account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
assert!(machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty());
machine
.create_outgoing_key_request(
session.room_id(),
@ -990,15 +937,8 @@ mod test {
)
.await
.unwrap();
assert!(!machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert_eq!(
machine.outgoing_to_device_requests().await.unwrap().len(),
1
);
assert!(!machine.outgoing_to_device_requests().await.unwrap().is_empty());
assert_eq!(machine.outgoing_to_device_requests().await.unwrap().len(), 1);
machine
.create_outgoing_key_request(
@ -1014,15 +954,8 @@ mod test {
let request = requests.get(0).unwrap();
machine
.mark_outgoing_request_as_sent(request.request_id)
.await
.unwrap();
assert!(machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
machine.mark_outgoing_request_as_sent(request.request_id).await.unwrap();
assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty());
}
#[async_test]
@ -1037,10 +970,8 @@ mod test {
alice_device.set_trust_state(LocalTrust::Verified);
machine.store.save_devices(&[alice_device]).await.unwrap();
let (_, session) = account
.create_group_session_pair_with_defaults(&room_id())
.await
.unwrap();
let (_, session) =
account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
machine
.create_outgoing_key_request(
session.room_id(),
@ -1060,10 +991,7 @@ mod test {
let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap();
let mut event = ToDeviceEvent {
sender: alice_id(),
content,
};
let mut event = ToDeviceEvent { sender: alice_id(), content };
assert!(
machine
@ -1078,19 +1006,13 @@ mod test {
.is_none()
);
let (_, first_session) = machine
.receive_forwarded_room_key(&session.sender_key, &mut event)
.await
.unwrap();
let (_, first_session) =
machine.receive_forwarded_room_key(&session.sender_key, &mut event).await.unwrap();
let first_session = first_session.unwrap();
assert_eq!(first_session.first_known_index(), 10);
machine
.store
.save_inbound_group_sessions(&[first_session.clone()])
.await
.unwrap();
machine.store.save_inbound_group_sessions(&[first_session.clone()]).await.unwrap();
// Get the cancel request.
let request = machine.outgoing_to_device_requests.iter().next().unwrap();
@ -1110,24 +1032,16 @@ mod test {
let requests = machine.outgoing_to_device_requests().await.unwrap();
let request = &requests[0];
machine
.mark_outgoing_request_as_sent(request.request_id)
.await
.unwrap();
machine.mark_outgoing_request_as_sent(request.request_id).await.unwrap();
let export = session.export_at_index(15).await;
let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap();
let mut event = ToDeviceEvent {
sender: alice_id(),
content,
};
let mut event = ToDeviceEvent { sender: alice_id(), content };
let (_, second_session) = machine
.receive_forwarded_room_key(&session.sender_key, &mut event)
.await
.unwrap();
let (_, second_session) =
machine.receive_forwarded_room_key(&session.sender_key, &mut event).await.unwrap();
assert!(second_session.is_none());
@ -1135,15 +1049,10 @@ mod test {
let content: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap();
let mut event = ToDeviceEvent {
sender: alice_id(),
content,
};
let mut event = ToDeviceEvent { sender: alice_id(), content };
let (_, second_session) = machine
.receive_forwarded_room_key(&session.sender_key, &mut event)
.await
.unwrap();
let (_, second_session) =
machine.receive_forwarded_room_key(&session.sender_key, &mut event).await.unwrap();
assert_eq!(second_session.unwrap().first_known_index(), 0);
}
@ -1153,17 +1062,11 @@ mod test {
let machine = get_machine().await;
let account = account();
let own_device = machine
.store
.get_device(&alice_id(), &alice_device_id())
.await
.unwrap()
.unwrap();
let own_device =
machine.store.get_device(&alice_id(), &alice_device_id()).await.unwrap().unwrap();
let (outbound, inbound) = account
.create_group_session_pair_with_defaults(&room_id())
.await
.unwrap();
let (outbound, inbound) =
account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
// We don't share keys with untrusted devices.
assert_eq!(
@ -1175,20 +1078,13 @@ mod test {
);
own_device.set_trust_state(LocalTrust::Verified);
// Now we do want to share the keys.
assert!(machine
.should_share_key(&own_device, &inbound)
.await
.is_ok());
assert!(machine.should_share_key(&own_device, &inbound).await.is_ok());
let bob_device = ReadOnlyDevice::from_account(&bob_account()).await;
machine.store.save_devices(&[bob_device]).await.unwrap();
let bob_device = machine
.store
.get_device(&bob_id(), &bob_device_id())
.await
.unwrap()
.unwrap();
let bob_device =
machine.store.get_device(&bob_id(), &bob_device_id()).await.unwrap().unwrap();
// We don't share sessions with other user's devices if no outbound
// session was provided.
@ -1231,17 +1127,12 @@ mod test {
// We now share the session, since it was shared before.
outbound.mark_shared_with(bob_device.user_id(), bob_device.device_id());
assert!(machine
.should_share_key(&bob_device, &inbound)
.await
.is_ok());
assert!(machine.should_share_key(&bob_device, &inbound).await.is_ok());
// But we don't share some other session that doesn't match our outbound
// session
let (_, other_inbound) = account
.create_group_session_pair_with_defaults(&room_id())
.await
.unwrap();
let (_, other_inbound) =
account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
assert_eq!(
machine
@ -1255,10 +1146,7 @@ mod test {
#[async_test]
async fn key_share_cycle() {
let alice_machine = get_machine().await;
let alice_account = Account {
inner: account(),
store: alice_machine.store.clone(),
};
let alice_account = Account { inner: account(), store: alice_machine.store.clone() };
let bob_machine = bob_machine();
let bob_account = bob_account();
@ -1268,11 +1156,7 @@ mod test {
// We need a trusted device, otherwise we won't request keys
alice_device.set_trust_state(LocalTrust::Verified);
alice_machine
.store
.save_devices(&[alice_device])
.await
.unwrap();
alice_machine.store.save_devices(&[alice_device]).await.unwrap();
// Create Olm sessions for our two accounts.
let (alice_session, bob_session) = alice_account.create_session_for(&bob_account).await;
@ -1282,37 +1166,15 @@ mod test {
// Populate our stores with Olm sessions and a Megolm session.
alice_machine
.store
.save_sessions(&[alice_session])
.await
.unwrap();
alice_machine
.store
.save_devices(&[bob_device])
.await
.unwrap();
bob_machine
.store
.save_sessions(&[bob_session])
.await
.unwrap();
bob_machine
.store
.save_devices(&[alice_device])
.await
.unwrap();
alice_machine.store.save_sessions(&[alice_session]).await.unwrap();
alice_machine.store.save_devices(&[bob_device]).await.unwrap();
bob_machine.store.save_sessions(&[bob_session]).await.unwrap();
bob_machine.store.save_devices(&[alice_device]).await.unwrap();
let (group_session, inbound_group_session) = bob_account
.create_group_session_pair_with_defaults(&room_id())
.await
.unwrap();
let (group_session, inbound_group_session) =
bob_account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
bob_machine
.store
.save_inbound_group_sessions(&[inbound_group_session])
.await
.unwrap();
bob_machine.store.save_inbound_group_sessions(&[inbound_group_session]).await.unwrap();
// Alice wants to request the outbound group session from bob.
alice_machine
@ -1326,9 +1188,7 @@ mod test {
group_session.mark_shared_with(&alice_id(), &alice_device_id());
// Put the outbound session into bobs store.
bob_machine
.outbound_group_sessions
.insert(group_session.clone());
bob_machine.outbound_group_sessions.insert(group_session.clone());
// Get the request and convert it into a event.
let requests = alice_machine.outgoing_to_device_requests().await.unwrap();
@ -1346,15 +1206,9 @@ mod test {
let content: RoomKeyRequestToDeviceEventContent =
serde_json::from_str(content.get()).unwrap();
alice_machine
.mark_outgoing_request_as_sent(id)
.await
.unwrap();
alice_machine.mark_outgoing_request_as_sent(id).await.unwrap();
let event = ToDeviceEvent {
sender: alice_id(),
content,
};
let event = ToDeviceEvent { sender: alice_id(), content };
// Bob doesn't have any outgoing requests.
assert!(bob_machine.outgoing_to_device_requests.is_empty());
@ -1383,10 +1237,7 @@ mod test {
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();
let event = ToDeviceEvent {
sender: bob_id(),
content,
};
let event = ToDeviceEvent { sender: bob_id(), content };
// Check that alice doesn't have the session.
assert!(alice_machine
@ -1407,11 +1258,7 @@ mod test {
.receive_forwarded_room_key(&decrypted.sender_key, &mut e)
.await
.unwrap();
alice_machine
.store
.save_inbound_group_sessions(&[session.unwrap()])
.await
.unwrap();
alice_machine.store.save_inbound_group_sessions(&[session.unwrap()]).await.unwrap();
} else {
panic!("Invalid decrypted event type");
}
@ -1434,10 +1281,7 @@ mod test {
#[async_test]
async fn key_share_cycle_without_session() {
let alice_machine = get_machine().await;
let alice_account = Account {
inner: account(),
store: alice_machine.store.clone(),
};
let alice_account = Account { inner: account(), store: alice_machine.store.clone() };
let bob_machine = bob_machine();
let bob_account = bob_account();
@ -1447,11 +1291,7 @@ mod test {
// We need a trusted device, otherwise we won't request keys
alice_device.set_trust_state(LocalTrust::Verified);
alice_machine
.store
.save_devices(&[alice_device])
.await
.unwrap();
alice_machine.store.save_devices(&[alice_device]).await.unwrap();
// Create Olm sessions for our two accounts.
let (alice_session, bob_session) = alice_account.create_session_for(&bob_account).await;
@ -1461,27 +1301,13 @@ mod test {
// Populate our stores with Olm sessions and a Megolm session.
alice_machine
.store
.save_devices(&[bob_device])
.await
.unwrap();
bob_machine
.store
.save_devices(&[alice_device])
.await
.unwrap();
alice_machine.store.save_devices(&[bob_device]).await.unwrap();
bob_machine.store.save_devices(&[alice_device]).await.unwrap();
let (group_session, inbound_group_session) = bob_account
.create_group_session_pair_with_defaults(&room_id())
.await
.unwrap();
let (group_session, inbound_group_session) =
bob_account.create_group_session_pair_with_defaults(&room_id()).await.unwrap();
bob_machine
.store
.save_inbound_group_sessions(&[inbound_group_session])
.await
.unwrap();
bob_machine.store.save_inbound_group_sessions(&[inbound_group_session]).await.unwrap();
// Alice wants to request the outbound group session from bob.
alice_machine
@ -1495,9 +1321,7 @@ mod test {
group_session.mark_shared_with(&alice_id(), &alice_device_id());
// Put the outbound session into bobs store.
bob_machine
.outbound_group_sessions
.insert(group_session.clone());
bob_machine.outbound_group_sessions.insert(group_session.clone());
// Get the request and convert it into a event.
let requests = alice_machine.outgoing_to_device_requests().await.unwrap();
@ -1515,22 +1339,12 @@ mod test {
let content: RoomKeyRequestToDeviceEventContent =
serde_json::from_str(content.get()).unwrap();
alice_machine
.mark_outgoing_request_as_sent(id)
.await
.unwrap();
alice_machine.mark_outgoing_request_as_sent(id).await.unwrap();
let event = ToDeviceEvent {
sender: alice_id(),
content,
};
let event = ToDeviceEvent { sender: alice_id(), content };
// Bob doesn't have any outgoing requests.
assert!(bob_machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(bob_machine.outgoing_to_device_requests().await.unwrap().is_empty());
assert!(bob_machine.users_for_key_claim.is_empty());
assert!(bob_machine.wait_queue.is_empty());
@ -1538,35 +1352,19 @@ mod test {
bob_machine.receive_incoming_key_request(&event);
bob_machine.collect_incoming_key_requests().await.unwrap();
// Bob doens't have an outgoing requests since we're lacking a session.
assert!(bob_machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(bob_machine.outgoing_to_device_requests().await.unwrap().is_empty());
assert!(!bob_machine.users_for_key_claim.is_empty());
assert!(!bob_machine.wait_queue.is_empty());
// We create a session now.
alice_machine
.store
.save_sessions(&[alice_session])
.await
.unwrap();
bob_machine
.store
.save_sessions(&[bob_session])
.await
.unwrap();
alice_machine.store.save_sessions(&[alice_session]).await.unwrap();
bob_machine.store.save_sessions(&[bob_session]).await.unwrap();
bob_machine.retry_keyshare(&alice_id(), &alice_device_id());
assert!(bob_machine.users_for_key_claim.is_empty());
bob_machine.collect_incoming_key_requests().await.unwrap();
// Bob now has an outgoing requests.
assert!(!bob_machine
.outgoing_to_device_requests()
.await
.unwrap()
.is_empty());
assert!(!bob_machine.outgoing_to_device_requests().await.unwrap().is_empty());
assert!(bob_machine.wait_queue.is_empty());
// Get the request and convert it to a encrypted to-device event.
@ -1588,10 +1386,7 @@ mod test {
bob_machine.mark_outgoing_request_as_sent(id).await.unwrap();
let event = ToDeviceEvent {
sender: bob_id(),
content,
};
let event = ToDeviceEvent { sender: bob_id(), content };
// Check that alice doesn't have the session.
assert!(alice_machine
@ -1612,11 +1407,7 @@ mod test {
.receive_forwarded_room_key(&decrypted.sender_key, &mut e)
.await
.unwrap();
alice_machine
.store
.save_inbound_group_sessions(&[session.unwrap()])
.await
.unwrap();
alice_machine.store.save_inbound_group_sessions(&[session.unwrap()]).await.unwrap();
} else {
panic!("Invalid decrypted event type");
}

View file

@ -17,8 +17,6 @@ use std::path::Path;
use std::{collections::BTreeMap, mem, sync::Arc};
use dashmap::DashMap;
use tracing::{debug, error, info, trace, warn};
use matrix_sdk_common::{
api::r0::{
keys::{
@ -43,6 +41,7 @@ use matrix_sdk_common::{
uuid::Uuid,
UInt,
};
use tracing::{debug, error, info, trace, warn};
#[cfg(feature = "sled_cryptostore")]
use crate::store::sled::SledStore;
@ -148,19 +147,12 @@ impl OlmMachine {
let store = Arc::new(store);
let verification_machine =
VerificationMachine::new(account.clone(), user_identity.clone(), store.clone());
let store = Store::new(
user_id.clone(),
user_identity.clone(),
store,
verification_machine.clone(),
);
let store =
Store::new(user_id.clone(), user_identity.clone(), store, verification_machine.clone());
let device_id: Arc<DeviceIdBox> = Arc::new(device_id);
let users_for_key_claim = Arc::new(DashMap::new());
let account = Account {
inner: account,
store: store.clone(),
};
let account = Account { inner: account, store: store.clone() };
let group_session_manager = GroupSessionManager::new(account.clone(), store.clone());
@ -244,9 +236,7 @@ impl OlmMachine {
}
};
Ok(OlmMachine::new_helper(
&user_id, device_id, store, account, identity,
))
Ok(OlmMachine::new_helper(&user_id, device_id, store, account, identity))
}
/// Create a new machine with the default crypto store.
@ -296,19 +286,16 @@ impl OlmMachine {
pub async fn outgoing_requests(&self) -> StoreResult<Vec<OutgoingRequest>> {
let mut requests = Vec::new();
if let Some(r) = self.keys_for_upload().await.map(|r| OutgoingRequest {
request_id: Uuid::new_v4(),
request: Arc::new(r.into()),
}) {
if let Some(r) = self
.keys_for_upload()
.await
.map(|r| OutgoingRequest { request_id: Uuid::new_v4(), request: Arc::new(r.into()) })
{
requests.push(r);
}
for request in self
.identity_manager
.users_for_key_query()
.await
.into_iter()
.map(|r| OutgoingRequest {
for request in
self.identity_manager.users_for_key_query().await.into_iter().map(|r| OutgoingRequest {
request_id: Uuid::new_v4(),
request: Arc::new(r.into()),
})
@ -317,12 +304,7 @@ impl OlmMachine {
}
requests.append(&mut self.verification_machine.outgoing_messages());
requests.append(
&mut self
.key_request_machine
.outgoing_to_device_requests()
.await?,
);
requests.append(&mut self.key_request_machine.outgoing_to_device_requests().await?);
Ok(requests)
}
@ -373,10 +355,7 @@ impl OlmMachine {
let identity = self.user_identity.lock().await;
identity.mark_as_shared();
let changes = Changes {
private_identity: Some(identity.clone()),
..Default::default()
};
let changes = Changes { private_identity: Some(identity.clone()), ..Default::default() };
self.store.save_changes(changes).await
}
@ -406,10 +385,7 @@ impl OlmMachine {
);
let changes = Changes {
identities: IdentityChanges {
new: vec![public.into()],
..Default::default()
},
identities: IdentityChanges { new: vec![public.into()], ..Default::default() },
private_identity: Some(identity.clone()),
..Default::default()
};
@ -421,10 +397,8 @@ impl OlmMachine {
info!("Trying to upload the existing cross signing identity");
let request = identity.as_upload_request().await;
// TODO remove this expect.
let signature_request = identity
.sign_account(&self.account)
.await
.expect("Can't sign device keys");
let signature_request =
identity.sign_account(&self.account).await.expect("Can't sign device keys");
Ok((request, signature_request))
}
}
@ -518,9 +492,7 @@ impl OlmMachine {
///
/// * `response` - The response containing the claimed one-time keys.
async fn receive_keys_claim_response(&self, response: &KeysClaimResponse) -> OlmResult<()> {
self.session_manager
.receive_keys_claim_response(response)
.await
self.session_manager.receive_keys_claim_response(response).await
}
/// Receive a successful keys query response.
@ -536,9 +508,7 @@ impl OlmMachine {
&self,
response: &KeysQueryResponse,
) -> OlmResult<(DeviceChanges, IdentityChanges)> {
self.identity_manager
.receive_keys_query_response(response)
.await
self.identity_manager.receive_keys_query_response(response).await
}
/// Get a request to upload E2EE keys to the server.
@ -676,9 +646,7 @@ impl OlmMachine {
/// Returns true if a session was invalidated, false if there was no session
/// to invalidate.
pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
self.group_session_manager
.invalidate_group_session(room_id)
.await
self.group_session_manager.invalidate_group_session(room_id).await
}
/// Get to-device requests to share a group session with users in a room.
@ -695,9 +663,7 @@ impl OlmMachine {
users: impl Iterator<Item = &UserId>,
encryption_settings: impl Into<EncryptionSettings>,
) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
self.group_session_manager
.share_group_session(room_id, users, encryption_settings)
.await
self.group_session_manager.share_group_session(room_id, users, encryption_settings).await
}
/// Receive and properly handle a decrypted to-device event.
@ -716,18 +682,15 @@ impl OlmMachine {
let event = match decrypted.event.deserialize() {
Ok(e) => e,
Err(e) => {
warn!(
"Decrypted to-device event failed to be parsed correctly {:?}",
e
);
warn!("Decrypted to-device event failed to be parsed correctly {:?}", e);
return Ok((None, None));
}
};
match event {
AnyToDeviceEvent::RoomKey(mut e) => Ok(self
.add_room_key(&decrypted.sender_key, &decrypted.signing_key, &mut e)
.await?),
AnyToDeviceEvent::RoomKey(mut e) => {
Ok(self.add_room_key(&decrypted.sender_key, &decrypted.signing_key, &mut e).await?)
}
AnyToDeviceEvent::ForwardedRoomKey(mut e) => Ok(self
.key_request_machine
.receive_forwarded_room_key(&decrypted.sender_key, &mut e)
@ -748,14 +711,9 @@ impl OlmMachine {
/// Mark an outgoing to-device requests as sent.
async fn mark_to_device_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> {
self.verification_machine.mark_request_as_sent(request_id);
self.key_request_machine
.mark_outgoing_request_as_sent(*request_id)
.await?;
self.group_session_manager
.mark_request_as_sent(request_id)
.await?;
self.session_manager
.mark_outgoing_request_as_sent(request_id);
self.key_request_machine.mark_outgoing_request_as_sent(*request_id).await?;
self.group_session_manager.mark_request_as_sent(request_id).await?;
self.session_manager.mark_outgoing_request_as_sent(request_id);
Ok(())
}
@ -830,10 +788,8 @@ impl OlmMachine {
// Always save the account, a new session might get created which also
// touches the account.
let mut changes = Changes {
account: Some(self.account.inner.clone()),
..Default::default()
};
let mut changes =
Changes { account: Some(self.account.inner.clone()), ..Default::default() };
self.update_one_time_key_count(one_time_keys_counts).await;
@ -850,10 +806,7 @@ impl OlmMachine {
Ok(e) => e,
Err(e) => {
// Skip invalid events.
warn!(
"Received an invalid to-device event {:?} {:?}",
e, raw_event
);
warn!("Received an invalid to-device event {:?} {:?}", e, raw_event);
continue;
}
};
@ -865,10 +818,7 @@ impl OlmMachine {
let decrypted = match self.decrypt_to_device_event(&e).await {
Ok(e) => e,
Err(err) => {
warn!(
"Failed to decrypt to-device event from {} {}",
e.sender, err
);
warn!("Failed to decrypt to-device event from {} {}", e.sender, err);
if let OlmError::SessionWedged(sender, curve_key) = err {
if let Err(e) = self
@ -916,10 +866,7 @@ impl OlmMachine {
events.push(raw_event);
}
let changed_sessions = self
.key_request_machine
.collect_incoming_key_requests()
.await?;
let changed_sessions = self.key_request_machine.collect_incoming_key_requests().await?;
changes.sessions.extend(changed_sessions);
@ -1036,25 +983,16 @@ impl OlmMachine {
// TODO check if this is from a verified device.
let (decrypted_event, _) = session.decrypt(event).await?;
trace!(
"Successfully decrypted a Megolm event {:?}",
decrypted_event
);
trace!("Successfully decrypted a Megolm event {:?}", decrypted_event);
if let Ok(e) = decrypted_event.deserialize() {
self.verification_machine
.receive_room_event(room_id, &e)
.await?;
self.verification_machine.receive_room_event(room_id, &e).await?;
}
let encryption_info = self
.get_encryption_info(&session, &event.sender, &content.device_id)
.await?;
let encryption_info =
self.get_encryption_info(&session, &event.sender, &content.device_id).await?;
Ok(SyncRoomEvent {
encryption_info: Some(encryption_info),
event: decrypted_event,
})
Ok(SyncRoomEvent { encryption_info: Some(encryption_info), event: decrypted_event })
}
/// Update the tracked users.
@ -1210,17 +1148,11 @@ impl OlmMachine {
let num_sessions = sessions.len();
let changes = Changes {
inbound_group_sessions: sessions,
..Default::default()
};
let changes = Changes { inbound_group_sessions: sessions, ..Default::default() };
self.store.save_changes(changes).await?;
info!(
"Successfully imported {} inbound group sessions",
num_sessions
);
info!("Successfully imported {} inbound group sessions", num_sessions);
Ok((num_sessions, total_sessions))
}
@ -1288,15 +1220,6 @@ pub(crate) mod test {
};
use http::Response;
use serde_json::json;
use crate::{
machine::OlmMachine,
olm::Utility,
verification::test::{outgoing_request_to_event, request_to_event},
EncryptionSettings, ReadOnlyDevice, ToDeviceRequest,
};
use matrix_sdk_common::{
api::r0::keys::{claim_keys, get_keys, upload_keys, OneTimeKey},
events::{
@ -1313,6 +1236,14 @@ pub(crate) mod test {
IncomingResponse, Raw,
};
use matrix_sdk_test::test_json;
use serde_json::json;
use crate::{
machine::OlmMachine,
olm::Utility,
verification::test::{outgoing_request_to_event, request_to_event},
EncryptionSettings, ReadOnlyDevice, ToDeviceRequest,
};
/// These keys need to be periodically uploaded to the server.
type OneTimeKeys = BTreeMap<DeviceKeyId, OneTimeKey>;
@ -1332,10 +1263,7 @@ pub(crate) mod test {
}
pub fn response_from_file(json: &serde_json::Value) -> Response<Vec<u8>> {
Response::builder()
.status(200)
.body(json.to_string().as_bytes().to_vec())
.unwrap()
Response::builder().status(200).body(json.to_string().as_bytes().to_vec()).unwrap()
}
fn keys_upload_response() -> upload_keys::Response {
@ -1354,15 +1282,7 @@ pub(crate) mod test {
let to_device_request = &requests[0];
let content: Raw<EncryptedEventContent> = serde_json::from_str(
to_device_request
.messages
.values()
.next()
.unwrap()
.values()
.next()
.unwrap()
.get(),
to_device_request.messages.values().next().unwrap().values().next().unwrap().get(),
)
.unwrap();
@ -1372,15 +1292,9 @@ pub(crate) mod test {
pub(crate) async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) {
let machine = OlmMachine::new(&user_id(), &alice_device_id());
machine.account.inner.update_uploaded_key_count(0);
let request = machine
.keys_for_upload()
.await
.expect("Can't prepare initial key upload");
let request = machine.keys_for_upload().await.expect("Can't prepare initial key upload");
let response = keys_upload_response();
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
machine.receive_keys_upload_response(&response).await.unwrap();
(machine, request.one_time_keys.unwrap())
}
@ -1389,10 +1303,7 @@ pub(crate) mod test {
let (machine, otk) = get_prepared_machine().await;
let response = keys_query_response();
machine
.receive_keys_query_response(&response)
.await
.unwrap();
machine.receive_keys_query_response(&response).await.unwrap();
(machine, otk)
}
@ -1435,28 +1346,15 @@ pub(crate) mod test {
async fn get_machine_pair_with_setup_sessions() -> (OlmMachine, OlmMachine) {
let (alice, bob) = get_machine_pair_with_session().await;
let bob_device = alice
.get_device(&bob.user_id, &bob.device_id)
.await
.unwrap()
.unwrap();
let bob_device = alice.get_device(&bob.user_id, &bob.device_id).await.unwrap().unwrap();
let (session, content) = bob_device
.encrypt(EventType::Dummy, json!({}))
.await
.unwrap();
let (session, content) = bob_device.encrypt(EventType::Dummy, json!({})).await.unwrap();
alice.store.save_sessions(&[session]).await.unwrap();
let event = ToDeviceEvent {
sender: alice.user_id().clone(),
content,
};
let event = ToDeviceEvent { sender: alice.user_id().clone(), content };
let decrypted = bob.decrypt_to_device_event(&event).await.unwrap();
bob.store
.save_sessions(&[decrypted.session.session()])
.await
.unwrap();
bob.store.save_sessions(&[decrypted.session.session()]).await.unwrap();
(alice, bob)
}
@ -1472,34 +1370,18 @@ pub(crate) mod test {
let machine = OlmMachine::new(&user_id(), &alice_device_id());
let mut response = keys_upload_response();
response
.one_time_key_counts
.remove(&DeviceKeyAlgorithm::SignedCurve25519)
.unwrap();
response.one_time_key_counts.remove(&DeviceKeyAlgorithm::SignedCurve25519).unwrap();
assert!(machine.should_upload_keys().await);
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
machine.receive_keys_upload_response(&response).await.unwrap();
assert!(machine.should_upload_keys().await);
response
.one_time_key_counts
.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(10));
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
response.one_time_key_counts.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(10));
machine.receive_keys_upload_response(&response).await.unwrap();
assert!(machine.should_upload_keys().await);
response
.one_time_key_counts
.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50));
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
response.one_time_key_counts.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50));
machine.receive_keys_upload_response(&response).await.unwrap();
assert!(!machine.should_upload_keys().await);
}
@ -1511,20 +1393,12 @@ pub(crate) mod test {
assert!(machine.should_upload_keys().await);
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
machine.receive_keys_upload_response(&response).await.unwrap();
assert!(machine.should_upload_keys().await);
assert!(machine.account.generate_one_time_keys().await.is_ok());
response
.one_time_key_counts
.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50));
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
response.one_time_key_counts.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50));
machine.receive_keys_upload_response(&response).await.unwrap();
assert!(machine.account.generate_one_time_keys().await.is_err());
}
@ -1551,14 +1425,8 @@ pub(crate) mod test {
let machine = OlmMachine::new(&user_id(), &alice_device_id());
let room_id = room_id!("!test:example.org");
machine
.create_outbound_group_session_with_defaults(&room_id)
.await
.unwrap();
assert!(machine
.group_session_manager
.get_outbound_group_session(&room_id)
.is_some());
machine.create_outbound_group_session_with_defaults(&room_id).await.unwrap();
assert!(machine.group_session_manager.get_outbound_group_session(&room_id).is_some());
machine.invalidate_group_session(&room_id).await.unwrap();
@ -1614,10 +1482,8 @@ pub(crate) mod test {
let identity_keys = machine.account.identity_keys();
let ed25519_key = identity_keys.ed25519();
let mut request = machine
.keys_for_upload()
.await
.expect("Can't prepare initial key upload");
let mut request =
machine.keys_for_upload().await.expect("Can't prepare initial key upload");
let utility = Utility::new();
let ret = utility.verify_json(
@ -1640,15 +1506,10 @@ pub(crate) mod test {
let mut response = keys_upload_response();
response.one_time_key_counts.insert(
DeviceKeyAlgorithm::SignedCurve25519,
(request.one_time_keys.unwrap().len() as u64)
.try_into()
.unwrap(),
(request.one_time_keys.unwrap().len() as u64).try_into().unwrap(),
);
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
machine.receive_keys_upload_response(&response).await.unwrap();
let ret = machine.keys_for_upload().await;
assert!(ret.is_none());
@ -1664,17 +1525,9 @@ pub(crate) mod test {
let alice_devices = machine.store.get_user_devices(&alice_id).await.unwrap();
assert!(alice_devices.devices().peekable().peek().is_none());
machine
.receive_keys_query_response(&response)
.await
.unwrap();
machine.receive_keys_query_response(&response).await.unwrap();
let device = machine
.store
.get_device(&alice_id, alice_device_id)
.await
.unwrap()
.unwrap();
let device = machine.store.get_device(&alice_id, alice_device_id).await.unwrap().unwrap();
assert_eq!(device.user_id(), &alice_id);
assert_eq!(device.device_id(), alice_device_id);
}
@ -1686,11 +1539,8 @@ pub(crate) mod test {
let alice = alice_id();
let alice_device = alice_device_id();
let (_, missing_sessions) = machine
.get_missing_sessions(&mut [alice.clone()].iter())
.await
.unwrap()
.unwrap();
let (_, missing_sessions) =
machine.get_missing_sessions(&mut [alice.clone()].iter()).await.unwrap().unwrap();
assert!(missing_sessions.one_time_keys.contains_key(&alice));
let user_sessions = missing_sessions.one_time_keys.get(&alice).unwrap();
@ -1713,10 +1563,7 @@ pub(crate) mod test {
let response = claim_keys::Response::new(one_time_keys);
alice_machine
.receive_keys_claim_response(&response)
.await
.unwrap();
alice_machine.receive_keys_claim_response(&response).await.unwrap();
let session = alice_machine
.store
@ -1732,28 +1579,14 @@ pub(crate) mod test {
async fn test_olm_encryption() {
let (alice, bob) = get_machine_pair_with_session().await;
let bob_device = alice
.get_device(&bob.user_id, &bob.device_id)
.await
.unwrap()
.unwrap();
let bob_device = alice.get_device(&bob.user_id, &bob.device_id).await.unwrap().unwrap();
let event = ToDeviceEvent {
sender: alice.user_id().clone(),
content: bob_device
.encrypt(EventType::Dummy, json!({}))
.await
.unwrap()
.1,
content: bob_device.encrypt(EventType::Dummy, json!({})).await.unwrap().1,
};
let event = bob
.decrypt_to_device_event(&event)
.await
.unwrap()
.event
.deserialize()
.unwrap();
let event = bob.decrypt_to_device_event(&event).await.unwrap().event.deserialize().unwrap();
if let AnyToDeviceEvent::Dummy(e) = event {
assert_eq!(&e.sender, alice.user_id());
@ -1782,17 +1615,12 @@ pub(crate) mod test {
content: to_device_requests_to_content(to_device_requests),
};
let alice_session = alice
.group_session_manager
.get_outbound_group_session(&room_id)
.unwrap();
let alice_session =
alice.group_session_manager.get_outbound_group_session(&room_id).unwrap();
let decrypted = bob.decrypt_to_device_event(&event).await.unwrap();
bob.store
.save_sessions(&[decrypted.session.session()])
.await
.unwrap();
bob.store.save_sessions(&[decrypted.session.session()]).await.unwrap();
bob.store
.save_inbound_group_sessions(&[decrypted.inbound_group_session.unwrap()])
.await
@ -1837,25 +1665,16 @@ pub(crate) mod test {
content: to_device_requests_to_content(to_device_requests),
};
let group_session = bob
.decrypt_to_device_event(&event)
.await
.unwrap()
.inbound_group_session;
bob.store
.save_inbound_group_sessions(&[group_session.unwrap()])
.await
.unwrap();
let group_session =
bob.decrypt_to_device_event(&event).await.unwrap().inbound_group_session;
bob.store.save_inbound_group_sessions(&[group_session.unwrap()]).await.unwrap();
let plaintext = "It is a secret to everybody";
let content = MessageEventContent::text_plain(plaintext);
let encrypted_content = alice
.encrypt(
&room_id,
AnyMessageEventContent::RoomMessage(content.clone()),
)
.encrypt(&room_id, AnyMessageEventContent::RoomMessage(content.clone()))
.await
.unwrap();
@ -1867,13 +1686,8 @@ pub(crate) mod test {
unsigned: Unsigned::default(),
};
let decrypted_event = bob
.decrypt_room_event(&event, &room_id)
.await
.unwrap()
.event
.deserialize()
.unwrap();
let decrypted_event =
bob.decrypt_room_event(&event, &room_id).await.unwrap().event.deserialize().unwrap();
if let AnySyncRoomEvent::Message(AnySyncMessageEvent::RoomMessage(SyncMessageEvent {
sender,
@ -1912,10 +1726,7 @@ pub(crate) mod test {
let device_id = machine.device_id().to_owned();
let ed25519_key = machine.identity_keys().ed25519().to_owned();
machine
.receive_keys_upload_response(&keys_upload_response())
.await
.unwrap();
machine.receive_keys_upload_response(&keys_upload_response()).await.unwrap();
drop(machine);
@ -1937,11 +1748,7 @@ pub(crate) mod test {
async fn interactive_verification() {
let (alice, bob) = get_machine_pair_with_setup_sessions().await;
let bob_device = alice
.get_device(bob.user_id(), bob.device_id())
.await
.unwrap()
.unwrap();
let bob_device = alice.get_device(bob.user_id(), bob.device_id()).await.unwrap().unwrap();
assert!(!bob_device.is_trusted());
@ -1955,10 +1762,7 @@ pub(crate) mod test {
assert!(alice_sas.emoji().is_none());
assert!(bob_sas.emoji().is_none());
let event = bob_sas
.accept()
.map(|r| request_to_event(bob.user_id(), &r))
.unwrap();
let event = bob_sas.accept().map(|r| request_to_event(bob.user_id(), &r)).unwrap();
alice.handle_verification_event(&event).await;
@ -2007,11 +1811,8 @@ pub(crate) mod test {
assert!(alice_sas.is_done());
assert!(bob_device.is_trusted());
let alice_device = bob
.get_device(alice.user_id(), alice.device_id())
.await
.unwrap()
.unwrap();
let alice_device =
bob.get_device(alice.user_id(), alice.device_id()).await.unwrap().unwrap();
assert!(!alice_device.is_trusted());
bob.handle_verification_event(&event).await;

View file

@ -12,10 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use matrix_sdk_common::events::ToDeviceEvent;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use sha2::{Digest, Sha256};
use std::{
collections::BTreeMap,
convert::{TryFrom, TryInto},
@ -26,7 +22,6 @@ use std::{
Arc,
},
};
use tracing::{debug, trace, warn};
#[cfg(test)]
use matrix_sdk_common::events::EventType;
@ -37,7 +32,7 @@ use matrix_sdk_common::{
encryption::DeviceKeys,
events::{
room::encrypted::{EncryptedEventContent, EncryptedEventScheme},
AnyToDeviceEvent,
AnyToDeviceEvent, ToDeviceEvent,
},
identifiers::{
DeviceId, DeviceIdBox, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, RoomId,
@ -53,7 +48,15 @@ use olm_rs::{
session::{OlmMessage, PreKeyMessage},
PicklingMode,
};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use sha2::{Digest, Sha256};
use tracing::{debug, trace, warn};
use super::{
EncryptionSettings, InboundGroupSession, OutboundGroupSession, PrivateCrossSigningIdentity,
Session,
};
use crate::{
error::{EventError, OlmResult, SessionCreationError},
identities::ReadOnlyDevice,
@ -63,11 +66,6 @@ use crate::{
OlmError,
};
use super::{
EncryptionSettings, InboundGroupSession, OutboundGroupSession, PrivateCrossSigningIdentity,
Session,
};
#[derive(Debug, Clone)]
pub struct Account {
pub(crate) inner: ReadOnlyAccount,
@ -141,10 +139,8 @@ impl Account {
// Try to find a ciphertext that was meant for our device.
if let Some(ciphertext) = own_ciphertext {
let message_type: u8 = ciphertext
.message_type
.try_into()
.map_err(|_| EventError::UnsupportedOlmType)?;
let message_type: u8 =
ciphertext.message_type.try_into().map_err(|_| EventError::UnsupportedOlmType)?;
let sha = Sha256::new()
.chain(&content.sender_key)
@ -162,20 +158,18 @@ impl Account {
.map_err(|_| EventError::UnsupportedOlmType)?;
// Decrypt the OlmMessage and get a Ruma event out of it.
let (session, event, signing_key) = match self
.decrypt_olm_message(&event.sender, &content.sender_key, message)
.await
{
Ok(d) => d,
Err(OlmError::SessionWedged(user_id, sender_key)) => {
if self.store.is_message_known(&message_hash).await? {
return Err(OlmError::ReplayedMessage(user_id, sender_key));
} else {
return Err(OlmError::SessionWedged(user_id, sender_key));
let (session, event, signing_key) =
match self.decrypt_olm_message(&event.sender, &content.sender_key, message).await {
Ok(d) => d,
Err(OlmError::SessionWedged(user_id, sender_key)) => {
if self.store.is_message_known(&message_hash).await? {
return Err(OlmError::ReplayedMessage(user_id, sender_key));
} else {
return Err(OlmError::SessionWedged(user_id, sender_key));
}
}
}
Err(e) => return Err(e),
};
Err(e) => return Err(e),
};
debug!("Decrypted a to-device event {:?}", event);
@ -210,9 +204,8 @@ impl Account {
}
self.inner.mark_as_shared();
let one_time_key_count = response
.one_time_key_counts
.get(&DeviceKeyAlgorithm::SignedCurve25519);
let one_time_key_count =
response.one_time_key_counts.get(&DeviceKeyAlgorithm::SignedCurve25519);
let count: u64 = one_time_key_count.map_or(0, |c| (*c).into());
debug!(
@ -297,9 +290,8 @@ impl Account {
message: OlmMessage,
) -> OlmResult<(SessionType, Raw<AnyToDeviceEvent>, String)> {
// First try to decrypt using an existing session.
let (session, plaintext) = if let Some(d) = self
.try_decrypt_olm_message(sender, sender_key, &message)
.await?
let (session, plaintext) = if let Some(d) =
self.try_decrypt_olm_message(sender, sender_key, &message).await?
{
// Decryption succeeded, de-structure the session/plaintext out of
// the Option.
@ -316,32 +308,26 @@ impl Account {
available sessions {} {}",
sender, sender_key
);
return Err(OlmError::SessionWedged(
sender.to_owned(),
sender_key.to_owned(),
));
return Err(OlmError::SessionWedged(sender.to_owned(), sender_key.to_owned()));
}
OlmMessage::PreKey(m) => {
// Create the new session.
let session = match self
.inner
.create_inbound_session(sender_key, m.clone())
.await
{
Ok(s) => s,
Err(e) => {
warn!(
"Failed to create a new Olm session for {} {}
let session =
match self.inner.create_inbound_session(sender_key, m.clone()).await {
Ok(s) => s,
Err(e) => {
warn!(
"Failed to create a new Olm session for {} {}
from a prekey message: {}",
sender, sender_key, e
);
return Err(OlmError::SessionWedged(
sender.to_owned(),
sender_key.to_owned(),
));
}
};
sender, sender_key, e
);
return Err(OlmError::SessionWedged(
sender.to_owned(),
sender_key.to_owned(),
));
}
};
session
}
@ -428,9 +414,8 @@ impl Account {
return Err(EventError::MissmatchedKeys.into());
}
let signing_key = keys
.get(&DeviceKeyAlgorithm::Ed25519)
.ok_or(EventError::MissingSigningKey)?;
let signing_key =
keys.get(&DeviceKeyAlgorithm::Ed25519).ok_or(EventError::MissingSigningKey)?;
Ok((
Raw::from(serde_json::from_value::<AnyToDeviceEvent>(decrypted_json)?),
@ -547,8 +532,7 @@ impl ReadOnlyAccount {
/// * `new_count` - The new count that was reported by the server.
pub(crate) fn update_uploaded_key_count(&self, new_count: u64) {
let key_count = i64::try_from(new_count).unwrap_or(i64::MAX);
self.uploaded_signed_key_count
.store(key_count, Ordering::Relaxed);
self.uploaded_signed_key_count.store(key_count, Ordering::Relaxed);
}
/// Get the currently known uploaded key count.
@ -631,19 +615,12 @@ impl ReadOnlyAccount {
/// Returns None if no keys need to be uploaded.
pub(crate) async fn keys_for_upload(
&self,
) -> Option<(
Option<DeviceKeys>,
Option<BTreeMap<DeviceKeyId, OneTimeKey>>,
)> {
) -> Option<(Option<DeviceKeys>, Option<BTreeMap<DeviceKeyId, OneTimeKey>>)> {
if !self.should_upload_keys().await {
return None;
}
let device_keys = if !self.shared() {
Some(self.device_keys().await)
} else {
None
};
let device_keys = if !self.shared() { Some(self.device_keys().await) } else { None };
let one_time_keys = self.signed_one_time_keys().await.ok();
@ -666,7 +643,8 @@ impl ReadOnlyAccount {
///
/// # Arguments
///
/// * `pickle_mode` - The mode that was used to pickle the account, either an
/// * `pickle_mode` - The mode that was used to pickle the account, either
/// an
/// unencrypted mode or an encrypted using passphrase.
pub async fn pickle(&self, pickle_mode: PicklingMode) -> PickledAccount {
let pickle = AccountPickle(self.inner.lock().await.pickle(pickle_mode));
@ -686,7 +664,8 @@ impl ReadOnlyAccount {
///
/// * `pickle` - The pickled version of the Account.
///
/// * `pickle_mode` - The mode that was used to pickle the account, either an
/// * `pickle_mode` - The mode that was used to pickle the account, either
/// an
/// unencrypted mode or an encrypted using passphrase.
pub fn from_pickle(
pickle: PickledAccount,
@ -742,25 +721,17 @@ impl ReadOnlyAccount {
"keys": device_keys.keys,
});
device_keys
.signatures
.entry(self.user_id().clone())
.or_insert_with(BTreeMap::new)
.insert(
DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, &self.device_id),
self.sign_json(json_device_keys).await,
);
device_keys.signatures.entry(self.user_id().clone()).or_insert_with(BTreeMap::new).insert(
DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, &self.device_id),
self.sign_json(json_device_keys).await,
);
device_keys
}
pub(crate) async fn bootstrap_cross_signing(
&self,
) -> (
PrivateCrossSigningIdentity,
UploadSigningKeysRequest,
SignatureUploadRequest,
) {
) -> (PrivateCrossSigningIdentity, UploadSigningKeysRequest, SignatureUploadRequest) {
PrivateCrossSigningIdentity::new_with_account(self).await
}
@ -873,8 +844,8 @@ impl ReadOnlyAccount {
/// # Arguments
/// * `device` - The other account's device.
///
/// * `key_map` - A map from the algorithm and device id to the one-time
/// key that the other account created and shared with us.
/// * `key_map` - A map from the algorithm and device id to the one-time key
/// that the other account created and shared with us.
pub(crate) async fn create_outbound_session(
&self,
device: ReadOnlyDevice,
@ -911,24 +882,20 @@ impl ReadOnlyAccount {
)
})?;
let curve_key = device
.get_key(DeviceKeyAlgorithm::Curve25519)
.ok_or_else(|| {
SessionCreationError::DeviceMissingCurveKey(
device.user_id().to_owned(),
device.device_id().into(),
)
})?;
let curve_key = device.get_key(DeviceKeyAlgorithm::Curve25519).ok_or_else(|| {
SessionCreationError::DeviceMissingCurveKey(
device.user_id().to_owned(),
device.device_id().into(),
)
})?;
self.create_outbound_session_helper(curve_key, &one_time_key)
.await
.map_err(|e| {
SessionCreationError::OlmError(
device.user_id().to_owned(),
device.device_id().into(),
e,
)
})
self.create_outbound_session_helper(curve_key, &one_time_key).await.map_err(|e| {
SessionCreationError::OlmError(
device.user_id().to_owned(),
device.device_id().into(),
e,
)
})
}
/// Create a new session with another account given a pre-key Olm message.
@ -946,17 +913,10 @@ impl ReadOnlyAccount {
their_identity_key: &str,
message: PreKeyMessage,
) -> Result<Session, OlmSessionError> {
let session = self
.inner
.lock()
.await
.create_inbound_session_from(their_identity_key, message)?;
let session =
self.inner.lock().await.create_inbound_session_from(their_identity_key, message)?;
self.inner
.lock()
.await
.remove_one_time_keys(&session)
.expect(
self.inner.lock().await.remove_one_time_keys(&session).expect(
"Session was successfully created but the account doesn't hold a matching one-time key",
);
@ -1028,8 +988,7 @@ impl ReadOnlyAccount {
&self,
room_id: &RoomId,
) -> Result<(OutboundGroupSession, InboundGroupSession), ()> {
self.create_group_session_pair(room_id, EncryptionSettings::default())
.await
self.create_group_session_pair(room_id, EncryptionSettings::default()).await
}
#[cfg(test)]
@ -1039,27 +998,19 @@ impl ReadOnlyAccount {
let device = ReadOnlyDevice::from_account(other).await;
let mut our_session = self
.create_outbound_session(device.clone(), &one_time)
.await
.unwrap();
let mut our_session =
self.create_outbound_session(device.clone(), &one_time).await.unwrap();
other.mark_keys_as_published().await;
let message = our_session
.encrypt(&device, EventType::Dummy, json!({}))
.await
.unwrap();
let message = our_session.encrypt(&device, EventType::Dummy, json!({})).await.unwrap();
let content = if let EncryptedEventScheme::OlmV1Curve25519AesSha2(c) = message.scheme {
c
} else {
panic!("Invalid encrypted event algorithm");
};
let own_ciphertext = content
.ciphertext
.get(other.identity_keys.curve25519())
.unwrap();
let own_ciphertext = content.ciphertext.get(other.identity_keys.curve25519()).unwrap();
let message_type: u8 = own_ciphertext.message_type.try_into().unwrap();
let message =

View file

@ -19,19 +19,6 @@ use std::{
sync::Arc,
};
use olm_rs::{
errors::OlmGroupSessionError, inbound_group_session::OlmInboundGroupSession, PicklingMode,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use zeroize::Zeroizing;
pub use olm_rs::{
account::IdentityKeys,
session::{OlmMessage, PreKeyMessage},
utility::OlmUtility,
};
use matrix_sdk_common::{
events::{
forwarded_room_key::ForwardedRoomKeyToDeviceEventContent,
@ -45,6 +32,17 @@ use matrix_sdk_common::{
locks::Mutex,
Raw,
};
pub use olm_rs::{
account::IdentityKeys,
session::{OlmMessage, PreKeyMessage},
utility::OlmUtility,
};
use olm_rs::{
errors::OlmGroupSessionError, inbound_group_session::OlmInboundGroupSession, PicklingMode,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use zeroize::Zeroizing;
use super::{ExportedGroupSessionKey, ExportedRoomKey, GroupSessionKey};
use crate::error::{EventError, MegolmResult};
@ -149,10 +147,8 @@ impl InboundGroupSession {
forwarding_chains.push(sender_key.to_owned());
let mut sender_claimed_key = BTreeMap::new();
sender_claimed_key.insert(
DeviceKeyAlgorithm::Ed25519,
content.sender_claimed_ed25519_key.to_owned(),
);
sender_claimed_key
.insert(DeviceKeyAlgorithm::Ed25519, content.sender_claimed_ed25519_key.to_owned());
Ok(InboundGroupSession {
inner: Mutex::new(session).into(),
@ -219,11 +215,7 @@ impl InboundGroupSession {
let message_index = std::cmp::max(self.first_known_index(), message_index);
let session_key = ExportedGroupSessionKey(
self.inner
.lock()
.await
.export(message_index)
.expect("Can't export session"),
self.inner.lock().await.export(message_index).expect("Can't export session"),
);
ExportedRoomKey {
@ -316,9 +308,7 @@ impl InboundGroupSession {
let (plaintext, message_index) = self.decrypt_helper(content.ciphertext.clone()).await?;
let mut decrypted_value = serde_json::from_str::<Value>(&plaintext)?;
let decrypted_object = decrypted_value
.as_object_mut()
.ok_or(EventError::NotAnObject)?;
let decrypted_object = decrypted_value.as_object_mut().ok_or(EventError::NotAnObject)?;
// TODO better number conversion here.
let server_ts = event
@ -337,10 +327,8 @@ impl InboundGroupSession {
serde_json::to_value(&event.unsigned).unwrap_or_default(),
);
if let Some(decrypted_content) = decrypted_object
.get_mut("content")
.map(|c| c.as_object_mut())
.flatten()
if let Some(decrypted_content) =
decrypted_object.get_mut("content").map(|c| c.as_object_mut()).flatten()
{
if !decrypted_content.contains_key("m.relates_to") {
if let Some(relation) = &event.content.relates_to {
@ -352,19 +340,14 @@ impl InboundGroupSession {
}
}
Ok((
serde_json::from_value::<Raw<AnySyncRoomEvent>>(decrypted_value)?,
message_index,
))
Ok((serde_json::from_value::<Raw<AnySyncRoomEvent>>(decrypted_value)?, message_index))
}
}
#[cfg(not(tarpaulin_include))]
impl fmt::Debug for InboundGroupSession {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("InboundGroupSession")
.field("session_id", &self.session_id())
.finish()
f.debug_struct("InboundGroupSession").field("session_id", &self.session_id()).finish()
}
}
@ -399,7 +382,8 @@ pub struct PickledInboundGroupSession {
pub history_visibility: Option<HistoryVisibility>,
}
/// The typed representation of a base64 encoded string of the GroupSession pickle.
/// The typed representation of a base64 encoded string of the GroupSession
/// pickle.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InboundGroupSessionPickle(String);

View file

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::BTreeMap, convert::TryInto};
use matrix_sdk_common::{
events::forwarded_room_key::{
ForwardedRoomKeyToDeviceEventContent, ForwardedRoomKeyToDeviceEventContentInit,
@ -19,7 +21,6 @@ use matrix_sdk_common::{
identifiers::{DeviceKeyAlgorithm, EventEncryptionAlgorithm, RoomId},
};
use serde::{Deserialize, Serialize};
use std::{collections::BTreeMap, convert::TryInto};
use zeroize::Zeroize;
mod inbound;
@ -107,10 +108,8 @@ impl From<ForwardedRoomKeyToDeviceEventContent> for ExportedRoomKey {
/// Convert the content of a forwarded room key into a exported room key.
fn from(forwarded_key: ForwardedRoomKeyToDeviceEventContent) -> Self {
let mut sender_claimed_keys: BTreeMap<DeviceKeyAlgorithm, String> = BTreeMap::new();
sender_claimed_keys.insert(
DeviceKeyAlgorithm::Ed25519,
forwarded_key.sender_claimed_ed25519_key,
);
sender_claimed_keys
.insert(DeviceKeyAlgorithm::Ed25519, forwarded_key.sender_claimed_ed25519_key);
Self {
algorithm: forwarded_key.algorithm,
@ -142,10 +141,7 @@ mod test {
#[tokio::test]
#[cfg(target_os = "linux")]
async fn expiration() {
let settings = EncryptionSettings {
rotation_period_msgs: 1,
..Default::default()
};
let settings = EncryptionSettings { rotation_period_msgs: 1, ..Default::default() };
let account = ReadOnlyAccount::new(&user_id!("@alice:example.org"), "DEVICEID".into());
let (session, _) = account
@ -155,9 +151,9 @@ mod test {
assert!(!session.expired());
let _ = session
.encrypt(AnyMessageEventContent::RoomMessage(
MessageEventContent::text_plain("Test message"),
))
.encrypt(AnyMessageEventContent::RoomMessage(MessageEventContent::text_plain(
"Test message",
)))
.await;
assert!(session.expired());

View file

@ -12,15 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use dashmap::DashMap;
use matrix_sdk_common::{
api::r0::to_device::DeviceIdOrAllDevices,
events::room::{
encrypted::MegolmV1AesSha2ContentInit, history_visibility::HistoryVisibility,
message::Relation,
},
uuid::Uuid,
};
use std::{
cmp::max,
collections::BTreeMap,
@ -31,23 +22,24 @@ use std::{
},
time::Duration,
};
use tracing::{debug, error, trace};
use dashmap::DashMap;
use matrix_sdk_common::{
api::r0::to_device::DeviceIdOrAllDevices,
events::{
room::{
encrypted::{EncryptedEventContent, EncryptedEventScheme},
encrypted::{EncryptedEventContent, EncryptedEventScheme, MegolmV1AesSha2ContentInit},
encryption::EncryptionEventContent,
history_visibility::HistoryVisibility,
message::Relation,
},
AnyMessageEventContent, EventContent,
},
identifiers::{DeviceId, DeviceIdBox, EventEncryptionAlgorithm, RoomId, UserId},
instant::Instant,
locks::Mutex,
uuid::Uuid,
};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
pub use olm_rs::{
account::IdentityKeys,
session::{OlmMessage, PreKeyMessage},
@ -56,13 +48,15 @@ pub use olm_rs::{
use olm_rs::{
errors::OlmGroupSessionError, outbound_group_session::OlmOutboundGroupSession, PicklingMode,
};
use crate::ToDeviceRequest;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use tracing::{debug, error, trace};
use super::{
super::{deserialize_instant, serialize_instant},
GroupSessionKey,
};
use crate::ToDeviceRequest;
const ROTATION_PERIOD: Duration = Duration::from_millis(604800000);
const ROTATION_MESSAGES: u64 = 100;
@ -102,12 +96,10 @@ impl EncryptionSettings {
/// Create new encryption settings using an `EncryptionEventContent` and a
/// history visibility.
pub fn new(content: EncryptionEventContent, history_visibility: HistoryVisibility) -> Self {
let rotation_period: Duration = content
.rotation_period_ms
.map_or(ROTATION_PERIOD, |r| Duration::from_millis(r.into()));
let rotation_period_msgs: u64 = content
.rotation_period_msgs
.map_or(ROTATION_MESSAGES, Into::into);
let rotation_period: Duration =
content.rotation_period_ms.map_or(ROTATION_PERIOD, |r| Duration::from_millis(r.into()));
let rotation_period_msgs: u64 =
content.rotation_period_msgs.map_or(ROTATION_MESSAGES, Into::into);
Self {
algorithm: content.algorithm,
@ -186,8 +178,7 @@ impl OutboundGroupSession {
request: Arc<ToDeviceRequest>,
message_index: u32,
) {
self.to_share_with_set
.insert(request_id, (request, message_index));
self.to_share_with_set.insert(request_id, (request, message_index));
}
/// This should be called if an the user wishes to rotate this session.
@ -225,10 +216,7 @@ impl OutboundGroupSession {
});
user_pairs.for_each(|(u, d)| {
self.shared_with_set
.entry(u)
.or_insert_with(DashMap::new)
.extend(d);
self.shared_with_set.entry(u).or_insert_with(DashMap::new).extend(d);
});
if self.to_share_with_set.is_empty() {
@ -241,11 +229,8 @@ impl OutboundGroupSession {
self.mark_as_shared();
}
} else {
let request_ids: Vec<String> = self
.to_share_with_set
.iter()
.map(|e| e.key().to_string())
.collect();
let request_ids: Vec<String> =
self.to_share_with_set.iter().map(|e| e.key().to_string()).collect();
error!(
all_request_ids = ?request_ids,
@ -296,11 +281,7 @@ impl OutboundGroupSession {
let relates_to: Option<Relation> = json_content
.get("content")
.map(|c| {
c.get("m.relates_to")
.cloned()
.map(|r| serde_json::from_value(r).ok())
})
.map(|c| c.get("m.relates_to").cloned().map(|r| serde_json::from_value(r).ok()))
.flatten()
.flatten();
@ -443,10 +424,7 @@ impl OutboundGroupSession {
/// Get the list of requests that need to be sent out for this session to be
/// marked as shared.
pub(crate) fn pending_requests(&self) -> Vec<Arc<ToDeviceRequest>> {
self.to_share_with_set
.iter()
.map(|i| i.value().0.clone())
.collect()
self.to_share_with_set.iter().map(|i| i.value().0.clone()).collect()
}
/// Get the list of request ids this session is waiting for to be sent out.
@ -462,10 +440,10 @@ impl OutboundGroupSession {
/// # Arguments
///
/// * `device_id` - The device id of the device that created this session.
/// Put differently, our own device id.
/// Put differently, our own device id.
///
/// * `identity_keys` - The identity keys of the device that created this
/// session, our own identity keys.
/// session, our own identity keys.
///
/// * `pickle` - The pickled version of the `OutboundGroupSession`.
///
@ -507,7 +485,8 @@ impl OutboundGroupSession {
///
/// # Arguments
///
/// * `pickle_mode` - The mode that should be used to pickle the group session,
/// * `pickle_mode` - The mode that should be used to pickle the group
/// session,
/// either an unencrypted mode or an encrypted using passphrase.
pub async fn pickle(&self, pickling_mode: PicklingMode) -> PickledOutboundGroupSession {
let pickle: OutboundGroupSessionPickle =
@ -528,10 +507,7 @@ impl OutboundGroupSession {
(
u.key().clone(),
#[allow(clippy::map_clone)]
u.value()
.iter()
.map(|d| (d.key().clone(), *d.value()))
.collect(),
u.value().iter().map(|d| (d.key().clone(), *d.value())).collect(),
)
})
.collect(),
@ -578,10 +554,7 @@ pub struct PickledOutboundGroupSession {
/// The room id this session is used for.
pub room_id: Arc<RoomId>,
/// The timestamp when this session was created.
#[serde(
deserialize_with = "deserialize_instant",
serialize_with = "serialize_instant"
)]
#[serde(deserialize_with = "deserialize_instant", serialize_with = "serialize_instant")]
pub creation_time: Instant,
/// The number of messages this session has already encrypted.
pub message_count: u64,

View file

@ -30,14 +30,13 @@ pub use group_sessions::{
OutboundGroupSession, PickledInboundGroupSession, PickledOutboundGroupSession,
};
pub(crate) use group_sessions::{GroupSessionKey, ShareState};
use matrix_sdk_common::instant::{Duration, Instant};
pub use olm_rs::{account::IdentityKeys, PicklingMode};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub use session::{PickledSession, Session, SessionPickle};
pub use signing::{PickledCrossSigningIdentity, PrivateCrossSigningIdentity};
pub(crate) use utility::Utility;
use matrix_sdk_common::instant::{Duration, Instant};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub(crate) fn serialize_instant<S>(instant: &Instant, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
@ -60,14 +59,16 @@ where
#[cfg(test)]
pub(crate) mod test {
use crate::olm::{InboundGroupSession, ReadOnlyAccount, Session};
use std::{collections::BTreeMap, convert::TryInto};
use matrix_sdk_common::{
api::r0::keys::SignedKey,
events::forwarded_room_key::ForwardedRoomKeyToDeviceEventContent,
identifiers::{room_id, user_id, DeviceId, UserId},
};
use olm_rs::session::OlmMessage;
use std::{collections::BTreeMap, convert::TryInto};
use crate::olm::{InboundGroupSession, ReadOnlyAccount, Session};
fn alice_id() -> UserId {
user_id!("@alice:example.org")
@ -90,21 +91,12 @@ pub(crate) mod test {
let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id());
bob.generate_one_time_keys_helper(1).await;
let one_time_key = bob
.one_time_keys()
.await
.curve25519()
.iter()
.next()
.unwrap()
.1
.to_owned();
let one_time_key =
bob.one_time_keys().await.curve25519().iter().next().unwrap().1.to_owned();
let one_time_key = SignedKey::new(one_time_key, BTreeMap::new());
let sender_key = bob.identity_keys().curve25519().to_owned();
let session = alice
.create_outbound_session_helper(&sender_key, &one_time_key)
.await
.unwrap();
let session =
alice.create_outbound_session_helper(&sender_key, &one_time_key).await.unwrap();
(alice, session)
}
@ -120,10 +112,7 @@ pub(crate) mod test {
assert_ne!(identity_keys.keys().len(), 0);
assert_ne!(identity_keys.iter().len(), 0);
assert!(identity_keys.contains_key("ed25519"));
assert_eq!(
identity_keys.ed25519(),
identity_keys.get("ed25519").unwrap()
);
assert_eq!(identity_keys.ed25519(), identity_keys.get("ed25519").unwrap());
assert!(!identity_keys.curve25519().is_empty());
account.mark_as_shared();
@ -147,10 +136,7 @@ pub(crate) mod test {
assert_ne!(one_time_keys.iter().len(), 0);
assert!(one_time_keys.contains_key("curve25519"));
assert_eq!(one_time_keys.curve25519().keys().len(), 10);
assert_eq!(
one_time_keys.curve25519(),
one_time_keys.get("curve25519").unwrap()
);
assert_eq!(one_time_keys.curve25519(), one_time_keys.get("curve25519").unwrap());
account.mark_keys_as_published().await;
let one_time_keys = account.one_time_keys().await;
@ -166,13 +152,7 @@ pub(crate) mod test {
let one_time_keys = alice.one_time_keys().await;
alice.mark_keys_as_published().await;
let one_time_key = one_time_keys
.curve25519()
.iter()
.next()
.unwrap()
.1
.to_owned();
let one_time_key = one_time_keys.curve25519().iter().next().unwrap().1.to_owned();
let one_time_key = SignedKey::new(one_time_key, BTreeMap::new());
@ -196,10 +176,7 @@ pub(crate) mod test {
.await
.unwrap();
assert!(alice_session
.matches(bob_keys.curve25519(), prekey_message)
.await
.unwrap());
assert!(alice_session.matches(bob_keys.curve25519(), prekey_message).await.unwrap());
assert_eq!(bob_session.session_id(), alice_session.session_id());
@ -212,10 +189,7 @@ pub(crate) mod test {
let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id());
let room_id = room_id!("!test:localhost");
let (outbound, _) = alice
.create_group_session_pair_with_defaults(&room_id)
.await
.unwrap();
let (outbound, _) = alice.create_group_session_pair_with_defaults(&room_id).await.unwrap();
assert_eq!(0, outbound.message_index().await);
assert!(!outbound.shared());
@ -238,10 +212,7 @@ pub(crate) mod test {
let plaintext = "This is a secret to everybody".to_owned();
let ciphertext = outbound.encrypt_helper(plaintext.clone()).await;
assert_eq!(
plaintext,
inbound.decrypt_helper(ciphertext).await.unwrap().0
);
assert_eq!(plaintext, inbound.decrypt_helper(ciphertext).await.unwrap().0);
}
#[tokio::test]
@ -249,10 +220,7 @@ pub(crate) mod test {
let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id());
let room_id = room_id!("!test:localhost");
let (_, inbound) = alice
.create_group_session_pair_with_defaults(&room_id)
.await
.unwrap();
let (_, inbound) = alice.create_group_session_pair_with_defaults(&room_id).await.unwrap();
let export = inbound.export().await;
let export: ForwardedRoomKeyToDeviceEventContent = export.try_into().unwrap();

View file

@ -27,20 +27,18 @@ use matrix_sdk_common::{
locks::Mutex,
};
use olm_rs::{errors::OlmSessionError, session::OlmSession, PicklingMode};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use crate::{
error::{EventError, OlmResult, SessionUnpicklingError},
ReadOnlyDevice,
};
pub use olm_rs::{
session::{OlmMessage, PreKeyMessage},
utility::OlmUtility,
};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use super::{deserialize_instant, serialize_instant, IdentityKeys};
use crate::{
error::{EventError, OlmResult, SessionUnpicklingError},
ReadOnlyDevice,
};
/// Cryptographic session that enables secure communication between two
/// `Account`s
@ -105,8 +103,8 @@ impl Session {
/// # Arguments
///
/// * `recipient_device` - The device for which this message is going to be
/// encrypted, this needs to be the device that was used to create this
/// session with.
/// encrypted, this needs to be the device that was used to create this
/// session with.
///
/// * `event_type` - The type of the event.
///
@ -121,10 +119,8 @@ impl Session {
.get_key(DeviceKeyAlgorithm::Ed25519)
.ok_or(EventError::MissingSigningKey)?;
let relates_to = content
.get("m.relates_to")
.cloned()
.and_then(|v| serde_json::from_value(v).ok());
let relates_to =
content.get("m.relates_to").cloned().and_then(|v| serde_json::from_value(v).ok());
let payload = json!({
"sender": self.user_id.as_str(),
@ -174,10 +170,7 @@ impl Session {
their_identity_key: &str,
message: PreKeyMessage,
) -> Result<bool, OlmSessionError> {
self.inner
.lock()
.await
.matches_inbound_session_from(their_identity_key, message)
self.inner.lock().await.matches_inbound_session_from(their_identity_key, message)
}
/// Returns the unique identifier for this session.
@ -259,20 +252,15 @@ pub struct PickledSession {
/// The curve25519 key of the other user that we share this session with.
pub sender_key: String,
/// The relative time elapsed since the session was created.
#[serde(
deserialize_with = "deserialize_instant",
serialize_with = "serialize_instant"
)]
#[serde(deserialize_with = "deserialize_instant", serialize_with = "serialize_instant")]
pub creation_time: Instant,
/// The relative time elapsed since the session was last used.
#[serde(
deserialize_with = "deserialize_instant",
serialize_with = "serialize_instant"
)]
#[serde(deserialize_with = "deserialize_instant", serialize_with = "serialize_instant")]
pub last_use_time: Instant,
}
/// The typed representation of a base64 encoded string of the Olm Session pickle.
/// The typed representation of a base64 encoded string of the Olm Session
/// pickle.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionPickle(String);

View file

@ -14,8 +14,6 @@
mod pk_signing;
use serde::{Deserialize, Serialize};
use serde_json::Error as JsonError;
use std::{
collections::BTreeMap,
sync::{
@ -30,14 +28,15 @@ use matrix_sdk_common::{
identifiers::{DeviceKeyAlgorithm, DeviceKeyId, UserId},
locks::Mutex,
};
use pk_signing::{MasterSigning, PickledSignings, SelfSigning, Signing, SigningError, UserSigning};
use serde::{Deserialize, Serialize};
use serde_json::Error as JsonError;
use crate::{
error::SignatureError, requests::UploadSigningKeysRequest, OwnUserIdentity, ReadOnlyAccount,
ReadOnlyDevice, UserIdentity,
};
use pk_signing::{MasterSigning, PickledSignings, SelfSigning, Signing, SigningError, UserSigning};
/// Private cross signing identity.
///
/// This object holds the private and public ed25519 key triplet that is used
@ -186,10 +185,7 @@ impl PrivateCrossSigningIdentity {
signed_keys
.entry((&*self.user_id).to_owned())
.or_insert_with(BTreeMap::new)
.insert(
device_keys.device_id.to_string(),
serde_json::to_value(device_keys)?,
);
.insert(device_keys.device_id.to_string(), serde_json::to_value(device_keys)?);
Ok(SignatureUploadRequest::new(signed_keys))
}
@ -229,10 +225,7 @@ impl PrivateCrossSigningIdentity {
signature,
);
let master = MasterSigning {
inner: master,
public_key: public_key.into(),
};
let master = MasterSigning { inner: master, public_key: public_key.into() };
let identity = Self::new_helper(account.user_id(), master).await;
let signature_request = identity
@ -250,20 +243,14 @@ impl PrivateCrossSigningIdentity {
let mut public_key = user.cross_signing_key(user_id.to_owned(), KeyUsage::UserSigning);
master.sign_subkey(&mut public_key).await;
let user = UserSigning {
inner: user,
public_key: public_key.into(),
};
let user = UserSigning { inner: user, public_key: public_key.into() };
let self_signing = Signing::new();
let mut public_key =
self_signing.cross_signing_key(user_id.to_owned(), KeyUsage::SelfSigning);
master.sign_subkey(&mut public_key).await;
let self_signing = SelfSigning {
inner: self_signing,
public_key: public_key.into(),
};
let self_signing = SelfSigning { inner: self_signing, public_key: public_key.into() };
Self {
user_id: Arc::new(user_id.to_owned()),
@ -281,10 +268,7 @@ impl PrivateCrossSigningIdentity {
let master = Signing::new();
let public_key = master.cross_signing_key(user_id.clone(), KeyUsage::Master);
let master = MasterSigning {
inner: master,
public_key: public_key.into(),
};
let master = MasterSigning { inner: master, public_key: public_key.into() };
Self::new_helper(&user_id, master).await
}
@ -334,11 +318,7 @@ impl PrivateCrossSigningIdentity {
None
};
let pickle = PickledSignings {
master_key,
user_signing_key,
self_signing_key,
};
let pickle = PickledSignings { master_key, user_signing_key, self_signing_key };
let pickle = serde_json::to_string(&pickle)?;
@ -390,54 +370,35 @@ impl PrivateCrossSigningIdentity {
/// Get the upload request that is needed to share the public keys of this
/// identity.
pub(crate) async fn as_upload_request(&self) -> UploadSigningKeysRequest {
let master_key = self
.master_key
.lock()
.await
.as_ref()
.cloned()
.map(|k| k.public_key.into());
let master_key =
self.master_key.lock().await.as_ref().cloned().map(|k| k.public_key.into());
let user_signing_key = self
.user_signing_key
.lock()
.await
.as_ref()
.cloned()
.map(|k| k.public_key.into());
let user_signing_key =
self.user_signing_key.lock().await.as_ref().cloned().map(|k| k.public_key.into());
let self_signing_key = self
.self_signing_key
.lock()
.await
.as_ref()
.cloned()
.map(|k| k.public_key.into());
let self_signing_key =
self.self_signing_key.lock().await.as_ref().cloned().map(|k| k.public_key.into());
UploadSigningKeysRequest {
master_key,
self_signing_key,
user_signing_key,
}
UploadSigningKeysRequest { master_key, self_signing_key, user_signing_key }
}
}
#[cfg(test)]
mod test {
use crate::{
identities::{ReadOnlyDevice, UserIdentity},
olm::ReadOnlyAccount,
};
use std::{collections::BTreeMap, sync::Arc};
use super::{PrivateCrossSigningIdentity, Signing};
use matrix_sdk_common::{
api::r0::keys::CrossSigningKey,
identifiers::{user_id, UserId},
};
use matrix_sdk_test::async_test;
use super::{PrivateCrossSigningIdentity, Signing};
use crate::{
identities::{ReadOnlyDevice, UserIdentity},
olm::ReadOnlyAccount,
};
fn user_id() -> UserId {
user_id!("@example:localhost")
}
@ -481,28 +442,12 @@ mod test {
assert!(master_key
.public_key
.verify_subkey(
&identity
.self_signing_key
.lock()
.await
.as_ref()
.unwrap()
.public_key,
)
.verify_subkey(&identity.self_signing_key.lock().await.as_ref().unwrap().public_key,)
.is_ok());
assert!(master_key
.public_key
.verify_subkey(
&identity
.user_signing_key
.lock()
.await
.as_ref()
.unwrap()
.public_key,
)
.verify_subkey(&identity.user_signing_key.lock().await.as_ref().unwrap().public_key,)
.is_ok());
}
@ -512,15 +457,11 @@ mod test {
let pickled = identity.pickle(pickle_key()).await.unwrap();
let unpickled = PrivateCrossSigningIdentity::from_pickle(pickled, pickle_key())
.await
.unwrap();
let unpickled =
PrivateCrossSigningIdentity::from_pickle(pickled, pickle_key()).await.unwrap();
assert_eq!(identity.user_id, unpickled.user_id);
assert_eq!(
&*identity.master_key.lock().await,
&*unpickled.master_key.lock().await
);
assert_eq!(&*identity.master_key.lock().await, &*unpickled.master_key.lock().await);
assert_eq!(
&*identity.user_signing_key.lock().await,
&*unpickled.user_signing_key.lock().await
@ -591,9 +532,6 @@ mod test {
bob_public.master_key = master.into();
user_signing
.public_key
.verify_master_key(bob_public.master_key())
.unwrap();
user_signing.public_key.verify_master_key(bob_public.master_key()).unwrap();
}
}

View file

@ -12,32 +12,27 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::BTreeMap, convert::TryInto, sync::Arc};
use aes_gcm::{
aead::{generic_array::GenericArray, Aead, NewAead},
Aes256Gcm,
};
use getrandom::getrandom;
use matrix_sdk_common::{
encryption::DeviceKeys,
identifiers::{DeviceKeyAlgorithm, DeviceKeyId},
};
use serde::{Deserialize, Serialize};
use serde_json::{json, Error as JsonError, Value};
use std::{collections::BTreeMap, convert::TryInto, sync::Arc};
use thiserror::Error;
use zeroize::Zeroizing;
use olm_rs::pk::OlmPkSigning;
#[cfg(test)]
use olm_rs::{errors::OlmUtilityError, utility::OlmUtility};
use matrix_sdk_common::{
api::r0::keys::{CrossSigningKey, KeyUsage},
identifiers::UserId,
encryption::DeviceKeys,
identifiers::{DeviceKeyAlgorithm, DeviceKeyId, UserId},
locks::Mutex,
CanonicalJsonValue,
};
use olm_rs::pk::OlmPkSigning;
#[cfg(test)]
use olm_rs::{errors::OlmUtilityError, utility::OlmUtility};
use serde::{Deserialize, Serialize};
use serde_json::{json, Error as JsonError, Value};
use thiserror::Error;
use zeroize::Zeroizing;
use crate::{
error::SignatureError,
@ -73,9 +68,7 @@ pub struct Signing {
impl std::fmt::Debug for Signing {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Signing")
.field("public_key", &self.public_key.as_str())
.finish()
f.debug_struct("Signing").field("public_key", &self.public_key.as_str()).finish()
}
}
@ -156,10 +149,7 @@ impl MasterSigning {
) -> Result<Self, SigningError> {
let inner = Signing::from_pickle(pickle.pickle, pickle_key)?;
Ok(Self {
inner,
public_key: pickle.public_key.into(),
})
Ok(Self { inner, public_key: pickle.public_key.into() })
}
pub async fn sign_subkey<'a>(&self, subkey: &mut CrossSigningKey) {
@ -200,10 +190,7 @@ impl UserSigning {
user: &UserIdentity,
) -> Result<BTreeMap<UserId, BTreeMap<String, Value>>, SignatureError> {
let user_master: &CrossSigningKey = user.master_key().as_ref();
let signature = self
.inner
.sign_json(serde_json::to_value(user_master)?)
.await?;
let signature = self.inner.sign_json(serde_json::to_value(user_master)?).await?;
let mut signatures = BTreeMap::new();
@ -228,10 +215,7 @@ impl UserSigning {
) -> Result<Self, SigningError> {
let inner = Signing::from_pickle(pickle.pickle, pickle_key)?;
Ok(Self {
inner,
public_key: pickle.public_key.into(),
})
Ok(Self { inner, public_key: pickle.public_key.into() })
}
}
@ -279,10 +263,7 @@ impl SelfSigning {
) -> Result<Self, SigningError> {
let inner = Signing::from_pickle(pickle.pickle, pickle_key)?;
Ok(Self {
inner,
public_key: pickle.public_key.into(),
})
Ok(Self { inner, public_key: pickle.public_key.into() })
}
}
@ -353,17 +334,12 @@ impl Signing {
getrandom(&mut nonce).expect("Can't generate nonce to pickle the signing object");
let nonce = GenericArray::from_slice(nonce.as_slice());
let ciphertext = cipher
.encrypt(nonce, self.seed.as_slice())
.expect("Can't encrypt signing pickle");
let ciphertext =
cipher.encrypt(nonce, self.seed.as_slice()).expect("Can't encrypt signing pickle");
let ciphertext = encode(ciphertext);
let pickle = InnerPickle {
version: 1,
nonce: encode(nonce.as_slice()),
ciphertext,
};
let pickle = InnerPickle { version: 1, nonce: encode(nonce.as_slice()), ciphertext };
PickledSigning(serde_json::to_string(&pickle).expect("Can't encode pickled signing"))
}
@ -376,11 +352,8 @@ impl Signing {
let mut keys = BTreeMap::new();
keys.insert(
DeviceKeyId::from_parts(
DeviceKeyAlgorithm::Ed25519,
self.public_key().as_str().into(),
)
.to_string(),
DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, self.public_key().as_str().into())
.to_string(),
self.public_key().to_string(),
);

View file

@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use olm_rs::utility::OlmUtility;
use serde_json::Value;
use std::convert::TryInto;
use matrix_sdk_common::{
identifiers::{DeviceKeyAlgorithm, DeviceKeyId, UserId},
CanonicalJsonValue,
};
use olm_rs::utility::OlmUtility;
use serde_json::Value;
use crate::error::SignatureError;
@ -29,9 +29,7 @@ pub(crate) struct Utility {
impl Utility {
pub fn new() -> Self {
Self {
inner: OlmUtility::new(),
}
Self { inner: OlmUtility::new() }
}
/// Verify a signed JSON object.
@ -49,7 +47,7 @@ impl Utility {
/// * `key_id` - The id of the key that signed the JSON object.
///
/// * `signing_key` - The public ed25519 key which was used to sign the JSON
/// object.
/// object.
///
/// * `json` - The JSON object that should be verified.
pub(crate) fn verify_json(
@ -67,29 +65,20 @@ impl Utility {
let unsigned = json_object.remove("unsigned");
let signatures = json_object.remove("signatures");
let canonical_json: CanonicalJsonValue = json
.clone()
.try_into()
.map_err(|_| SignatureError::NotAnObject)?;
let canonical_json: CanonicalJsonValue =
json.clone().try_into().map_err(|_| SignatureError::NotAnObject)?;
let canonical_json: String = canonical_json.to_string();
let signatures = signatures.ok_or(SignatureError::NoSignatureFound)?;
let signature_object = signatures
.as_object()
.ok_or(SignatureError::NoSignatureFound)?;
let signature = signature_object
.get(user_id.as_str())
.ok_or(SignatureError::NoSignatureFound)?;
let signature = signature
.get(key_id.to_string())
.ok_or(SignatureError::NoSignatureFound)?;
let signature_object = signatures.as_object().ok_or(SignatureError::NoSignatureFound)?;
let signature =
signature_object.get(user_id.as_str()).ok_or(SignatureError::NoSignatureFound)?;
let signature =
signature.get(key_id.to_string()).ok_or(SignatureError::NoSignatureFound)?;
let signature = signature.as_str().ok_or(SignatureError::NoSignatureFound)?;
let ret = match self
.inner
.ed25519_verify(signing_key, &canonical_json, signature)
{
let ret = match self.inner.ed25519_verify(signing_key, &canonical_json, signature) {
Ok(_) => Ok(()),
Err(_) => Err(SignatureError::VerificationError),
};
@ -108,10 +97,11 @@ impl Utility {
#[cfg(test)]
mod test {
use super::Utility;
use matrix_sdk_common::identifiers::{user_id, DeviceKeyAlgorithm, DeviceKeyId};
use serde_json::json;
use super::Utility;
#[test]
fn signature_test() {
let mut device_keys = json!({

View file

@ -35,18 +35,19 @@ use matrix_sdk_common::{
identifiers::{DeviceIdBox, RoomId, UserId},
uuid::Uuid,
};
use serde::{Deserialize, Serialize};
use serde_json::value::RawValue as RawJsonValue;
/// Customized version of `ruma_client_api::r0::to_device::send_event_to_device::Request`, using a
/// Customized version of
/// `ruma_client_api::r0::to_device::send_event_to_device::Request`, using a
/// UUID for the transaction ID.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToDeviceRequest {
/// Type of event being sent to each device.
pub event_type: EventType,
/// A request identifier unique to the access token used to send the request.
/// A request identifier unique to the access token used to send the
/// request.
pub txn_id: Uuid,
/// A map of users to devices to a content for a message event to be
@ -80,15 +81,18 @@ impl ToDeviceRequest {
pub struct UploadSigningKeysRequest {
/// The user's master key.
pub master_key: Option<CrossSigningKey>,
/// The user's self-signing key. Must be signed with the accompanied master, or by the
/// user's most recently uploaded master key if no master key is included in the request.
/// The user's self-signing key. Must be signed with the accompanied master,
/// or by the user's most recently uploaded master key if no master key
/// is included in the request.
pub self_signing_key: Option<CrossSigningKey>,
/// The user's user-signing key. Must be signed with the accompanied master, or by the
/// user's most recently uploaded master key if no master key is included in the request.
/// The user's user-signing key. Must be signed with the accompanied master,
/// or by the user's most recently uploaded master key if no master key
/// is included in the request.
pub user_signing_key: Option<CrossSigningKey>,
}
/// Customized version of `ruma_client_api::r0::keys::get_keys::Request`, without any references.
/// Customized version of `ruma_client_api::r0::keys::get_keys::Request`,
/// without any references.
#[derive(Clone, Debug)]
pub struct KeysQueryRequest {
/// The time (in milliseconds) to wait when downloading keys from remote
@ -109,11 +113,7 @@ pub struct KeysQueryRequest {
impl KeysQueryRequest {
pub(crate) fn new(device_keys: BTreeMap<UserId, Vec<DeviceIdBox>>) -> Self {
Self {
timeout: None,
device_keys,
token: None,
}
Self { timeout: None, device_keys, token: None }
}
}
@ -177,19 +177,13 @@ impl From<SignatureUploadRequest> for OutgoingRequests {
impl From<OutgoingVerificationRequest> for OutgoingRequest {
fn from(r: OutgoingVerificationRequest) -> Self {
Self {
request_id: r.request_id(),
request: Arc::new(r.into()),
}
Self { request_id: r.request_id(), request: Arc::new(r.into()) }
}
}
impl From<SignatureUploadRequest> for OutgoingRequest {
fn from(r: SignatureUploadRequest) -> Self {
Self {
request_id: Uuid::new_v4(),
request: Arc::new(r.into()),
}
Self { request_id: Uuid::new_v4(), request: Arc::new(r.into()) }
}
}

View file

@ -17,9 +17,8 @@ use std::{
sync::Arc,
};
use futures::future::join_all;
use dashmap::DashMap;
use futures::future::join_all;
use matrix_sdk_common::{
api::r0::to_device::DeviceIdOrAllDevices,
events::{
@ -105,10 +104,7 @@ impl GroupSessionCache {
room_id: &RoomId,
session_id: &str,
) -> StoreResult<Option<OutboundGroupSession>> {
Ok(self
.get_or_load(room_id)
.await?
.filter(|o| session_id == o.session_id()))
Ok(self.get_or_load(room_id).await?.filter(|o| session_id == o.session_id()))
}
}
@ -127,11 +123,7 @@ impl GroupSessionManager {
const MAX_TO_DEVICE_MESSAGES: usize = 250;
pub(crate) fn new(account: Account, store: Store) -> Self {
Self {
account,
store: store.clone(),
sessions: GroupSessionCache::new(store),
}
Self { account, store: store.clone(), sessions: GroupSessionCache::new(store) }
}
pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
@ -231,9 +223,7 @@ impl GroupSessionManager {
Ok((s, None))
}
} else {
self.create_outbound_group_session(room_id, settings)
.await
.map(|(o, i)| (o, i.into()))
self.create_outbound_group_session(room_id, settings).await.map(|(o, i)| (o, i.into()))
}
}
@ -253,13 +243,10 @@ impl GroupSessionManager {
let used_session = match encrypted {
Ok((session, encrypted)) => {
message
.entry(device.user_id().clone())
.or_insert_with(BTreeMap::new)
.insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
serde_json::value::to_raw_value(&encrypted)?,
);
message.entry(device.user_id().clone()).or_insert_with(BTreeMap::new).insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
serde_json::value::to_raw_value(&encrypted)?,
);
Some(session)
}
// TODO we'll want to create m.room_key.withheld here.
@ -271,10 +258,8 @@ impl GroupSessionManager {
Ok((used_session, message))
};
let tasks: Vec<_> = devices
.iter()
.map(|d| spawn(encrypt(d.clone(), content.clone())))
.collect();
let tasks: Vec<_> =
devices.iter().map(|d| spawn(encrypt(d.clone(), content.clone()))).collect();
let results = join_all(tasks).await;
@ -286,20 +271,14 @@ impl GroupSessionManager {
}
for (user, device_messages) in message.into_iter() {
messages
.entry(user)
.or_insert_with(BTreeMap::new)
.extend(device_messages);
messages.entry(user).or_insert_with(BTreeMap::new).extend(device_messages);
}
}
let id = Uuid::new_v4();
let request = ToDeviceRequest {
event_type: EventType::RoomEncrypted,
txn_id: id,
messages,
};
let request =
ToDeviceRequest { event_type: EventType::RoomEncrypted, txn_id: id, messages };
trace!(
recipient_count = request.message_count(),
@ -331,20 +310,14 @@ impl GroupSessionManager {
"Calculating group session recipients"
);
let users_shared_with: HashSet<UserId> = outbound
.shared_with_set
.iter()
.map(|k| k.key().clone())
.collect();
let users_shared_with: HashSet<UserId> =
outbound.shared_with_set.iter().map(|k| k.key().clone()).collect();
let users_shared_with: HashSet<&UserId> = users_shared_with.iter().collect();
// A user left if a user is missing from the set of users that should
// get the session but is in the set of users that received the session.
let user_left = !users_shared_with
.difference(&users)
.collect::<HashSet<_>>()
.is_empty();
let user_left = !users_shared_with.difference(&users).collect::<HashSet<_>>().is_empty();
let visibility_changed = outbound.settings().history_visibility != history_visibility;
@ -359,10 +332,8 @@ impl GroupSessionManager {
for user_id in users {
let user_devices = self.store.get_user_devices(&user_id).await?;
let non_blacklisted_devices: Vec<Device> = user_devices
.devices()
.filter(|d| !d.is_blacklisted())
.collect();
let non_blacklisted_devices: Vec<Device> =
user_devices.devices().filter(|d| !d.is_blacklisted()).collect();
// If we haven't already concluded that the session should be
// rotated for other reasons, we also need to check whether any
@ -370,10 +341,8 @@ impl GroupSessionManager {
// meantime. If so, we should also rotate the session.
if !should_rotate {
// Device IDs that should receive this session
let non_blacklisted_device_ids: HashSet<&DeviceId> = non_blacklisted_devices
.iter()
.map(|d| d.device_id())
.collect();
let non_blacklisted_device_ids: HashSet<&DeviceId> =
non_blacklisted_devices.iter().map(|d| d.device_id()).collect();
if let Some(shared) = outbound.shared_with_set.get(user_id) {
#[allow(clippy::map_clone)]
@ -389,9 +358,8 @@ impl GroupSessionManager {
//
// represents newly deleted or blacklisted devices. If this
// set is non-empty, we must rotate.
let newly_deleted_or_blacklisted = shared
.difference(&non_blacklisted_device_ids)
.collect::<HashSet<_>>();
let newly_deleted_or_blacklisted =
shared.difference(&non_blacklisted_device_ids).collect::<HashSet<_>>();
if !newly_deleted_or_blacklisted.is_empty() {
should_rotate = true;
@ -399,10 +367,7 @@ impl GroupSessionManager {
};
}
devices
.entry(user_id.clone())
.or_insert_with(Vec::new)
.extend(non_blacklisted_devices);
devices.entry(user_id.clone()).or_insert_with(Vec::new).extend(non_blacklisted_devices);
}
debug!(
@ -462,25 +427,22 @@ impl GroupSessionManager {
let history_visibility = encryption_settings.history_visibility.clone();
let mut changes = Changes::default();
let (outbound, inbound) = self
.get_or_create_outbound_session(room_id, encryption_settings.clone())
.await?;
let (outbound, inbound) =
self.get_or_create_outbound_session(room_id, encryption_settings.clone()).await?;
if let Some(inbound) = inbound {
changes.outbound_group_sessions.push(outbound.clone());
changes.inbound_group_sessions.push(inbound);
}
let (should_rotate, devices) = self
.collect_session_recipients(users, history_visibility, &outbound)
.await?;
let (should_rotate, devices) =
self.collect_session_recipients(users, history_visibility, &outbound).await?;
let outbound = if should_rotate {
let old_session_id = outbound.session_id();
let (outbound, inbound) = self
.create_outbound_group_session(room_id, encryption_settings)
.await?;
let (outbound, inbound) =
self.create_outbound_group_session(room_id, encryption_settings).await?;
changes.outbound_group_sessions.push(outbound.clone());
changes.inbound_group_sessions.push(inbound);
@ -515,9 +477,7 @@ impl GroupSessionManager {
if !devices.is_empty() {
let users = devices.iter().fold(BTreeMap::new(), |mut acc, d| {
acc.entry(d.user_id())
.or_insert_with(BTreeSet::new)
.insert(d.device_id());
acc.entry(d.user_id()).or_insert_with(BTreeSet::new).insert(d.device_id());
acc
});
@ -626,14 +586,8 @@ mod test {
let machine = OlmMachine::new(&alice_id(), &alice_device_id());
machine
.mark_request_as_sent(&uuid, &keys_query)
.await
.unwrap();
machine
.mark_request_as_sent(&uuid, &keys_claim)
.await
.unwrap();
machine.mark_request_as_sent(&uuid, &keys_query).await.unwrap();
machine.mark_request_as_sent(&uuid, &keys_claim).await.unwrap();
machine
}
@ -647,11 +601,7 @@ mod test {
let users: Vec<_> = keys_claim.one_time_keys.keys().collect();
let requests = machine
.share_group_session(
&room_id,
users.clone().into_iter(),
EncryptionSettings::default(),
)
.share_group_session(&room_id, users.clone().into_iter(), EncryptionSettings::default())
.await
.unwrap();

View file

@ -77,11 +77,7 @@ impl SessionManager {
}
pub async fn mark_device_as_wedged(&self, sender: &UserId, curve_key: &str) -> StoreResult<()> {
if let Some(device) = self
.store
.get_device_from_curve_key(sender, curve_key)
.await?
{
if let Some(device) = self.store.get_device_from_curve_key(sender, curve_key).await? {
let sessions = device.get_sessions().await?;
if let Some(sessions) = sessions {
@ -120,25 +116,16 @@ impl SessionManager {
///
/// If the device was wedged this will queue up a dummy to-device message.
async fn check_if_unwedged(&self, user_id: &UserId, device_id: &DeviceId) -> OlmResult<()> {
if self
.wedged_devices
.get(user_id)
.map(|d| d.remove(device_id))
.flatten()
.is_some()
{
if self.wedged_devices.get(user_id).map(|d| d.remove(device_id)).flatten().is_some() {
if let Some(device) = self.store.get_device(user_id, device_id).await? {
let (_, content) = device.encrypt(EventType::Dummy, json!({})).await?;
let id = Uuid::new_v4();
let mut messages = BTreeMap::new();
messages
.entry(device.user_id().to_owned())
.or_insert_with(BTreeMap::new)
.insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
to_raw_value(&content)?,
);
messages.entry(device.user_id().to_owned()).or_insert_with(BTreeMap::new).insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
to_raw_value(&content)?,
);
let request = OutgoingRequest {
request_id: id,
@ -307,13 +294,13 @@ impl SessionManager {
#[cfg(test)]
mod test {
use dashmap::DashMap;
use matrix_sdk_common::locks::Mutex;
use std::{collections::BTreeMap, sync::Arc};
use dashmap::DashMap;
use matrix_sdk_common::{
api::r0::keys::claim_keys::Response as KeyClaimResponse,
identifiers::{user_id, DeviceIdBox, UserId},
locks::Mutex,
};
use matrix_sdk_test::async_test;
@ -347,9 +334,7 @@ mod test {
let account = ReadOnlyAccount::new(&user_id, &device_id);
let store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
store.save_account(account.clone()).await.unwrap();
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(
user_id.clone(),
)));
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(user_id.clone())));
let verification =
VerificationMachine::new(account.clone(), identity.clone(), store.clone());
@ -358,10 +343,7 @@ mod test {
let store = Store::new(user_id.clone(), identity, store, verification);
let account = Account {
inner: account,
store: store.clone(),
};
let account = Account { inner: account, store: store.clone() };
let session_cache = GroupSessionCache::new(store.clone());
@ -405,10 +387,7 @@ mod test {
let response = KeyClaimResponse::new(one_time_keys);
manager
.receive_keys_claim_response(&response)
.await
.unwrap();
manager.receive_keys_claim_response(&response).await.unwrap();
assert!(manager
.get_missing_sessions(&mut [bob.user_id().clone()].iter())
@ -434,11 +413,7 @@ mod test {
let bob_device = ReadOnlyDevice::from_account(&bob).await;
session.creation_time = Arc::new(Instant::now() - Duration::from_secs(3601));
manager
.store
.save_devices(&[bob_device.clone()])
.await
.unwrap();
manager.store.save_devices(&[bob_device.clone()]).await.unwrap();
manager.store.save_sessions(&[session]).await.unwrap();
assert!(manager
@ -451,10 +426,7 @@ mod test {
assert!(!manager.users_for_key_claim.contains_key(bob.user_id()));
assert!(!manager.is_device_wedged(&bob_device));
manager
.mark_device_as_wedged(bob_device.user_id(), &curve_key)
.await
.unwrap();
manager.mark_device_as_wedged(bob_device.user_id(), &curve_key).await.unwrap();
assert!(manager.is_device_wedged(&bob_device));
assert!(manager.users_for_key_claim.contains_key(bob.user_id()));
@ -480,10 +452,7 @@ mod test {
assert!(manager.outgoing_to_device_requests.is_empty());
manager
.receive_keys_claim_response(&response)
.await
.unwrap();
manager.receive_keys_claim_response(&response).await.unwrap();
assert!(!manager.is_device_wedged(&bob_device));
assert!(manager

View file

@ -39,9 +39,7 @@ pub struct SessionStore {
impl SessionStore {
/// Create a new empty Session store.
pub fn new() -> Self {
SessionStore {
entries: Arc::new(DashMap::new()),
}
SessionStore { entries: Arc::new(DashMap::new()) }
}
/// Add a session to the store.
@ -72,8 +70,7 @@ impl SessionStore {
/// Add a list of sessions belonging to the sender key.
pub fn set_for_sender(&self, sender_key: &str, sessions: Vec<Session>) {
self.entries
.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
self.entries.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
}
}
@ -87,9 +84,7 @@ pub struct GroupSessionStore {
impl GroupSessionStore {
/// Create a new empty store.
pub fn new() -> Self {
GroupSessionStore {
entries: Arc::new(DashMap::new()),
}
GroupSessionStore { entries: Arc::new(DashMap::new()) }
}
/// Add an inbound group session to the store.
@ -148,9 +143,7 @@ pub struct DeviceStore {
impl DeviceStore {
/// Create a new empty device store.
pub fn new() -> Self {
DeviceStore {
entries: Arc::new(DashMap::new()),
}
DeviceStore { entries: Arc::new(DashMap::new()) }
}
/// Add a device to the store.
@ -167,19 +160,15 @@ impl DeviceStore {
/// Get the device with the given device_id and belonging to the given user.
pub fn get(&self, user_id: &UserId, device_id: &DeviceId) -> Option<ReadOnlyDevice> {
self.entries
.get(user_id)
.and_then(|m| m.get(device_id).map(|d| d.value().clone()))
self.entries.get(user_id).and_then(|m| m.get(device_id).map(|d| d.value().clone()))
}
/// Remove the device with the given device_id and belonging to the given user.
/// Remove the device with the given device_id and belonging to the given
/// user.
///
/// Returns the device if it was removed, None if it wasn't in the store.
pub fn remove(&self, user_id: &UserId, device_id: &DeviceId) -> Option<ReadOnlyDevice> {
self.entries
.get(user_id)
.and_then(|m| m.remove(device_id))
.map(|(_, d)| d)
self.entries.get(user_id).and_then(|m| m.remove(device_id)).map(|(_, d)| d)
}
/// Get a read-only view over all devices of the given user.
@ -195,12 +184,13 @@ impl DeviceStore {
#[cfg(test)]
mod test {
use matrix_sdk_common::identifiers::room_id;
use crate::{
identities::device::test::get_device,
olm::{test::get_account_and_session, InboundGroupSession},
store::caches::{DeviceStore, GroupSessionStore, SessionStore},
};
use matrix_sdk_common::identifiers::room_id;
#[tokio::test]
async fn test_session_store() {
@ -239,10 +229,8 @@ mod test {
let (account, _) = get_account_and_session().await;
let room_id = room_id!("!test:localhost");
let (outbound, _) = account
.create_group_session_pair_with_defaults(&room_id)
.await
.unwrap();
let (outbound, _) =
account.create_group_session_pair_with_defaults(&room_id).await.unwrap();
assert_eq!(0, outbound.message_index().await);
assert!(!outbound.shared());
@ -261,9 +249,7 @@ mod test {
let store = GroupSessionStore::new();
store.add(inbound.clone());
let loaded_session = store
.get(&room_id, "test_key", outbound.session_id())
.unwrap();
let loaded_session = store.get(&room_id, "test_key", outbound.session_id()).unwrap();
assert_eq!(inbound, loaded_session);
}

View file

@ -37,10 +37,7 @@ use crate::{
};
fn encode_key_info(info: &RequestedKeyInfo) -> String {
format!(
"{}{}{}{}",
info.room_id, info.sender_key, info.algorithm, info.session_id
)
format!("{}{}{}{}", info.room_id, info.sender_key, info.algorithm, info.session_id)
}
/// An in-memory only store that will forget all the E2EE key once it's dropped.
@ -121,22 +118,14 @@ impl CryptoStore for MemoryStore {
async fn save_changes(&self, mut changes: Changes) -> Result<()> {
self.save_sessions(changes.sessions).await;
self.save_inbound_group_sessions(changes.inbound_group_sessions)
.await;
self.save_inbound_group_sessions(changes.inbound_group_sessions).await;
self.save_devices(changes.devices.new).await;
self.save_devices(changes.devices.changed).await;
self.delete_devices(changes.devices.deleted).await;
for identity in changes
.identities
.new
.drain(..)
.chain(changes.identities.changed)
{
let _ = self
.identities
.insert(identity.user_id().to_owned(), identity.clone());
for identity in changes.identities.new.drain(..).chain(changes.identities.changed) {
let _ = self.identities.insert(identity.user_id().to_owned(), identity.clone());
}
for hash in changes.message_hashes {
@ -167,9 +156,7 @@ impl CryptoStore for MemoryStore {
sender_key: &str,
session_id: &str,
) -> Result<Option<InboundGroupSession>> {
Ok(self
.inbound_group_sessions
.get(room_id, sender_key, session_id))
Ok(self.inbound_group_sessions.get(room_id, sender_key, session_id))
}
async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
@ -250,10 +237,7 @@ impl CryptoStore for MemoryStore {
&self,
request_id: Uuid,
) -> Result<Option<OutgoingKeyRequest>> {
Ok(self
.outgoing_key_requests
.get(&request_id)
.map(|r| r.clone()))
Ok(self.outgoing_key_requests.get(&request_id).map(|r| r.clone()))
}
async fn get_key_request_by_info(
@ -278,12 +262,10 @@ impl CryptoStore for MemoryStore {
}
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> {
self.outgoing_key_requests
.remove(&request_id)
.and_then(|(_, i)| {
let key_info_string = encode_key_info(&i.info);
self.key_requests_by_info.remove(&key_info_string)
});
self.outgoing_key_requests.remove(&request_id).and_then(|(_, i)| {
let key_info_string = encode_key_info(&i.info);
self.key_requests_by_info.remove(&key_info_string)
});
Ok(())
}
@ -291,12 +273,13 @@ impl CryptoStore for MemoryStore {
#[cfg(test)]
mod test {
use matrix_sdk_common::identifiers::room_id;
use crate::{
identities::device::test::get_device,
olm::{test::get_account_and_session, InboundGroupSession, OlmMessageHash},
store::{memorystore::MemoryStore, Changes, CryptoStore},
};
use matrix_sdk_common::identifiers::room_id;
#[tokio::test]
async fn test_session_store() {
@ -308,11 +291,7 @@ mod test {
store.save_sessions(vec![session.clone()]).await;
let sessions = store
.get_sessions(&session.sender_key)
.await
.unwrap()
.unwrap();
let sessions = store.get_sessions(&session.sender_key).await.unwrap().unwrap();
let sessions = sessions.lock().await;
let loaded_session = &sessions[0];
@ -325,10 +304,8 @@ mod test {
let (account, _) = get_account_and_session().await;
let room_id = room_id!("!test:localhost");
let (outbound, _) = account
.create_group_session_pair_with_defaults(&room_id)
.await
.unwrap();
let (outbound, _) =
account.create_group_session_pair_with_defaults(&room_id).await.unwrap();
let inbound = InboundGroupSession::new(
"test_key",
"test_key",
@ -339,9 +316,7 @@ mod test {
.unwrap();
let store = MemoryStore::new();
let _ = store
.save_inbound_group_sessions(vec![inbound.clone()])
.await;
let _ = store.save_inbound_group_sessions(vec![inbound.clone()]).await;
let loaded_session = store
.get_inbound_group_session(&room_id, "test_key", outbound.session_id())
@ -358,11 +333,8 @@ mod test {
store.save_devices(vec![device.clone()]).await;
let loaded_device = store
.get_device(device.user_id(), device.device_id())
.await
.unwrap()
.unwrap();
let loaded_device =
store.get_device(device.user_id(), device.device_id()).await.unwrap().unwrap();
assert_eq!(device, loaded_device);
@ -376,11 +348,7 @@ mod test {
assert_eq!(&device, loaded_device);
store.delete_devices(vec![device.clone()]).await;
assert!(store
.get_device(device.user_id(), device.device_id())
.await
.unwrap()
.is_none());
assert!(store.get_device(device.user_id(), device.device_id()).await.unwrap().is_none());
}
#[tokio::test]
@ -388,14 +356,8 @@ mod test {
let device = get_device();
let store = MemoryStore::new();
assert!(store
.update_tracked_user(device.user_id(), false)
.await
.unwrap());
assert!(!store
.update_tracked_user(device.user_id(), false)
.await
.unwrap());
assert!(store.update_tracked_user(device.user_id(), false).await.unwrap());
assert!(!store.update_tracked_user(device.user_id(), false).await.unwrap());
assert!(store.is_user_tracked(device.user_id()));
}
@ -404,10 +366,8 @@ mod test {
async fn test_message_hash() {
let store = MemoryStore::new();
let hash = OlmMessageHash {
sender_key: "test_sender".to_owned(),
hash: "test_hash".to_owned(),
};
let hash =
OlmMessageHash { sender_key: "test_sender".to_owned(), hash: "test_hash".to_owned() };
let mut changes = Changes::default();
changes.message_hashes.push(hash.clone());

View file

@ -43,11 +43,6 @@ mod pickle_key;
#[cfg(feature = "sled_cryptostore")]
pub(crate) mod sled;
#[cfg(feature = "sled_cryptostore")]
pub use self::sled::SledStore;
pub use memorystore::MemoryStore;
pub use pickle_key::{EncryptedPickleKey, PickleKey};
use std::{
collections::{HashMap, HashSet},
fmt::Debug,
@ -56,10 +51,6 @@ use std::{
sync::Arc,
};
use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
use serde_json::Error as SerdeError;
use thiserror::Error;
use matrix_sdk_common::{
async_trait,
events::room_key_request::RequestedKeyInfo,
@ -71,7 +62,14 @@ use matrix_sdk_common::{
uuid::Uuid,
AsyncTraitDeps,
};
pub use memorystore::MemoryStore;
use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
pub use pickle_key::{EncryptedPickleKey, PickleKey};
use serde_json::Error as SerdeError;
use thiserror::Error;
#[cfg(feature = "sled_cryptostore")]
pub use self::sled::SledStore;
use crate::{
error::SessionUnpicklingError,
identities::{Device, ReadOnlyDevice, UserDevices, UserIdentities},
@ -145,12 +143,7 @@ impl Store {
store: Arc<Box<dyn CryptoStore>>,
verification_machine: VerificationMachine,
) -> Self {
Self {
user_id,
identity,
inner: store,
verification_machine,
}
Self { user_id, identity, inner: store, verification_machine }
}
pub async fn get_readonly_device(
@ -162,10 +155,7 @@ impl Store {
}
pub async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
let changes = Changes {
sessions: sessions.to_vec(),
..Default::default()
};
let changes = Changes { sessions: sessions.to_vec(), ..Default::default() };
self.save_changes(changes).await
}
@ -173,10 +163,7 @@ impl Store {
#[cfg(test)]
pub async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()> {
let changes = Changes {
devices: DeviceChanges {
changed: devices.to_vec(),
..Default::default()
},
devices: DeviceChanges { changed: devices.to_vec(), ..Default::default() },
..Default::default()
};
@ -188,10 +175,7 @@ impl Store {
&self,
sessions: &[InboundGroupSession],
) -> Result<()> {
let changes = Changes {
inbound_group_sessions: sessions.to_vec(),
..Default::default()
};
let changes = Changes { inbound_group_sessions: sessions.to_vec(), ..Default::default() };
self.save_changes(changes).await
}
@ -210,8 +194,7 @@ impl Store {
) -> Result<Option<Device>> {
self.get_user_devices(user_id).await.map(|d| {
d.devices().find(|d| {
d.get_key(DeviceKeyAlgorithm::Curve25519)
.map_or(false, |k| k == curve_key)
d.get_key(DeviceKeyAlgorithm::Curve25519).map_or(false, |k| k == curve_key)
})
})
}
@ -219,12 +202,8 @@ impl Store {
pub async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices> {
let devices = self.inner.get_user_devices(user_id).await?;
let own_identity = self
.inner
.get_user_identity(&self.user_id)
.await?
.map(|i| i.own().cloned())
.flatten();
let own_identity =
self.inner.get_user_identity(&self.user_id).await?.map(|i| i.own().cloned()).flatten();
let device_owner_identity = self.inner.get_user_identity(user_id).await.ok().flatten();
Ok(UserDevices {
@ -241,24 +220,17 @@ impl Store {
user_id: &UserId,
device_id: &DeviceId,
) -> Result<Option<Device>> {
let own_identity = self
.get_user_identity(&self.user_id)
.await?
.map(|i| i.own().cloned())
.flatten();
let own_identity =
self.get_user_identity(&self.user_id).await?.map(|i| i.own().cloned()).flatten();
let device_owner_identity = self.get_user_identity(user_id).await?;
Ok(self
.inner
.get_device(user_id, device_id)
.await?
.map(|d| Device {
inner: d,
private_identity: self.identity.clone(),
verification_machine: self.verification_machine.clone(),
own_identity,
device_owner_identity,
}))
Ok(self.inner.get_device(user_id, device_id).await?.map(|d| Device {
inner: d,
private_identity: self.identity.clone(),
verification_machine: self.verification_machine.clone(),
own_identity,
device_owner_identity,
}))
}
}
@ -366,7 +338,8 @@ pub trait CryptoStore: AsyncTraitDeps {
/// Get all the inbound group sessions we have stored.
async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>>;
/// Get the outobund group sessions we have stored that is used for the given room.
/// Get the outobund group sessions we have stored that is used for the
/// given room.
async fn get_outbound_group_sessions(
&self,
room_id: &RoomId,

View file

@ -22,11 +22,10 @@ use getrandom::getrandom;
use hmac::Hmac;
use olm_rs::PicklingMode;
use pbkdf2::pbkdf2;
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use zeroize::{Zeroize, Zeroizing};
use serde::{Deserialize, Serialize};
const KEY_SIZE: usize = 32;
const NONCE_SIZE: usize = 12;
const KDF_SALT_SIZE: usize = 32;
@ -114,9 +113,7 @@ impl PickleKey {
/// Get a `PicklingMode` version of this pickle key.
pub fn pickle_mode(&self) -> PicklingMode {
PicklingMode::Encrypted {
key: self.aes256_key.clone(),
}
PicklingMode::Encrypted { key: self.aes256_key.clone() }
}
/// Get the raw AES256 key.
@ -142,10 +139,7 @@ impl PickleKey {
getrandom(&mut nonce).expect("Can't generate new random nonce for the pickle key");
let ciphertext = cipher
.encrypt(
&GenericArray::from_slice(nonce.as_ref()),
self.aes256_key.as_slice(),
)
.encrypt(&GenericArray::from_slice(nonce.as_ref()), self.aes256_key.as_slice())
.expect("Can't encrypt pickle key");
EncryptedPickleKey {
@ -181,9 +175,7 @@ impl PickleKey {
}
};
Ok(Self {
aes256_key: decrypted,
})
Ok(Self { aes256_key: decrypted })
}
}

View file

@ -20,13 +20,6 @@ use std::{
};
use dashmap::DashSet;
use olm_rs::{account::IdentityKeys, PicklingMode};
pub use sled::Error;
use sled::{
transaction::{ConflictableTransactionError, TransactionError},
Config, Db, Transactional, Tree,
};
use matrix_sdk_common::{
async_trait,
events::room_key_request::RequestedKeyInfo,
@ -34,6 +27,12 @@ use matrix_sdk_common::{
locks::Mutex,
uuid,
};
use olm_rs::{account::IdentityKeys, PicklingMode};
pub use sled::Error;
use sled::{
transaction::{ConflictableTransactionError, TransactionError},
Config, Db, Transactional, Tree,
};
use uuid::Uuid;
use super::{
@ -97,13 +96,7 @@ impl EncodeKey for &str {
impl EncodeKey for (&str, &str) {
fn encode(&self) -> Vec<u8> {
[
self.0.as_bytes(),
&[Self::SEPARATOR],
self.1.as_bytes(),
&[Self::SEPARATOR],
]
.concat()
[self.0.as_bytes(), &[Self::SEPARATOR], self.1.as_bytes(), &[Self::SEPARATOR]].concat()
}
}
@ -164,9 +157,7 @@ impl std::fmt::Debug for SledStore {
if let Some(path) = &self.path {
f.debug_struct("SledStore").field("path", &path).finish()
} else {
f.debug_struct("SledStore")
.field("path", &"memory store")
.finish()
f.debug_struct("SledStore").field("path", &"memory store").finish()
}
}
}
@ -253,9 +244,8 @@ impl SledStore {
}
fn get_or_create_pickle_key(passphrase: &str, database: &Db) -> Result<PickleKey> {
let key = if let Some(key) = database
.get("pickle_key".encode())?
.map(|v| serde_json::from_slice(&v))
let key = if let Some(key) =
database.get("pickle_key".encode())?.map(|v| serde_json::from_slice(&v))
{
PickleKey::from_encrypted(passphrase, key?)
.map_err(|_| CryptoStoreError::UnpicklingError)?
@ -297,9 +287,7 @@ impl SledStore {
&self,
room_id: &RoomId,
) -> Result<Option<OutboundGroupSession>> {
let account_info = self
.get_account_info()
.ok_or(CryptoStoreError::AccountUnset)?;
let account_info = self.get_account_info().ok_or(CryptoStoreError::AccountUnset)?;
self.outbound_group_sessions
.get(room_id.encode())?
@ -501,17 +489,11 @@ impl SledStore {
&self,
id: &[u8],
) -> Result<Option<OutgoingKeyRequest>> {
let request = self
.outgoing_key_requests
.get(id)?
.map(|r| serde_json::from_slice(&r))
.transpose()?;
let request =
self.outgoing_key_requests.get(id)?.map(|r| serde_json::from_slice(&r)).transpose()?;
let request = if request.is_none() {
self.unsent_key_requests
.get(id)?
.map(|r| serde_json::from_slice(&r))
.transpose()?
self.unsent_key_requests.get(id)?.map(|r| serde_json::from_slice(&r)).transpose()?
} else {
request
};
@ -553,10 +535,7 @@ impl CryptoStore for SledStore {
*self.account_info.write().unwrap() = Some(account_info);
let changes = Changes {
account: Some(account),
..Default::default()
};
let changes = Changes { account: Some(account), ..Default::default() };
self.save_changes(changes).await
}
@ -579,9 +558,7 @@ impl CryptoStore for SledStore {
}
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
let account_info = self
.get_account_info()
.ok_or(CryptoStoreError::AccountUnset)?;
let account_info = self.get_account_info().ok_or(CryptoStoreError::AccountUnset)?;
if self.session_cache.get(sender_key).is_none() {
let sessions: Result<Vec<Session>> = self
@ -613,16 +590,10 @@ impl CryptoStore for SledStore {
session_id: &str,
) -> Result<Option<InboundGroupSession>> {
let key = (room_id.as_str(), sender_key, session_id).encode();
let pickle = self
.inbound_group_sessions
.get(&key)?
.map(|p| serde_json::from_slice(&p));
let pickle = self.inbound_group_sessions.get(&key)?.map(|p| serde_json::from_slice(&p));
if let Some(pickle) = pickle {
Ok(Some(InboundGroupSession::from_pickle(
pickle?,
self.get_pickle_mode(),
)?))
Ok(Some(InboundGroupSession::from_pickle(pickle?, self.get_pickle_mode())?))
} else {
Ok(None)
}
@ -658,10 +629,7 @@ impl CryptoStore for SledStore {
fn users_for_key_query(&self) -> HashSet<UserId> {
#[allow(clippy::map_clone)]
self.users_for_key_query_cache
.iter()
.map(|u| u.clone())
.collect()
self.users_for_key_query_cache.iter().map(|u| u.clone()).collect()
}
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
@ -715,9 +683,7 @@ impl CryptoStore for SledStore {
}
async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result<bool> {
Ok(self
.olm_hashes
.contains_key(serde_json::to_vec(message_hash)?)?)
Ok(self.olm_hashes.contains_key(serde_json::to_vec(message_hash)?)?)
}
async fn get_outgoing_key_request(
@ -753,36 +719,33 @@ impl CryptoStore for SledStore {
}
async fn delete_outgoing_key_request(&self, request_id: Uuid) -> Result<()> {
let ret: Result<(), TransactionError<serde_json::Error>> = (
&self.outgoing_key_requests,
&self.unsent_key_requests,
&self.key_requests_by_info,
)
.transaction(
|(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| {
let sent_request: Option<OutgoingKeyRequest> = outgoing_key_requests
.remove(request_id.encode())?
.map(|r| serde_json::from_slice(&r))
.transpose()
.map_err(ConflictableTransactionError::Abort)?;
let ret: Result<(), TransactionError<serde_json::Error>> =
(&self.outgoing_key_requests, &self.unsent_key_requests, &self.key_requests_by_info)
.transaction(
|(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| {
let sent_request: Option<OutgoingKeyRequest> = outgoing_key_requests
.remove(request_id.encode())?
.map(|r| serde_json::from_slice(&r))
.transpose()
.map_err(ConflictableTransactionError::Abort)?;
let unsent_request: Option<OutgoingKeyRequest> = unsent_key_requests
.remove(request_id.encode())?
.map(|r| serde_json::from_slice(&r))
.transpose()
.map_err(ConflictableTransactionError::Abort)?;
let unsent_request: Option<OutgoingKeyRequest> = unsent_key_requests
.remove(request_id.encode())?
.map(|r| serde_json::from_slice(&r))
.transpose()
.map_err(ConflictableTransactionError::Abort)?;
if let Some(request) = sent_request {
key_requests_by_info.remove((&request.info).encode())?;
}
if let Some(request) = sent_request {
key_requests_by_info.remove((&request.info).encode())?;
}
if let Some(request) = unsent_request {
key_requests_by_info.remove((&request.info).encode())?;
}
if let Some(request) = unsent_request {
key_requests_by_info.remove((&request.info).encode())?;
}
Ok(())
},
);
Ok(())
},
);
ret?;
self.inner.flush_async().await?;
@ -793,6 +756,19 @@ impl CryptoStore for SledStore {
#[cfg(test)]
mod test {
use std::collections::BTreeMap;
use matrix_sdk_common::{
api::r0::keys::SignedKey,
events::room_key_request::RequestedKeyInfo,
identifiers::{room_id, user_id, DeviceId, EventEncryptionAlgorithm, UserId},
uuid::Uuid,
};
use matrix_sdk_test::async_test;
use olm_rs::outbound_group_session::OlmOutboundGroupSession;
use tempfile::tempdir;
use super::{CryptoStore, OutgoingKeyRequest, SledStore};
use crate::{
identities::{
device::test::get_device,
@ -804,18 +780,6 @@ mod test {
},
store::{Changes, DeviceChanges, IdentityChanges},
};
use matrix_sdk_common::{
api::r0::keys::SignedKey,
events::room_key_request::RequestedKeyInfo,
identifiers::{room_id, user_id, DeviceId, EventEncryptionAlgorithm, UserId},
uuid::Uuid,
};
use matrix_sdk_test::async_test;
use olm_rs::outbound_group_session::OlmOutboundGroupSession;
use std::collections::BTreeMap;
use tempfile::tempdir;
use super::{CryptoStore, OutgoingKeyRequest, SledStore};
fn alice_id() -> UserId {
user_id!("@alice:example.org")
@ -846,10 +810,7 @@ mod test {
async fn get_loaded_store() -> (ReadOnlyAccount, SledStore, tempfile::TempDir) {
let (store, dir) = get_store(None).await;
let account = get_account();
store
.save_account(account.clone())
.await
.expect("Can't save account");
store.save_account(account.clone()).await.expect("Can't save account");
(account, store, dir)
}
@ -863,21 +824,12 @@ mod test {
let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id());
bob.generate_one_time_keys_helper(1).await;
let one_time_key = bob
.one_time_keys()
.await
.curve25519()
.iter()
.next()
.unwrap()
.1
.to_owned();
let one_time_key =
bob.one_time_keys().await.curve25519().iter().next().unwrap().1.to_owned();
let one_time_key = SignedKey::new(one_time_key, BTreeMap::new());
let sender_key = bob.identity_keys().curve25519().to_owned();
let session = alice
.create_outbound_session_helper(&sender_key, &one_time_key)
.await
.unwrap();
let session =
alice.create_outbound_session_helper(&sender_key, &one_time_key).await.unwrap();
(alice, session)
}
@ -895,10 +847,7 @@ mod test {
assert!(store.load_account().await.unwrap().is_none());
let account = get_account();
store
.save_account(account)
.await
.expect("Can't save account");
store.save_account(account).await.expect("Can't save account");
}
#[async_test]
@ -906,10 +855,7 @@ mod test {
let (store, _dir) = get_store(None).await;
let account = get_account();
store
.save_account(account.clone())
.await
.expect("Can't save account");
store.save_account(account.clone()).await.expect("Can't save account");
let loaded_account = store.load_account().await.expect("Can't load account");
let loaded_account = loaded_account.unwrap();
@ -922,10 +868,7 @@ mod test {
let (store, _dir) = get_store(Some("secret_passphrase")).await;
let account = get_account();
store
.save_account(account.clone())
.await
.expect("Can't save account");
store.save_account(account.clone()).await.expect("Can't save account");
let loaded_account = store.load_account().await.expect("Can't load account");
let loaded_account = loaded_account.unwrap();
@ -938,50 +881,32 @@ mod test {
let (store, _dir) = get_store(None).await;
let account = get_account();
store
.save_account(account.clone())
.await
.expect("Can't save account");
store.save_account(account.clone()).await.expect("Can't save account");
account.mark_as_shared();
account.update_uploaded_key_count(50);
store
.save_account(account.clone())
.await
.expect("Can't save account");
store.save_account(account.clone()).await.expect("Can't save account");
let loaded_account = store.load_account().await.expect("Can't load account");
let loaded_account = loaded_account.unwrap();
assert_eq!(account, loaded_account);
assert_eq!(
account.uploaded_key_count(),
loaded_account.uploaded_key_count()
);
assert_eq!(account.uploaded_key_count(), loaded_account.uploaded_key_count());
}
#[async_test]
async fn load_sessions() {
let (store, _dir) = get_store(None).await;
let (account, session) = get_account_and_session().await;
store
.save_account(account.clone())
.await
.expect("Can't save account");
store.save_account(account.clone()).await.expect("Can't save account");
let changes = Changes {
sessions: vec![session.clone()],
..Default::default()
};
let changes = Changes { sessions: vec![session.clone()], ..Default::default() };
store.save_changes(changes).await.unwrap();
let sessions = store
.get_sessions(&session.sender_key)
.await
.expect("Can't load sessions")
.unwrap();
let sessions =
store.get_sessions(&session.sender_key).await.expect("Can't load sessions").unwrap();
let loaded_session = sessions.lock().await.get(0).cloned().unwrap();
assert_eq!(&session, &loaded_session);
@ -994,15 +919,9 @@ mod test {
let sender_key = session.sender_key.to_owned();
let session_id = session.session_id().to_owned();
store
.save_account(account.clone())
.await
.expect("Can't save account");
store.save_account(account.clone()).await.expect("Can't save account");
let changes = Changes {
sessions: vec![session.clone()],
..Default::default()
};
let changes = Changes { sessions: vec![session.clone()], ..Default::default() };
store.save_changes(changes).await.unwrap();
let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap();
@ -1040,15 +959,9 @@ mod test {
)
.expect("Can't create session");
let changes = Changes {
inbound_group_sessions: vec![session],
..Default::default()
};
let changes = Changes { inbound_group_sessions: vec![session], ..Default::default() };
store
.save_changes(changes)
.await
.expect("Can't save group session");
store.save_changes(changes).await.expect("Can't save group session");
}
#[async_test]
@ -1072,15 +985,10 @@ mod test {
let session = InboundGroupSession::from_export(export).unwrap();
let changes = Changes {
inbound_group_sessions: vec![session.clone()],
..Default::default()
};
let changes =
Changes { inbound_group_sessions: vec![session.clone()], ..Default::default() };
store
.save_changes(changes)
.await
.expect("Can't save group session");
store.save_changes(changes).await.expect("Can't save group session");
drop(store);
@ -1103,21 +1011,12 @@ mod test {
let (_account, store, dir) = get_loaded_store().await;
let device = get_device();
assert!(store
.update_tracked_user(device.user_id(), false)
.await
.unwrap());
assert!(!store
.update_tracked_user(device.user_id(), false)
.await
.unwrap());
assert!(store.update_tracked_user(device.user_id(), false).await.unwrap());
assert!(!store.update_tracked_user(device.user_id(), false).await.unwrap());
assert!(store.is_user_tracked(device.user_id()));
assert!(!store.users_for_key_query().contains(device.user_id()));
assert!(!store
.update_tracked_user(device.user_id(), true)
.await
.unwrap());
assert!(!store.update_tracked_user(device.user_id(), true).await.unwrap());
assert!(store.users_for_key_query().contains(device.user_id()));
drop(store);
@ -1128,10 +1027,7 @@ mod test {
assert!(store.is_user_tracked(device.user_id()));
assert!(store.users_for_key_query().contains(device.user_id()));
store
.update_tracked_user(device.user_id(), false)
.await
.unwrap();
store.update_tracked_user(device.user_id(), false).await.unwrap();
assert!(!store.users_for_key_query().contains(device.user_id()));
drop(store);
@ -1148,10 +1044,7 @@ mod test {
let device = get_device();
let changes = Changes {
devices: DeviceChanges {
changed: vec![device.clone()],
..Default::default()
},
devices: DeviceChanges { changed: vec![device.clone()], ..Default::default() },
..Default::default()
};
@ -1163,11 +1056,8 @@ mod test {
store.load_account().await.unwrap();
let loaded_device = store
.get_device(device.user_id(), device.device_id())
.await
.unwrap()
.unwrap();
let loaded_device =
store.get_device(device.user_id(), device.device_id()).await.unwrap().unwrap();
assert_eq!(device, loaded_device);
@ -1188,20 +1078,14 @@ mod test {
let device = get_device();
let changes = Changes {
devices: DeviceChanges {
changed: vec![device.clone()],
..Default::default()
},
devices: DeviceChanges { changed: vec![device.clone()], ..Default::default() },
..Default::default()
};
store.save_changes(changes).await.unwrap();
let changes = Changes {
devices: DeviceChanges {
deleted: vec![device.clone()],
..Default::default()
},
devices: DeviceChanges { deleted: vec![device.clone()], ..Default::default() },
..Default::default()
};
@ -1212,10 +1096,7 @@ mod test {
store.load_account().await.unwrap();
let loaded_device = store
.get_device(device.user_id(), device.device_id())
.await
.unwrap();
let loaded_device = store.get_device(device.user_id(), device.device_id()).await.unwrap();
assert!(loaded_device.is_none());
}
@ -1232,10 +1113,7 @@ mod test {
let account = ReadOnlyAccount::new(&user_id, &device_id);
store
.save_account(account.clone())
.await
.expect("Can't save account");
store.save_account(account.clone()).await.expect("Can't save account");
let own_identity = get_own_identity();
@ -1247,10 +1125,7 @@ mod test {
..Default::default()
};
store
.save_changes(changes)
.await
.expect("Can't save identity");
store.save_changes(changes).await.expect("Can't save identity");
drop(store);
@ -1258,17 +1133,10 @@ mod test {
store.load_account().await.unwrap();
let loaded_user = store
.get_user_identity(own_identity.user_id())
.await
.unwrap()
.unwrap();
let loaded_user = store.get_user_identity(own_identity.user_id()).await.unwrap().unwrap();
assert_eq!(loaded_user.master_key(), own_identity.master_key());
assert_eq!(
loaded_user.self_signing_key(),
own_identity.self_signing_key()
);
assert_eq!(loaded_user.self_signing_key(), own_identity.self_signing_key());
assert_eq!(loaded_user, own_identity.clone().into());
let other_identity = get_other_identity();
@ -1283,17 +1151,10 @@ mod test {
store.save_changes(changes).await.unwrap();
let loaded_user = store
.get_user_identity(other_identity.user_id())
.await
.unwrap()
.unwrap();
let loaded_user = store.get_user_identity(other_identity.user_id()).await.unwrap().unwrap();
assert_eq!(loaded_user.master_key(), other_identity.master_key());
assert_eq!(
loaded_user.self_signing_key(),
other_identity.self_signing_key()
);
assert_eq!(loaded_user.self_signing_key(), other_identity.self_signing_key());
assert_eq!(loaded_user, other_identity.into());
own_identity.mark_as_verified();
@ -1317,10 +1178,7 @@ mod test {
assert!(store.load_identity().await.unwrap().is_none());
let identity = PrivateCrossSigningIdentity::new(alice_id()).await;
let changes = Changes {
private_identity: Some(identity.clone()),
..Default::default()
};
let changes = Changes { private_identity: Some(identity.clone()), ..Default::default() };
store.save_changes(changes).await.unwrap();
let loaded_identity = store.load_identity().await.unwrap().unwrap();
@ -1331,10 +1189,8 @@ mod test {
async fn olm_hash_saving() {
let (_, store, _dir) = get_loaded_store().await;
let hash = OlmMessageHash {
sender_key: "test_sender".to_owned(),
hash: "test_hash".to_owned(),
};
let hash =
OlmMessageHash { sender_key: "test_sender".to_owned(), hash: "test_hash".to_owned() };
let mut changes = Changes::default();
changes.message_hashes.push(hash.clone());

View file

@ -15,9 +15,6 @@
use std::{convert::TryFrom, sync::Arc};
use dashmap::DashMap;
use tracing::{info, trace, warn};
use matrix_sdk_common::{
events::{
room::message::MessageType, AnyMessageEvent, AnySyncMessageEvent, AnySyncRoomEvent,
@ -27,12 +24,12 @@ use matrix_sdk_common::{
locks::Mutex,
uuid::Uuid,
};
use tracing::{info, trace, warn};
use super::{
requests::VerificationRequest,
sas::{content_to_request, OutgoingContent, Sas, VerificationResult},
};
use crate::{
olm::PrivateCrossSigningIdentity,
requests::OutgoingRequest,
@ -85,18 +82,14 @@ impl VerificationMachine {
);
let request = match content.into() {
OutgoingContent::Room(r, c) => RoomMessageRequest {
room_id: r,
txn_id: Uuid::new_v4(),
content: c,
OutgoingContent::Room(r, c) => {
RoomMessageRequest { room_id: r, txn_id: Uuid::new_v4(), content: c }.into()
}
.into(),
OutgoingContent::ToDevice(c) => {
let request =
content_to_request(device.user_id(), device.device_id().to_owned(), c);
self.verifications
.insert(sas.flow_id().as_str().to_owned(), sas.clone());
self.verifications.insert(sas.flow_id().as_str().to_owned(), sas.clone());
request.into()
}
@ -136,10 +129,7 @@ impl VerificationMachine {
let request = content_to_request(recipient, recipient_device.to_owned(), c);
let request_id = request.txn_id;
let request = OutgoingRequest {
request_id,
request: Arc::new(request.into()),
};
let request = OutgoingRequest { request_id, request: Arc::new(request.into()) };
self.outgoing_messages.insert(request_id, request);
}
@ -149,12 +139,7 @@ impl VerificationMachine {
let request = OutgoingRequest {
request: Arc::new(
RoomMessageRequest {
room_id: r,
txn_id: request_id,
content: c,
}
.into(),
RoomMessageRequest { room_id: r, txn_id: request_id, content: c }.into(),
),
request_id,
};
@ -181,24 +166,17 @@ impl VerificationMachine {
}
pub fn outgoing_messages(&self) -> Vec<OutgoingRequest> {
self.outgoing_messages
.iter()
.map(|r| (*r).clone())
.collect()
self.outgoing_messages.iter().map(|r| (*r).clone()).collect()
}
pub fn garbage_collect(&self) {
self.verifications
.retain(|_, s| !(s.is_done() || s.is_canceled()));
self.verifications.retain(|_, s| !(s.is_done() || s.is_canceled()));
for sas in self.verifications.iter() {
if let Some(r) = sas.cancel_if_timed_out() {
self.outgoing_messages.insert(
r.request_id(),
OutgoingRequest {
request_id: r.request_id(),
request: Arc::new(r.into()),
},
OutgoingRequest { request_id: r.request_id(), request: Arc::new(r.into()) },
);
}
}
@ -239,8 +217,7 @@ impl VerificationMachine {
r,
);
self.requests
.insert(request.flow_id().as_str().to_owned(), request);
self.requests.insert(request.flow_id().as_str().to_owned(), request);
}
}
}
@ -261,10 +238,8 @@ impl VerificationMachine {
if let Some((_, request)) =
self.requests.remove(e.content.relation.event_id.as_str())
{
if let Some(d) = self
.store
.get_device(&e.sender, &e.content.from_device)
.await?
if let Some(d) =
self.store.get_device(&e.sender, &e.content.from_device).await?
{
match request.into_started_sas(
e,
@ -370,8 +345,7 @@ impl VerificationMachine {
&e.content,
);
self.requests
.insert(request.flow_id().as_str().to_string(), request);
self.requests.insert(request.flow_id().as_str().to_string(), request);
}
AnyToDeviceEvent::KeyVerificationReady(e) => {
if let Some(request) = self.requests.get(&e.content.transaction_id) {
@ -388,11 +362,7 @@ impl VerificationMachine {
e.content.from_device
);
if let Some(d) = self
.store
.get_device(&e.sender, &e.content.from_device)
.await?
{
if let Some(d) = self.store.get_device(&e.sender, &e.content.from_device).await? {
let private_identity = self.private_identity.lock().await.clone();
match Sas::from_start_event(
self.account.clone(),
@ -403,8 +373,7 @@ impl VerificationMachine {
self.store.get_user_identity(&e.sender).await?,
) {
Ok(s) => {
self.verifications
.insert(e.content.transaction_id.clone(), s);
self.verifications.insert(e.content.transaction_id.clone(), s);
}
Err(c) => {
warn!(
@ -455,10 +424,7 @@ impl VerificationMachine {
self.outgoing_messages.insert(
request_id,
OutgoingRequest {
request_id,
request: Arc::new(r.into()),
},
OutgoingRequest { request_id, request: Arc::new(r.into()) },
);
}
}
@ -535,10 +501,7 @@ mod test {
);
machine
.receive_event(&wrap_any_to_device_content(
bob_sas.user_id(),
start_content.into(),
))
.receive_event(&wrap_any_to_device_content(bob_sas.user_id(), start_content.into()))
.await
.unwrap();

View file

@ -17,11 +17,10 @@ mod requests;
mod sas;
pub use machine::VerificationMachine;
use matrix_sdk_common::identifiers::{EventId, RoomId};
pub use requests::VerificationRequest;
pub use sas::{AcceptSettings, Sas, VerificationResult};
use matrix_sdk_common::identifiers::{EventId, RoomId};
#[derive(Clone, Debug, Hash, PartialEq, PartialOrd)]
pub enum FlowId {
ToDevice(String),
@ -59,18 +58,17 @@ impl From<(RoomId, EventId)> for FlowId {
#[cfg(test)]
pub(crate) mod test {
use crate::{
requests::{OutgoingRequest, OutgoingRequests},
OutgoingVerificationRequest,
};
use serde_json::Value;
use matrix_sdk_common::{
events::{AnyToDeviceEvent, AnyToDeviceEventContent, EventType, ToDeviceEvent},
identifiers::UserId,
};
use serde_json::Value;
use super::sas::OutgoingContent;
use crate::{
requests::{OutgoingRequest, OutgoingRequests},
OutgoingVerificationRequest,
};
pub(crate) fn request_to_event(
sender: &UserId,
@ -94,11 +92,7 @@ pub(crate) mod test {
sender: &UserId,
content: OutgoingContent,
) -> AnyToDeviceEvent {
let content = if let OutgoingContent::ToDevice(c) = content {
c
} else {
unreachable!()
};
let content = if let OutgoingContent::ToDevice(c) = content { c } else { unreachable!() };
match content {
AnyToDeviceEventContent::KeyVerificationKey(c) => {
@ -133,22 +127,11 @@ pub(crate) mod test {
pub(crate) fn get_content_from_request(
request: &OutgoingVerificationRequest,
) -> OutgoingContent {
let request = if let OutgoingVerificationRequest::ToDevice(r) = request {
r
} else {
unreachable!()
};
let request =
if let OutgoingVerificationRequest::ToDevice(r) = request { r } else { unreachable!() };
let json: Value = serde_json::from_str(
request
.messages
.values()
.next()
.unwrap()
.values()
.next()
.unwrap()
.get(),
request.messages.values().next().unwrap().values().next().unwrap().get(),
)
.unwrap();

View file

@ -35,6 +35,10 @@ use matrix_sdk_common::{
uuid::Uuid,
};
use super::{
sas::{content_to_request, OutgoingContent, StartContent},
FlowId,
};
use crate::{
olm::{PrivateCrossSigningIdentity, ReadOnlyAccount},
store::CryptoStore,
@ -42,11 +46,6 @@ use crate::{
UserIdentities,
};
use super::{
sas::{content_to_request, OutgoingContent, StartContent},
FlowId,
};
const SUPPORTED_METHODS: &[VerificationMethod] = &[VerificationMethod::MSasV1];
pub enum RequestContent<'a> {
@ -256,15 +255,13 @@ impl VerificationRequest {
content: RequestContent,
) -> Self {
Self {
inner: Arc::new(Mutex::new(InnerRequest::Requested(
RequestState::from_request_event(
account.user_id(),
account.device_id(),
sender,
&flow_id,
content,
),
))),
inner: Arc::new(Mutex::new(InnerRequest::Requested(RequestState::from_request_event(
account.user_id(),
account.device_id(),
sender,
&flow_id,
content,
)))),
account,
other_user_id: sender.clone().into(),
private_cross_signing_identity,
@ -278,15 +275,12 @@ impl VerificationRequest {
let mut inner = self.inner.lock().unwrap();
inner.accept().map(|c| match c {
OutgoingContent::ToDevice(content) => self
.content_to_request(inner.other_device_id(), content)
.into(),
OutgoingContent::Room(room_id, content) => RoomMessageRequest {
room_id,
txn_id: Uuid::new_v4(),
content,
OutgoingContent::ToDevice(content) => {
self.content_to_request(inner.other_device_id(), content).into()
}
OutgoingContent::Room(room_id, content) => {
RoomMessageRequest { room_id, txn_id: Uuid::new_v4(), content }.into()
}
.into(),
})
}
@ -445,10 +439,7 @@ impl RequestState<Created> {
own_user_id: own_user_id.to_owned(),
own_device_id: own_device_id.to_owned(),
other_user_id: other_user_id.to_owned(),
state: Created {
methods: SUPPORTED_METHODS.to_vec(),
flow_id: flow_id.to_owned(),
},
state: Created { methods: SUPPORTED_METHODS.to_vec(), flow_id: flow_id.to_owned() },
}
}
@ -589,14 +580,9 @@ impl RequestState<Ready> {
other_identity: Option<UserIdentities>,
) -> (Sas, StartContent) {
match self.state.flow_id {
FlowId::ToDevice(t) => Sas::start(
account,
private_identity,
other_device,
store,
other_identity,
Some(t),
),
FlowId::ToDevice(t) => {
Sas::start(account, private_identity, other_device, store, other_identity, Some(t))
}
FlowId::InRoom(r, e) => Sas::start_in_room(
e,
r,
@ -630,6 +616,7 @@ mod test {
};
use matrix_sdk_test::async_test;
use super::VerificationRequest;
use crate::{
olm::{PrivateCrossSigningIdentity, ReadOnlyAccount},
store::{CryptoStore, MemoryStore},
@ -640,8 +627,6 @@ mod test {
ReadOnlyDevice,
};
use super::VerificationRequest;
fn alice_id() -> UserId {
UserId::try_from("@alice:example.org").unwrap()
}
@ -760,9 +745,7 @@ mod test {
panic!("Invalid start event content type");
};
let alice_sas = alice_request
.into_started_sas(&event, bob_device, None)
.unwrap();
let alice_sas = alice_request.into_started_sas(&event, bob_device, None).unwrap();
assert!(!bob_sas.is_canceled());
assert!(!alice_sas.is_canceled());

View file

@ -59,10 +59,7 @@ impl StartContent {
StartContent::Room(_, c) => serde_json::to_value(c),
};
content
.expect("Can't serialize content")
.try_into()
.expect("Can't canonicalize content")
content.expect("Can't serialize content").try_into().expect("Can't canonicalize content")
}
}
@ -287,14 +284,7 @@ impl From<OutgoingVerificationRequest> for OutgoingContent {
match request {
OutgoingVerificationRequest::ToDevice(r) => {
let json: Value = serde_json::from_str(
r.messages
.values()
.next()
.unwrap()
.values()
.next()
.unwrap()
.get(),
r.messages.values().next().unwrap().values().next().unwrap().get(),
)
.unwrap();

View file

@ -14,11 +14,6 @@
use std::{collections::BTreeMap, convert::TryInto};
use sha2::{Digest, Sha256};
use tracing::{trace, warn};
use olm_rs::sas::OlmSas;
use matrix_sdk_common::{
api::r0::to_device::DeviceIdOrAllDevices,
events::{
@ -32,17 +27,19 @@ use matrix_sdk_common::{
identifiers::{DeviceKeyAlgorithm, DeviceKeyId, UserId},
uuid::Uuid,
};
use crate::{
identities::{ReadOnlyDevice, UserIdentities},
utilities::encode,
ReadOnlyAccount, ToDeviceRequest,
};
use olm_rs::sas::OlmSas;
use sha2::{Digest, Sha256};
use tracing::{trace, warn};
use super::{
event_enums::{MacContent, StartContent},
FlowId,
};
use crate::{
identities::{ReadOnlyDevice, UserIdentities},
utilities::encode,
ReadOnlyAccount, ToDeviceRequest,
};
#[derive(Clone, Debug)]
pub struct SasIds {
@ -65,12 +62,7 @@ pub fn calculate_commitment(public_key: &str, content: impl Into<StartContent>)
let content = content.into().canonical_json();
let content_string = content.to_string();
encode(
Sha256::new()
.chain(&public_key)
.chain(&content_string)
.finalize(),
)
encode(Sha256::new().chain(&public_key).chain(&content_string).finalize())
}
/// Get a tuple of an emoji and a description of the emoji using a number.
@ -234,11 +226,7 @@ pub fn receive_mac_event(
.calculate_mac(key, &format!("{}{}", info, key_id))
.expect("Can't calculate SAS MAC")
{
trace!(
"Successfully verified the device key {} from {}",
key_id,
sender
);
trace!("Successfully verified the device key {} from {}", key_id, sender);
verified_devices.push(ids.other_device.clone());
} else {
@ -253,11 +241,7 @@ pub fn receive_mac_event(
.calculate_mac(key, &format!("{}{}", info, key_id))
.expect("Can't calculate SAS MAC")
{
trace!(
"Successfully verified the master key {} from {}",
key_id,
sender
);
trace!("Successfully verified the master key {} from {}", key_id, sender);
verified_identities.push(identity.clone())
} else {
return Err(CancelCode::KeyMismatch);
@ -319,8 +303,7 @@ pub fn get_mac_content(sas: &OlmSas, ids: &SasIds, flow_id: &FlowId) -> MacConte
mac.insert(
key_id.to_string(),
sas.calculate_mac(key, &format!("{}{}", info, key_id))
.expect("Can't calculate SAS MAC"),
sas.calculate_mac(key, &format!("{}{}", info, key_id)).expect("Can't calculate SAS MAC"),
);
// TODO Add the cross signing master key here if we trust/have it.
@ -332,23 +315,13 @@ pub fn get_mac_content(sas: &OlmSas, ids: &SasIds, flow_id: &FlowId) -> MacConte
.expect("Can't calculate SAS MAC");
match flow_id {
FlowId::ToDevice(s) => MacToDeviceEventContent {
transaction_id: s.to_string(),
keys,
mac,
FlowId::ToDevice(s) => {
MacToDeviceEventContent { transaction_id: s.to_string(), keys, mac }.into()
}
FlowId::InRoom(r, e) => {
(r.clone(), MacEventContent { mac, keys, relation: Relation { event_id: e.clone() } })
.into()
}
.into(),
FlowId::InRoom(r, e) => (
r.clone(),
MacEventContent {
mac,
keys,
relation: Relation {
event_id: e.clone(),
},
},
)
.into(),
}
}
@ -369,24 +342,12 @@ fn extra_info_sas(
flow_id: &str,
we_started: bool,
) -> String {
let our_info = format!(
"{}|{}|{}",
ids.account.user_id(),
ids.account.device_id(),
own_pubkey
);
let their_info = format!(
"{}|{}|{}",
ids.other_device.user_id(),
ids.other_device.device_id(),
their_pubkey
);
let our_info = format!("{}|{}|{}", ids.account.user_id(), ids.account.device_id(), own_pubkey);
let their_info =
format!("{}|{}|{}", ids.other_device.user_id(), ids.other_device.device_id(), their_pubkey);
let (first_info, second_info) = if we_started {
(our_info, their_info)
} else {
(their_info, our_info)
};
let (first_info, second_info) =
if we_started { (our_info, their_info) } else { (their_info, our_info) };
let info = format!(
"MATRIX_KEY_VERIFICATION_SAS|{first_info}|{second_info}|{flow_id}",
@ -585,11 +546,7 @@ pub fn content_to_request(
_ => unreachable!(),
};
ToDeviceRequest {
txn_id: Uuid::new_v4(),
event_type,
messages,
}
ToDeviceRequest { txn_id: Uuid::new_v4(), event_type, messages }
}
#[cfg(test)]
@ -627,18 +584,14 @@ mod test {
#[test]
fn emoji_generation() {
let bytes = vec![0, 0, 0, 0, 0, 0];
let index: Vec<(&'static str, &'static str)> = vec![0, 0, 0, 0, 0, 0, 0]
.into_iter()
.map(emoji_from_index)
.collect();
let index: Vec<(&'static str, &'static str)> =
vec![0, 0, 0, 0, 0, 0, 0].into_iter().map(emoji_from_index).collect();
assert_eq!(bytes_to_emoji(bytes), index.as_ref());
let bytes = vec![0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF];
let index: Vec<(&'static str, &'static str)> = vec![63, 63, 63, 63, 63, 63, 63]
.into_iter()
.map(emoji_from_index)
.collect();
let index: Vec<(&'static str, &'static str)> =
vec![63, 63, 63, 63, 63, 63, 63].into_iter().map(emoji_from_index).collect();
assert_eq!(bytes_to_emoji(bytes), index.as_ref());
}

View file

@ -12,11 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
#[cfg(test)]
use std::time::Instant;
use std::sync::Arc;
use matrix_sdk_common::{
events::{
key::verification::{cancel::CancelCode, ShortAuthenticationString},
@ -25,11 +24,6 @@ use matrix_sdk_common::{
identifiers::{EventId, RoomId},
};
use crate::{
identities::{ReadOnlyDevice, UserIdentities},
ReadOnlyAccount,
};
use super::{
event_enums::{AcceptContent, CancelContent, MacContent, OutgoingContent},
sas_state::{
@ -38,6 +32,10 @@ use super::{
},
FlowId, StartContent,
};
use crate::{
identities::{ReadOnlyDevice, UserIdentities},
ReadOnlyAccount,
};
#[derive(Clone, Debug)]
pub enum InnerSas {
@ -315,14 +313,15 @@ impl InnerSas {
_ => (self, None),
},
AnyToDeviceEvent::KeyVerificationMac(e) => match self {
InnerSas::KeyRecieved(s) => match s.into_mac_received(&e.sender, e.content.clone())
{
Ok(s) => (InnerSas::MacReceived(s), None),
Err(s) => {
let content = s.as_content();
(InnerSas::Canceled(s), Some(content.into()))
InnerSas::KeyRecieved(s) => {
match s.into_mac_received(&e.sender, e.content.clone()) {
Ok(s) => (InnerSas::MacReceived(s), None),
Err(s) => {
let content = s.as_content();
(InnerSas::Canceled(s), Some(content.into()))
}
}
},
}
InnerSas::Confirmed(s) => match s.into_done(&e.sender, e.content.clone()) {
Ok(s) => (InnerSas::Done(s), None),
Err(s) => {

View file

@ -17,13 +17,14 @@ mod helpers;
mod inner_sas;
mod sas_state;
use std::sync::{Arc, Mutex};
#[cfg(test)]
use std::time::Instant;
use event_enums::AcceptContent;
use std::sync::{Arc, Mutex};
use tracing::{error, info, trace, warn};
pub use event_enums::{CancelContent, OutgoingContent, StartContent};
pub use helpers::content_to_request;
use inner_sas::InnerSas;
use matrix_sdk_common::{
api::r0::keys::upload_signatures::Request as SignatureUploadRequest,
events::{
@ -37,7 +38,9 @@ use matrix_sdk_common::{
identifiers::{DeviceId, EventId, RoomId, UserId},
uuid::Uuid,
};
use tracing::{error, info, trace, warn};
use super::FlowId;
use crate::{
error::SignatureError,
identities::{LocalTrust, ReadOnlyDevice, UserIdentities},
@ -47,12 +50,6 @@ use crate::{
ReadOnlyAccount, ToDeviceRequest,
};
use super::FlowId;
pub use event_enums::{CancelContent, OutgoingContent, StartContent};
pub use helpers::content_to_request;
use inner_sas::InnerSas;
#[derive(Debug)]
/// A result of a verification flow.
#[allow(clippy::large_enum_variant)]
@ -275,22 +272,18 @@ impl Sas {
&self,
settings: AcceptSettings,
) -> Option<OutgoingVerificationRequest> {
self.inner
.lock()
.unwrap()
.accept()
.map(|c| match settings.apply(c) {
AcceptContent::ToDevice(c) => {
let content = AnyToDeviceEventContent::KeyVerificationAccept(c);
self.content_to_request(content).into()
}
AcceptContent::Room(room_id, content) => RoomMessageRequest {
room_id,
txn_id: Uuid::new_v4(),
content: AnyMessageEventContent::KeyVerificationAccept(content),
}
.into(),
})
self.inner.lock().unwrap().accept().map(|c| match settings.apply(c) {
AcceptContent::ToDevice(c) => {
let content = AnyToDeviceEventContent::KeyVerificationAccept(c);
self.content_to_request(content).into()
}
AcceptContent::Room(room_id, content) => RoomMessageRequest {
room_id,
txn_id: Uuid::new_v4(),
content: AnyMessageEventContent::KeyVerificationAccept(content),
}
.into(),
})
}
/// Confirm the Sas verification.
@ -303,10 +296,7 @@ impl Sas {
pub async fn confirm(
&self,
) -> Result<
(
Option<OutgoingVerificationRequest>,
Option<SignatureUploadRequest>,
),
(Option<OutgoingVerificationRequest>, Option<SignatureUploadRequest>),
CryptoStoreError,
> {
let (content, done) = {
@ -319,9 +309,9 @@ impl Sas {
};
let mac_request = content.map(|c| match c {
event_enums::MacContent::ToDevice(c) => self
.content_to_request(AnyToDeviceEventContent::KeyVerificationMac(c))
.into(),
event_enums::MacContent::ToDevice(c) => {
self.content_to_request(AnyToDeviceEventContent::KeyVerificationMac(c)).into()
}
event_enums::MacContent::Room(r, c) => RoomMessageRequest {
room_id: r,
txn_id: Uuid::new_v4(),
@ -374,10 +364,7 @@ impl Sas {
};
let mut changes = Changes {
devices: DeviceChanges {
changed: vec![device],
..Default::default()
},
devices: DeviceChanges { changed: vec![device], ..Default::default() },
..Default::default()
};
@ -437,10 +424,7 @@ impl Sas {
.map(VerificationResult::SignatureUpload)
.unwrap_or(VerificationResult::Ok))
} else {
Ok(self
.cancel()
.map(VerificationResult::Cancel)
.unwrap_or(VerificationResult::Ok))
Ok(self.cancel().map(VerificationResult::Cancel).unwrap_or(VerificationResult::Ok))
}
}
@ -463,14 +447,8 @@ impl Sas {
.as_ref()
.map_or(false, |i| i.master_key() == identity.master_key())
{
if self
.verified_identities()
.map_or(false, |i| i.contains(&identity))
{
trace!(
"Marking user identity of {} as verified.",
identity.user_id(),
);
if self.verified_identities().map_or(false, |i| i.contains(&identity)) {
trace!("Marking user identity of {} as verified.", identity.user_id(),);
if let UserIdentities::Own(i) = &identity {
i.mark_as_verified();
@ -509,17 +487,11 @@ impl Sas {
pub(crate) async fn mark_device_as_verified(
&self,
) -> Result<Option<ReadOnlyDevice>, CryptoStoreError> {
let device = self
.store
.get_device(self.other_user_id(), self.other_device_id())
.await?;
let device = self.store.get_device(self.other_user_id(), self.other_device_id()).await?;
if let Some(device) = device {
if device.keys() == self.other_device.keys() {
if self
.verified_devices()
.map_or(false, |v| v.contains(&device))
{
if self.verified_devices().map_or(false, |v| v.contains(&device)) {
trace!(
"Marking device {} {} as verified.",
device.user_id(),
@ -580,9 +552,9 @@ impl Sas {
content: AnyMessageEventContent::KeyVerificationCancel(content),
}
.into(),
CancelContent::ToDevice(c) => self
.content_to_request(AnyToDeviceEventContent::KeyVerificationCancel(c))
.into(),
CancelContent::ToDevice(c) => {
self.content_to_request(AnyToDeviceEventContent::KeyVerificationCancel(c)).into()
}
})
}
@ -684,11 +656,7 @@ impl Sas {
}
pub(crate) fn content_to_request(&self, content: AnyToDeviceEventContent) -> ToDeviceRequest {
content_to_request(
self.other_user_id(),
self.other_device_id().to_owned(),
content,
)
content_to_request(self.other_user_id(), self.other_device_id().to_owned(), content)
}
}
@ -717,9 +685,7 @@ impl AcceptSettings {
///
/// * `methods` - The methods this client allows at most
pub fn with_allowed_methods(methods: Vec<ShortAuthenticationString>) -> Self {
Self {
allowed_methods: methods,
}
Self { allowed_methods: methods }
}
fn apply(self, mut content: AcceptContent) -> AcceptContent {
@ -728,15 +694,8 @@ impl AcceptSettings {
method: AcceptMethod::MSasV1(c),
..
})
| AcceptContent::Room(
_,
AcceptEventContent {
method: AcceptMethod::MSasV1(c),
..
},
) => {
c.short_authentication_string
.retain(|sas| self.allowed_methods.contains(sas));
| AcceptContent::Room(_, AcceptEventContent { method: AcceptMethod::MSasV1(c), .. }) => {
c.short_authentication_string.retain(|sas| self.allowed_methods.contains(sas));
content
}
_ => content,
@ -750,6 +709,7 @@ mod test {
use matrix_sdk_common::identifiers::{DeviceId, UserId};
use super::Sas;
use crate::{
olm::PrivateCrossSigningIdentity,
store::{CryptoStore, MemoryStore},
@ -757,8 +717,6 @@ mod test {
ReadOnlyAccount, ReadOnlyDevice,
};
use super::Sas;
fn alice_id() -> UserId {
UserId::try_from("@alice:example.org").unwrap()
}
@ -841,13 +799,7 @@ mod test {
);
alice.receive_event(&event);
assert!(alice
.verified_devices()
.unwrap()
.contains(&alice.other_device()));
assert!(bob
.verified_devices()
.unwrap()
.contains(&bob.other_device()));
assert!(alice.verified_devices().unwrap().contains(&alice.other_device()));
assert!(bob.verified_devices().unwrap().contains(&bob.other_device()));
}
}

View file

@ -19,8 +19,6 @@ use std::{
time::{Duration, Instant},
};
use olm_rs::sas::OlmSas;
use matrix_sdk_common::{
events::key::verification::{
accept::{
@ -40,6 +38,7 @@ use matrix_sdk_common::{
identifiers::{DeviceId, EventId, RoomId, UserId},
uuid::Uuid,
};
use olm_rs::sas::OlmSas;
use tracing::info;
use super::{
@ -51,7 +50,6 @@ use super::{
receive_mac_event, SasIds,
},
};
use crate::{
identities::{ReadOnlyDevice, UserIdentities},
verification::FlowId,
@ -62,10 +60,8 @@ const KEY_AGREEMENT_PROTOCOLS: &[KeyAgreementProtocol] =
&[KeyAgreementProtocol::Curve25519HkdfSha256];
const HASHES: &[HashAlgorithm] = &[HashAlgorithm::Sha256];
const MACS: &[MessageAuthenticationCode] = &[MessageAuthenticationCode::HkdfHmacSha256];
const STRINGS: &[ShortAuthenticationString] = &[
ShortAuthenticationString::Decimal,
ShortAuthenticationString::Emoji,
];
const STRINGS: &[ShortAuthenticationString] =
&[ShortAuthenticationString::Decimal, ShortAuthenticationString::Emoji];
// The max time a SAS flow can take from start to done.
const MAX_AGE: Duration = Duration::from_secs(60 * 5);
@ -91,9 +87,7 @@ impl TryFrom<AcceptV1Content> for AcceptedProtocols {
if !KEY_AGREEMENT_PROTOCOLS.contains(&content.key_agreement_protocol)
|| !HASHES.contains(&content.hash)
|| !MACS.contains(&content.message_authentication_code)
|| (!content
.short_authentication_string
.contains(&ShortAuthenticationString::Emoji)
|| (!content.short_authentication_string.contains(&ShortAuthenticationString::Emoji)
&& !content
.short_authentication_string
.contains(&ShortAuthenticationString::Decimal))
@ -402,11 +396,7 @@ impl SasState<Created> {
) -> SasState<Created> {
SasState {
inner: Arc::new(Mutex::new(OlmSas::new())),
ids: SasIds {
account,
other_device,
other_identity,
},
ids: SasIds { account, other_device, other_identity },
verification_flow_id: flow_id.into(),
creation_time: Arc::new(Instant::now()),
@ -441,9 +431,7 @@ impl SasState<Created> {
MSasV1Content::new(self.state.protocol_definitions.clone())
.expect("Invalid initial protocol definitions."),
),
relation: Relation {
event_id: e.clone(),
},
relation: Relation { event_id: e.clone() },
},
),
}
@ -490,8 +478,8 @@ impl SasState<Created> {
}
impl SasState<Started> {
/// Create a new SAS verification flow from an in-room m.key.verification.start
/// event.
/// Create a new SAS verification flow from an in-room
/// m.key.verification.start event.
///
/// This will put us in the `started` state.
///
@ -549,11 +537,7 @@ impl SasState<Started> {
Ok(SasState {
inner: Arc::new(Mutex::new(sas)),
ids: SasIds {
account,
other_device,
other_identity,
},
ids: SasIds { account, other_device, other_identity },
creation_time: Arc::new(Instant::now()),
last_event_time: Arc::new(Instant::now()),
@ -605,19 +589,12 @@ impl SasState<Started> {
);
match self.verification_flow_id.as_ref() {
FlowId::ToDevice(s) => AcceptToDeviceEventContent {
transaction_id: s.to_string(),
method,
FlowId::ToDevice(s) => {
AcceptToDeviceEventContent { transaction_id: s.to_string(), method }.into()
}
.into(),
FlowId::InRoom(r, e) => (
r.clone(),
AcceptEventContent {
method,
relation: Relation {
event_id: e.clone(),
},
},
AcceptEventContent { method, relation: Relation { event_id: e.clone() } },
)
.into(),
}
@ -683,10 +660,8 @@ impl SasState<Accepted> {
self.check_event(&sender, content.flow_id().as_str())
.map_err(|c| self.clone().cancel(c))?;
let commitment = calculate_commitment(
content.public_key(),
self.state.start_content.as_ref().clone(),
);
let commitment =
calculate_commitment(content.public_key(), self.state.start_content.as_ref().clone());
if self.state.commitment != commitment {
Err(self.cancel(CancelCode::InvalidMessage))
@ -728,9 +703,7 @@ impl SasState<Accepted> {
r.clone(),
KeyEventContent {
key: self.inner.lock().unwrap().public_key(),
relation: Relation {
event_id: e.clone(),
},
relation: Relation { event_id: e.clone() },
},
)
.into(),
@ -754,9 +727,7 @@ impl SasState<KeyReceived> {
r.clone(),
KeyEventContent {
key: self.inner.lock().unwrap().public_key(),
relation: Relation {
event_id: e.clone(),
},
relation: Relation { event_id: e.clone() },
},
)
.into(),
@ -779,8 +750,8 @@ impl SasState<KeyReceived> {
/// Get the index of the emoji of the short authentication string.
///
/// Returns seven u8 numbers in the range from 0 to 63 inclusive, those numbers
/// can be converted to a unique emoji defined by the spec.
/// Returns seven u8 numbers in the range from 0 to 63 inclusive, those
/// numbers can be converted to a unique emoji defined by the spec.
pub fn get_emoji_index(&self) -> [u8; 7] {
get_emoji_index(
&self.inner.lock().unwrap(),
@ -952,11 +923,7 @@ impl SasState<Confirmed> {
///
/// The content needs to be automatically sent to the other side.
pub fn as_content(&self) -> MacContent {
get_mac_content(
&self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
)
get_mac_content(&self.inner.lock().unwrap(), &self.ids, &self.verification_flow_id)
}
}
@ -1015,8 +982,8 @@ impl SasState<MacReceived> {
/// Get the index of the emoji of the short authentication string.
///
/// Returns seven u8 numbers in the range from 0 to 63 inclusive, those numbers
/// can be converted to a unique emoji defined by the spec.
/// Returns seven u8 numbers in the range from 0 to 63 inclusive, those
/// numbers can be converted to a unique emoji defined by the spec.
pub fn get_emoji_index(&self) -> [u8; 7] {
get_emoji_index(
&self.inner.lock().unwrap(),
@ -1048,11 +1015,7 @@ impl SasState<WaitingForDone> {
/// The content needs to be automatically sent to the other side if it
/// wasn't already sent.
pub fn as_content(&self) -> MacContent {
get_mac_content(
&self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
)
get_mac_content(&self.inner.lock().unwrap(), &self.ids, &self.verification_flow_id)
}
pub fn done_content(&self) -> DoneContent {
@ -1060,15 +1023,9 @@ impl SasState<WaitingForDone> {
FlowId::ToDevice(_) => {
unreachable!("The done content isn't supported yet for to-device verifications")
}
FlowId::InRoom(r, e) => (
r.clone(),
DoneEventContent {
relation: Relation {
event_id: e.clone(),
},
},
)
.into(),
FlowId::InRoom(r, e) => {
(r.clone(), DoneEventContent { relation: Relation { event_id: e.clone() } }).into()
}
}
}
@ -1110,11 +1067,7 @@ impl SasState<Done> {
/// The content needs to be automatically sent to the other side if it
/// wasn't already sent.
pub fn as_content(&self) -> MacContent {
get_mac_content(
&self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
)
get_mac_content(&self.inner.lock().unwrap(), &self.ids, &self.verification_flow_id)
}
pub fn done_content(&self) -> DoneContent {
@ -1122,15 +1075,9 @@ impl SasState<Done> {
FlowId::ToDevice(_) => {
unreachable!("The done content isn't supported yet for to-device verifications")
}
FlowId::InRoom(r, e) => (
r.clone(),
DoneEventContent {
relation: Relation {
event_id: e.clone(),
},
},
)
.into(),
FlowId::InRoom(r, e) => {
(r.clone(), DoneEventContent { relation: Relation { event_id: e.clone() } }).into()
}
}
}
@ -1166,10 +1113,7 @@ impl Canceled {
_ => unimplemented!(),
};
Canceled {
cancel_code: code,
reason,
}
Canceled { cancel_code: code, reason }
}
}
@ -1188,9 +1132,7 @@ impl SasState<Canceled> {
CancelEventContent {
reason: self.state.reason.to_string(),
code: self.state.cancel_code.clone(),
relation: Relation {
event_id: e.clone(),
},
relation: Relation { event_id: e.clone() },
},
)
.into(),
@ -1202,10 +1144,6 @@ impl SasState<Canceled> {
mod test {
use std::convert::TryFrom;
use crate::{
verification::sas::{event_enums::AcceptContent, StartContent},
ReadOnlyAccount, ReadOnlyDevice,
};
use matrix_sdk_common::{
events::key::verification::{
accept::{AcceptMethod, CustomContent},
@ -1215,6 +1153,10 @@ mod test {
};
use super::{Accepted, Created, SasState, Started};
use crate::{
verification::sas::{event_enums::AcceptContent, StartContent},
ReadOnlyAccount, ReadOnlyDevice,
};
fn alice_id() -> UserId {
UserId::try_from("@alice:example.org").unwrap()
@ -1353,9 +1295,7 @@ mod test {
let content = bob.as_content();
let sender = UserId::try_from("@malory:example.org").unwrap();
alice
.into_accepted(&sender, content)
.expect_err("Didn't cancel on a invalid sender");
alice.into_accepted(&sender, content).expect_err("Didn't cancel on a invalid sender");
}
#[tokio::test]

View file

@ -1,7 +1,6 @@
use std::{collections::HashMap, panic};
use http::Response;
use matrix_sdk_common::{
api::r0::sync::sync_events::Response as SyncResponse,
events::{
@ -11,9 +10,8 @@ use matrix_sdk_common::{
identifiers::{room_id, RoomId},
IncomingResponse,
};
use serde_json::Value as JsonValue;
pub use matrix_sdk_test_macros::async_test;
use serde_json::Value as JsonValue;
pub mod test_json;
@ -44,16 +42,17 @@ pub enum EventsJson {
Typing,
}
/// The `EventBuilder` struct can be used to easily generate valid sync responses for testing.
/// These can be then fed into either `Client` or `Room`.
/// The `EventBuilder` struct can be used to easily generate valid sync
/// responses for testing. These can be then fed into either `Client` or `Room`.
///
/// It supports generated a number of canned events, such as a member entering a room, his power
/// level and display name changing and similar. It also supports insertion of custom events in the
/// form of `EventsJson` values.
/// It supports generated a number of canned events, such as a member entering a
/// room, his power level and display name changing and similar. It also
/// supports insertion of custom events in the form of `EventsJson` values.
///
/// **Important** You *must* use the *same* builder when sending multiple sync responses to
/// a single client. Otherwise, the subsequent responses will be *ignored* by the client because
/// the `next_batch` sync token will not be rotated properly.
/// **Important** You *must* use the *same* builder when sending multiple sync
/// responses to a single client. Otherwise, the subsequent responses will be
/// *ignored* by the client because the `next_batch` sync token will not be
/// rotated properly.
///
/// # Example usage
///
@ -94,7 +93,8 @@ pub struct EventBuilder {
ephemeral: Vec<AnySyncEphemeralRoomEvent>,
/// The account data events that determine the state of a `Room`.
account_data: Vec<AnyGlobalAccountDataEvent>,
/// Internal counter to enable the `prev_batch` and `next_batch` of each sync response to vary.
/// Internal counter to enable the `prev_batch` and `next_batch` of each
/// sync response to vary.
batch_counter: i64,
}
@ -154,10 +154,7 @@ impl EventBuilder {
}
fn add_joined_event(&mut self, room_id: &RoomId, event: AnySyncRoomEvent) {
self.joined_room_events
.entry(room_id.clone())
.or_insert_with(Vec::new)
.push(event);
self.joined_room_events.entry(room_id.clone()).or_insert_with(Vec::new).push(event);
}
pub fn add_custom_invited_event(
@ -166,10 +163,7 @@ impl EventBuilder {
event: serde_json::Value,
) -> &mut Self {
let event = serde_json::from_value::<AnySyncStateEvent>(event).unwrap();
self.invited_room_events
.entry(room_id.clone())
.or_insert_with(Vec::new)
.push(event);
self.invited_room_events.entry(room_id.clone()).or_insert_with(Vec::new).push(event);
self
}
@ -179,10 +173,7 @@ impl EventBuilder {
event: serde_json::Value,
) -> &mut Self {
let event = serde_json::from_value::<AnySyncRoomEvent>(event).unwrap();
self.left_room_events
.entry(room_id.clone())
.or_insert_with(Vec::new)
.push(event);
self.left_room_events.entry(room_id.clone()).or_insert_with(Vec::new).push(event);
self
}
@ -227,7 +218,8 @@ impl EventBuilder {
pub fn build_json_sync_response(&mut self) -> JsonValue {
let main_room_id = room_id!("!SVkFJHzfwvuaIEawgC:localhost");
// First time building a sync response, so initialize the `prev_batch` to a default one.
// First time building a sync response, so initialize the `prev_batch` to a
// default one.
let prev_batch = self.generate_sync_token();
self.batch_counter += 1;
let next_batch = self.generate_sync_token();
@ -352,9 +344,7 @@ impl EventBuilder {
pub fn build_sync_response(&mut self) -> SyncResponse {
let body = self.build_json_sync_response();
let response = Response::builder()
.body(serde_json::to_vec(&body).unwrap())
.unwrap();
let response = Response::builder().body(serde_json::to_vec(&body).unwrap()).unwrap();
SyncResponse::try_from_http_response(response).unwrap()
}
@ -395,15 +385,10 @@ pub fn sync_response(kind: SyncResponseFile) -> SyncResponse {
SyncResponseFile::Voip => &test_json::VOIP_SYNC,
};
let response = Response::builder()
.body(data.to_string().as_bytes().to_vec())
.unwrap();
let response = Response::builder().body(data.to_string().as_bytes().to_vec()).unwrap();
SyncResponse::try_from_http_response(response).unwrap()
}
pub fn response_from_file(json: &serde_json::Value) -> Response<Vec<u8>> {
Response::builder()
.status(200)
.body(json.to_string().as_bytes().to_vec())
.unwrap()
Response::builder().status(200).body(json.to_string().as_bytes().to_vec()).unwrap()
}

View file

@ -1,8 +1,8 @@
//! Test data for the matrix-sdk crates.
//!
//! Exporting each const allows all the test data to have a single source of truth.
//! When running `cargo publish` no external folders are allowed so all the
//! test data needs to be contained within this crate.
//! Exporting each const allows all the test data to have a single source of
//! truth. When running `cargo publish` no external folders are allowed so all
//! the test data needs to be contained within this crate.
use lazy_static::lazy_static;
use serde_json::{json, Value as JsonValue};
@ -17,12 +17,11 @@ pub use events::{
PUBLIC_ROOMS, REACTION, REDACTED, REDACTED_INVALID, REDACTED_STATE, REDACTION,
REGISTRATION_RESPONSE_ERR, ROOM_ID, ROOM_MESSAGES, TYPING,
};
pub use members::MEMBERS;
pub use sync::{
DEFAULT_SYNC_SUMMARY, INVITE_SYNC, LEAVE_SYNC, LEAVE_SYNC_EVENT, MORE_SYNC, SYNC, VOIP_SYNC,
};
pub use members::MEMBERS;
lazy_static! {
pub static ref DEVICES: JsonValue = json!({
"devices": [

View file

@ -1,6 +1,7 @@
use proc_macro::TokenStream;
/// Attribute to use `wasm_bindgen_test` for wasm32 targets and `tokio::test` for everything else
/// Attribute to use `wasm_bindgen_test` for wasm32 targets and `tokio::test`
/// for everything else
#[proc_macro_attribute]
pub fn async_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
let attrs = r#"