Merge remote-tracking branch 'origin/master' into feature/qrcode-feature

master
Alexander Sieg 2021-09-09 15:10:18 +02:00
commit c47ac8d6b1
42 changed files with 2124 additions and 1578 deletions

View File

@ -32,9 +32,12 @@ appservice = ["ruma/appservice-api-s", "ruma/appservice-api-helper", "ruma/rand"
docs = ["encryption", "sled_cryptostore", "sled_state_store", "sso_login"] docs = ["encryption", "sled_cryptostore", "sled_state_store", "sso_login"]
[dependencies] [dependencies]
anyhow = { version = "1.0.42", optional = true }
dashmap = "4.0.2" dashmap = "4.0.2"
event-listener = "2.5.1"
futures = "0.3.15" futures = "0.3.15"
http = "0.2.4" http = "0.2.4"
serde = "1.0.126"
serde_json = "1.0.64" serde_json = "1.0.64"
thiserror = "1.0.25" thiserror = "1.0.25"
tracing = "0.1.26" tracing = "0.1.26"
@ -91,6 +94,7 @@ version = "3.0.2"
features = ["wasm-bindgen"] features = ["wasm-bindgen"]
[dev-dependencies] [dev-dependencies]
anyhow = "1.0"
dirs = "3.0.2" dirs = "3.0.2"
matches = "0.1.8" matches = "0.1.8"
matrix-sdk-test = { version = "0.3.0", path = "../matrix_sdk_test" } matrix-sdk-test = { version = "0.3.0", path = "../matrix_sdk_test" }

View File

@ -1,33 +1,19 @@
use std::{env, process::exit}; use std::{env, process::exit};
use matrix_sdk::{ use matrix_sdk::{
async_trait,
room::Room, room::Room,
ruma::events::{room::member::MemberEventContent, StrippedStateEvent}, ruma::events::{room::member::MemberEventContent, StrippedStateEvent},
Client, ClientConfig, EventHandler, SyncSettings, Client, ClientConfig, SyncSettings,
}; };
use tokio::time::{sleep, Duration}; use tokio::time::{sleep, Duration};
use url::Url; use url::Url;
struct AutoJoinBot {
client: Client,
}
impl AutoJoinBot {
pub fn new(client: Client) -> Self {
Self { client }
}
}
#[async_trait]
impl EventHandler for AutoJoinBot {
async fn on_stripped_state_member( async fn on_stripped_state_member(
&self, room_member: StrippedStateEvent<MemberEventContent>,
client: Client,
room: Room, room: Room,
room_member: &StrippedStateEvent<MemberEventContent>,
_: Option<MemberEventContent>,
) { ) {
if room_member.state_key != self.client.user_id().await.unwrap() { if room_member.state_key != client.user_id().await.unwrap() {
return; return;
} }
@ -39,12 +25,7 @@ impl EventHandler for AutoJoinBot {
// retry autojoin due to synapse sending invites, before the // retry autojoin due to synapse sending invites, before the
// invited user can join for more information see // invited user can join for more information see
// https://github.com/matrix-org/synapse/issues/4345 // https://github.com/matrix-org/synapse/issues/4345
eprintln!( eprintln!("Failed to join room {} ({:?}), retrying in {}s", room.room_id(), err, delay);
"Failed to join room {} ({:?}), retrying in {}s",
room.room_id(),
err,
delay
);
sleep(Duration::from_secs(delay)).await; sleep(Duration::from_secs(delay)).await;
delay *= 2; delay *= 2;
@ -57,7 +38,6 @@ impl EventHandler for AutoJoinBot {
println!("Successfully joined room {}", room.room_id()); println!("Successfully joined room {}", room.room_id());
} }
} }
}
async fn login_and_sync( async fn login_and_sync(
homeserver_url: String, homeserver_url: String,
@ -76,7 +56,7 @@ async fn login_and_sync(
println!("logged in as {}", username); println!("logged in as {}", username);
client.set_event_handler(Box::new(AutoJoinBot::new(client.clone()))).await; client.register_event_handler(on_stripped_state_member).await;
client.sync(SyncSettings::default()).await; client.sync(SyncSettings::default()).await;

View File

@ -1,27 +1,16 @@
use std::{env, process::exit}; use std::{env, process::exit};
use matrix_sdk::{ use matrix_sdk::{
async_trait,
room::Room, room::Room,
ruma::events::{ ruma::events::{
room::message::{MessageEventContent, MessageType, TextMessageEventContent}, room::message::{MessageEventContent, MessageType, TextMessageEventContent},
AnyMessageEventContent, SyncMessageEvent, AnyMessageEventContent, SyncMessageEvent,
}, },
Client, ClientConfig, EventHandler, SyncSettings, Client, ClientConfig, SyncSettings,
}; };
use url::Url; use url::Url;
struct CommandBot; async fn on_room_message(event: SyncMessageEvent<MessageEventContent>, room: Room) {
impl CommandBot {
pub fn new() -> Self {
Self {}
}
}
#[async_trait]
impl EventHandler for CommandBot {
async fn on_room_message(&self, room: Room, event: &SyncMessageEvent<MessageEventContent>) {
if let Room::Joined(room) = room { if let Room::Joined(room) = room {
let msg_body = if let SyncMessageEvent { let msg_body = if let SyncMessageEvent {
content: content:
@ -52,7 +41,6 @@ impl EventHandler for CommandBot {
} }
} }
} }
}
async fn login_and_sync( async fn login_and_sync(
homeserver_url: String, homeserver_url: String,
@ -79,7 +67,7 @@ async fn login_and_sync(
client.sync_once(SyncSettings::default()).await.unwrap(); client.sync_once(SyncSettings::default()).await.unwrap();
// add our CommandBot to be notified of incoming messages, we do this after the // add our CommandBot to be notified of incoming messages, we do this after the
// initial sync to avoid responding to messages before the bot was running. // initial sync to avoid responding to messages before the bot was running.
client.set_event_handler(Box::new(CommandBot::new())).await; client.register_event_handler(on_room_message).await;
// since we called `sync_once` before we entered our sync loop we must pass // since we called `sync_once` before we entered our sync loop we must pass
// that sync token to `sync` // that sync token to `sync`

View File

@ -8,31 +8,22 @@ use std::{
}; };
use matrix_sdk::{ use matrix_sdk::{
self, async_trait, self,
room::Room, room::Room,
ruma::events::{ ruma::events::{
room::message::{MessageEventContent, MessageType, TextMessageEventContent}, room::message::{MessageEventContent, MessageType, TextMessageEventContent},
SyncMessageEvent, SyncMessageEvent,
}, },
Client, EventHandler, SyncSettings, Client, SyncSettings,
}; };
use tokio::sync::Mutex; use tokio::sync::Mutex;
use url::Url; use url::Url;
struct ImageBot { async fn on_room_message(
event: SyncMessageEvent<MessageEventContent>,
room: Room,
image: Arc<Mutex<File>>, image: Arc<Mutex<File>>,
} ) {
impl ImageBot {
pub fn new(image: File) -> Self {
let image = Arc::new(Mutex::new(image));
Self { image }
}
}
#[async_trait]
impl EventHandler for ImageBot {
async fn on_room_message(&self, room: Room, event: &SyncMessageEvent<MessageEventContent>) {
if let Room::Joined(room) = room { if let Room::Joined(room) = room {
let msg_body = if let SyncMessageEvent { let msg_body = if let SyncMessageEvent {
content: content:
@ -50,7 +41,7 @@ impl EventHandler for ImageBot {
if msg_body.contains("!image") { if msg_body.contains("!image") {
println!("sending image"); println!("sending image");
let mut image = self.image.lock().await; let mut image = image.lock().await;
room.send_attachment("cat", &mime::IMAGE_JPEG, &mut *image, None).await.unwrap(); room.send_attachment("cat", &mime::IMAGE_JPEG, &mut *image, None).await.unwrap();
@ -60,7 +51,6 @@ impl EventHandler for ImageBot {
} }
} }
} }
}
async fn login_and_sync( async fn login_and_sync(
homeserver_url: String, homeserver_url: String,
@ -74,7 +64,9 @@ async fn login_and_sync(
client.login(&username, &password, None, Some("command bot")).await?; client.login(&username, &password, None, Some("command bot")).await?;
client.sync_once(SyncSettings::default()).await.unwrap(); client.sync_once(SyncSettings::default()).await.unwrap();
client.set_event_handler(Box::new(ImageBot::new(image))).await;
let image = Arc::new(Mutex::new(image));
client.register_event_handler(move |ev, room| on_room_message(ev, room, image.clone())).await;
let settings = SyncSettings::default().token(client.sync_token().await.unwrap()); let settings = SyncSettings::default().token(client.sync_token().await.unwrap());
client.sync(settings).await; client.sync(settings).await;

View File

@ -1,21 +1,17 @@
use std::{env, process::exit}; use std::{env, process::exit};
use matrix_sdk::{ use matrix_sdk::{
self, async_trait, self,
room::Room, room::Room,
ruma::events::{ ruma::events::{
room::message::{MessageEventContent, MessageType, TextMessageEventContent}, room::message::{MessageEventContent, MessageType, TextMessageEventContent},
SyncMessageEvent, SyncMessageEvent,
}, },
Client, EventHandler, SyncSettings, Client, SyncSettings,
}; };
use url::Url; use url::Url;
struct EventCallback; async fn on_room_message(event: SyncMessageEvent<MessageEventContent>, room: Room) {
#[async_trait]
impl EventHandler for EventCallback {
async fn on_room_message(&self, room: Room, event: &SyncMessageEvent<MessageEventContent>) {
if let Room::Joined(room) = room { if let Room::Joined(room) = room {
if let SyncMessageEvent { if let SyncMessageEvent {
content: content:
@ -27,13 +23,12 @@ impl EventHandler for EventCallback {
.. ..
} = event } = event
{ {
let member = room.get_member(sender).await.unwrap().unwrap(); let member = room.get_member(&sender).await.unwrap().unwrap();
let name = member.display_name().unwrap_or_else(|| member.user_id().as_str()); let name = member.display_name().unwrap_or_else(|| member.user_id().as_str());
println!("{}: {}", name, msg_body); println!("{}: {}", name, msg_body);
} }
} }
} }
}
async fn login( async fn login(
homeserver_url: String, homeserver_url: String,
@ -43,7 +38,7 @@ async fn login(
let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL"); let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL");
let client = Client::new(homeserver_url).unwrap(); let client = Client::new(homeserver_url).unwrap();
client.set_event_handler(Box::new(EventCallback)).await; client.register_event_handler(on_room_message).await;
client.login(username, password, None, Some("rust-sdk")).await?; client.login(username, password, None, Some("rust-sdk")).await?;
client.sync(SyncSettings::new()).await; client.sync(SyncSettings::new()).await;

View File

@ -15,10 +15,15 @@
#[cfg(all(feature = "encryption", not(target_arch = "wasm32")))] #[cfg(all(feature = "encryption", not(target_arch = "wasm32")))]
use std::path::PathBuf; use std::path::PathBuf;
#[cfg(feature = "encryption")]
use std::{ use std::{
collections::BTreeMap, collections::BTreeMap,
io::{Cursor, Write}, fmt::{self, Debug},
future::Future,
io::Read,
path::Path,
pin::Pin,
result::Result as StdResult,
sync::Arc,
}; };
#[cfg(feature = "sso_login")] #[cfg(feature = "sso_login")]
use std::{ use std::{
@ -26,16 +31,14 @@ use std::{
io::{Error as IoError, ErrorKind as IoErrorKind}, io::{Error as IoError, ErrorKind as IoErrorKind},
ops::Range, ops::Range,
}; };
#[cfg(feature = "encryption")]
use std::{ use std::{
fmt::{self, Debug}, collections::HashSet,
future::Future, io::{Cursor, Write},
io::Read,
path::Path,
result::Result as StdResult,
sync::Arc,
}; };
use dashmap::DashMap; use dashmap::DashMap;
use futures::FutureExt;
use futures_timer::Delay as sleep; use futures_timer::Delay as sleep;
use http::HeaderValue; use http::HeaderValue;
#[cfg(feature = "sso_login")] #[cfg(feature = "sso_login")]
@ -48,9 +51,9 @@ use matrix_sdk_base::crypto::{
ToDeviceRequest, ToDeviceRequest,
}; };
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use matrix_sdk_base::deserialized_responses::RoomEvent; use matrix_sdk_base::{crypto::CrossSigningStatus, deserialized_responses::RoomEvent};
use matrix_sdk_base::{ use matrix_sdk_base::{
deserialized_responses::SyncResponse, deserialized_responses::{JoinedRoom, LeftRoom, SyncResponse},
media::{MediaEventContent, MediaFormat, MediaRequest, MediaThumbnailSize, MediaType}, media::{MediaEventContent, MediaFormat, MediaRequest, MediaThumbnailSize, MediaType},
BaseClient, BaseClientConfig, Session, Store, BaseClient, BaseClientConfig, Session, Store,
}; };
@ -59,15 +62,20 @@ use mime::{self, Mime};
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use reqwest::header::InvalidHeaderValue; use reqwest::header::InvalidHeaderValue;
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use ruma::events::{AnyMessageEvent, AnyRoomEvent, AnySyncMessageEvent}; use ruma::events::{AnyMessageEvent, AnyRoomEvent, AnySyncMessageEvent, EventType};
use ruma::{api::SendAccessToken, events::AnyMessageEventContent, MxcUri}; use ruma::{
api::{client::r0::push::get_notifications::Notification, SendAccessToken},
events::AnyMessageEventContent,
MxcUri,
};
use serde::de::DeserializeOwned;
#[cfg(feature = "sso_login")] #[cfg(feature = "sso_login")]
use tokio::{net::TcpListener, sync::oneshot}; use tokio::{net::TcpListener, sync::oneshot};
#[cfg(feature = "sso_login")] #[cfg(feature = "sso_login")]
use tokio_stream::wrappers::TcpListenerStream; use tokio_stream::wrappers::TcpListenerStream;
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use tracing::{debug, warn}; use tracing::{debug, trace};
use tracing::{error, info, instrument}; use tracing::{error, info, instrument, warn};
use url::Url; use url::Url;
#[cfg(feature = "sso_login")] #[cfg(feature = "sso_login")]
use warp::Filter; use warp::Filter;
@ -134,15 +142,15 @@ use ruma::{
use crate::verification::QrVerification; use crate::verification::QrVerification;
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use crate::{ use crate::{
device::{Device, UserDevices},
error::RoomKeyImportError, error::RoomKeyImportError,
identities::{Device, UserDevices},
verification::{SasVerification, Verification, VerificationRequest}, verification::{SasVerification, Verification, VerificationRequest},
}; };
use crate::{ use crate::{
error::HttpError, error::{HttpError, HttpResult},
event_handler::Handler, event_handler::{EventHandler, EventHandlerData, EventHandlerResult, EventKind, SyncEvent},
http_client::{client_with_config, HttpClient, HttpSend}, http_client::{client_with_config, HttpClient, HttpSend},
room, Error, EventHandler, Result, room, Error, Result,
}; };
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10); const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10);
@ -158,6 +166,14 @@ const SSO_SERVER_BIND_RANGE: Range<u16> = 20000..30000;
#[cfg(feature = "sso_login")] #[cfg(feature = "sso_login")]
const SSO_SERVER_BIND_TRIES: u8 = 10; const SSO_SERVER_BIND_TRIES: u8 = 10;
type EventHandlerFut = Pin<Box<dyn Future<Output = ()> + Send>>;
type EventHandlerFn = Box<dyn Fn(EventHandlerData<'_>) -> EventHandlerFut + Send + Sync>;
type EventHandlerMap = BTreeMap<(EventKind, &'static str), Vec<EventHandlerFn>>;
type NotificationHandlerFut = EventHandlerFut;
type NotificationHandlerFn =
Box<dyn Fn(Notification, room::Room, Client) -> NotificationHandlerFut + Send + Sync>;
/// An async/await enabled Matrix client. /// An async/await enabled Matrix client.
/// ///
/// All of the state is held in an `Arc` so the `Client` can be cloned freely. /// All of the state is held in an `Arc` so the `Client` can be cloned freely.
@ -178,13 +194,20 @@ pub struct Client {
key_claim_lock: Arc<Mutex<()>>, key_claim_lock: Arc<Mutex<()>>,
pub(crate) members_request_locks: Arc<DashMap<RoomId, Arc<Mutex<()>>>>, pub(crate) members_request_locks: Arc<DashMap<RoomId, Arc<Mutex<()>>>>,
pub(crate) typing_notice_times: Arc<DashMap<RoomId, Instant>>, pub(crate) typing_notice_times: Arc<DashMap<RoomId, Instant>>,
/// Any implementor of EventHandler will act as the callbacks for various /// Event handlers. See `register_event_handler`.
/// events. pub(crate) event_handlers: Arc<RwLock<EventHandlerMap>>,
event_handler: Arc<RwLock<Option<Handler>>>, /// Notification handlers. See `register_notification_handler`.
notification_handlers: Arc<RwLock<Vec<NotificationHandlerFn>>>,
/// Whether the client should operate in application service style mode. /// Whether the client should operate in application service style mode.
/// This is low-level functionality. For an high-level API check the /// This is low-level functionality. For an high-level API check the
/// `matrix_sdk_appservice` crate. /// `matrix_sdk_appservice` crate.
appservice_mode: bool, appservice_mode: bool,
/// An event that can be listened on to wait for a successful sync. The
/// event will only be fired if a sync loop is running. Can be used for
/// synchronization, e.g. if we send out a request to create a room, we can
/// wait for the sync to get the data to fetch a room object from the state
/// store.
sync_beat: Arc<event_listener::Event>,
} }
#[cfg(not(tarpaulin_include))] #[cfg(not(tarpaulin_include))]
@ -559,13 +582,15 @@ impl Client {
http_client, http_client,
base_client, base_client,
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
group_session_locks: Arc::new(DashMap::new()), group_session_locks: Default::default(),
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
key_claim_lock: Arc::new(Mutex::new(())), key_claim_lock: Default::default(),
members_request_locks: Arc::new(DashMap::new()), members_request_locks: Default::default(),
typing_notice_times: Arc::new(DashMap::new()), typing_notice_times: Default::default(),
event_handler: Arc::new(RwLock::new(None)), event_handlers: Default::default(),
notification_handlers: Default::default(),
appservice_mode: config.appservice_mode, appservice_mode: config.appservice_mode,
sync_beat: event_listener::Event::new().into(),
}) })
} }
@ -629,7 +654,7 @@ impl Client {
Ok(result) Ok(result)
} }
async fn discover_homeserver(&self) -> Result<discover_homeserver::Response> { async fn discover_homeserver(&self) -> HttpResult<discover_homeserver::Response> {
self.send(discover_homeserver::Request::new(), Some(RequestConfig::new().disable_retry())) self.send(discover_homeserver::Request::new(), Some(RequestConfig::new().disable_retry()))
.await .await
} }
@ -644,7 +669,7 @@ impl Client {
*homeserver = homeserver_url; *homeserver = homeserver_url;
} }
async fn get_supported_versions(&self) -> Result<get_supported_versions::Response> { async fn get_supported_versions(&self) -> HttpResult<get_supported_versions::Response> {
self.send( self.send(
get_supported_versions::Request::new(), get_supported_versions::Request::new(),
Some(RequestConfig::new().disable_retry()), Some(RequestConfig::new().disable_retry()),
@ -668,12 +693,7 @@ impl Client {
) -> Result<()> { ) -> Result<()> {
let txn_id = incoming_transaction.txn_id.clone(); let txn_id = incoming_transaction.txn_id.clone();
let response = incoming_transaction.try_into_sync_response(txn_id)?; let response = incoming_transaction.try_into_sync_response(txn_id)?;
let base_client = self.base_client.clone(); self.process_sync(response).await?;
let sync_response = base_client.receive_sync_response(response).await?;
if let Some(handler) = self.event_handler.read().await.as_ref() {
handler.handle_sync(&sync_response).await;
}
Ok(()) Ok(())
} }
@ -708,6 +728,16 @@ impl Client {
self.base_client.olm_machine().await.map(|o| o.identity_keys().ed25519().to_owned()) self.base_client.olm_machine().await.map(|o| o.identity_keys().ed25519().to_owned())
} }
/// Get all the tracked users we know about
///
/// Tracked users are users for which we keep the device list of E2EE
/// capable devices up to date.
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn tracked_users(&self) -> HashSet<UserId> {
self.base_client.olm_machine().await.map(|o| o.tracked_users()).unwrap_or_default()
}
/// Fetches the display name of the owner of the client. /// Fetches the display name of the owner of the client.
/// ///
/// # Example /// # Example
@ -869,13 +899,125 @@ impl Client {
Ok(()) Ok(())
} }
/// Add `EventHandler` to `Client`. /// Register a handler for a specific event type.
/// ///
/// The methods of `EventHandler` are called when the respective /// The handler is a function or closure with one or more arguments. The
/// `RoomEvents` occur. /// first argument is the event itself. All additional arguments are
pub async fn set_event_handler(&self, handler: Box<dyn EventHandler>) { /// "context" arguments: They have to implement [`EventHandlerContext`].
let handler = Handler { inner: handler, client: self.clone() }; /// This trait is named that way because most of the types implementing it
*self.event_handler.write().await = Some(handler); /// give additional context about an event: The room it was in, its raw form
/// and other similar things. As an exception to this,
/// [`Client`] also implements the `EventHandlerContext` trait
/// so you don't have to clone your client into the event handler manually.
///
/// Some context arguments are not universally applicable. A context
/// argument that isn't available for the given event type will result in
/// the event handler being skipped and an error being logged. The following
/// context argument types are only available for a subset of event types:
///
/// * [`Room`][room::Room] is only available for room-specific events, i.e.
/// not for events like global account data events or presence events
///
/// [`EventHandlerContext`]: crate::event_handler::EventHandlerContext
///
/// # Examples
///
/// ```no_run
/// # let client: matrix_sdk::Client = unimplemented!();
/// use matrix_sdk::{
/// room::Room,
/// ruma::{
/// events::{
/// macros::EventContent,
/// push_rules::PushRulesEvent,
/// room::{message::MessageEventContent, topic::TopicEventContent},
/// SyncMessageEvent, SyncStateEvent,
/// },
/// Int, MilliSecondsSinceUnixEpoch,
/// },
/// Client,
/// };
/// use serde::{Deserialize, Serialize};
///
/// # let _ = async {
/// client
/// .register_event_handler(
/// |ev: SyncMessageEvent<MessageEventContent>, room: Room, client: Client| async move {
/// // Common usage: Room event plus room and client.
/// },
/// )
/// .await
/// .register_event_handler(|ev: SyncStateEvent<TopicEventContent>| async move {
/// // Also possible: Omit any or all arguments after the first.
/// })
/// .await;
///
/// // Custom events work exactly the same way, you just need to declare the content struct and
/// // use the EventContent derive macro on it.
/// #[derive(Clone, Debug, Deserialize, Serialize, EventContent)]
/// #[ruma_event(type = "org.shiny_new_2fa.token", kind = Message)]
/// struct TokenEventContent {
/// token: String,
/// #[serde(rename = "exp")]
/// expires_at: MilliSecondsSinceUnixEpoch,
/// }
///
/// client.register_event_handler(
/// |ev: SyncMessageEvent<TokenEventContent>, room: Room| async move {
/// todo!("Display the token");
/// },
/// ).await;
/// # };
/// ```
pub async fn register_event_handler<Ev, Ctx, H>(&self, handler: H) -> &Self
where
Ev: SyncEvent + DeserializeOwned + Send + 'static,
H: EventHandler<Ev, Ctx>,
<H::Future as Future>::Output: EventHandlerResult,
{
let event_type = H::ID.1;
self.event_handlers.write().await.entry(H::ID).or_default().push(Box::new(move |data| {
let maybe_fut = serde_json::from_str(data.raw.get())
.map(|ev| handler.clone().handle_event(ev, data));
async move {
match maybe_fut {
Ok(Some(fut)) => {
fut.await.print_error(event_type);
}
Ok(None) => {
error!("Event handler for {} has an invalid context argument", event_type);
}
Err(e) => {
warn!(
"Failed to deserialize `{}` event, skipping event handler.\n\
Deserialization error: {}",
event_type, e,
);
}
}
}
.boxed()
}));
self
}
/// Register a handler for a notification.
///
/// Similar to `.register_event_handler`, but only allows functions or
/// closures with exactly the three arguments `Notification`, `room::Room`,
/// `Client` for now.
pub async fn register_notification_handler<H, Fut>(&self, handler: H) -> &Self
where
H: Fn(Notification, room::Room, Client) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
self.notification_handlers.write().await.push(Box::new(
move |notification, room, client| (handler)(notification, room, client).boxed(),
));
self
} }
/// Get all the rooms the client knows about. /// Get all the rooms the client knows about.
@ -990,7 +1132,7 @@ impl Client {
/// ///
/// This should be the first step when trying to login so you can call the /// This should be the first step when trying to login so you can call the
/// appropriate method for the next step. /// appropriate method for the next step.
pub async fn get_login_types(&self) -> Result<get_login_types::Response> { pub async fn get_login_types(&self) -> HttpResult<get_login_types::Response> {
let request = get_login_types::Request::new(); let request = get_login_types::Request::new();
self.send(request, None).await self.send(request, None).await
} }
@ -1407,7 +1549,7 @@ impl Client {
pub async fn register( pub async fn register(
&self, &self,
registration: impl Into<register::Request<'_>>, registration: impl Into<register::Request<'_>>,
) -> Result<register::Response> { ) -> HttpResult<register::Response> {
info!("Registering to {}", self.homeserver().await); info!("Registering to {}", self.homeserver().await);
let config = if self.appservice_mode { let config = if self.appservice_mode {
@ -1497,7 +1639,7 @@ impl Client {
/// # Arguments /// # Arguments
/// ///
/// * `room_id` - The `RoomId` of the room to be joined. /// * `room_id` - The `RoomId` of the room to be joined.
pub async fn join_room_by_id(&self, room_id: &RoomId) -> Result<join_room_by_id::Response> { pub async fn join_room_by_id(&self, room_id: &RoomId) -> HttpResult<join_room_by_id::Response> {
let request = join_room_by_id::Request::new(room_id); let request = join_room_by_id::Request::new(room_id);
self.send(request, None).await self.send(request, None).await
} }
@ -1515,7 +1657,7 @@ impl Client {
&self, &self,
alias: &RoomIdOrAliasId, alias: &RoomIdOrAliasId,
server_names: &[Box<ServerName>], server_names: &[Box<ServerName>],
) -> Result<join_room_by_id_or_alias::Response> { ) -> HttpResult<join_room_by_id_or_alias::Response> {
let request = assign!(join_room_by_id_or_alias::Request::new(alias), { let request = assign!(join_room_by_id_or_alias::Request::new(alias), {
server_name: server_names, server_name: server_names,
}); });
@ -1558,7 +1700,7 @@ impl Client {
limit: Option<u32>, limit: Option<u32>,
since: Option<&str>, since: Option<&str>,
server: Option<&ServerName>, server: Option<&ServerName>,
) -> Result<get_public_rooms::Response> { ) -> HttpResult<get_public_rooms::Response> {
let limit = limit.map(UInt::from); let limit = limit.map(UInt::from);
let request = assign!(get_public_rooms::Request::new(), { let request = assign!(get_public_rooms::Request::new(), {
@ -1599,7 +1741,7 @@ impl Client {
pub async fn create_room( pub async fn create_room(
&self, &self,
room: impl Into<create_room::Request<'_>>, room: impl Into<create_room::Request<'_>>,
) -> Result<create_room::Response> { ) -> HttpResult<create_room::Response> {
let request = room.into(); let request = room.into();
self.send(request, None).await self.send(request, None).await
} }
@ -1639,7 +1781,7 @@ impl Client {
pub async fn public_rooms_filtered( pub async fn public_rooms_filtered(
&self, &self,
room_search: impl Into<get_public_rooms_filtered::Request<'_>>, room_search: impl Into<get_public_rooms_filtered::Request<'_>>,
) -> Result<get_public_rooms_filtered::Response> { ) -> HttpResult<get_public_rooms_filtered::Response> {
let request = room_search.into(); let request = room_search.into();
self.send(request, None).await self.send(request, None).await
} }
@ -1773,7 +1915,7 @@ impl Client {
let txn_id = txn_id.unwrap_or_else(Uuid::new_v4).to_string(); let txn_id = txn_id.unwrap_or_else(Uuid::new_v4).to_string();
let request = send_message_event::Request::new(room_id, &txn_id, &content); let request = send_message_event::Request::new(room_id, &txn_id, &content);
self.send(request, None).await Ok(self.send(request, None).await?)
} }
} }
@ -1821,7 +1963,7 @@ impl Client {
&self, &self,
request: Request, request: Request,
config: Option<RequestConfig>, config: Option<RequestConfig>,
) -> Result<Request::IncomingResponse> ) -> HttpResult<Request::IncomingResponse>
where where
Request: OutgoingRequest + Debug, Request: OutgoingRequest + Debug,
HttpError: From<FromHttpResponseError<Request::EndpointError>>, HttpError: From<FromHttpResponseError<Request::EndpointError>>,
@ -1833,8 +1975,9 @@ impl Client {
pub(crate) async fn send_to_device( pub(crate) async fn send_to_device(
&self, &self,
request: &ToDeviceRequest, request: &ToDeviceRequest,
) -> Result<ToDeviceResponse> { ) -> HttpResult<ToDeviceResponse> {
let txn_id_string = request.txn_id_string(); let txn_id_string = request.txn_id_string();
let request = RumaToDeviceRequest::new_raw( let request = RumaToDeviceRequest::new_raw(
request.event_type.as_str(), request.event_type.as_str(),
&txn_id_string, &txn_id_string,
@ -1867,7 +2010,7 @@ impl Client {
/// } /// }
/// # }); /// # });
/// ``` /// ```
pub async fn devices(&self) -> Result<get_devices::Response> { pub async fn devices(&self) -> HttpResult<get_devices::Response> {
let request = get_devices::Request::new(); let request = get_devices::Request::new();
self.send(request, None).await self.send(request, None).await
@ -1924,7 +2067,7 @@ impl Client {
&self, &self,
devices: &[DeviceIdBox], devices: &[DeviceIdBox],
auth_data: Option<AuthData<'_>>, auth_data: Option<AuthData<'_>>,
) -> Result<delete_devices::Response> { ) -> HttpResult<delete_devices::Response> {
let mut request = delete_devices::Request::new(devices); let mut request = delete_devices::Request::new(devices);
request.auth = auth_data; request.auth = auth_data;
@ -1959,13 +2102,91 @@ impl Client {
); );
let response = self.send(request, Some(request_config)).await?; let response = self.send(request, Some(request_config)).await?;
let sync_response = self.base_client.receive_sync_response(response).await?; self.process_sync(response).await
if let Some(handler) = self.event_handler.read().await.as_ref() {
handler.handle_sync(&sync_response).await;
} }
Ok(sync_response) async fn process_sync(&self, response: sync_events::Response) -> Result<SyncResponse> {
let response = self.base_client.receive_sync_response(response).await?;
let SyncResponse {
next_batch: _,
rooms,
presence,
account_data,
to_device: _,
device_lists: _,
device_one_time_keys_count: _,
ambiguity_changes: _,
notifications,
} = &response;
self.handle_sync_events(EventKind::GlobalAccountData, &None, &account_data.events).await?;
self.handle_sync_events(EventKind::Presence, &None, &presence.events).await?;
for (room_id, room_info) in &rooms.join {
let room = self.get_room(room_id);
if room.is_none() {
error!("Can't call event handler, room {} not found", room_id);
continue;
}
let JoinedRoom { unread_notifications: _, timeline, state, account_data, ephemeral } =
room_info;
self.handle_sync_events(EventKind::EphemeralRoomData, &room, &ephemeral.events).await?;
self.handle_sync_events(EventKind::RoomAccountData, &room, &account_data.events)
.await?;
self.handle_sync_state_events(&room, &state.events).await?;
self.handle_sync_timeline_events(&room, &timeline.events).await?;
}
for (room_id, room_info) in &rooms.leave {
let room = self.get_room(room_id);
if room.is_none() {
error!("Can't call event handler, room {} not found", room_id);
continue;
}
let LeftRoom { timeline, state, account_data } = room_info;
self.handle_sync_events(EventKind::RoomAccountData, &room, &account_data.events)
.await?;
self.handle_sync_state_events(&room, &state.events).await?;
self.handle_sync_timeline_events(&room, &timeline.events).await?;
}
for (room_id, room_info) in &rooms.invite {
let room = self.get_room(room_id);
if room.is_none() {
error!("Can't call event handler, room {} not found", room_id);
continue;
}
// FIXME: Destructure room_info
self.handle_sync_events(EventKind::InitialState, &room, &room_info.invite_state.events)
.await?;
}
for handler in &*self.notification_handlers.read().await {
for (room_id, room_notifications) in notifications {
let room = match self.get_room(room_id) {
Some(room) => room,
None => {
warn!("Can't call notification handler, room {} not found", room_id);
continue;
}
};
for notification in room_notifications {
matrix_sdk_common::executor::spawn((handler)(
notification.clone(),
room.clone(),
self.clone(),
));
}
}
}
Ok(response)
} }
/// Repeatedly call sync to synchronize the client state with the server. /// Repeatedly call sync to synchronize the client state with the server.
@ -2153,6 +2374,87 @@ impl Client {
sync_settings.token = sync_settings.token =
Some(self.sync_token().await.expect("No sync token found after initial sync")); Some(self.sync_token().await.expect("No sync token found after initial sync"));
self.sync_beat.notify(usize::MAX);
}
}
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
async fn send_account_data(
&self,
content: ruma::events::AnyGlobalAccountDataEventContent,
) -> Result<ruma::api::client::r0::config::set_global_account_data::Response> {
let own_user =
self.user_id().await.ok_or_else(|| Error::from(HttpError::AuthenticationRequired))?;
let data = serde_json::value::to_raw_value(&content)?;
let request = ruma::api::client::r0::config::set_global_account_data::Request::new(
&data,
ruma::events::EventContent::event_type(&content),
&own_user,
);
Ok(self.send(request, None).await?)
}
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub(crate) async fn create_dm_room(&self, user_id: UserId) -> Result<Option<room::Joined>> {
use ruma::{
api::client::r0::room::create_room::RoomPreset,
events::AnyGlobalAccountDataEventContent,
};
const SYNC_WAIT_TIME: Duration = Duration::from_secs(3);
// First we create the DM room, where we invite the user and tell the
// invitee that the room should be a DM.
let invite = &[user_id.clone()];
let request = assign!(
ruma::api::client::r0::room::create_room::Request::new(),
{
invite,
is_direct: true,
preset: Some(RoomPreset::TrustedPrivateChat),
}
);
let response = self.send(request, None).await?;
// Now we need to mark the room as a DM for ourselves, we fetch the
// existing `m.direct` event and append the room to the list of DMs we
// have with this user.
let mut content = self
.store()
.get_account_data_event(EventType::Direct)
.await?
.map(|e| e.deserialize())
.transpose()?
.and_then(|e| {
if let AnyGlobalAccountDataEventContent::Direct(c) = e.content() {
Some(c)
} else {
None
}
})
.unwrap_or_else(|| ruma::events::direct::DirectEventContent(BTreeMap::new()));
content.entry(user_id.to_owned()).or_default().push(response.room_id.to_owned());
// TODO We should probably save the fact that we need to send this out
// because otherwise we might end up in a state where we have a DM that
// isn't marked as one.
self.send_account_data(AnyGlobalAccountDataEventContent::Direct(content)).await?;
// If the room is already in our store, fetch it, otherwise wait for a
// sync to be done which should put the room into our store.
if let Some(room) = self.get_joined_room(&response.room_id) {
Ok(Some(room))
} else {
self.sync_beat.listen().wait_timeout(SYNC_WAIT_TIME);
Ok(self.get_joined_room(&response.room_id))
} }
} }
@ -2296,7 +2598,7 @@ impl Client {
/// ///
/// println!("{:?}", device.verified()); /// println!("{:?}", device.verified());
/// ///
/// let verification = device.start_verification().await.unwrap(); /// let verification = device.request_verification().await.unwrap();
/// # }); /// # });
/// ``` /// ```
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
@ -2311,6 +2613,89 @@ impl Client {
Ok(device.map(|d| Device { inner: d, client: self.clone() })) Ok(device.map(|d| Device { inner: d, client: self.clone() }))
} }
/// Get a E2EE identity of an user.
///
/// # Arguments
///
/// * `user_id` - The unique id of the user that the identity belongs to.
///
/// Returns a `UserIdentity` if one is found and the crypto store
/// didn't throw an error.
///
/// This will always return None if the client hasn't been logged in.
///
/// # Example
///
/// ```no_run
/// # use std::convert::TryFrom;
/// # use matrix_sdk::{Client, ruma::UserId};
/// # use url::Url;
/// # use futures::executor::block_on;
/// # let alice = UserId::try_from("@alice:example.org").unwrap();
/// # let homeserver = Url::parse("http://example.com").unwrap();
/// # let client = Client::new(homeserver).unwrap();
/// # block_on(async {
/// let user = client.get_user_identity(&alice).await?;
///
/// if let Some(user) = user {
/// println!("{:?}", user.verified());
///
/// let verification = user.request_verification().await?;
/// }
/// # anyhow::Result::<()>::Ok(()) });
/// ```
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn get_user_identity(
&self,
user_id: &UserId,
) -> StdResult<Option<crate::identities::UserIdentity>, CryptoStoreError> {
use crate::identities::UserIdentity;
if let Some(olm) = self.base_client.olm_machine().await {
let identity = olm.get_identity(user_id).await?;
Ok(identity.map(|i| match i {
matrix_sdk_base::crypto::UserIdentities::Own(i) => {
UserIdentity::new_own(self.clone(), i)
}
matrix_sdk_base::crypto::UserIdentities::Other(i) => {
UserIdentity::new(self.clone(), i, self.get_dm_room(user_id))
}
}))
} else {
Ok(None)
}
}
/// Get the status of the private cross signing keys.
///
/// This can be used to check which private cross signing keys we have
/// stored locally.
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn cross_signing_status(&self) -> Option<CrossSigningStatus> {
if let Some(machine) = self.base_client.olm_machine().await {
Some(machine.cross_signing_status().await)
} else {
None
}
}
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
fn get_dm_room(&self, user_id: &UserId) -> Option<room::Joined> {
let rooms = self.joined_rooms();
let room_pairs: Vec<_> =
rooms.iter().map(|r| (r.room_id().to_owned(), r.direct_target())).collect();
trace!(rooms =? room_pairs, "Finding direct room");
let room = rooms.into_iter().find(|r| r.direct_target().as_ref() == Some(user_id));
trace!(room =? room, "Found room");
room
}
/// Create and upload a new cross signing identity. /// Create and upload a new cross signing identity.
/// ///
/// # Arguments /// # Arguments
@ -2741,7 +3126,7 @@ impl Client {
} }
/// Gets information about the owner of a given access token. /// Gets information about the owner of a given access token.
pub async fn whoami(&self) -> Result<whoami::Response> { pub async fn whoami(&self) -> HttpResult<whoami::Response> {
let request = whoami::Request::new(); let request = whoami::Request::new();
self.send(request, None).await self.send(request, None).await
} }
@ -2769,8 +3154,10 @@ mod test {
use std::{ use std::{
collections::BTreeMap, collections::BTreeMap,
convert::{TryFrom, TryInto}, convert::{TryFrom, TryInto},
future,
io::Cursor, io::Cursor,
str::FromStr, str::FromStr,
sync::Arc,
time::Duration, time::Duration,
}; };
@ -2800,17 +3187,18 @@ mod test {
event_id, event_id,
events::{ events::{
room::{ room::{
member::MemberEventContent,
message::{ImageMessageEventContent, MessageEventContent}, message::{ImageMessageEventContent, MessageEventContent},
ImageInfo, ImageInfo,
}, },
AnyMessageEventContent, AnySyncStateEvent, EventType, AnyMessageEventContent, AnySyncStateEvent, EventType, SyncStateEvent,
}, },
mxc_uri, room_id, thirdparty, uint, user_id, UserId, mxc_uri, room_id, thirdparty, uint, user_id, UserId,
}; };
use serde_json::json; use serde_json::json;
use super::{Client, Session, SyncSettings, Url}; use super::{Client, Session, SyncSettings, Url};
use crate::{ClientConfig, HttpError, RequestConfig, RoomMember}; use crate::{room, ClientConfig, HttpError, RequestConfig, RoomMember};
async fn logged_in_client() -> Client { async fn logged_in_client() -> Client {
let session = Session { let session = Session {
@ -3070,6 +3458,56 @@ mod test {
// assert_eq!(1, ignored_users.len()) // assert_eq!(1, ignored_users.len())
} }
#[tokio::test]
async fn event_handler() {
use std::sync::atomic::{AtomicU8, Ordering::SeqCst};
let client = logged_in_client().await;
let member_count = Arc::new(AtomicU8::new(0));
let typing_count = Arc::new(AtomicU8::new(0));
let power_levels_count = Arc::new(AtomicU8::new(0));
client
.register_event_handler({
let member_count = member_count.clone();
move |_ev: SyncStateEvent<MemberEventContent>, _room: room::Room| {
member_count.fetch_add(1, SeqCst);
future::ready(())
}
})
.await
.register_event_handler({
let typing_count = typing_count.clone();
move |_ev: SyncStateEvent<MemberEventContent>| {
typing_count.fetch_add(1, SeqCst);
future::ready(())
}
})
.await
.register_event_handler({
let power_levels_count = power_levels_count.clone();
move |_ev: SyncStateEvent<MemberEventContent>,
_client: Client,
_room: room::Room| {
power_levels_count.fetch_add(1, SeqCst);
future::ready(())
}
})
.await;
let response = EventBuilder::default()
.add_room_event(EventsJson::Member)
.add_ephemeral(EventsJson::Typing)
.add_state_event(EventsJson::PowerLevels)
.build_sync_response();
client.process_sync(response).await.unwrap();
assert_eq!(member_count.load(SeqCst), 1);
assert_eq!(typing_count.load(SeqCst), 1);
assert_eq!(power_levels_count.load(SeqCst), 1);
}
#[tokio::test] #[tokio::test]
async fn room_creation() { async fn room_creation() {
let client = logged_in_client().await; let client = logged_in_client().await;
@ -3137,12 +3575,8 @@ mod test {
}); });
if let Err(err) = client.register(user).await { if let Err(err) = client.register(user).await {
if let crate::Error::Http(HttpError::UiaaError(FromHttpResponseError::Http( if let HttpError::UiaaError(FromHttpResponseError::Http(ServerError::Known(
ServerError::Known(UiaaResponse::MatrixError(client_api::Error { UiaaResponse::MatrixError(client_api::Error { kind, message, status_code }),
kind,
message,
status_code,
})),
))) = err ))) = err
{ {
if let client_api::error::ErrorKind::Forbidden = kind { if let client_api::error::ErrorKind::Forbidden = kind {

View File

@ -40,6 +40,9 @@ use url::ParseError as UrlParseError;
/// Result type of the rust-sdk. /// Result type of the rust-sdk.
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
/// Result type of a pure HTTP request.
pub type HttpResult<T> = std::result::Result<T, HttpError>;
/// An HTTP error, representing either a connection error or an error while /// An HTTP error, representing either a connection error or an error while
/// converting the raw HTTP response into a Matrix response. /// converting the raw HTTP response into a Matrix response.
#[derive(Error, Debug)] #[derive(Error, Debug)]
@ -182,6 +185,30 @@ pub enum RoomKeyImportError {
Export(#[from] KeyExportError), Export(#[from] KeyExportError),
} }
impl HttpError {
/// Try to destructure the error into an universal interactive auth info.
///
/// Some requests require universal interactive auth, doing such a request
/// will always fail the first time with a 401 status code, the response
/// body will contain info how the client can authenticate.
///
/// The request will need to be retried, this time containing additional
/// authentication data.
///
/// This method is an convenience method to get to the info the server
/// returned on the first, failed request.
pub fn uiaa_response(&self) -> Option<&UiaaInfo> {
if let HttpError::UiaaError(FromHttpResponseError::Http(ServerError::Known(
UiaaError::AuthResponse(i),
))) = self
{
Some(i)
} else {
None
}
}
}
impl Error { impl Error {
/// Try to destructure the error into an universal interactive auth info. /// Try to destructure the error into an universal interactive auth info.
/// ///

View File

@ -0,0 +1,439 @@
// Copyright 2021 Jonas Platte
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Types and traits related for event handlers. For usage, see
//! [`Client::register_event_handler`].
//!
//! ### How it works
//!
//! The `register_event_handler` method registers event handlers of different
//! signatures by actually storing boxed closures that all have the same
//! signature of `async (EventHandlerData) -> ()` where `EventHandlerData` is a
//! private type that contains all of the data an event handler *might* need.
//!
//! The stored closure takes care of deserializing the event which the
//! `EventHandlerData` contains as a (borrowed) [`serde_json::value::RawValue`],
//! extracing the context arguments from other fields of `EventHandlerData` and
//! calling / `.await`ing the event handler if the previous steps succeeded.
//! It also logs any errors from the above chain of function calls.
//!
//! For more details, see the [`EventHandler`] trait.
use std::{borrow::Cow, future::Future, ops::Deref};
use matrix_sdk_base::deserialized_responses::SyncRoomEvent;
use ruma::{events::AnySyncStateEvent, serde::Raw};
use serde::Deserialize;
use serde_json::value::RawValue as RawJsonValue;
use crate::{room, Client};
#[doc(hidden)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum EventKind {
GlobalAccountData,
RoomAccountData,
EphemeralRoomData,
Message { redacted: bool },
State { redacted: bool },
StrippedState { redacted: bool },
InitialState,
ToDevice,
Presence,
}
/// A statically-known event kind/type that can be retrieved from an event sync.
pub trait SyncEvent {
#[doc(hidden)]
const ID: (EventKind, &'static str);
}
/// Interface for event handlers.
///
/// This trait is an abstraction for a certain kind of functions / closures,
/// specifically:
///
/// * They must have at least one argument, which is the event itself, a type
/// that implements [`SyncEvent`]. Any additional arguments need to implement
/// the [`EventHandlerContext`] trait.
/// * Their return type has to be one of: `()`, `Result<(), impl
/// std::error::Error>` or `anyhow::Result<()>` (requires the `anyhow` Cargo
/// feature to be enabled)
///
/// ### How it works
///
/// This trait is basically a very constrained version of `Fn`: It requires at
/// least one argument, which is represented as its own generic parameter `Ev`
/// with the remaining parameter types being represented by the second generic
/// parameter `Ctx`; they have to be stuffed into one generic parameter as a
/// tuple because Rust doesn't have variadic generics.
///
/// `Ev` and `Ctx` are generic parameters rather than associated types because
/// the argument list is a generic parameter for the `Fn` traits too, so a
/// single type could implement `Fn` multiple times with different argument
/// lists¹. Luckily, when calling [`Client::register_event_handler`] with a
/// closure argument the trait solver takes into account that only a single one
/// of the implementations applies (even though this could theoretically change
/// through a dependency upgrade) and uses that rather than raising an ambiguity
/// error. This is the same trick used by web frameworks like actix-web and
/// axum.
///
/// ¹ the only thing stopping such types from existing in stable Rust is that
/// all manual implementations of the `Fn` traits require a Nightly feature
pub trait EventHandler<Ev, Ctx>: Clone + Send + Sync + 'static {
/// The future returned by `handle_event`.
#[doc(hidden)]
type Future: Future + Send + 'static;
/// The event type being handled, for example a message event of type
/// `m.room.message`.
#[doc(hidden)]
const ID: (EventKind, &'static str);
/// Create a future for handling the given event.
///
/// `data` provides additional data about the event, for example the room it
/// appeared in.
///
/// Returns `None` if one of the context extractors failed.
#[doc(hidden)]
fn handle_event(&self, ev: Ev, data: EventHandlerData<'_>) -> Option<Self::Future>;
}
#[doc(hidden)]
#[derive(Debug)]
pub struct EventHandlerData<'a> {
pub client: Client,
pub room: Option<room::Room>,
pub raw: &'a RawJsonValue,
}
/// Context for an event handler.
///
/// This trait defines the set of types that may be used as additional arguments
/// in event handler functions after the event itself.
pub trait EventHandlerContext: Sized {
#[doc(hidden)]
fn from_data(_: &EventHandlerData<'_>) -> Option<Self>;
}
impl EventHandlerContext for Client {
fn from_data(data: &EventHandlerData<'_>) -> Option<Self> {
Some(data.client.clone())
}
}
/// This event handler context argument is only applicable to room-specific
/// events.
///
/// Trying to use it in the event handler for another event, for example a
/// global account data or presence event, will result in the event handler
/// being skipped and an error getting logged.
impl EventHandlerContext for room::Room {
fn from_data(data: &EventHandlerData<'_>) -> Option<Self> {
data.room.clone()
}
}
/// The raw JSON form of an event.
///
/// Used as a context argument for event handlers (see
/// [`Client::register_event_handler`]).
// FIXME: This could be made to not own the raw JSON value with some changes to
// the traits above, but only with GATs.
#[derive(Clone, Debug)]
pub struct RawEvent(pub Box<RawJsonValue>);
impl Deref for RawEvent {
type Target = RawJsonValue;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl EventHandlerContext for RawEvent {
fn from_data(data: &EventHandlerData<'_>) -> Option<Self> {
Some(Self(data.raw.to_owned()))
}
}
/// Return types supported for event handlers implement this trait.
///
/// It is not meant to be implemented outside of matrix-sdk.
pub trait EventHandlerResult: Sized {
#[doc(hidden)]
fn print_error(&self, event_type: &str);
}
impl EventHandlerResult for () {
fn print_error(&self, _event_type: &str) {}
}
impl<E: std::error::Error> EventHandlerResult for Result<(), E> {
fn print_error(&self, event_type: &str) {
if let Err(e) = self {
tracing::error!("Event handler for `{}` failed: {}", event_type, e);
}
}
}
#[cfg(feature = "anyhow")]
impl EventHandlerResult for anyhow::Result<()> {
fn print_error(&self, event_type: &str) {
if let Err(e) = self {
tracing::error!("Event handler for `{}` failed: {:?}", event_type, e);
}
}
}
#[derive(Deserialize)]
struct UnsignedDetails {
redacted_because: Option<serde::de::IgnoredAny>,
}
/// Event handling internals.
impl Client {
pub(crate) async fn handle_sync_events<T>(
&self,
kind: EventKind,
room: &Option<room::Room>,
events: &[Raw<T>],
) -> serde_json::Result<()> {
self.handle_sync_events_wrapped(kind, room, events, |x| x).await
}
pub(crate) async fn handle_sync_state_events(
&self,
room: &Option<room::Room>,
state_events: &[Raw<AnySyncStateEvent>],
) -> serde_json::Result<()> {
#[derive(Deserialize)]
struct StateEventDetails<'a> {
#[serde(borrow, rename = "type")]
event_type: Cow<'a, str>,
unsigned: Option<UnsignedDetails>,
}
self.handle_sync_events_wrapped_with(room, state_events, std::convert::identity, |raw| {
let StateEventDetails { event_type, unsigned } = raw.deserialize_as()?;
let redacted = unsigned.and_then(|u| u.redacted_because).is_some();
Ok((EventKind::State { redacted }, event_type))
})
.await
}
pub(crate) async fn handle_sync_timeline_events(
&self,
room: &Option<room::Room>,
timeline_events: &[SyncRoomEvent],
) -> serde_json::Result<()> {
// FIXME: add EncryptionInfo to context
#[derive(Deserialize)]
struct TimelineEventDetails<'a> {
#[serde(borrow, rename = "type")]
event_type: Cow<'a, str>,
state_key: Option<serde::de::IgnoredAny>,
unsigned: Option<UnsignedDetails>,
}
self.handle_sync_events_wrapped_with(
room,
timeline_events,
|e| &e.event,
|raw| {
let TimelineEventDetails { event_type, state_key, unsigned } =
raw.deserialize_as()?;
let redacted = unsigned.and_then(|u| u.redacted_because).is_some();
let kind = match state_key {
Some(_) => EventKind::State { redacted },
None => EventKind::Message { redacted },
};
Ok((kind, event_type))
},
)
.await
}
async fn handle_sync_events_wrapped<'a, T: 'a, U: 'a>(
&self,
kind: EventKind,
room: &Option<room::Room>,
events: &'a [U],
get_event: impl Fn(&'a U) -> &'a Raw<T>,
) -> Result<(), serde_json::Error> {
#[derive(Deserialize)]
struct ExtractType<'a> {
#[serde(borrow, rename = "type")]
event_type: Cow<'a, str>,
}
self.handle_sync_events_wrapped_with(room, events, get_event, |raw| {
Ok((kind, raw.deserialize_as::<ExtractType>()?.event_type))
})
.await
}
async fn handle_sync_events_wrapped_with<'a, T: 'a, U: 'a>(
&self,
room: &Option<room::Room>,
list: &'a [U],
get_event: impl Fn(&'a U) -> &'a Raw<T>,
get_id: impl Fn(&Raw<T>) -> serde_json::Result<(EventKind, Cow<'_, str>)>,
) -> serde_json::Result<()> {
for x in list {
let event = get_event(x);
let (ev_kind, ev_type) = get_id(event)?;
let event_handler_id = (ev_kind, &*ev_type);
if let Some(handlers) = self.event_handlers.read().await.get(&event_handler_id) {
for handler in &*handlers {
let data = EventHandlerData {
client: self.clone(),
room: room.clone(),
raw: event.json(),
};
matrix_sdk_common::executor::spawn((handler)(data));
}
}
}
Ok(())
}
}
macro_rules! impl_event_handler {
($($ty:ident),* $(,)?) => {
impl<Ev, Fun, Fut, $($ty),*> EventHandler<Ev, ($($ty,)*)> for Fun
where
Ev: SyncEvent,
Fun: Fn(Ev, $($ty),*) -> Fut + Clone + Send + Sync + 'static,
Fut: Future + Send + 'static,
Fut::Output: EventHandlerResult,
$($ty: EventHandlerContext),*
{
type Future = Fut;
const ID: (EventKind, &'static str) = Ev::ID;
fn handle_event(&self, ev: Ev, _d: EventHandlerData<'_>) -> Option<Self::Future> {
Some((self)(ev, $($ty::from_data(&_d)?),*))
}
}
};
}
impl_event_handler!();
impl_event_handler!(A);
impl_event_handler!(A, B);
impl_event_handler!(A, B, C);
impl_event_handler!(A, B, C, D);
impl_event_handler!(A, B, C, D, E);
impl_event_handler!(A, B, C, D, E, F);
impl_event_handler!(A, B, C, D, E, F, G);
impl_event_handler!(A, B, C, D, E, F, G, H);
mod static_events {
use ruma::events::{
self,
presence::{PresenceEvent, PresenceEventContent},
StaticEventContent,
};
use super::{EventKind, SyncEvent};
impl<C> SyncEvent for events::GlobalAccountDataEvent<C>
where
C: StaticEventContent + events::GlobalAccountDataEventContent,
{
const ID: (EventKind, &'static str) = (EventKind::GlobalAccountData, C::TYPE);
}
impl<C> SyncEvent for events::RoomAccountDataEvent<C>
where
C: StaticEventContent + events::RoomAccountDataEventContent,
{
const ID: (EventKind, &'static str) = (EventKind::RoomAccountData, C::TYPE);
}
impl<C> SyncEvent for events::SyncEphemeralRoomEvent<C>
where
C: StaticEventContent + events::EphemeralRoomEventContent,
{
const ID: (EventKind, &'static str) = (EventKind::EphemeralRoomData, C::TYPE);
}
impl<C> SyncEvent for events::SyncMessageEvent<C>
where
C: StaticEventContent + events::MessageEventContent,
{
const ID: (EventKind, &'static str) = (EventKind::Message { redacted: false }, C::TYPE);
}
impl<C> SyncEvent for events::SyncStateEvent<C>
where
C: StaticEventContent + events::StateEventContent,
{
const ID: (EventKind, &'static str) = (EventKind::State { redacted: false }, C::TYPE);
}
impl<C> SyncEvent for events::StrippedStateEvent<C>
where
C: StaticEventContent + events::StateEventContent,
{
const ID: (EventKind, &'static str) =
(EventKind::StrippedState { redacted: false }, C::TYPE);
}
impl<C> SyncEvent for events::InitialStateEvent<C>
where
C: StaticEventContent + events::StateEventContent,
{
const ID: (EventKind, &'static str) = (EventKind::InitialState, C::TYPE);
}
impl<C> SyncEvent for events::ToDeviceEvent<C>
where
C: StaticEventContent + events::ToDeviceEventContent,
{
const ID: (EventKind, &'static str) = (EventKind::ToDevice, C::TYPE);
}
impl SyncEvent for PresenceEvent {
const ID: (EventKind, &'static str) = (EventKind::Presence, PresenceEventContent::TYPE);
}
impl<C> SyncEvent for events::RedactedSyncMessageEvent<C>
where
C: StaticEventContent + events::RedactedMessageEventContent,
{
const ID: (EventKind, &'static str) = (EventKind::Message { redacted: true }, C::TYPE);
}
impl<C> SyncEvent for events::RedactedSyncStateEvent<C>
where
C: StaticEventContent + events::RedactedStateEventContent,
{
const ID: (EventKind, &'static str) = (EventKind::State { redacted: true }, C::TYPE);
}
impl<C> SyncEvent for events::RedactedStrippedStateEvent<C>
where
C: StaticEventContent + events::RedactedStateEventContent,
{
const ID: (EventKind, &'static str) =
(EventKind::StrippedState { redacted: true }, C::TYPE);
}
}

View File

@ -1,964 +0,0 @@
// Copyright 2020 Damir Jelić
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::ops::Deref;
use matrix_sdk_base::{hoist_and_deserialize_state_event, hoist_room_event_prev_content};
use matrix_sdk_common::async_trait;
use ruma::{
api::client::r0::push::get_notifications::Notification,
events::{
call::{
answer::AnswerEventContent, candidates::CandidatesEventContent,
hangup::HangupEventContent, invite::InviteEventContent,
},
custom::CustomEventContent,
fully_read::FullyReadEventContent,
ignored_user_list::IgnoredUserListEventContent,
presence::PresenceEvent,
push_rules::PushRulesEventContent,
reaction::ReactionEventContent,
receipt::ReceiptEventContent,
room::{
aliases::AliasesEventContent,
avatar::AvatarEventContent,
canonical_alias::CanonicalAliasEventContent,
join_rules::JoinRulesEventContent,
member::MemberEventContent,
message::{feedback::FeedbackEventContent, MessageEventContent as MsgEventContent},
name::NameEventContent,
power_levels::PowerLevelsEventContent,
redaction::SyncRedactionEvent,
tombstone::TombstoneEventContent,
},
typing::TypingEventContent,
AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent,
AnySyncEphemeralRoomEvent, AnySyncMessageEvent, AnySyncRoomEvent, AnySyncStateEvent,
GlobalAccountDataEvent, RoomAccountDataEvent, StrippedStateEvent, SyncEphemeralRoomEvent,
SyncMessageEvent, SyncStateEvent,
},
serde::Raw,
RoomId,
};
use serde_json::value::RawValue as RawJsonValue;
use crate::{deserialized_responses::SyncResponse, room::Room, Client};
pub(crate) struct Handler {
pub(crate) inner: Box<dyn EventHandler>,
pub(crate) client: Client,
}
impl Deref for Handler {
type Target = dyn EventHandler;
fn deref(&self) -> &Self::Target {
&*self.inner
}
}
impl Handler {
fn get_room(&self, room_id: &RoomId) -> Option<Room> {
self.client.get_room(room_id)
}
pub(crate) async fn handle_sync(&self, response: &SyncResponse) {
for event in response.account_data.events.iter().filter_map(|e| e.deserialize().ok()) {
self.handle_account_data_event(&event).await;
}
for (room_id, room_info) in &response.rooms.join {
if let Some(room) = self.get_room(room_id) {
for event in room_info.ephemeral.events.iter().filter_map(|e| e.deserialize().ok())
{
self.handle_ephemeral_event(room.clone(), &event).await;
}
for event in
room_info.account_data.events.iter().filter_map(|e| e.deserialize().ok())
{
self.handle_room_account_data_event(room.clone(), &event).await;
}
for (raw_event, event) in room_info.state.events.iter().filter_map(|e| {
if let Ok(d) = hoist_and_deserialize_state_event(e) {
Some((e, d))
} else {
None
}
}) {
self.handle_state_event(room.clone(), &event, raw_event).await;
}
for (raw_event, event) in room_info.timeline.events.iter().filter_map(|e| {
if let Ok(d) = hoist_room_event_prev_content(&e.event) {
Some((&e.event, d))
} else {
None
}
}) {
self.handle_timeline_event(room.clone(), &event, raw_event).await;
}
}
}
for (room_id, room_info) in &response.rooms.leave {
if let Some(room) = self.get_room(room_id) {
for event in
room_info.account_data.events.iter().filter_map(|e| e.deserialize().ok())
{
self.handle_room_account_data_event(room.clone(), &event).await;
}
for (raw_event, event) in room_info.state.events.iter().filter_map(|e| {
if let Ok(d) = hoist_and_deserialize_state_event(e) {
Some((e, d))
} else {
None
}
}) {
self.handle_state_event(room.clone(), &event, raw_event).await;
}
for (raw_event, event) in room_info.timeline.events.iter().filter_map(|e| {
if let Ok(d) = hoist_room_event_prev_content(&e.event) {
Some((&e.event, d))
} else {
None
}
}) {
self.handle_timeline_event(room.clone(), &event, raw_event).await;
}
}
}
for (room_id, room_info) in &response.rooms.invite {
if let Some(room) = self.get_room(room_id) {
for event in
room_info.invite_state.events.iter().filter_map(|e| e.deserialize().ok())
{
self.handle_stripped_state_event(room.clone(), &event).await;
}
}
}
for event in response.presence.events.iter().filter_map(|e| e.deserialize().ok()) {
self.on_presence_event(&event).await;
}
for (room_id, notifications) in &response.notifications {
if let Some(room) = self.get_room(room_id) {
for notification in notifications {
self.on_room_notification(room.clone(), notification.clone()).await;
}
}
}
}
async fn handle_timeline_event(
&self,
room: Room,
event: &AnySyncRoomEvent,
raw_event: &Raw<AnySyncRoomEvent>,
) {
match event {
AnySyncRoomEvent::State(event) => match event {
AnySyncStateEvent::RoomMember(e) => self.on_room_member(room, e).await,
AnySyncStateEvent::RoomName(e) => self.on_room_name(room, e).await,
AnySyncStateEvent::RoomCanonicalAlias(e) => {
self.on_room_canonical_alias(room, e).await
}
AnySyncStateEvent::RoomAliases(e) => self.on_room_aliases(room, e).await,
AnySyncStateEvent::RoomAvatar(e) => self.on_room_avatar(room, e).await,
AnySyncStateEvent::RoomPowerLevels(e) => self.on_room_power_levels(room, e).await,
AnySyncStateEvent::RoomTombstone(e) => self.on_room_tombstone(room, e).await,
AnySyncStateEvent::RoomJoinRules(e) => self.on_room_join_rules(room, e).await,
AnySyncStateEvent::PolicyRuleRoom(_)
| AnySyncStateEvent::PolicyRuleServer(_)
| AnySyncStateEvent::PolicyRuleUser(_)
| AnySyncStateEvent::RoomCreate(_)
| AnySyncStateEvent::RoomEncryption(_)
| AnySyncStateEvent::RoomGuestAccess(_)
| AnySyncStateEvent::RoomHistoryVisibility(_)
| AnySyncStateEvent::RoomPinnedEvents(_)
| AnySyncStateEvent::RoomServerAcl(_)
| AnySyncStateEvent::RoomThirdPartyInvite(_)
| AnySyncStateEvent::RoomTopic(_)
| AnySyncStateEvent::SpaceChild(_)
| AnySyncStateEvent::SpaceParent(_) => {}
_ => {
if let Ok(e) = raw_event.deserialize_as::<SyncStateEvent<CustomEventContent>>()
{
self.on_custom_event(room, &CustomEvent::State(&e)).await;
}
}
},
AnySyncRoomEvent::Message(event) => match event {
AnySyncMessageEvent::RoomMessage(e) => self.on_room_message(room, e).await,
AnySyncMessageEvent::RoomMessageFeedback(e) => {
self.on_room_message_feedback(room, e).await
}
AnySyncMessageEvent::RoomRedaction(e) => self.on_room_redaction(room, e).await,
AnySyncMessageEvent::Reaction(e) => self.on_room_reaction(room, e).await,
AnySyncMessageEvent::CallInvite(e) => self.on_room_call_invite(room, e).await,
AnySyncMessageEvent::CallAnswer(e) => self.on_room_call_answer(room, e).await,
AnySyncMessageEvent::CallCandidates(e) => {
self.on_room_call_candidates(room, e).await
}
AnySyncMessageEvent::CallHangup(e) => self.on_room_call_hangup(room, e).await,
AnySyncMessageEvent::KeyVerificationReady(_)
| AnySyncMessageEvent::KeyVerificationStart(_)
| AnySyncMessageEvent::KeyVerificationCancel(_)
| AnySyncMessageEvent::KeyVerificationAccept(_)
| AnySyncMessageEvent::KeyVerificationKey(_)
| AnySyncMessageEvent::KeyVerificationMac(_)
| AnySyncMessageEvent::KeyVerificationDone(_)
| AnySyncMessageEvent::RoomEncrypted(_)
| AnySyncMessageEvent::Sticker(_) => {}
_ => {
if let Ok(e) =
raw_event.deserialize_as::<SyncMessageEvent<CustomEventContent>>()
{
self.on_custom_event(room, &CustomEvent::Message(&e)).await;
}
}
},
AnySyncRoomEvent::RedactedState(_event) => {}
AnySyncRoomEvent::RedactedMessage(_event) => {}
}
}
async fn handle_state_event(
&self,
room: Room,
event: &AnySyncStateEvent,
raw_event: &Raw<AnySyncStateEvent>,
) {
match event {
AnySyncStateEvent::RoomMember(member) => self.on_state_member(room, member).await,
AnySyncStateEvent::RoomName(name) => self.on_state_name(room, name).await,
AnySyncStateEvent::RoomCanonicalAlias(canonical) => {
self.on_state_canonical_alias(room, canonical).await
}
AnySyncStateEvent::RoomAliases(aliases) => self.on_state_aliases(room, aliases).await,
AnySyncStateEvent::RoomAvatar(avatar) => self.on_state_avatar(room, avatar).await,
AnySyncStateEvent::RoomPowerLevels(power) => {
self.on_state_power_levels(room, power).await
}
AnySyncStateEvent::RoomJoinRules(rules) => self.on_state_join_rules(room, rules).await,
AnySyncStateEvent::RoomTombstone(tomb) => {
// TODO make `on_state_tombstone` method
self.on_room_tombstone(room, tomb).await
}
AnySyncStateEvent::PolicyRuleRoom(_)
| AnySyncStateEvent::PolicyRuleServer(_)
| AnySyncStateEvent::PolicyRuleUser(_)
| AnySyncStateEvent::RoomCreate(_)
| AnySyncStateEvent::RoomEncryption(_)
| AnySyncStateEvent::RoomGuestAccess(_)
| AnySyncStateEvent::RoomHistoryVisibility(_)
| AnySyncStateEvent::RoomPinnedEvents(_)
| AnySyncStateEvent::RoomServerAcl(_)
| AnySyncStateEvent::RoomThirdPartyInvite(_)
| AnySyncStateEvent::RoomTopic(_)
| AnySyncStateEvent::SpaceChild(_)
| AnySyncStateEvent::SpaceParent(_) => {}
_ => {
if let Ok(e) = raw_event.deserialize_as::<SyncStateEvent<CustomEventContent>>() {
self.on_custom_event(room, &CustomEvent::State(&e)).await;
}
}
}
}
pub(crate) async fn handle_stripped_state_event(
&self,
// TODO these events are only handled in invited rooms.
room: Room,
event: &AnyStrippedStateEvent,
) {
match event {
AnyStrippedStateEvent::RoomMember(member) => {
self.on_stripped_state_member(room, member, None).await
}
AnyStrippedStateEvent::RoomName(name) => self.on_stripped_state_name(room, name).await,
AnyStrippedStateEvent::RoomCanonicalAlias(canonical) => {
self.on_stripped_state_canonical_alias(room, canonical).await
}
AnyStrippedStateEvent::RoomAliases(aliases) => {
self.on_stripped_state_aliases(room, aliases).await
}
AnyStrippedStateEvent::RoomAvatar(avatar) => {
self.on_stripped_state_avatar(room, avatar).await
}
AnyStrippedStateEvent::RoomPowerLevels(power) => {
self.on_stripped_state_power_levels(room, power).await
}
AnyStrippedStateEvent::RoomJoinRules(rules) => {
self.on_stripped_state_join_rules(room, rules).await
}
_ => {}
}
}
pub(crate) async fn handle_room_account_data_event(
&self,
room: Room,
event: &AnyRoomAccountDataEvent,
) {
if let AnyRoomAccountDataEvent::FullyRead(event) = event {
self.on_non_room_fully_read(room, event).await
}
}
pub(crate) async fn handle_account_data_event(&self, event: &AnyGlobalAccountDataEvent) {
match event {
AnyGlobalAccountDataEvent::IgnoredUserList(ignored) => {
self.on_non_room_ignored_users(ignored).await
}
AnyGlobalAccountDataEvent::PushRules(rules) => self.on_non_room_push_rules(rules).await,
_ => {}
}
}
pub(crate) async fn handle_ephemeral_event(
&self,
room: Room,
event: &AnySyncEphemeralRoomEvent,
) {
match event {
AnySyncEphemeralRoomEvent::Typing(typing) => {
self.on_non_room_typing(room, typing).await
}
AnySyncEphemeralRoomEvent::Receipt(receipt) => {
self.on_non_room_receipt(room, receipt).await
}
_ => {}
}
}
}
/// This represents the various "unrecognized" events.
#[derive(Clone, Copy, Debug)]
pub enum CustomEvent<'c> {
/// A custom basic event.
Basic(&'c GlobalAccountDataEvent<CustomEventContent>),
/// A custom basic event.
EphemeralRoom(&'c SyncEphemeralRoomEvent<CustomEventContent>),
/// A custom room event.
Message(&'c SyncMessageEvent<CustomEventContent>),
/// A custom state event.
State(&'c SyncStateEvent<CustomEventContent>),
/// A custom stripped state event.
StrippedState(&'c StrippedStateEvent<CustomEventContent>),
}
/// This trait allows any type implementing `EventHandler` to specify event
/// callbacks for each event. The `Client` calls each method when the
/// corresponding event is received.
///
/// # Examples
/// ```
/// # use std::ops::Deref;
/// # use std::sync::Arc;
/// # use std::{env, process::exit};
/// # use matrix_sdk::{
/// # async_trait,
/// # EventHandler,
/// # ruma::events::{
/// # room::message::{MessageEventContent, MessageType, TextMessageEventContent},
/// # SyncMessageEvent
/// # },
/// # locks::RwLock,
/// # room::Room,
/// # };
///
/// struct EventCallback;
///
/// #[async_trait]
/// impl EventHandler for EventCallback {
/// async fn on_room_message(&self, room: Room, event: &SyncMessageEvent<MessageEventContent>) {
/// if let Room::Joined(room) = room {
/// if let SyncMessageEvent {
/// content:
/// MessageEventContent {
/// msgtype: MessageType::Text(TextMessageEventContent { body: msg_body, .. }),
/// ..
/// },
/// sender,
/// ..
/// } = event
/// {
/// let member = room.get_member(&sender).await.unwrap().unwrap();
/// let name = member
/// .display_name()
/// .unwrap_or_else(|| member.user_id().as_str());
/// println!("{}: {}", name, msg_body);
/// }
/// }
/// }
/// }
/// ```
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait EventHandler: Send + Sync {
// ROOM EVENTS from `IncomingTimeline`
/// Fires when `Client` receives a `RoomEvent::RoomMember` event.
async fn on_room_member(&self, _: Room, _: &SyncStateEvent<MemberEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::RoomName` event.
async fn on_room_name(&self, _: Room, _: &SyncStateEvent<NameEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::RoomCanonicalAlias` event.
async fn on_room_canonical_alias(
&self,
_: Room,
_: &SyncStateEvent<CanonicalAliasEventContent>,
) {
}
/// Fires when `Client` receives a `RoomEvent::RoomAliases` event.
async fn on_room_aliases(&self, _: Room, _: &SyncStateEvent<AliasesEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::RoomAvatar` event.
async fn on_room_avatar(&self, _: Room, _: &SyncStateEvent<AvatarEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::RoomMessage` event.
async fn on_room_message(&self, _: Room, _: &SyncMessageEvent<MsgEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::RoomMessageFeedback` event.
async fn on_room_message_feedback(&self, _: Room, _: &SyncMessageEvent<FeedbackEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::Reaction` event.
async fn on_room_reaction(&self, _: Room, _: &SyncMessageEvent<ReactionEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::CallInvite` event
async fn on_room_call_invite(&self, _: Room, _: &SyncMessageEvent<InviteEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::CallAnswer` event
async fn on_room_call_answer(&self, _: Room, _: &SyncMessageEvent<AnswerEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::CallCandidates` event
async fn on_room_call_candidates(&self, _: Room, _: &SyncMessageEvent<CandidatesEventContent>) {
}
/// Fires when `Client` receives a `RoomEvent::CallHangup` event
async fn on_room_call_hangup(&self, _: Room, _: &SyncMessageEvent<HangupEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::RoomRedaction` event.
async fn on_room_redaction(&self, _: Room, _: &SyncRedactionEvent) {}
/// Fires when `Client` receives a `RoomEvent::RoomPowerLevels` event.
async fn on_room_power_levels(&self, _: Room, _: &SyncStateEvent<PowerLevelsEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::RoomJoinRules` event.
async fn on_room_join_rules(&self, _: Room, _: &SyncStateEvent<JoinRulesEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::Tombstone` event.
async fn on_room_tombstone(&self, _: Room, _: &SyncStateEvent<TombstoneEventContent>) {}
/// Fires when `Client` receives room events that trigger notifications
/// according to the push rules of the user.
async fn on_room_notification(&self, _: Room, _: Notification) {}
// `RoomEvent`s from `IncomingState`
/// Fires when `Client` receives a `StateEvent::RoomMember` event.
async fn on_state_member(&self, _: Room, _: &SyncStateEvent<MemberEventContent>) {}
/// Fires when `Client` receives a `StateEvent::RoomName` event.
async fn on_state_name(&self, _: Room, _: &SyncStateEvent<NameEventContent>) {}
/// Fires when `Client` receives a `StateEvent::RoomCanonicalAlias` event.
async fn on_state_canonical_alias(
&self,
_: Room,
_: &SyncStateEvent<CanonicalAliasEventContent>,
) {
}
/// Fires when `Client` receives a `StateEvent::RoomAliases` event.
async fn on_state_aliases(&self, _: Room, _: &SyncStateEvent<AliasesEventContent>) {}
/// Fires when `Client` receives a `StateEvent::RoomAvatar` event.
async fn on_state_avatar(&self, _: Room, _: &SyncStateEvent<AvatarEventContent>) {}
/// Fires when `Client` receives a `StateEvent::RoomPowerLevels` event.
async fn on_state_power_levels(&self, _: Room, _: &SyncStateEvent<PowerLevelsEventContent>) {}
/// Fires when `Client` receives a `StateEvent::RoomJoinRules` event.
async fn on_state_join_rules(&self, _: Room, _: &SyncStateEvent<JoinRulesEventContent>) {}
// `AnyStrippedStateEvent`s
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomMember` event.
async fn on_stripped_state_member(
&self,
_: Room,
_: &StrippedStateEvent<MemberEventContent>,
_: Option<MemberEventContent>,
) {
}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomName`
/// event.
async fn on_stripped_state_name(&self, _: Room, _: &StrippedStateEvent<NameEventContent>) {}
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event.
async fn on_stripped_state_canonical_alias(
&self,
_: Room,
_: &StrippedStateEvent<CanonicalAliasEventContent>,
) {
}
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomAliases` event.
async fn on_stripped_state_aliases(
&self,
_: Room,
_: &StrippedStateEvent<AliasesEventContent>,
) {
}
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomAvatar` event.
async fn on_stripped_state_avatar(&self, _: Room, _: &StrippedStateEvent<AvatarEventContent>) {}
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomPowerLevels` event.
async fn on_stripped_state_power_levels(
&self,
_: Room,
_: &StrippedStateEvent<PowerLevelsEventContent>,
) {
}
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomJoinRules` event.
async fn on_stripped_state_join_rules(
&self,
_: Room,
_: &StrippedStateEvent<JoinRulesEventContent>,
) {
}
// `NonRoomEvent` (this is a type alias from ruma_events)
/// Fires when `Client` receives a `NonRoomEvent::RoomPresence` event.
async fn on_non_room_presence(&self, _: Room, _: &PresenceEvent) {}
/// Fires when `Client` receives a `NonRoomEvent::RoomName` event.
async fn on_non_room_ignored_users(
&self,
_: &GlobalAccountDataEvent<IgnoredUserListEventContent>,
) {
}
/// Fires when `Client` receives a `NonRoomEvent::RoomCanonicalAlias` event.
async fn on_non_room_push_rules(&self, _: &GlobalAccountDataEvent<PushRulesEventContent>) {}
/// Fires when `Client` receives a `NonRoomEvent::RoomAliases` event.
async fn on_non_room_fully_read(
&self,
_: Room,
_: &RoomAccountDataEvent<FullyReadEventContent>,
) {
}
/// Fires when `Client` receives a `NonRoomEvent::Typing` event.
async fn on_non_room_typing(&self, _: Room, _: &SyncEphemeralRoomEvent<TypingEventContent>) {}
/// Fires when `Client` receives a `NonRoomEvent::Receipt` event.
///
/// This is always a read receipt.
async fn on_non_room_receipt(&self, _: Room, _: &SyncEphemeralRoomEvent<ReceiptEventContent>) {}
// `PresenceEvent` is a struct so there is only the one method
/// Fires when `Client` receives a `NonRoomEvent::RoomAliases` event.
async fn on_presence_event(&self, _: &PresenceEvent) {}
/// Fires when `Client` receives a `Event::Custom` event or if
/// deserialization fails because the event was unknown to ruma.
///
/// The only guarantee this method can give about the event is that it is
/// valid JSON.
async fn on_unrecognized_event(&self, _: Room, _: &RawJsonValue) {}
/// Fires when `Client` receives a `Event::Custom` event or if
/// deserialization fails because the event was unknown to ruma.
///
/// The only guarantee this method can give about the event is that it is in
/// the shape of a valid matrix event.
async fn on_custom_event(&self, _: Room, _: &CustomEvent<'_>) {}
}
#[cfg(test)]
mod test {
use std::{sync::Arc, time::Duration};
use matrix_sdk_common::{async_trait, locks::Mutex};
use matrix_sdk_test::{async_test, test_json};
use mockito::{mock, Matcher};
use ruma::user_id;
#[cfg(target_arch = "wasm32")]
pub use wasm_bindgen_test::*;
use super::*;
#[derive(Clone)]
pub struct EvHandlerTest(Arc<Mutex<Vec<String>>>);
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl EventHandler for EvHandlerTest {
async fn on_room_member(&self, _: Room, _: &SyncStateEvent<MemberEventContent>) {
self.0.lock().await.push("member".to_string())
}
async fn on_room_name(&self, _: Room, _: &SyncStateEvent<NameEventContent>) {
self.0.lock().await.push("name".to_string())
}
async fn on_room_canonical_alias(
&self,
_: Room,
_: &SyncStateEvent<CanonicalAliasEventContent>,
) {
self.0.lock().await.push("canonical".to_string())
}
async fn on_room_aliases(&self, _: Room, _: &SyncStateEvent<AliasesEventContent>) {
self.0.lock().await.push("aliases".to_string())
}
async fn on_room_avatar(&self, _: Room, _: &SyncStateEvent<AvatarEventContent>) {
self.0.lock().await.push("avatar".to_string())
}
async fn on_room_message(&self, _: Room, _: &SyncMessageEvent<MsgEventContent>) {
self.0.lock().await.push("message".to_string())
}
async fn on_room_message_feedback(
&self,
_: Room,
_: &SyncMessageEvent<FeedbackEventContent>,
) {
self.0.lock().await.push("feedback".to_string())
}
async fn on_room_call_invite(&self, _: Room, _: &SyncMessageEvent<InviteEventContent>) {
self.0.lock().await.push("call invite".to_string())
}
async fn on_room_call_answer(&self, _: Room, _: &SyncMessageEvent<AnswerEventContent>) {
self.0.lock().await.push("call answer".to_string())
}
async fn on_room_call_candidates(
&self,
_: Room,
_: &SyncMessageEvent<CandidatesEventContent>,
) {
self.0.lock().await.push("call candidates".to_string())
}
async fn on_room_call_hangup(&self, _: Room, _: &SyncMessageEvent<HangupEventContent>) {
self.0.lock().await.push("call hangup".to_string())
}
async fn on_room_redaction(&self, _: Room, _: &SyncRedactionEvent) {
self.0.lock().await.push("redaction".to_string())
}
async fn on_room_power_levels(&self, _: Room, _: &SyncStateEvent<PowerLevelsEventContent>) {
self.0.lock().await.push("power".to_string())
}
async fn on_room_tombstone(&self, _: Room, _: &SyncStateEvent<TombstoneEventContent>) {
self.0.lock().await.push("tombstone".to_string())
}
async fn on_state_member(&self, _: Room, _: &SyncStateEvent<MemberEventContent>) {
self.0.lock().await.push("state member".to_string())
}
async fn on_state_name(&self, _: Room, _: &SyncStateEvent<NameEventContent>) {
self.0.lock().await.push("state name".to_string())
}
async fn on_state_canonical_alias(
&self,
_: Room,
_: &SyncStateEvent<CanonicalAliasEventContent>,
) {
self.0.lock().await.push("state canonical".to_string())
}
async fn on_state_aliases(&self, _: Room, _: &SyncStateEvent<AliasesEventContent>) {
self.0.lock().await.push("state aliases".to_string())
}
async fn on_state_avatar(&self, _: Room, _: &SyncStateEvent<AvatarEventContent>) {
self.0.lock().await.push("state avatar".to_string())
}
async fn on_state_power_levels(
&self,
_: Room,
_: &SyncStateEvent<PowerLevelsEventContent>,
) {
self.0.lock().await.push("state power".to_string())
}
async fn on_state_join_rules(&self, _: Room, _: &SyncStateEvent<JoinRulesEventContent>) {
self.0.lock().await.push("state rules".to_string())
}
// `AnyStrippedStateEvent`s
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomMember` event.
async fn on_stripped_state_member(
&self,
_: Room,
_: &StrippedStateEvent<MemberEventContent>,
_: Option<MemberEventContent>,
) {
self.0.lock().await.push("stripped state member".to_string())
}
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomName` event.
async fn on_stripped_state_name(&self, _: Room, _: &StrippedStateEvent<NameEventContent>) {
self.0.lock().await.push("stripped state name".to_string())
}
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event.
async fn on_stripped_state_canonical_alias(
&self,
_: Room,
_: &StrippedStateEvent<CanonicalAliasEventContent>,
) {
self.0.lock().await.push("stripped state canonical".to_string())
}
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomAliases` event.
async fn on_stripped_state_aliases(
&self,
_: Room,
_: &StrippedStateEvent<AliasesEventContent>,
) {
self.0.lock().await.push("stripped state aliases".to_string())
}
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomAvatar` event.
async fn on_stripped_state_avatar(
&self,
_: Room,
_: &StrippedStateEvent<AvatarEventContent>,
) {
self.0.lock().await.push("stripped state avatar".to_string())
}
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomPowerLevels` event.
async fn on_stripped_state_power_levels(
&self,
_: Room,
_: &StrippedStateEvent<PowerLevelsEventContent>,
) {
self.0.lock().await.push("stripped state power".to_string())
}
/// Fires when `Client` receives a
/// `AnyStrippedStateEvent::StrippedRoomJoinRules` event.
async fn on_stripped_state_join_rules(
&self,
_: Room,
_: &StrippedStateEvent<JoinRulesEventContent>,
) {
self.0.lock().await.push("stripped state rules".to_string())
}
async fn on_non_room_presence(&self, _: Room, _: &PresenceEvent) {
self.0.lock().await.push("presence".to_string())
}
async fn on_non_room_ignored_users(
&self,
_: &GlobalAccountDataEvent<IgnoredUserListEventContent>,
) {
self.0.lock().await.push("account ignore".to_string())
}
async fn on_non_room_push_rules(&self, _: &GlobalAccountDataEvent<PushRulesEventContent>) {
self.0.lock().await.push("account push rules".to_string())
}
async fn on_non_room_fully_read(
&self,
_: Room,
_: &RoomAccountDataEvent<FullyReadEventContent>,
) {
self.0.lock().await.push("account read".to_string())
}
async fn on_non_room_typing(
&self,
_: Room,
_: &SyncEphemeralRoomEvent<TypingEventContent>,
) {
self.0.lock().await.push("typing event".to_string())
}
async fn on_non_room_receipt(
&self,
_: Room,
_: &SyncEphemeralRoomEvent<ReceiptEventContent>,
) {
self.0.lock().await.push("receipt event".to_string())
}
async fn on_presence_event(&self, _: &PresenceEvent) {
self.0.lock().await.push("presence event".to_string())
}
async fn on_unrecognized_event(&self, _: Room, _: &RawJsonValue) {
self.0.lock().await.push("unrecognized event".to_string())
}
async fn on_custom_event(&self, _: Room, _: &CustomEvent<'_>) {
self.0.lock().await.push("custom event".to_string())
}
async fn on_room_notification(&self, _: Room, _: Notification) {
self.0.lock().await.push("notification".to_string())
}
}
use crate::{Client, Session, SyncSettings};
async fn get_client() -> Client {
let session = Session {
access_token: "1234".to_owned(),
user_id: user_id!("@example:localhost"),
device_id: "DEVICEID".into(),
};
let homeserver = url::Url::parse(&mockito::server_url()).unwrap();
let client = Client::new(homeserver).unwrap();
client.restore_login(session).await.unwrap();
client
}
async fn mock_sync(client: &Client, response: String) {
let _m = mock("GET", Matcher::Regex(r"^/_matrix/client/r0/sync\?.*$".to_string()))
.with_status(200)
.match_header("authorization", "Bearer 1234")
.with_body(response)
.create();
let sync_settings = SyncSettings::new().timeout(Duration::from_millis(3000));
let _response = client.sync_once(sync_settings).await.unwrap();
}
#[async_test]
async fn event_handler_joined() {
let vec = Arc::new(Mutex::new(Vec::new()));
let test_vec = Arc::clone(&vec);
let handler = Box::new(EvHandlerTest(vec));
let client = get_client().await;
client.set_event_handler(handler).await;
mock_sync(&client, test_json::SYNC.to_string()).await;
let v = test_vec.lock().await;
assert_eq!(
v.as_slice(),
[
"account ignore",
"receipt event",
"account read",
"state rules",
"state member",
"state aliases",
"state power",
"state canonical",
"state member",
"state member",
"message",
"presence event",
"notification",
],
)
}
#[async_test]
async fn event_handler_invite() {
let vec = Arc::new(Mutex::new(Vec::new()));
let test_vec = Arc::clone(&vec);
let handler = Box::new(EvHandlerTest(vec));
let client = get_client().await;
client.set_event_handler(handler).await;
mock_sync(&client, test_json::INVITE_SYNC.to_string()).await;
let v = test_vec.lock().await;
assert_eq!(v.as_slice(), ["stripped state name", "stripped state member", "presence event"],)
}
#[async_test]
async fn event_handler_leave() {
let vec = Arc::new(Mutex::new(Vec::new()));
let test_vec = Arc::clone(&vec);
let handler = Box::new(EvHandlerTest(vec));
let client = get_client().await;
client.set_event_handler(handler).await;
mock_sync(&client, test_json::LEAVE_SYNC.to_string()).await;
let v = test_vec.lock().await;
assert_eq!(
v.as_slice(),
[
"account ignore",
"state rules",
"state member",
"state aliases",
"state power",
"state canonical",
"state member",
"state member",
"message",
"presence event",
"notification",
],
)
}
#[async_test]
async fn event_handler_more_events() {
let vec = Arc::new(Mutex::new(Vec::new()));
let test_vec = Arc::clone(&vec);
let handler = Box::new(EvHandlerTest(vec));
let client = get_client().await;
client.set_event_handler(handler).await;
mock_sync(&client, test_json::MORE_SYNC.to_string()).await;
let v = test_vec.lock().await;
assert_eq!(
v.as_slice(),
[
"receipt event",
"typing event",
"message",
"message", // this is a message edit event
"redaction",
"message", // this is a notice event
],
)
}
#[async_test]
async fn event_handler_voip() {
let vec = Arc::new(Mutex::new(Vec::new()));
let test_vec = Arc::clone(&vec);
let handler = Box::new(EvHandlerTest(vec));
let client = get_client().await;
client.set_event_handler(handler).await;
mock_sync(&client, test_json::VOIP_SYNC.to_string()).await;
let v = test_vec.lock().await;
assert_eq!(v.as_slice(), ["call invite", "call answer", "call candidates", "call hangup",],)
}
#[async_test]
async fn event_handler_two_syncs() {
let vec = Arc::new(Mutex::new(Vec::new()));
let test_vec = Arc::clone(&vec);
let handler = Box::new(EvHandlerTest(vec));
let client = get_client().await;
client.set_event_handler(handler).await;
mock_sync(&client, test_json::SYNC.to_string()).await;
mock_sync(&client, test_json::MORE_SYNC.to_string()).await;
let v = test_vec.lock().await;
assert_eq!(
v.as_slice(),
[
"account ignore",
"receipt event",
"account read",
"state rules",
"state member",
"state aliases",
"state power",
"state canonical",
"state member",
"state member",
"message",
"presence event",
"notification",
"receipt event",
"typing event",
"message",
"message", // this is a message edit event
"redaction",
"message", // this is a notice event
"notification",
"notification",
"notification",
],
)
}
}

View File

@ -20,14 +20,19 @@ use matrix_sdk_base::crypto::{
}; };
use ruma::{events::key::verification::VerificationMethod, DeviceId, DeviceIdBox}; use ruma::{events::key::verification::VerificationMethod, DeviceId, DeviceIdBox};
use super::ManualVerifyError;
use crate::{ use crate::{
error::Result, error::Result,
verification::{SasVerification, VerificationRequest}, verification::{SasVerification, VerificationRequest},
Client, Client,
}; };
/// A device represents a E2EE capable client or device of an user.
///
/// A device is backed by [device keys] that are uploaded to the server.
///
/// [device keys]: https://spec.matrix.org/unstable/client-server-api/#device-keys
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
/// A device represents a E2EE capable client of an user.
pub struct Device { pub struct Device {
pub(crate) inner: BaseDevice, pub(crate) inner: BaseDevice,
pub(crate) client: Client, pub(crate) client: Client,
@ -42,49 +47,13 @@ impl Deref for Device {
} }
impl Device { impl Device {
/// Start a interactive verification with this `Device` /// Request an interactive verification with this `Device`.
/// ///
/// Returns a `Sas` object that represents the interactive verification /// Returns a [`VerificationRequest`] object that can be used to control the
/// flow. /// verification flow.
///
/// This method has been deprecated in the spec and the
/// [`request_verification()`] method should be used instead.
///
/// # Examples
///
/// ```no_run
/// # use std::convert::TryFrom;
/// # use matrix_sdk::{Client, ruma::UserId};
/// # use url::Url;
/// # use futures::executor::block_on;
/// # let alice = UserId::try_from("@alice:example.org").unwrap();
/// # let homeserver = Url::parse("http://example.com").unwrap();
/// # let client = Client::new(homeserver).unwrap();
/// # block_on(async {
/// let device = client.get_device(&alice, "DEVICEID".into())
/// .await
/// .unwrap()
/// .unwrap();
///
/// let verification = device.start_verification().await.unwrap();
/// # });
/// ```
///
/// [`request_verification()`]: #method.request_verification
pub async fn start_verification(&self) -> Result<SasVerification> {
let (sas, request) = self.inner.start_verification().await?;
self.client.send_to_device(&request).await?;
Ok(SasVerification { inner: sas, client: self.client.clone() })
}
/// Request an interacitve verification with this `Device`
///
/// Returns a `VerificationRequest` object and a to-device request that
/// needs to be sent out.
/// ///
/// The default methods that are supported are `m.sas.v1` and /// The default methods that are supported are `m.sas.v1` and
/// `m.qr_code.show.v1`, if this isn't desireable the /// `m.qr_code.show.v1`, if this isn't desirable the
/// [`request_verification_with_methods()`] method can be used to override /// [`request_verification_with_methods()`] method can be used to override
/// this. /// this.
/// ///
@ -99,13 +68,12 @@ impl Device {
/// # let homeserver = Url::parse("http://example.com").unwrap(); /// # let homeserver = Url::parse("http://example.com").unwrap();
/// # let client = Client::new(homeserver).unwrap(); /// # let client = Client::new(homeserver).unwrap();
/// # block_on(async { /// # block_on(async {
/// let device = client.get_device(&alice, "DEVICEID".into()) /// let device = client.get_device(&alice, "DEVICEID".into()).await?;
/// .await
/// .unwrap()
/// .unwrap();
/// ///
/// let verification = device.request_verification().await.unwrap(); /// if let Some(device) = device {
/// # }); /// let verification = device.request_verification().await?;
/// }
/// # anyhow::Result::<()>::Ok(()) });
/// ``` /// ```
/// ///
/// [`request_verification_with_methods()`]: /// [`request_verification_with_methods()`]:
@ -117,14 +85,19 @@ impl Device {
Ok(VerificationRequest { inner: verification, client: self.client.clone() }) Ok(VerificationRequest { inner: verification, client: self.client.clone() })
} }
/// Request an interacitve verification with this `Device` /// Request an interactive verification with this `Device`.
/// ///
/// Returns a `VerificationRequest` object and a to-device request that /// Returns a [`VerificationRequest`] object that can be used to control the
/// needs to be sent out. /// verification flow.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `methods` - The verification methods that we want to support. /// * `methods` - The verification methods that we want to support. Must be
/// non-empty.
///
/// # Panics
///
/// This method will panic if `methods` is empty.
/// ///
/// # Examples /// # Examples
/// ///
@ -143,30 +116,157 @@ impl Device {
/// # let homeserver = Url::parse("http://example.com").unwrap(); /// # let homeserver = Url::parse("http://example.com").unwrap();
/// # let client = Client::new(homeserver).unwrap(); /// # let client = Client::new(homeserver).unwrap();
/// # block_on(async { /// # block_on(async {
/// let device = client.get_device(&alice, "DEVICEID".into()) /// let device = client.get_device(&alice, "DEVICEID".into()).await?;
/// .await
/// .unwrap()
/// .unwrap();
/// ///
/// // We don't want to support showing a QR code, we only support SAS /// // We don't want to support showing a QR code, we only support SAS
/// // verification /// // verification
/// let methods = vec![VerificationMethod::SasV1]; /// let methods = vec![VerificationMethod::SasV1];
/// ///
/// let verification = device.request_verification_with_methods(methods).await.unwrap(); /// if let Some(device) = device {
/// # }); /// let verification = device.request_verification_with_methods(methods).await?;
/// }
/// # anyhow::Result::<()>::Ok(()) });
/// ``` /// ```
pub async fn request_verification_with_methods( pub async fn request_verification_with_methods(
&self, &self,
methods: Vec<VerificationMethod>, methods: Vec<VerificationMethod>,
) -> Result<VerificationRequest> { ) -> Result<VerificationRequest> {
if methods.is_empty() {
panic!("The list of verification methods can't be non-empty");
}
let (verification, request) = self.inner.request_verification_with_methods(methods).await; let (verification, request) = self.inner.request_verification_with_methods(methods).await;
self.client.send_verification_request(request).await?; self.client.send_verification_request(request).await?;
Ok(VerificationRequest { inner: verification, client: self.client.clone() }) Ok(VerificationRequest { inner: verification, client: self.client.clone() })
} }
/// Is the device considered to be verified, either by locally trusting it /// Start an interactive verification with this [`Device`]
/// or using cross signing. ///
/// Returns a [`SasVerification`] object that represents the interactive
/// verification flow.
///
/// This method has been deprecated in the spec and the
/// [`request_verification()`] method should be used instead.
///
/// # Examples
///
/// ```no_run
/// # use std::convert::TryFrom;
/// # use matrix_sdk::{Client, ruma::UserId};
/// # use url::Url;
/// # use futures::executor::block_on;
/// # let alice = UserId::try_from("@alice:example.org").unwrap();
/// # let homeserver = Url::parse("http://example.com").unwrap();
/// # let client = Client::new(homeserver).unwrap();
/// # block_on(async {
/// let device = client.get_device(&alice, "DEVICEID".into()).await?;
///
/// if let Some(device) = device {
/// let verification = device.start_verification().await?;
/// }
/// # anyhow::Result::<()>::Ok(()) });
/// ```
///
/// [`request_verification()`]: #method.request_verification
#[deprecated(
since = "0.4.0",
note = "directly starting a verification is deprecated in the spec. \
Users should instead use request_verification()"
)]
pub async fn start_verification(&self) -> Result<SasVerification> {
let (sas, request) = self.inner.start_verification().await?;
self.client.send_to_device(&request).await?;
Ok(SasVerification { inner: sas, client: self.client.clone() })
}
/// Manually verify this device.
///
/// This method will attempt to sign the device using our private cross
/// signing key.
///
/// This method will always fail if the device belongs to someone else, we
/// can only sign our own devices.
///
/// It can also fail if we don't have the private part of our self-signing
/// key.
///
/// The state of our private cross signing keys can be inspected using the
/// [`Client::cross_signing_status()`] method.
///
/// [`Client::cross_signing_status()`]: crate::Client::cross_signing_status
///
/// # Examples
///
/// ```no_run
/// # use std::convert::TryFrom;
/// # use matrix_sdk::{
/// # Client,
/// # ruma::{
/// # UserId,
/// # events::key::verification::VerificationMethod,
/// # }
/// # };
/// # use url::Url;
/// # use futures::executor::block_on;
/// # let alice = UserId::try_from("@alice:example.org").unwrap();
/// # let homeserver = Url::parse("http://example.com").unwrap();
/// # let client = Client::new(homeserver).unwrap();
/// # block_on(async {
/// let device = client.get_device(&alice, "DEVICEID".into()).await?;
///
/// if let Some(device) = device {
/// device.verify().await?;
/// }
/// # anyhow::Result::<()>::Ok(()) });
/// ```
pub async fn verify(&self) -> std::result::Result<(), ManualVerifyError> {
let request = self.inner.verify().await?;
self.client.send(request, None).await?;
Ok(())
}
/// Is the device considered to be verified.
///
/// A device is considered to be verified, either if it's locally trusted,
/// or if it's signed by the appropriate cross signing key.
///
/// If the device belongs to our own userk, the device needs to be signed by
/// our self-signing key and our own user identity needs to be verified.
///
/// If the device belongs to some other user, the device needs to be signed
/// by the users signing key and the user identity of the user needs to be
/// verified.
//// # Examples
///
/// ```no_run
/// # use std::convert::TryFrom;
/// # use matrix_sdk::{
/// # Client,
/// # ruma::{
/// # UserId,
/// # events::key::verification::VerificationMethod,
/// # }
/// # };
/// # use url::Url;
/// # use futures::executor::block_on;
/// # let alice = UserId::try_from("@alice:example.org").unwrap();
/// # let homeserver = Url::parse("http://example.com").unwrap();
/// # let client = Client::new(homeserver).unwrap();
/// # block_on(async {
/// let user = client.get_user_identity(&alice).await?;
///
/// if let Some(user) = user {
/// if user.verified() {
/// println!("User {} is verified", user.user_id().as_str());
/// } else {
/// println!("User {} is not verified", user.user_id().as_str());
/// }
/// }
/// # anyhow::Result::<()>::Ok(()) });
/// ```
pub fn verified(&self) -> bool { pub fn verified(&self) -> bool {
self.inner.verified() self.inner.verified()
} }
@ -187,7 +287,7 @@ impl Device {
} }
} }
/// A read only view over all devices belonging to a user. /// The collection of all the [`Device`]s a user has.
#[derive(Debug)] #[derive(Debug)]
pub struct UserDevices { pub struct UserDevices {
pub(crate) inner: BaseUserDevices, pub(crate) inner: BaseUserDevices,

View File

@ -0,0 +1,120 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Cryptographic identities used in Matrix.
//!
//! There are two types of cryptographic identities in Matrix.
//!
//! 1. Devices, which are backed by [device keys], they represent each
//! individual log in by an E2EE capable Matrix client. We represent devices
//! using the [`Device`] struct.
//!
//! 2. User identities, which are backed by [cross signing keys]. The user
//! identity represent a unique E2EE capable identity of any given user. This
//! identity is generally created and uploaded to the server by the first E2EE
//! capable client the user logs in with. We represent user identities using the
//! [`UserIdentity`] struct.
//!
//! A [`Device`] or an [`UserIdentity`] can be used to inspect the public keys
//! of the device or identity, or it can be used to initiate a interactive
//! verification flow. They can also be manually marked as verified.
//!
//! # Examples
//!
//! Verifying a device is pretty straightforward:
//!
//! ```no_run
//! # use std::convert::TryFrom;
//! # use matrix_sdk::{Client, ruma::UserId};
//! # use url::Url;
//! # use futures::executor::block_on;
//! # let alice = UserId::try_from("@alice:example.org").unwrap();
//! # let homeserver = Url::parse("http://example.com").unwrap();
//! # let client = Client::new(homeserver).unwrap();
//! # block_on(async {
//! let device = client.get_device(&alice, "DEVICEID".into()).await?;
//!
//! if let Some(device) = device {
//! // Let's request the device to be verified.
//! let verification = device.request_verification().await?;
//!
//! // Actually this is taking too long.
//! verification.cancel().await?;
//!
//! // Let's just mark it as verified.
//! device.verify().await?;
//! }
//! # anyhow::Result::<()>::Ok(()) });
//! ```
//!
//! Verifying a user identity works largely the same:
//!
//! ```no_run
//! # use std::convert::TryFrom;
//! # use matrix_sdk::{Client, ruma::UserId};
//! # use url::Url;
//! # use futures::executor::block_on;
//! # let alice = UserId::try_from("@alice:example.org").unwrap();
//! # let homeserver = Url::parse("http://example.com").unwrap();
//! # let client = Client::new(homeserver).unwrap();
//! # block_on(async {
//! let user = client.get_user_identity(&alice).await?;
//!
//! if let Some(user) = user {
//! // Let's request the user to be verified.
//! let verification = user.request_verification().await?;
//!
//! // Actually this is taking too long.
//! verification.cancel().await?;
//!
//! // Let's just mark it as verified.
//! user.verify().await?;
//! }
//! # anyhow::Result::<()>::Ok(()) });
//! ```
//!
//! [cross signing keys]: https://spec.matrix.org/unstable/client-server-api/#cross-signing
//! [device keys]: https://spec.matrix.org/unstable/client-server-api/#device-keys
mod devices;
mod users;
pub use devices::{Device, UserDevices};
pub use matrix_sdk_base::crypto::MasterPubkey;
pub use users::UserIdentity;
/// Error for the manual verification step, when we manually sign users or
/// devices.
#[derive(thiserror::Error, Debug)]
pub enum ManualVerifyError {
/// Error that happens when we try to upload the user or device signature.
#[error(transparent)]
Http(#[from] crate::HttpError),
/// Error that happens when we try to sign the user or device.
#[error(transparent)]
Signature(#[from] matrix_sdk_base::crypto::SignatureError),
}
/// Error when requesting a verification.
#[derive(thiserror::Error, Debug)]
pub enum RequestVerificationError {
/// An ordinary error coming from the SDK, i.e. when we fail to send out a
/// HTTP request or if there's an error with the storage layer.
#[error(transparent)]
Sdk(#[from] crate::Error),
/// Verifying other users requires having a DM open with them, this error
/// signals that we didn't have a DM and that we failed to create one.
#[error("Couldn't create a DM with user {0} where the verification should take place")]
RoomCreation(ruma::UserId),
}

View File

@ -0,0 +1,406 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{result::Result, sync::Arc};
use matrix_sdk_base::{
crypto::{
MasterPubkey, OwnUserIdentity as InnerOwnUserIdentity, UserIdentity as InnerUserIdentity,
},
locks::RwLock,
};
use ruma::{
events::{
key::verification::VerificationMethod,
room::message::{MessageEventContent, MessageType},
AnyMessageEventContent,
},
UserId,
};
use super::{ManualVerifyError, RequestVerificationError};
use crate::{room::Joined, verification::VerificationRequest, Client};
/// A struct representing a E2EE capable identity of a user.
///
/// The identity is backed by public [cross signing] keys that users upload.
///
/// [cross signing]: https://spec.matrix.org/unstable/client-server-api/#cross-signing
#[derive(Debug, Clone)]
pub struct UserIdentity {
inner: UserIdentities,
}
impl UserIdentity {
pub(crate) fn new_own(client: Client, identity: InnerOwnUserIdentity) -> Self {
let identity = OwnUserIdentity { inner: identity, client };
Self { inner: identity.into() }
}
pub(crate) fn new(client: Client, identity: InnerUserIdentity, room: Option<Joined>) -> Self {
let identity = OtherUserIdentity {
inner: identity,
client,
direct_message_room: RwLock::new(room).into(),
};
Self { inner: identity.into() }
}
/// The ID of the user this E2EE identity belongs to.
pub fn user_id(&self) -> &UserId {
match &self.inner {
UserIdentities::Own(i) => i.inner.user_id(),
UserIdentities::Other(i) => i.inner.user_id(),
}
}
/// Request an interacitve verification with this `UserIdentity`.
///
/// Returns a [`VerificationRequest`] object that can be used to control the
/// verification flow.
///
/// This will send out a `m.key.verification.request` event to all the E2EE
/// capable devices we have if we're requesting verification with our own
/// user identity or will send out the event to a DM we share with the user.
///
/// If we don't share a DM with this user one will be created before the
/// event gets sent out.
///
/// The default methods that are supported are `m.sas.v1` and
/// `m.qr_code.show.v1`, if this isn't desirable the
/// [`request_verification_with_methods()`] method can be used to override
/// this.
///
/// # Examples
///
/// ```no_run
/// # use std::convert::TryFrom;
/// # use matrix_sdk::{Client, ruma::UserId};
/// # use url::Url;
/// # let alice = UserId::try_from("@alice:example.org").unwrap();
/// # let homeserver = Url::parse("http://example.com").unwrap();
/// # let client = Client::new(homeserver).unwrap();
/// # futures::executor::block_on(async {
/// let user = client.get_user_identity(&alice).await?;
///
/// if let Some(user) = user {
/// let verification = user.request_verification().await?;
/// }
///
/// # anyhow::Result::<()>::Ok(()) });
/// ```
///
/// [`request_verification_with_methods()`]:
/// #method.request_verification_with_methods
pub async fn request_verification(
&self,
) -> Result<VerificationRequest, RequestVerificationError> {
match &self.inner {
UserIdentities::Own(i) => i.request_verification(None).await,
UserIdentities::Other(i) => i.request_verification(None).await,
}
}
/// Request an interacitve verification with this `UserIdentity`.
///
/// Returns a [`VerificationRequest`] object that can be used to control the
/// verification flow.
///
/// This methods behaves the same way as [`request_verification()`],
/// but the advertised verification methods can be manually selected.
///
/// # Arguments
///
/// * `methods` - The verification methods that we want to support. Must be
/// non-empty.
///
/// # Panics
///
/// This method will panic if `methods` is empty.
///
/// # Examples
///
/// ```no_run
/// # use std::convert::TryFrom;
/// # use matrix_sdk::{
/// # Client,
/// # ruma::{
/// # UserId,
/// # events::key::verification::VerificationMethod,
/// # }
/// # };
/// # use url::Url;
/// # use futures::executor::block_on;
/// # let alice = UserId::try_from("@alice:example.org").unwrap();
/// # let homeserver = Url::parse("http://example.com").unwrap();
/// # let client = Client::new(homeserver).unwrap();
/// # block_on(async {
/// let user = client.get_user_identity(&alice).await?;
///
/// // We don't want to support showing a QR code, we only support SAS
/// // verification
/// let methods = vec![VerificationMethod::SasV1];
///
/// if let Some(user) = user {
/// let verification = user.request_verification_with_methods(methods).await?;
/// }
/// # anyhow::Result::<()>::Ok(()) });
/// ```
///
/// [`request_verification()`]: #method.request_verification
pub async fn request_verification_with_methods(
&self,
methods: Vec<VerificationMethod>,
) -> Result<VerificationRequest, RequestVerificationError> {
if methods.is_empty() {
panic!("The list of verification methods can't be non-empty");
}
match &self.inner {
UserIdentities::Own(i) => i.request_verification(Some(methods)).await,
UserIdentities::Other(i) => i.request_verification(Some(methods)).await,
}
}
/// Manually verify this [`UserIdentity`].
///
/// This method will attempt to sign the user identity using our private
/// cross signing key. Verifying can fail if we don't have the private
/// part of our user-signing key.
///
/// The state of our private cross signing keys can be inspected using the
/// [`Client::cross_signing_status()`] method.
///
/// [`Client::cross_signing_status()`]: crate::Client::cross_signing_status
///
/// # Examples
///
/// ```no_run
/// # use std::convert::TryFrom;
/// # use matrix_sdk::{
/// # Client,
/// # ruma::{
/// # UserId,
/// # events::key::verification::VerificationMethod,
/// # }
/// # };
/// # use url::Url;
/// # use futures::executor::block_on;
/// # let alice = UserId::try_from("@alice:example.org").unwrap();
/// # let homeserver = Url::parse("http://example.com").unwrap();
/// # let client = Client::new(homeserver).unwrap();
/// # block_on(async {
/// let user = client.get_user_identity(&alice).await?;
///
/// if let Some(user) = user {
/// user.verify().await?;
/// }
/// # anyhow::Result::<()>::Ok(()) });
/// ```
pub async fn verify(&self) -> Result<(), ManualVerifyError> {
match &self.inner {
UserIdentities::Own(i) => i.verify().await,
UserIdentities::Other(i) => i.verify().await,
}
}
/// Is the user identity considered to be verified.
///
/// A user identity is considered to be verified if it has been signed by
/// our user-signing key, if the identity belongs to another user, or if we
/// locally marked it as verified, if the user identity belongs to us.
///
/// If the identity belongs to another user, our own user identity needs to
/// be verified as well for the identity to be considered to be verified.
///
/// # Examples
///
/// ```no_run
/// # use std::convert::TryFrom;
/// # use matrix_sdk::{
/// # Client,
/// # ruma::{
/// # UserId,
/// # events::key::verification::VerificationMethod,
/// # }
/// # };
/// # use url::Url;
/// # use futures::executor::block_on;
/// # let alice = UserId::try_from("@alice:example.org").unwrap();
/// # let homeserver = Url::parse("http://example.com").unwrap();
/// # let client = Client::new(homeserver).unwrap();
/// # block_on(async {
/// let user = client.get_user_identity(&alice).await?;
///
/// if let Some(user) = user {
/// if user.verified() {
/// println!("User {} is verified", user.user_id().as_str());
/// } else {
/// println!("User {} is not verified", user.user_id().as_str());
/// }
/// }
/// # anyhow::Result::<()>::Ok(()) });
/// ```
pub fn verified(&self) -> bool {
match &self.inner {
UserIdentities::Own(i) => i.inner.is_verified(),
UserIdentities::Other(i) => i.inner.verified(),
}
}
/// Get the public part of the master key of this user identity.
///
/// # Examples
///
/// ```no_run
/// # use std::convert::TryFrom;
/// # use matrix_sdk::{
/// # Client,
/// # ruma::{
/// # UserId,
/// # events::key::verification::VerificationMethod,
/// # }
/// # };
/// # use url::Url;
/// # use futures::executor::block_on;
/// # let alice = UserId::try_from("@alice:example.org").unwrap();
/// # let homeserver = Url::parse("http://example.com").unwrap();
/// # let client = Client::new(homeserver).unwrap();
/// # block_on(async {
/// let user = client.get_user_identity(&alice).await?;
///
/// if let Some(user) = user {
/// // Let's verify the user after we confirm that the master key
/// // matches what we expect, for this we fetch the first public key we
/// // can find, there's currently only a single key allowed so this is
/// // fine.
/// if user.master_key().get_first_key() == Some("MyMasterKey") {
/// println!(
/// "Master keys match for user {}, marking the user as verified",
/// user.user_id().as_str(),
/// );
/// user.verify().await?;
/// } else {
/// println!("Master keys don't match for user {}", user.user_id().as_str());
/// }
/// }
/// # anyhow::Result::<()>::Ok(()) });
/// ```
pub fn master_key(&self) -> &MasterPubkey {
match &self.inner {
UserIdentities::Own(i) => i.inner.master_key(),
UserIdentities::Other(i) => i.inner.master_key(),
}
}
}
#[derive(Debug, Clone)]
enum UserIdentities {
Own(OwnUserIdentity),
Other(OtherUserIdentity),
}
impl From<OwnUserIdentity> for UserIdentities {
fn from(i: OwnUserIdentity) -> Self {
Self::Own(i)
}
}
impl From<OtherUserIdentity> for UserIdentities {
fn from(i: OtherUserIdentity) -> Self {
Self::Other(i)
}
}
#[derive(Debug, Clone)]
struct OwnUserIdentity {
pub(crate) inner: InnerOwnUserIdentity,
pub(crate) client: Client,
}
#[derive(Debug, Clone)]
struct OtherUserIdentity {
pub(crate) inner: InnerUserIdentity,
pub(crate) client: Client,
pub(crate) direct_message_room: Arc<RwLock<Option<Joined>>>,
}
impl OwnUserIdentity {
async fn request_verification(
&self,
methods: Option<Vec<VerificationMethod>>,
) -> Result<VerificationRequest, RequestVerificationError> {
let (verification, request) = if let Some(methods) = methods {
self.inner
.request_verification_with_methods(methods)
.await
.map_err(crate::Error::from)?
} else {
self.inner.request_verification().await.map_err(crate::Error::from)?
};
self.client.send_verification_request(request).await?;
Ok(VerificationRequest { inner: verification, client: self.client.clone() })
}
async fn verify(&self) -> Result<(), ManualVerifyError> {
let request = self.inner.verify().await?;
self.client.send(request, None).await?;
Ok(())
}
}
impl OtherUserIdentity {
async fn request_verification(
&self,
methods: Option<Vec<VerificationMethod>>,
) -> Result<VerificationRequest, RequestVerificationError> {
let content = self.inner.verification_request_content(methods.clone()).await;
let room = if let Some(room) = self.direct_message_room.read().await.as_ref() {
room.clone()
} else if let Some(room) =
self.client.create_dm_room(self.inner.user_id().to_owned()).await?
{
room
} else {
return Err(RequestVerificationError::RoomCreation(self.inner.user_id().to_owned()));
};
let response = room
.send(
AnyMessageEventContent::RoomMessage(MessageEventContent::new(
MessageType::VerificationRequest(content),
)),
None,
)
.await?;
let verification =
self.inner.request_verification(room.room_id(), &response.event_id, methods).await;
Ok(VerificationRequest { inner: verification, client: self.client.clone() })
}
async fn verify(&self) -> Result<(), ManualVerifyError> {
let request = self.inner.verify().await?;
self.client.send(request, None).await?;
Ok(())
}
}

View File

@ -39,23 +39,24 @@
//! The following crate feature flags are available: //! The following crate feature flags are available:
//! //!
//! * `encryption`: Enables end-to-end encryption support in the library. //! * `encryption`: Enables end-to-end encryption support in the library.
//! * `sled_cryptostore`: Enables a Sled based store for the encryption //! * `sled_cryptostore`: Enables a Sled based store for the encryption keys. If
//! keys. If this is disabled and `encryption` support is enabled the keys will //! this is disabled and `encryption` support is enabled the keys will by
//! by default be stored only in memory and thus lost after the client is //! default be stored only in memory and thus lost after the client is
//! destroyed. //! destroyed.
//! * `markdown`: Support for sending markdown formatted messages. //! * `markdown`: Support for sending markdown formatted messages.
//! * `socks`: Enables SOCKS support in reqwest, the default HTTP client. //! * `socks`: Enables SOCKS support in reqwest, the default HTTP client.
//! * `sso_login`: Enables SSO login with a local http server. //! * `sso_login`: Enables SSO login with a local http server.
//! * `require_auth_for_profile_requests`: Whether to send the access token in //! * `require_auth_for_profile_requests`: Whether to send the access token in
//! the authentication //! the authentication header when calling endpoints that retrieve profile
//! header when calling endpoints that retrieve profile data. This matches the //! data. This matches the synapse configuration
//! synapse configuration `require_auth_for_profile_requests`. Enabled by //! `require_auth_for_profile_requests`. Enabled by default.
//! default.
//! * `appservice`: Enables low-level appservice functionality. For an //! * `appservice`: Enables low-level appservice functionality. For an
//! high-level API there's the `matrix-sdk-appservice` crate //! high-level API there's the `matrix-sdk-appservice` crate
//! * `anyhow`: Support for returning `anyhow::Result<()>` from event handlers.
#![deny( #![deny(
missing_debug_implementations, missing_debug_implementations,
missing_docs,
dead_code, dead_code,
missing_docs, missing_docs,
trivial_casts, trivial_casts,
@ -90,7 +91,7 @@ pub use ruma;
mod client; mod client;
mod error; mod error;
mod event_handler; pub mod event_handler;
mod http_client; mod http_client;
/// High-level room API /// High-level room API
pub mod room; pub mod room;
@ -98,16 +99,14 @@ pub mod room;
mod room_member; mod room_member;
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
mod device; #[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub mod identities;
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub mod verification; pub mod verification;
pub use client::{Client, ClientConfig, LoopCtrl, RequestConfig, SyncSettings}; pub use client::{Client, ClientConfig, LoopCtrl, RequestConfig, SyncSettings};
#[cfg(feature = "encryption")] pub use error::{Error, HttpError, HttpResult, Result};
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub use device::Device;
pub use error::{Error, HttpError, Result};
pub use event_handler::{CustomEvent, EventHandler};
pub use http_client::HttpSend; pub use http_client::HttpSend;
pub use room_member::RoomMember; pub use room_member::RoomMember;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]

View File

@ -14,6 +14,7 @@ use ruma::{
}; };
use crate::{ use crate::{
error::HttpResult,
media::{MediaFormat, MediaRequest, MediaType}, media::{MediaFormat, MediaRequest, MediaType},
room::RoomType, room::RoomType,
BaseRoom, Client, Result, RoomMember, BaseRoom, Client, Result, RoomMember,
@ -143,7 +144,7 @@ impl Common {
pub async fn messages( pub async fn messages(
&self, &self,
request: impl Into<get_message_events::Request<'_>>, request: impl Into<get_message_events::Request<'_>>,
) -> Result<get_message_events::Response> { ) -> HttpResult<get_message_events::Response> {
let request = request.into(); let request = request.into();
self.client.send(request, None).await self.client.send(request, None).await
} }
@ -376,4 +377,25 @@ impl Common {
.await .await
.map_err(Into::into) .map_err(Into::into)
} }
/// Check if all members of this room are verified and all their devices are
/// verified.
///
/// Returns true if all devices in the room are verified, otherwise false.
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn contains_only_verified_devices(&self) -> Result<bool> {
let user_ids = self.client.store().get_user_ids(self.room_id()).await?;
for user_id in user_ids {
let devices = self.client.get_user_devices(&user_id).await?;
let any_unverified = devices.devices().any(|d| !d.verified());
if any_unverified {
return Ok(false);
}
}
Ok(true)
}
} }

View File

@ -46,7 +46,7 @@ use ruma::{
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use tracing::instrument; use tracing::instrument;
use crate::{room::Common, BaseRoom, Client, Error, Result, RoomType}; use crate::{error::HttpResult, room::Common, BaseRoom, Client, HttpError, Result, RoomType};
const TYPING_NOTICE_TIMEOUT: Duration = Duration::from_secs(4); const TYPING_NOTICE_TIMEOUT: Duration = Duration::from_secs(4);
const TYPING_NOTICE_RESEND_TIMEOUT: Duration = Duration::from_secs(3); const TYPING_NOTICE_RESEND_TIMEOUT: Duration = Duration::from_secs(3);
@ -562,7 +562,7 @@ impl Joined {
&self, &self,
content: impl Into<AnyStateEventContent>, content: impl Into<AnyStateEventContent>,
state_key: &str, state_key: &str,
) -> Result<send_state_event::Response> { ) -> HttpResult<send_state_event::Response> {
let content = content.into(); let content = content.into();
let request = send_state_event::Request::new(self.inner.room_id(), state_key, &content); let request = send_state_event::Request::new(self.inner.room_id(), state_key, &content);
@ -606,7 +606,7 @@ impl Joined {
event_id: &EventId, event_id: &EventId,
reason: Option<&str>, reason: Option<&str>,
txn_id: Option<Uuid>, txn_id: Option<Uuid>,
) -> Result<redact_event::Response> { ) -> HttpResult<redact_event::Response> {
let txn_id = txn_id.unwrap_or_else(Uuid::new_v4).to_string(); let txn_id = txn_id.unwrap_or_else(Uuid::new_v4).to_string();
let request = let request =
assign!(redact_event::Request::new(self.inner.room_id(), event_id, &txn_id), { assign!(redact_event::Request::new(self.inner.room_id(), event_id, &txn_id), {
@ -642,8 +642,8 @@ impl Joined {
/// room.set_tag("u.work", tag_info ); /// room.set_tag("u.work", tag_info );
/// # }) /// # })
/// ``` /// ```
pub async fn set_tag(&self, tag: &str, tag_info: TagInfo) -> Result<create_tag::Response> { pub async fn set_tag(&self, tag: &str, tag_info: TagInfo) -> HttpResult<create_tag::Response> {
let user_id = self.client.user_id().await.ok_or(Error::AuthenticationRequired)?; let user_id = self.client.user_id().await.ok_or(HttpError::AuthenticationRequired)?;
let request = create_tag::Request::new(&user_id, self.inner.room_id(), tag, tag_info); let request = create_tag::Request::new(&user_id, self.inner.room_id(), tag, tag_info);
self.client.send(request, None).await self.client.send(request, None).await
} }
@ -654,8 +654,8 @@ impl Joined {
/// ///
/// # Arguments /// # Arguments
/// * `tag` - The tag to remove. /// * `tag` - The tag to remove.
pub async fn remove_tag(&self, tag: &str) -> Result<delete_tag::Response> { pub async fn remove_tag(&self, tag: &str) -> HttpResult<delete_tag::Response> {
let user_id = self.client.user_id().await.ok_or(Error::AuthenticationRequired)?; let user_id = self.client.user_id().await.ok_or(HttpError::AuthenticationRequired)?;
let request = delete_tag::Request::new(&user_id, self.inner.room_id(), tag); let request = delete_tag::Request::new(&user_id, self.inner.room_id(), tag);
self.client.send(request, None).await self.client.send(request, None).await
} }

View File

@ -1 +0,0 @@

View File

@ -25,9 +25,10 @@
//! [VerificationRequest::is_ready()] method returns true, the verification can //! [VerificationRequest::is_ready()] method returns true, the verification can
//! transition into one of the supported verification flows: //! transition into one of the supported verification flows:
//! //!
//! * [SasVerification] - Interactive verification using a short authentication //! * [`SasVerification`] - Interactive verification using a short
//! authentication
//! string. //! string.
//! * [QrVerification] - Interactive verification using QR codes. //! * [`QrVerification`] - Interactive verification using QR codes.
#[cfg(feature = "qrcode")] #[cfg(feature = "qrcode")]
mod qrcode; mod qrcode;

View File

@ -8,7 +8,17 @@ name = "matrix-sdk-appservice"
version = "0.1.0" version = "0.1.0"
[features] [features]
default = ["warp"] default = ["warp", "native-tls"]
encryption = ["matrix-sdk/encryption"]
sled_state_store = ["matrix-sdk/sled_state_store"]
sled_cryptostore = ["matrix-sdk/sled_cryptostore"]
markdown = ["matrix-sdk/markdown"]
native-tls = ["matrix-sdk/native-tls"]
rustls-tls = ["matrix-sdk/rustls-tls"]
socks = ["matrix-sdk/socks"]
sso_login = ["matrix-sdk/sso_login"]
require_auth_for_profile_requests = ["matrix-sdk/require_auth_for_profile_requests"]
docs = ["warp"] docs = ["warp"]
@ -26,7 +36,7 @@ tracing = "0.1"
url = "2" url = "2"
warp = { git = "https://github.com/seanmonstar/warp.git", rev = "629405", optional = true, default-features = false } warp = { git = "https://github.com/seanmonstar/warp.git", rev = "629405", optional = true, default-features = false }
matrix-sdk = { version = "0.3", path = "../matrix_sdk", default-features = false, features = ["appservice", "native-tls"] } matrix-sdk = { version = "0.3", path = "../matrix_sdk", default-features = false, features = ["appservice"] }
[dependencies.ruma] [dependencies.ruma]
version = "0.3.0" version = "0.3.0"

View File

@ -2,7 +2,6 @@ use std::{convert::TryFrom, env};
use matrix_sdk_appservice::{ use matrix_sdk_appservice::{
matrix_sdk::{ matrix_sdk::{
async_trait,
room::Room, room::Room,
ruma::{ ruma::{
events::{ events::{
@ -11,52 +10,28 @@ use matrix_sdk_appservice::{
}, },
UserId, UserId,
}, },
EventHandler,
}, },
AppService, AppServiceRegistration, AppService, AppServiceRegistration, Result,
}; };
use tracing::{error, trace}; use tracing::trace;
struct AppServiceEventHandler {
appservice: AppService,
}
impl AppServiceEventHandler {
pub fn new(appservice: AppService) -> Self {
Self { appservice }
}
pub async fn handle_room_member( pub async fn handle_room_member(
&self, appservice: AppService,
room: Room, room: Room,
event: &SyncStateEvent<MemberEventContent>, event: SyncStateEvent<MemberEventContent>,
) -> Result<(), Box<dyn std::error::Error>> { ) -> Result<()> {
if !self.appservice.user_id_is_in_namespace(&event.state_key)? { if !appservice.user_id_is_in_namespace(&event.state_key)? {
trace!("not an appservice user: {}", event.state_key); trace!("not an appservice user: {}", event.state_key);
} else if let MembershipState::Invite = event.content.membership { } else if let MembershipState::Invite = event.content.membership {
let user_id = UserId::try_from(event.state_key.clone())?; let user_id = UserId::try_from(event.state_key.as_str())?;
let appservice = self.appservice.clone();
appservice.register_virtual_user(user_id.localpart()).await?; appservice.register_virtual_user(user_id.localpart()).await?;
let client = appservice.virtual_user_client(user_id.localpart()).await?; let client = appservice.virtual_user_client(user_id.localpart()).await?;
client.join_room_by_id(room.room_id()).await?; client.join_room_by_id(room.room_id()).await?;
} }
Ok(()) Ok(())
} }
}
#[async_trait]
impl EventHandler for AppServiceEventHandler {
async fn on_room_member(&self, room: Room, event: &SyncStateEvent<MemberEventContent>) {
match self.handle_room_member(room, event).await {
Ok(_) => (),
Err(error) => error!("{:?}", error),
}
}
}
#[tokio::main] #[tokio::main]
pub async fn main() -> Result<(), Box<dyn std::error::Error>> { pub async fn main() -> Result<(), Box<dyn std::error::Error>> {
@ -68,7 +43,14 @@ pub async fn main() -> Result<(), Box<dyn std::error::Error>> {
let registration = AppServiceRegistration::try_from_yaml_file("./tests/registration.yaml")?; let registration = AppServiceRegistration::try_from_yaml_file("./tests/registration.yaml")?;
let mut appservice = AppService::new(homeserver_url, server_name, registration).await?; let mut appservice = AppService::new(homeserver_url, server_name, registration).await?;
appservice.set_event_handler(Box::new(AppServiceEventHandler::new(appservice.clone()))).await?; appservice
.register_event_handler({
let appservice = appservice.clone();
move |event: SyncStateEvent<MemberEventContent>, room: Room| {
handle_room_member(appservice.clone(), room, event)
}
})
.await?;
let (host, port) = appservice.registration().get_host_and_port()?; let (host, port) = appservice.registration().get_host_and_port()?;
appservice.run(host, port).await?; appservice.run(host, port).await?;

View File

@ -87,3 +87,9 @@ impl From<warp::Rejection> for Error {
Self::WarpRejection(format!("{:?}", rejection)) Self::WarpRejection(format!("{:?}", rejection))
} }
} }
impl From<matrix_sdk::HttpError> for Error {
fn from(e: matrix_sdk::HttpError) -> Self {
matrix_sdk::Error::from(e).into()
}
}

View File

@ -34,14 +34,10 @@
//! ```no_run //! ```no_run
//! # async { //! # async {
//! # //! #
//! # use matrix_sdk::{async_trait, EventHandler}; //! use matrix_sdk_appservice::{
//! # //! ruma::events::{SyncStateEvent, room::member::MemberEventContent},
//! # struct MyEventHandler; //! AppService, AppServiceRegistration
//! # //! };
//! # #[async_trait]
//! # impl EventHandler for MyEventHandler {}
//! #
//! use matrix_sdk_appservice::{AppService, AppServiceRegistration};
//! //!
//! let homeserver_url = "http://127.0.0.1:8008"; //! let homeserver_url = "http://127.0.0.1:8008";
//! let server_name = "localhost"; //! let server_name = "localhost";
@ -59,7 +55,9 @@
//! ")?; //! ")?;
//! //!
//! let mut appservice = AppService::new(homeserver_url, server_name, registration).await?; //! let mut appservice = AppService::new(homeserver_url, server_name, registration).await?;
//! appservice.set_event_handler(Box::new(MyEventHandler)).await?; //! appservice.register_event_handler(|_ev: SyncStateEvent<MemberEventContent>| async {
//! // do stuff
//! });
//! //!
//! let (host, port) = appservice.registration().get_host_and_port()?; //! let (host, port) = appservice.registration().get_host_and_port()?;
//! appservice.run(host, port).await?; //! appservice.run(host, port).await?;
@ -80,6 +78,7 @@ compile_error!("one webserver feature must be enabled. available ones: `warp`");
use std::{ use std::{
convert::{TryFrom, TryInto}, convert::{TryFrom, TryInto},
fs::File, fs::File,
future::Future,
ops::Deref, ops::Deref,
path::PathBuf, path::PathBuf,
sync::Arc, sync::Arc,
@ -92,7 +91,10 @@ pub use matrix_sdk;
#[doc(no_inline)] #[doc(no_inline)]
pub use matrix_sdk::ruma; pub use matrix_sdk::ruma;
use matrix_sdk::{ use matrix_sdk::{
bytes::Bytes, reqwest::Url, Client, ClientConfig, EventHandler, HttpError, Session, bytes::Bytes,
event_handler::{EventHandler, EventHandlerResult, SyncEvent},
reqwest::Url,
Client, ClientConfig, Session,
}; };
use regex::Regex; use regex::Regex;
use ruma::{ use ruma::{
@ -106,12 +108,13 @@ use ruma::{
}, },
assign, identifiers, DeviceId, ServerNameBox, UserId, assign, identifiers, DeviceId, ServerNameBox, UserId,
}; };
use serde::de::DeserializeOwned;
use tracing::{info, warn}; use tracing::{info, warn};
mod error; mod error;
mod webserver; mod webserver;
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T, E = Error> = std::result::Result<T, E>;
pub type Host = String; pub type Host = String;
pub type Port = u16; pub type Port = u16;
@ -354,8 +357,8 @@ impl AppService {
Ok(entry.value().clone()) Ok(entry.value().clone())
} }
/// Convenience wrapper around [`Client::set_event_handler()`] that attaches /// Convenience wrapper around [`Client::register_event_handler()`] that
/// the event handler to the [`MainUser`]'s [`Client`] /// attaches the event handler to the [`MainUser`]'s [`Client`]
/// ///
/// Note that the event handler in the [`AppService`] context only triggers /// Note that the event handler in the [`AppService`] context only triggers
/// [`join` room `timeline` events], so no state events or events from the /// [`join` room `timeline` events], so no state events or events from the
@ -370,10 +373,14 @@ impl AppService {
/// ///
/// [`join` room `timeline` events]: https://spec.matrix.org/unstable/client-server-api/#get_matrixclientr0sync /// [`join` room `timeline` events]: https://spec.matrix.org/unstable/client-server-api/#get_matrixclientr0sync
/// [MSC2409]: https://github.com/matrix-org/matrix-doc/pull/2409 /// [MSC2409]: https://github.com/matrix-org/matrix-doc/pull/2409
pub async fn set_event_handler(&mut self, handler: Box<dyn EventHandler>) -> Result<()> { pub async fn register_event_handler<Ev, Ctx, H>(&mut self, handler: H) -> Result<()>
where
Ev: SyncEvent + DeserializeOwned + Send + 'static,
H: EventHandler<Ev, Ctx>,
<H::Future as Future>::Output: EventHandlerResult,
{
let client = self.get_cached_client(None)?; let client = self.get_cached_client(None)?;
client.register_event_handler(handler).await;
client.set_event_handler(handler).await;
Ok(()) Ok(())
} }
@ -395,9 +402,9 @@ impl AppService {
match client.register(request).await { match client.register(request).await {
Ok(_) => (), Ok(_) => (),
Err(error) => match error { Err(error) => match error {
matrix_sdk::Error::Http(HttpError::UiaaError(FromHttpResponseError::Http( matrix_sdk::HttpError::UiaaError(FromHttpResponseError::Http(
ServerError::Known(UiaaResponse::MatrixError(ref matrix_error)), ServerError::Known(UiaaResponse::MatrixError(ref matrix_error)),
))) => { )) => {
match matrix_error.kind { match matrix_error.kind {
ErrorKind::UserInUse => { ErrorKind::UserInUse => {
// TODO: persist the fact that we registered that user // TODO: persist the fact that we registered that user

View File

@ -1,13 +1,14 @@
use std::sync::{Arc, Mutex}; use std::{
future,
sync::{Arc, Mutex},
};
use matrix_sdk::{ use matrix_sdk::{
async_trait,
room::Room,
ruma::{ ruma::{
api::appservice::Registration, api::appservice::Registration,
events::{room::member::MemberEventContent, SyncStateEvent}, events::{room::member::MemberEventContent, SyncStateEvent},
}, },
ClientConfig, EventHandler, RequestConfig, ClientConfig, RequestConfig,
}; };
use matrix_sdk_appservice::*; use matrix_sdk_appservice::*;
use matrix_sdk_test::{appservice::TransactionBuilder, async_test, EventsJson}; use matrix_sdk_test::{appservice::TransactionBuilder, async_test, EventsJson};
@ -203,28 +204,17 @@ async fn test_no_access_token() -> Result<()> {
async fn test_event_handler() -> Result<()> { async fn test_event_handler() -> Result<()> {
let mut appservice = appservice(None).await?; let mut appservice = appservice(None).await?;
#[derive(Clone)]
struct Example {
pub on_state_member: Arc<Mutex<bool>>,
}
impl Example {
pub fn new() -> Self {
#[allow(clippy::mutex_atomic)] #[allow(clippy::mutex_atomic)]
Self { on_state_member: Arc::new(Mutex::new(false)) } let on_state_member = Arc::new(Mutex::new(false));
} appservice
} .register_event_handler({
let on_state_member = on_state_member.clone();
#[async_trait] move |_ev: SyncStateEvent<MemberEventContent>| {
impl EventHandler for Example {
async fn on_room_member(&self, _: Room, _: &SyncStateEvent<MemberEventContent>) {
let on_state_member = self.on_state_member.clone();
*on_state_member.lock().unwrap() = true; *on_state_member.lock().unwrap() = true;
future::ready(())
} }
} })
.await?;
let example = Example::new();
appservice.set_event_handler(Box::new(example.clone())).await?;
let uri = "/_matrix/app/v1/transactions/1?access_token=hs_token"; let uri = "/_matrix/app/v1/transactions/1?access_token=hs_token";
@ -241,7 +231,7 @@ async fn test_event_handler() -> Result<()> {
.await .await
.unwrap(); .unwrap();
let on_room_member_called = *example.on_state_member.lock().unwrap(); let on_room_member_called = *on_state_member.lock().unwrap();
assert!(on_room_member_called); assert!(on_room_member_called);
Ok(()) Ok(())

View File

@ -50,16 +50,15 @@ use ruma::{
use ruma::{ use ruma::{
api::client::r0::{self as api, push::get_notifications::Notification}, api::client::r0::{self as api, push::get_notifications::Notification},
events::{ events::{
room::member::{MemberEventContent, MembershipState}, room::member::MembershipState, AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent,
AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent, AnyStrippedStateEvent, AnySyncEphemeralRoomEvent, AnySyncRoomEvent, AnySyncStateEvent,
AnySyncEphemeralRoomEvent, AnySyncRoomEvent, AnySyncStateEvent, EventContent, EventType, EventContent, EventType,
StateEvent,
}, },
push::{Action, PushConditionRoomCtx, Ruleset}, push::{Action, PushConditionRoomCtx, Ruleset},
serde::Raw, serde::Raw,
MilliSecondsSinceUnixEpoch, RoomId, UInt, UserId, MilliSecondsSinceUnixEpoch, RoomId, UInt, UserId,
}; };
use tracing::{info, warn}; use tracing::{info, trace, warn};
use zeroize::Zeroizing; use zeroize::Zeroizing;
use crate::{ use crate::{
@ -71,97 +70,6 @@ use crate::{
pub type Token = String; pub type Token = String;
/// A deserialization wrapper for extracting the prev_content field when
/// found in an `unsigned` field.
///
/// Represents the outer `unsigned` field
#[derive(serde::Deserialize)]
pub struct AdditionalEventData {
unsigned: AdditionalUnsignedData,
}
/// A deserialization wrapper for extracting the prev_content field when
/// found in an `unsigned` field.
///
/// Represents the inner `prev_content` field
#[derive(serde::Deserialize)]
pub struct AdditionalUnsignedData {
pub prev_content: Option<Raw<MemberEventContent>>,
}
/// Transform an `AnySyncStateEvent` by hoisting `prev_content` field from
/// `unsigned` to the top level.
///
/// Due to a [bug in synapse][synapse-bug], `prev_content` often ends up in
/// `unsigned` contrary to the C2S spec. Some more discussion can be found
/// [here][discussion]. Until this is fixed in synapse or handled in Ruma, we
/// use this to hoist up `prev_content` to the top level.
///
/// [synapse-bug]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668>
/// [discussion]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668>
pub fn hoist_and_deserialize_state_event(
event: &Raw<AnySyncStateEvent>,
) -> StdResult<AnySyncStateEvent, serde_json::Error> {
let prev_content = event.deserialize_as::<AdditionalEventData>()?.unsigned.prev_content;
let mut ev = event.deserialize()?;
if let AnySyncStateEvent::RoomMember(ref mut member) = ev {
if member.prev_content.is_none() {
member.prev_content = prev_content.and_then(|e| e.deserialize().ok());
}
}
Ok(ev)
}
fn hoist_member_event(
event: &Raw<StateEvent<MemberEventContent>>,
) -> StdResult<StateEvent<MemberEventContent>, serde_json::Error> {
let prev_content = event.deserialize_as::<AdditionalEventData>()?.unsigned.prev_content;
let mut e = event.deserialize()?;
if e.prev_content.is_none() {
e.prev_content = prev_content.and_then(|e| e.deserialize().ok());
}
Ok(e)
}
/// Transform an `AnySyncRoomEvent` by hoisting `prev_content` field from
/// `unsigned` to the top level.
///
/// Due to a [bug in synapse][synapse-bug], `prev_content` often ends up in
/// `unsigned` contrary to the C2S spec. Some more discussion can be found
/// [here][discussion]. Until this is fixed in synapse or handled in Ruma, we
/// use this to hoist up `prev_content` to the top level.
///
/// [synapse-bug]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668>
/// [discussion]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668>
pub fn hoist_room_event_prev_content(
event: &Raw<AnySyncRoomEvent>,
) -> StdResult<AnySyncRoomEvent, serde_json::Error> {
let prev_content = event
.deserialize_as::<AdditionalEventData>()
.map(|more_unsigned| more_unsigned.unsigned)
.map(|additional| additional.prev_content)?
.and_then(|p| p.deserialize().ok());
let mut ev = event.deserialize()?;
match &mut ev {
AnySyncRoomEvent::State(AnySyncStateEvent::RoomMember(ref mut member))
if member.prev_content.is_none() =>
{
member.prev_content = prev_content;
}
_ => (),
}
Ok(ev)
}
/// A no IO Client implementation. /// A no IO Client implementation.
/// ///
/// This Client is a state machine that receives responses and events and /// This Client is a state machine that receives responses and events and
@ -445,7 +353,7 @@ impl BaseClient {
#[allow(unused_mut)] #[allow(unused_mut)]
let mut event: SyncRoomEvent = event.into(); let mut event: SyncRoomEvent = event.into();
match hoist_room_event_prev_content(&event.event) { match event.event.deserialize() {
Ok(e) => { Ok(e) => {
#[allow(clippy::single_match)] #[allow(clippy::single_match)]
match &e { match &e {
@ -611,7 +519,7 @@ impl BaseClient {
let room_id = room_info.room_id.clone(); let room_id = room_info.room_id.clone();
for raw_event in events { for raw_event in events {
let event = match hoist_and_deserialize_state_event(raw_event) { let event = match raw_event.deserialize() {
Ok(e) => e, Ok(e) => e,
Err(e) => { Err(e) => {
warn!( warn!(
@ -687,15 +595,23 @@ impl BaseClient {
let mut account_data = BTreeMap::new(); let mut account_data = BTreeMap::new();
for raw_event in events { for raw_event in events {
let event = if let Ok(e) = raw_event.deserialize() { let event = match raw_event.deserialize() {
e Ok(e) => e,
} else { Err(e) => {
warn!(error =? e, "Failed to deserialize a global account data event");
continue; continue;
}
}; };
if let AnyGlobalAccountDataEvent::Direct(e) = &event { if let AnyGlobalAccountDataEvent::Direct(e) = &event {
for (user_id, rooms) in e.content.iter() { for (user_id, rooms) in e.content.iter() {
for room_id in rooms { for room_id in rooms {
trace!(
room_id = room_id.as_str(),
target = user_id.as_str(),
"Marking room as direct room"
);
if let Some(room) = changes.room_infos.get_mut(room_id) { if let Some(room) = changes.room_infos.get_mut(room_id) {
room.base_info.dm_target = Some(user_id.clone()); room.base_info.dm_target = Some(user_id.clone());
} else if let Some(room) = self.store.get_room(room_id) { } else if let Some(room) = self.store.get_room(room_id) {
@ -916,6 +832,12 @@ impl BaseClient {
new_rooms.invite.insert(room_id, new_info); new_rooms.invite.insert(room_id, new_info);
} }
// TODO remove this, we're processing account data events here again
// because we want to have the push rules in place before we process
// rooms and their events, but we want to create the rooms before we
// process the `m.direct` account data event.
self.handle_account_data(&account_data.events, &mut changes).await;
changes.presence = presence changes.presence = presence
.events .events
.iter() .iter()
@ -976,7 +898,7 @@ impl BaseClient {
let members: Vec<MemberEvent> = response let members: Vec<MemberEvent> = response
.chunk .chunk
.iter() .iter()
.filter_map(|e| hoist_member_event(e).ok().and_then(|e| MemberEvent::try_from(e).ok())) .filter_map(|e| e.deserialize().ok().and_then(|e| MemberEvent::try_from(e).ok()))
.collect(); .collect();
let mut ambiguity_cache = AmbiguityCache::new(self.store.clone()); let mut ambiguity_cache = AmbiguityCache::new(self.store.clone());

View File

@ -50,9 +50,7 @@ mod rooms;
mod session; mod session;
mod store; mod store;
pub use client::{ pub use client::{BaseClient, BaseClientConfig};
hoist_and_deserialize_state_event, hoist_room_event_prev_content, BaseClient, BaseClientConfig,
};
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub use matrix_sdk_crypto as crypto; pub use matrix_sdk_crypto as crypto;

View File

@ -66,30 +66,32 @@ pub struct MemoryStore {
room_event_receipts: room_event_receipts:
Arc<DashMap<RoomId, DashMap<String, DashMap<EventId, DashMap<UserId, Receipt>>>>>, Arc<DashMap<RoomId, DashMap<String, DashMap<EventId, DashMap<UserId, Receipt>>>>>,
media: Arc<Mutex<LruCache<String, Vec<u8>>>>, media: Arc<Mutex<LruCache<String, Vec<u8>>>>,
custom: Arc<DashMap<Vec<u8>, Vec<u8>>>,
} }
impl MemoryStore { impl MemoryStore {
#[allow(dead_code)] #[allow(dead_code)]
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
sync_token: Arc::new(RwLock::new(None)), sync_token: Default::default(),
filters: DashMap::new().into(), filters: Default::default(),
account_data: DashMap::new().into(), account_data: Default::default(),
members: DashMap::new().into(), members: Default::default(),
profiles: DashMap::new().into(), profiles: Default::default(),
display_names: DashMap::new().into(), display_names: Default::default(),
joined_user_ids: DashMap::new().into(), joined_user_ids: Default::default(),
invited_user_ids: DashMap::new().into(), invited_user_ids: Default::default(),
room_info: DashMap::new().into(), room_info: Default::default(),
room_state: DashMap::new().into(), room_state: Default::default(),
room_account_data: DashMap::new().into(), room_account_data: Default::default(),
stripped_room_info: DashMap::new().into(), stripped_room_info: Default::default(),
stripped_room_state: DashMap::new().into(), stripped_room_state: Default::default(),
stripped_members: DashMap::new().into(), stripped_members: Default::default(),
presence: DashMap::new().into(), presence: Default::default(),
room_user_receipts: DashMap::new().into(), room_user_receipts: Default::default(),
room_event_receipts: DashMap::new().into(), room_event_receipts: Default::default(),
media: Arc::new(Mutex::new(LruCache::new(100))), media: Arc::new(Mutex::new(LruCache::new(100))),
custom: DashMap::new().into(),
} }
} }
@ -407,6 +409,14 @@ impl MemoryStore {
.unwrap_or_else(Vec::new)) .unwrap_or_else(Vec::new))
} }
async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
Ok(self.custom.get(key).map(|e| e.value().clone()))
}
async fn set_custom_value(&self, key: &[u8], value: Vec<u8>) -> Result<Option<Vec<u8>>> {
Ok(self.custom.insert(key.to_vec(), value))
}
async fn add_media_content(&self, request: &MediaRequest, data: Vec<u8>) -> Result<()> { async fn add_media_content(&self, request: &MediaRequest, data: Vec<u8>) -> Result<()> {
self.media.lock().await.put(request.unique_key(), data); self.media.lock().await.put(request.unique_key(), data);
@ -563,6 +573,14 @@ impl StateStore for MemoryStore {
self.get_event_room_receipt_events(room_id, receipt_type, event_id).await self.get_event_room_receipt_events(room_id, receipt_type, event_id).await
} }
async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
self.get_custom_value(key).await
}
async fn set_custom_value(&self, key: &[u8], value: Vec<u8>) -> Result<Option<Vec<u8>>> {
self.set_custom_value(key, value).await
}
async fn add_media_content(&self, request: &MediaRequest, data: Vec<u8>) -> Result<()> { async fn add_media_content(&self, request: &MediaRequest, data: Vec<u8>) -> Result<()> {
self.add_media_content(request, data).await self.add_media_content(request, data).await
} }

View File

@ -263,6 +263,22 @@ pub trait StateStore: AsyncTraitDeps {
event_id: &EventId, event_id: &EventId,
) -> Result<Vec<(UserId, Receipt)>>; ) -> Result<Vec<(UserId, Receipt)>>;
/// Get arbitrary data from the custom store
///
/// # Arguments
///
/// * `key` - The key to fetch data for
async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>>;
/// Put arbitrary data into the custom store
///
/// # Arguments
///
/// * `key` - The key to insert data into
///
/// * `value` - The value to insert
async fn set_custom_value(&self, key: &[u8], value: Vec<u8>) -> Result<Option<Vec<u8>>>;
/// Add a media file's content in the media store. /// Add a media file's content in the media store.
/// ///
/// # Arguments /// # Arguments
@ -310,15 +326,12 @@ pub struct Store {
impl Store { impl Store {
fn new(inner: Box<dyn StateStore>) -> Self { fn new(inner: Box<dyn StateStore>) -> Self {
let session = Arc::new(RwLock::new(None));
let sync_token = Arc::new(RwLock::new(None));
Self { Self {
inner: inner.into(), inner: inner.into(),
session, session: Default::default(),
sync_token, sync_token: Default::default(),
rooms: DashMap::new().into(), rooms: Default::default(),
stripped_rooms: DashMap::new().into(), stripped_rooms: Default::default(),
} }
} }

View File

@ -189,6 +189,7 @@ pub struct SledStore {
room_user_receipts: Tree, room_user_receipts: Tree,
room_event_receipts: Tree, room_event_receipts: Tree,
media: Tree, media: Tree,
custom: Tree,
} }
impl std::fmt::Debug for SledStore { impl std::fmt::Debug for SledStore {
@ -226,6 +227,8 @@ impl SledStore {
let media = db.open_tree("media")?; let media = db.open_tree("media")?;
let custom = db.open_tree("custom")?;
Ok(Self { Ok(Self {
path, path,
inner: db, inner: db,
@ -247,6 +250,7 @@ impl SledStore {
room_user_receipts, room_user_receipts,
room_event_receipts, room_event_receipts,
media, media,
custom,
}) })
} }
@ -762,6 +766,17 @@ impl SledStore {
.map(|m| m.to_vec())) .map(|m| m.to_vec()))
} }
async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
Ok(self.custom.get(key)?.map(|v| v.to_vec()))
}
async fn set_custom_value(&self, key: &[u8], value: Vec<u8>) -> Result<Option<Vec<u8>>> {
let ret = self.custom.insert(key, value)?.map(|v| v.to_vec());
self.inner.flush_async().await?;
Ok(ret)
}
async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> { async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> {
self.media.remove( self.media.remove(
(request.media_type.unique_key().as_str(), request.format.unique_key().as_str()) (request.media_type.unique_key().as_str(), request.format.unique_key().as_str())
@ -899,6 +914,14 @@ impl StateStore for SledStore {
self.get_event_room_receipt_events(room_id, receipt_type, event_id).await self.get_event_room_receipt_events(room_id, receipt_type, event_id).await
} }
async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
self.get_custom_value(key).await
}
async fn set_custom_value(&self, key: &[u8], value: Vec<u8>) -> Result<Option<Vec<u8>>> {
self.set_custom_value(key, value).await
}
async fn add_media_content(&self, request: &MediaRequest, data: Vec<u8>) -> Result<()> { async fn add_media_content(&self, request: &MediaRequest, data: Vec<u8>) -> Result<()> {
self.add_media_content(request, data).await self.add_media_content(request, data).await
} }
@ -939,7 +962,7 @@ mod test {
}; };
use serde_json::json; use serde_json::json;
use super::{SledStore, StateChanges}; use super::{Result, SledStore, StateChanges};
use crate::{ use crate::{
deserialized_responses::MemberEvent, deserialized_responses::MemberEvent,
media::{MediaFormat, MediaRequest, MediaThumbnailSize, MediaType}, media::{MediaFormat, MediaRequest, MediaThumbnailSize, MediaType},
@ -1155,4 +1178,19 @@ mod test {
assert!(store.get_media_content(&request_file).await.unwrap().is_none()); assert!(store.get_media_content(&request_file).await.unwrap().is_none());
assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none()); assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none());
} }
#[async_test]
async fn test_custom_storage() -> Result<()> {
let key = "my_key";
let value = &[0, 1, 2, 3];
let store = SledStore::open()?;
store.set_custom_value(key.as_bytes(), value.to_vec()).await?;
let read = store.get_custom_value(key.as_bytes()).await?;
assert_eq!(Some(value.as_ref()), read.as_deref());
Ok(())
}
} }

View File

@ -58,7 +58,7 @@ indoc = "1.0.3"
criterion = { version = "0.3.4", features = ["async", "async_tokio", "html_reports"] } criterion = { version = "0.3.4", features = ["async", "async_tokio", "html_reports"] }
[target.'cfg(target_os = "linux")'.dev-dependencies] [target.'cfg(target_os = "linux")'.dev-dependencies]
pprof = { version = "0.4.3", features = ["flamegraph"] } pprof = { version = "0.5.0", features = ["flamegraph", "criterion"] }
[[bench]] [[bench]]
name = "crypto_bench" name = "crypto_bench"

View File

@ -0,0 +1,80 @@
# Benchmarks for the rust-sdk crypto layer
This directory contains various benchmarks that test critical functionality in
the crypto layer in the rust-sdk.
We're using [Criterion] for the benchmarks, the full documentation for Criterion
can be found [here](https://bheisler.github.io/criterion.rs/book/criterion_rs.html).
## Running the benchmarks
The benchmark can be simply run by using the `bench` command of `cargo`:
```bash
$ cargo bench
```
This will work from the workspace directory of the rust-sdk.
If you want to pass options to the benchmark [you'll need to specify the name of
the benchmark](https://bheisler.github.io/criterion.rs/book/faq.html#cargo-bench-gives-unrecognized-option-errors-for-valid-command-line-options):
```bash
$ cargo bench --bench crypto_bench -- # Your options go here
```
If you want to run only a specific benchmark, simply pass the name of the
benchmark as an argument:
```bash
$ cargo bench --bench crypto_bench "Room key sharing/"
```
After the benchmarks are done, a HTML report can be found in `target/criterion/report/index.html`.
### Using a baseline for the benchmark
The benchmarks will by default compare the results to the previous run of the
benchmark. If you are improving the performance of a specific feature and run
the benchmark many times, it may be useful to store a baseline to compare
against instead.
The `--save-baseline` switch can be used to create a baseline for the benchmark.
```bash
$ cargo bench --bench crypto_bench -- --save-baseline libolm
```
After you make your changes you can use the baseline to compare the results like
so:
```bash
$ cargo bench --bench crypto_bench -- --baseline libolm
```
### Generating Flame Graphs for the benchmarks
The benchmarks support profiling and generating [Flame Graphs] while they run in
profiling mode using [pprof].
Profiling usually requieres root permissions, to avoid the need for root
permissions you can adjust the value of `perf_event_paranoid`, e.g. the most
permisive value is `-1`:
```bash
$ echo -1 | sudo tee /proc/sys/kernel/perf_event_paranoid
```
To generate flame graphs feature simply enable the profiling mode using the
`--profile-time` command line flag:
```bash
$ cargo bench --bench crypto_bench -- --profile-time=5
```
After the benchmarks are done, a flame graph for each individual benchmark can be
found in `target/criterion/<name-of-benchmark>/profile/flamegraph.svg`.
[pprof]: https://docs.rs/pprof/0.5.0/pprof/index.html#
[Criterion]: https://docs.rs/criterion/0.3.5/criterion/
[Flame Graphs]: https://www.brendangregg.com/flamegraphs.html

View File

@ -1,6 +1,3 @@
#[cfg(target_os = "linux")]
mod perf;
use std::sync::Arc; use std::sync::Arc;
use criterion::*; use criterion::*;
@ -262,7 +259,10 @@ pub fn devices_missing_sessions_collecting(c: &mut Criterion) {
fn criterion() -> Criterion { fn criterion() -> Criterion {
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
let criterion = Criterion::default().with_profiler(perf::FlamegraphProfiler::new(100)); let criterion = Criterion::default().with_profiler(pprof::criterion::PProfProfiler::new(
100,
pprof::criterion::Output::Flamegraph(None),
));
#[cfg(not(target_os = "linux"))] #[cfg(not(target_os = "linux"))]
let criterion = Criterion::default(); let criterion = Criterion::default();

View File

@ -1,76 +0,0 @@
//! This is a simple Criterion Profiler implementation using pprof.
//!
//! It's mostly a direct copy from here: https://www.jibbow.com/posts/criterion-flamegraphs/
use std::{fs::File, os::raw::c_int, path::Path};
use criterion::profiler::Profiler;
use pprof::ProfilerGuard;
/// Small custom profiler that can be used with Criterion to create a flamegraph
/// for benchmarks. Also see [the Criterion documentation on
/// this][custom-profiler].
///
/// ## Example on how to enable the custom profiler:
///
/// ```
/// mod perf;
/// use perf::FlamegraphProfiler;
///
/// fn fibonacci_profiled(criterion: &mut Criterion) {
/// // Use the criterion struct as normal here.
/// }
///
/// fn custom() -> Criterion {
/// Criterion::default().with_profiler(FlamegraphProfiler::new())
/// }
///
/// criterion_group! {
/// name = benches;
/// config = custom();
/// targets = fibonacci_profiled
/// }
/// ```
///
/// The neat thing about this is that it will sample _only_ the benchmark, and
/// not other stuff like the setup process.
///
/// Further, it will only kick in if `--profile-time <time>` is passed to the
/// benchmark binary. A flamegraph will be created for each individual benchmark
/// in its report directory under `profile/flamegraph.svg`.
///
/// [custom-profiler]: https://bheisler.github.io/criterion.rs/book/user_guide/profiling.html#implementing-in-process-profiling-hooks
pub struct FlamegraphProfiler<'a> {
frequency: c_int,
active_profiler: Option<ProfilerGuard<'a>>,
}
impl<'a> FlamegraphProfiler<'a> {
pub fn new(frequency: c_int) -> Self {
FlamegraphProfiler { frequency, active_profiler: None }
}
}
impl<'a> Profiler for FlamegraphProfiler<'a> {
fn start_profiling(&mut self, _benchmark_id: &str, _benchmark_dir: &Path) {
self.active_profiler = Some(ProfilerGuard::new(self.frequency).unwrap());
}
fn stop_profiling(&mut self, _benchmark_id: &str, benchmark_dir: &Path) {
std::fs::create_dir_all(benchmark_dir)
.expect("Can't create a directory to store the benchmarking report");
let flamegraph_path = benchmark_dir.join("flamegraph.svg");
let flamegraph_file = File::create(&flamegraph_path)
.expect("File system error while creating flamegraph.svg");
if let Some(profiler) = self.active_profiler.take() {
profiler
.report()
.build()
.expect("Can't build profiling report")
.flamegraph(flamegraph_file)
.expect("Error writing flamegraph");
}
}
}

View File

@ -76,8 +76,8 @@ impl GossipMachine {
device_id, device_id,
store, store,
outbound_group_sessions, outbound_group_sessions,
outgoing_requests: DashMap::new().into(), outgoing_requests: Default::default(),
incoming_key_requests: DashMap::new().into(), incoming_key_requests: Default::default(),
wait_queue: WaitQueue::new(), wait_queue: WaitQueue::new(),
users_for_key_claim, users_for_key_claim,
} }

View File

@ -45,7 +45,7 @@ pub use file_encryption::{
DecryptorError, EncryptionInfo, KeyExportError, DecryptorError, EncryptionInfo, KeyExportError,
}; };
pub use identities::{ pub use identities::{
Device, LocalTrust, OwnUserIdentity, ReadOnlyDevice, ReadOnlyOwnUserIdentity, Device, LocalTrust, MasterPubkey, OwnUserIdentity, ReadOnlyDevice, ReadOnlyOwnUserIdentity,
ReadOnlyUserIdentities, ReadOnlyUserIdentity, UserDevices, UserIdentities, UserIdentity, ReadOnlyUserIdentities, ReadOnlyUserIdentity, UserDevices, UserIdentities, UserIdentity,
}; };
pub use machine::OlmMachine; pub use machine::OlmMachine;

View File

@ -14,7 +14,11 @@
#[cfg(feature = "sled_cryptostore")] #[cfg(feature = "sled_cryptostore")]
use std::path::Path; use std::path::Path;
use std::{collections::BTreeMap, mem, sync::Arc}; use std::{
collections::{BTreeMap, HashSet},
mem,
sync::Arc,
};
use dashmap::DashMap; use dashmap::DashMap;
use matrix_sdk_common::{ use matrix_sdk_common::{
@ -293,6 +297,11 @@ impl OlmMachine {
self.store.device_display_name().await self.store.device_display_name().await
} }
/// Get all the tracked users we know about
pub fn tracked_users(&self) -> HashSet<UserId> {
self.store.tracked_users()
}
/// Get the outgoing requests that need to be sent out. /// Get the outgoing requests that need to be sent out.
/// ///
/// This returns a list of `OutGoingRequest`, those requests need to be sent /// This returns a list of `OutGoingRequest`, those requests need to be sent

View File

@ -49,11 +49,7 @@ pub(crate) struct GroupSessionCache {
impl GroupSessionCache { impl GroupSessionCache {
pub(crate) fn new(store: Store) -> Self { pub(crate) fn new(store: Store) -> Self {
Self { Self { store, sessions: Default::default(), sessions_being_shared: Default::default() }
store,
sessions: DashMap::new().into(),
sessions_being_shared: Arc::new(DashMap::new()),
}
} }
pub(crate) fn insert(&self, session: OutboundGroupSession) { pub(crate) fn insert(&self, session: OutboundGroupSession) {

View File

@ -64,8 +64,8 @@ impl SessionManager {
store, store,
key_request_machine, key_request_machine,
users_for_key_claim, users_for_key_claim,
wedged_devices: Arc::new(DashMap::new()), wedged_devices: Default::default(),
outgoing_to_device_requests: Arc::new(DashMap::new()), outgoing_to_device_requests: Default::default(),
} }
} }

View File

@ -37,7 +37,7 @@ pub struct SessionStore {
impl SessionStore { impl SessionStore {
/// Create a new empty Session store. /// Create a new empty Session store.
pub fn new() -> Self { pub fn new() -> Self {
SessionStore { entries: Arc::new(DashMap::new()) } Self::default()
} }
/// Add a session to the store. /// Add a session to the store.
@ -82,7 +82,7 @@ pub struct GroupSessionStore {
impl GroupSessionStore { impl GroupSessionStore {
/// Create a new empty store. /// Create a new empty store.
pub fn new() -> Self { pub fn new() -> Self {
GroupSessionStore { entries: Arc::new(DashMap::new()) } Self::default()
} }
/// Add an inbound group session to the store. /// Add an inbound group session to the store.
@ -141,7 +141,7 @@ pub struct DeviceStore {
impl DeviceStore { impl DeviceStore {
/// Create a new empty device store. /// Create a new empty device store.
pub fn new() -> Self { pub fn new() -> Self {
DeviceStore { entries: Arc::new(DashMap::new()) } Self::default()
} }
/// Add a device to the store. /// Add a device to the store.

View File

@ -59,13 +59,13 @@ impl Default for MemoryStore {
MemoryStore { MemoryStore {
sessions: SessionStore::new(), sessions: SessionStore::new(),
inbound_group_sessions: GroupSessionStore::new(), inbound_group_sessions: GroupSessionStore::new(),
tracked_users: Arc::new(DashSet::new()), tracked_users: Default::default(),
users_for_key_query: Arc::new(DashSet::new()), users_for_key_query: Default::default(),
olm_hashes: Arc::new(DashMap::new()), olm_hashes: Default::default(),
devices: DeviceStore::new(), devices: DeviceStore::new(),
identities: Arc::new(DashMap::new()), identities: Default::default(),
outgoing_key_requests: Arc::new(DashMap::new()), outgoing_key_requests: Default::default(),
key_requests_by_info: Arc::new(DashMap::new()), key_requests_by_info: Default::default(),
} }
} }
} }
@ -183,6 +183,10 @@ impl CryptoStore for MemoryStore {
self.users_for_key_query.iter().map(|u| u.clone()).collect() self.users_for_key_query.iter().map(|u| u.clone()).collect()
} }
fn tracked_users(&self) -> HashSet<UserId> {
self.tracked_users.iter().map(|u| u.to_owned()).collect()
}
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> { async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
// TODO to prevent a race between the sync and a key query in flight we // TODO to prevent a race between the sync and a key query in flight we
// need to have an additional state to mention that the user changed. // need to have an additional state to mention that the user changed.

View File

@ -584,6 +584,9 @@ pub trait CryptoStore: AsyncTraitDeps {
/// the tracked users. /// the tracked users.
fn users_for_key_query(&self) -> HashSet<UserId>; fn users_for_key_query(&self) -> HashSet<UserId>;
/// Get all tracked users we know about.
fn tracked_users(&self) -> HashSet<UserId>;
/// Add an user for tracking. /// Add an user for tracking.
/// ///
/// Returns true if the user wasn't already tracked, false otherwise. /// Returns true if the user wasn't already tracked, false otherwise.

View File

@ -673,6 +673,10 @@ impl CryptoStore for SledStore {
!self.users_for_key_query_cache.is_empty() !self.users_for_key_query_cache.is_empty()
} }
fn tracked_users(&self) -> HashSet<UserId> {
self.tracked_users_cache.to_owned().iter().map(|u| u.clone()).collect()
}
fn users_for_key_query(&self) -> HashSet<UserId> { fn users_for_key_query(&self) -> HashSet<UserId> {
#[allow(clippy::map_clone)] #[allow(clippy::map_clone)]
self.users_for_key_query_cache.iter().map(|u| u.clone()).collect() self.users_for_key_query_cache.iter().map(|u| u.clone()).collect()

View File

@ -32,7 +32,7 @@ pub struct VerificationCache {
impl VerificationCache { impl VerificationCache {
pub fn new() -> Self { pub fn new() -> Self {
Self { verification: DashMap::new().into(), outgoing_requests: DashMap::new().into() } Self { verification: Default::default(), outgoing_requests: Default::default() }
} }
#[cfg(test)] #[cfg(test)]

View File

@ -62,7 +62,7 @@ impl VerificationMachine {
private_identity: identity, private_identity: identity,
store: VerificationStore { account, inner: store }, store: VerificationStore { account, inner: store },
verifications: VerificationCache::new(), verifications: VerificationCache::new(),
requests: DashMap::new().into(), requests: Default::default(),
} }
} }