diff --git a/matrix_sdk/Cargo.toml b/matrix_sdk/Cargo.toml index 9cf61d52..847b9444 100644 --- a/matrix_sdk/Cargo.toml +++ b/matrix_sdk/Cargo.toml @@ -32,9 +32,12 @@ appservice = ["ruma/appservice-api-s", "ruma/appservice-api-helper", "ruma/rand" docs = ["encryption", "sled_cryptostore", "sled_state_store", "sso_login"] [dependencies] +anyhow = { version = "1.0.42", optional = true } dashmap = "4.0.2" +event-listener = "2.5.1" futures = "0.3.15" http = "0.2.4" +serde = "1.0.126" serde_json = "1.0.64" thiserror = "1.0.25" tracing = "0.1.26" @@ -91,6 +94,7 @@ version = "3.0.2" features = ["wasm-bindgen"] [dev-dependencies] +anyhow = "1.0" dirs = "3.0.2" matches = "0.1.8" matrix-sdk-test = { version = "0.3.0", path = "../matrix_sdk_test" } diff --git a/matrix_sdk/examples/autojoin.rs b/matrix_sdk/examples/autojoin.rs index 6838b7c1..5db51444 100644 --- a/matrix_sdk/examples/autojoin.rs +++ b/matrix_sdk/examples/autojoin.rs @@ -1,61 +1,41 @@ use std::{env, process::exit}; use matrix_sdk::{ - async_trait, room::Room, ruma::events::{room::member::MemberEventContent, StrippedStateEvent}, - Client, ClientConfig, EventHandler, SyncSettings, + Client, ClientConfig, SyncSettings, }; use tokio::time::{sleep, Duration}; use url::Url; -struct AutoJoinBot { +async fn on_stripped_state_member( + room_member: StrippedStateEvent, client: Client, -} - -impl AutoJoinBot { - pub fn new(client: Client) -> Self { - Self { client } + room: Room, +) { + if room_member.state_key != client.user_id().await.unwrap() { + return; } -} -#[async_trait] -impl EventHandler for AutoJoinBot { - async fn on_stripped_state_member( - &self, - room: Room, - room_member: &StrippedStateEvent, - _: Option, - ) { - if room_member.state_key != self.client.user_id().await.unwrap() { - return; - } + if let Room::Invited(room) = room { + println!("Autojoining room {}", room.room_id()); + let mut delay = 2; - if let Room::Invited(room) = room { - println!("Autojoining room {}", room.room_id()); - let mut delay = 2; + while let Err(err) = room.accept_invitation().await { + // retry autojoin due to synapse sending invites, before the + // invited user can join for more information see + // https://github.com/matrix-org/synapse/issues/4345 + eprintln!("Failed to join room {} ({:?}), retrying in {}s", room.room_id(), err, delay); - while let Err(err) = room.accept_invitation().await { - // retry autojoin due to synapse sending invites, before the - // invited user can join for more information see - // https://github.com/matrix-org/synapse/issues/4345 - eprintln!( - "Failed to join room {} ({:?}), retrying in {}s", - room.room_id(), - err, - delay - ); + sleep(Duration::from_secs(delay)).await; + delay *= 2; - sleep(Duration::from_secs(delay)).await; - delay *= 2; - - if delay > 3600 { - eprintln!("Can't join room {} ({:?})", room.room_id(), err); - break; - } + if delay > 3600 { + eprintln!("Can't join room {} ({:?})", room.room_id(), err); + break; } - println!("Successfully joined room {}", room.room_id()); } + println!("Successfully joined room {}", room.room_id()); } } @@ -76,7 +56,7 @@ async fn login_and_sync( println!("logged in as {}", username); - client.set_event_handler(Box::new(AutoJoinBot::new(client.clone()))).await; + client.register_event_handler(on_stripped_state_member).await; client.sync(SyncSettings::default()).await; diff --git a/matrix_sdk/examples/command_bot.rs b/matrix_sdk/examples/command_bot.rs index 90f00e38..77fbc34e 100644 --- a/matrix_sdk/examples/command_bot.rs +++ b/matrix_sdk/examples/command_bot.rs @@ -1,55 +1,43 @@ use std::{env, process::exit}; use matrix_sdk::{ - async_trait, room::Room, ruma::events::{ room::message::{MessageEventContent, MessageType, TextMessageEventContent}, AnyMessageEventContent, SyncMessageEvent, }, - Client, ClientConfig, EventHandler, SyncSettings, + Client, ClientConfig, SyncSettings, }; use url::Url; -struct CommandBot; +async fn on_room_message(event: SyncMessageEvent, room: Room) { + if let Room::Joined(room) = room { + let msg_body = if let SyncMessageEvent { + content: + MessageEventContent { + msgtype: MessageType::Text(TextMessageEventContent { body: msg_body, .. }), + .. + }, + .. + } = event + { + msg_body + } else { + return; + }; -impl CommandBot { - pub fn new() -> Self { - Self {} - } -} + if msg_body.contains("!party") { + let content = AnyMessageEventContent::RoomMessage(MessageEventContent::text_plain( + "πŸŽ‰πŸŽŠπŸ₯³ let's PARTY!! πŸ₯³πŸŽŠπŸŽ‰", + )); -#[async_trait] -impl EventHandler for CommandBot { - async fn on_room_message(&self, room: Room, event: &SyncMessageEvent) { - if let Room::Joined(room) = room { - let msg_body = if let SyncMessageEvent { - content: - MessageEventContent { - msgtype: MessageType::Text(TextMessageEventContent { body: msg_body, .. }), - .. - }, - .. - } = event - { - msg_body - } else { - return; - }; + println!("sending"); - if msg_body.contains("!party") { - let content = AnyMessageEventContent::RoomMessage(MessageEventContent::text_plain( - "πŸŽ‰πŸŽŠπŸ₯³ let's PARTY!! πŸ₯³πŸŽŠπŸŽ‰", - )); + // send our message to the room we found the "!party" command in + // the last parameter is an optional Uuid which we don't care about. + room.send(content, None).await.unwrap(); - println!("sending"); - - // send our message to the room we found the "!party" command in - // the last parameter is an optional Uuid which we don't care about. - room.send(content, None).await.unwrap(); - - println!("message sent"); - } + println!("message sent"); } } } @@ -79,7 +67,7 @@ async fn login_and_sync( client.sync_once(SyncSettings::default()).await.unwrap(); // add our CommandBot to be notified of incoming messages, we do this after the // initial sync to avoid responding to messages before the bot was running. - client.set_event_handler(Box::new(CommandBot::new())).await; + client.register_event_handler(on_room_message).await; // since we called `sync_once` before we entered our sync loop we must pass // that sync token to `sync` diff --git a/matrix_sdk/examples/image_bot.rs b/matrix_sdk/examples/image_bot.rs index 7df95ffd..a895c549 100644 --- a/matrix_sdk/examples/image_bot.rs +++ b/matrix_sdk/examples/image_bot.rs @@ -8,56 +8,46 @@ use std::{ }; use matrix_sdk::{ - self, async_trait, + self, room::Room, ruma::events::{ room::message::{MessageEventContent, MessageType, TextMessageEventContent}, SyncMessageEvent, }, - Client, EventHandler, SyncSettings, + Client, SyncSettings, }; use tokio::sync::Mutex; use url::Url; -struct ImageBot { +async fn on_room_message( + event: SyncMessageEvent, + room: Room, image: Arc>, -} +) { + if let Room::Joined(room) = room { + let msg_body = if let SyncMessageEvent { + content: + MessageEventContent { + msgtype: MessageType::Text(TextMessageEventContent { body: msg_body, .. }), + .. + }, + .. + } = event + { + msg_body + } else { + return; + }; -impl ImageBot { - pub fn new(image: File) -> Self { - let image = Arc::new(Mutex::new(image)); - Self { image } - } -} + if msg_body.contains("!image") { + println!("sending image"); + let mut image = image.lock().await; -#[async_trait] -impl EventHandler for ImageBot { - async fn on_room_message(&self, room: Room, event: &SyncMessageEvent) { - if let Room::Joined(room) = room { - let msg_body = if let SyncMessageEvent { - content: - MessageEventContent { - msgtype: MessageType::Text(TextMessageEventContent { body: msg_body, .. }), - .. - }, - .. - } = event - { - msg_body - } else { - return; - }; + room.send_attachment("cat", &mime::IMAGE_JPEG, &mut *image, None).await.unwrap(); - if msg_body.contains("!image") { - println!("sending image"); - let mut image = self.image.lock().await; + image.seek(SeekFrom::Start(0)).unwrap(); - room.send_attachment("cat", &mime::IMAGE_JPEG, &mut *image, None).await.unwrap(); - - image.seek(SeekFrom::Start(0)).unwrap(); - - println!("message sent"); - } + println!("message sent"); } } } @@ -74,7 +64,9 @@ async fn login_and_sync( client.login(&username, &password, None, Some("command bot")).await?; client.sync_once(SyncSettings::default()).await.unwrap(); - client.set_event_handler(Box::new(ImageBot::new(image))).await; + + let image = Arc::new(Mutex::new(image)); + client.register_event_handler(move |ev, room| on_room_message(ev, room, image.clone())).await; let settings = SyncSettings::default().token(client.sync_token().await.unwrap()); client.sync(settings).await; diff --git a/matrix_sdk/examples/login.rs b/matrix_sdk/examples/login.rs index fa6ab054..b34e6e6f 100644 --- a/matrix_sdk/examples/login.rs +++ b/matrix_sdk/examples/login.rs @@ -1,36 +1,31 @@ use std::{env, process::exit}; use matrix_sdk::{ - self, async_trait, + self, room::Room, ruma::events::{ room::message::{MessageEventContent, MessageType, TextMessageEventContent}, SyncMessageEvent, }, - Client, EventHandler, SyncSettings, + Client, SyncSettings, }; use url::Url; -struct EventCallback; - -#[async_trait] -impl EventHandler for EventCallback { - async fn on_room_message(&self, room: Room, event: &SyncMessageEvent) { - if let Room::Joined(room) = room { - if let SyncMessageEvent { - content: - MessageEventContent { - msgtype: MessageType::Text(TextMessageEventContent { body: msg_body, .. }), - .. - }, - sender, - .. - } = event - { - let member = room.get_member(sender).await.unwrap().unwrap(); - let name = member.display_name().unwrap_or_else(|| member.user_id().as_str()); - println!("{}: {}", name, msg_body); - } +async fn on_room_message(event: SyncMessageEvent, room: Room) { + if let Room::Joined(room) = room { + if let SyncMessageEvent { + content: + MessageEventContent { + msgtype: MessageType::Text(TextMessageEventContent { body: msg_body, .. }), + .. + }, + sender, + .. + } = event + { + let member = room.get_member(&sender).await.unwrap().unwrap(); + let name = member.display_name().unwrap_or_else(|| member.user_id().as_str()); + println!("{}: {}", name, msg_body); } } } @@ -43,7 +38,7 @@ async fn login( let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL"); let client = Client::new(homeserver_url).unwrap(); - client.set_event_handler(Box::new(EventCallback)).await; + client.register_event_handler(on_room_message).await; client.login(username, password, None, Some("rust-sdk")).await?; client.sync(SyncSettings::new()).await; diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index b963fca5..fc2d976a 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -15,10 +15,15 @@ #[cfg(all(feature = "encryption", not(target_arch = "wasm32")))] use std::path::PathBuf; -#[cfg(feature = "encryption")] use std::{ collections::BTreeMap, - io::{Cursor, Write}, + fmt::{self, Debug}, + future::Future, + io::Read, + path::Path, + pin::Pin, + result::Result as StdResult, + sync::Arc, }; #[cfg(feature = "sso_login")] use std::{ @@ -26,16 +31,14 @@ use std::{ io::{Error as IoError, ErrorKind as IoErrorKind}, ops::Range, }; +#[cfg(feature = "encryption")] use std::{ - fmt::{self, Debug}, - future::Future, - io::Read, - path::Path, - result::Result as StdResult, - sync::Arc, + collections::HashSet, + io::{Cursor, Write}, }; use dashmap::DashMap; +use futures::FutureExt; use futures_timer::Delay as sleep; use http::HeaderValue; #[cfg(feature = "sso_login")] @@ -48,9 +51,9 @@ use matrix_sdk_base::crypto::{ ToDeviceRequest, }; #[cfg(feature = "encryption")] -use matrix_sdk_base::deserialized_responses::RoomEvent; +use matrix_sdk_base::{crypto::CrossSigningStatus, deserialized_responses::RoomEvent}; use matrix_sdk_base::{ - deserialized_responses::SyncResponse, + deserialized_responses::{JoinedRoom, LeftRoom, SyncResponse}, media::{MediaEventContent, MediaFormat, MediaRequest, MediaThumbnailSize, MediaType}, BaseClient, BaseClientConfig, Session, Store, }; @@ -59,15 +62,20 @@ use mime::{self, Mime}; use rand::{thread_rng, Rng}; use reqwest::header::InvalidHeaderValue; #[cfg(feature = "encryption")] -use ruma::events::{AnyMessageEvent, AnyRoomEvent, AnySyncMessageEvent}; -use ruma::{api::SendAccessToken, events::AnyMessageEventContent, MxcUri}; +use ruma::events::{AnyMessageEvent, AnyRoomEvent, AnySyncMessageEvent, EventType}; +use ruma::{ + api::{client::r0::push::get_notifications::Notification, SendAccessToken}, + events::AnyMessageEventContent, + MxcUri, +}; +use serde::de::DeserializeOwned; #[cfg(feature = "sso_login")] use tokio::{net::TcpListener, sync::oneshot}; #[cfg(feature = "sso_login")] use tokio_stream::wrappers::TcpListenerStream; #[cfg(feature = "encryption")] -use tracing::{debug, warn}; -use tracing::{error, info, instrument}; +use tracing::{debug, trace}; +use tracing::{error, info, instrument, warn}; use url::Url; #[cfg(feature = "sso_login")] use warp::Filter; @@ -134,15 +142,15 @@ use ruma::{ use crate::verification::QrVerification; #[cfg(feature = "encryption")] use crate::{ - device::{Device, UserDevices}, error::RoomKeyImportError, + identities::{Device, UserDevices}, verification::{SasVerification, Verification, VerificationRequest}, }; use crate::{ - error::HttpError, - event_handler::Handler, + error::{HttpError, HttpResult}, + event_handler::{EventHandler, EventHandlerData, EventHandlerResult, EventKind, SyncEvent}, http_client::{client_with_config, HttpClient, HttpSend}, - room, Error, EventHandler, Result, + room, Error, Result, }; const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10); @@ -158,6 +166,14 @@ const SSO_SERVER_BIND_RANGE: Range = 20000..30000; #[cfg(feature = "sso_login")] const SSO_SERVER_BIND_TRIES: u8 = 10; +type EventHandlerFut = Pin + Send>>; +type EventHandlerFn = Box) -> EventHandlerFut + Send + Sync>; +type EventHandlerMap = BTreeMap<(EventKind, &'static str), Vec>; + +type NotificationHandlerFut = EventHandlerFut; +type NotificationHandlerFn = + Box NotificationHandlerFut + Send + Sync>; + /// An async/await enabled Matrix client. /// /// All of the state is held in an `Arc` so the `Client` can be cloned freely. @@ -178,13 +194,20 @@ pub struct Client { key_claim_lock: Arc>, pub(crate) members_request_locks: Arc>>>, pub(crate) typing_notice_times: Arc>, - /// Any implementor of EventHandler will act as the callbacks for various - /// events. - event_handler: Arc>>, + /// Event handlers. See `register_event_handler`. + pub(crate) event_handlers: Arc>, + /// Notification handlers. See `register_notification_handler`. + notification_handlers: Arc>>, /// Whether the client should operate in application service style mode. /// This is low-level functionality. For an high-level API check the /// `matrix_sdk_appservice` crate. appservice_mode: bool, + /// An event that can be listened on to wait for a successful sync. The + /// event will only be fired if a sync loop is running. Can be used for + /// synchronization, e.g. if we send out a request to create a room, we can + /// wait for the sync to get the data to fetch a room object from the state + /// store. + sync_beat: Arc, } #[cfg(not(tarpaulin_include))] @@ -559,13 +582,15 @@ impl Client { http_client, base_client, #[cfg(feature = "encryption")] - group_session_locks: Arc::new(DashMap::new()), + group_session_locks: Default::default(), #[cfg(feature = "encryption")] - key_claim_lock: Arc::new(Mutex::new(())), - members_request_locks: Arc::new(DashMap::new()), - typing_notice_times: Arc::new(DashMap::new()), - event_handler: Arc::new(RwLock::new(None)), + key_claim_lock: Default::default(), + members_request_locks: Default::default(), + typing_notice_times: Default::default(), + event_handlers: Default::default(), + notification_handlers: Default::default(), appservice_mode: config.appservice_mode, + sync_beat: event_listener::Event::new().into(), }) } @@ -629,7 +654,7 @@ impl Client { Ok(result) } - async fn discover_homeserver(&self) -> Result { + async fn discover_homeserver(&self) -> HttpResult { self.send(discover_homeserver::Request::new(), Some(RequestConfig::new().disable_retry())) .await } @@ -644,7 +669,7 @@ impl Client { *homeserver = homeserver_url; } - async fn get_supported_versions(&self) -> Result { + async fn get_supported_versions(&self) -> HttpResult { self.send( get_supported_versions::Request::new(), Some(RequestConfig::new().disable_retry()), @@ -668,12 +693,7 @@ impl Client { ) -> Result<()> { let txn_id = incoming_transaction.txn_id.clone(); let response = incoming_transaction.try_into_sync_response(txn_id)?; - let base_client = self.base_client.clone(); - let sync_response = base_client.receive_sync_response(response).await?; - - if let Some(handler) = self.event_handler.read().await.as_ref() { - handler.handle_sync(&sync_response).await; - } + self.process_sync(response).await?; Ok(()) } @@ -708,6 +728,16 @@ impl Client { self.base_client.olm_machine().await.map(|o| o.identity_keys().ed25519().to_owned()) } + /// Get all the tracked users we know about + /// + /// Tracked users are users for which we keep the device list of E2EE + /// capable devices up to date. + #[cfg(feature = "encryption")] + #[cfg_attr(feature = "docs", doc(cfg(encryption)))] + pub async fn tracked_users(&self) -> HashSet { + self.base_client.olm_machine().await.map(|o| o.tracked_users()).unwrap_or_default() + } + /// Fetches the display name of the owner of the client. /// /// # Example @@ -869,13 +899,125 @@ impl Client { Ok(()) } - /// Add `EventHandler` to `Client`. + /// Register a handler for a specific event type. /// - /// The methods of `EventHandler` are called when the respective - /// `RoomEvents` occur. - pub async fn set_event_handler(&self, handler: Box) { - let handler = Handler { inner: handler, client: self.clone() }; - *self.event_handler.write().await = Some(handler); + /// The handler is a function or closure with one or more arguments. The + /// first argument is the event itself. All additional arguments are + /// "context" arguments: They have to implement [`EventHandlerContext`]. + /// This trait is named that way because most of the types implementing it + /// give additional context about an event: The room it was in, its raw form + /// and other similar things. As an exception to this, + /// [`Client`] also implements the `EventHandlerContext` trait + /// so you don't have to clone your client into the event handler manually. + /// + /// Some context arguments are not universally applicable. A context + /// argument that isn't available for the given event type will result in + /// the event handler being skipped and an error being logged. The following + /// context argument types are only available for a subset of event types: + /// + /// * [`Room`][room::Room] is only available for room-specific events, i.e. + /// not for events like global account data events or presence events + /// + /// [`EventHandlerContext`]: crate::event_handler::EventHandlerContext + /// + /// # Examples + /// + /// ```no_run + /// # let client: matrix_sdk::Client = unimplemented!(); + /// use matrix_sdk::{ + /// room::Room, + /// ruma::{ + /// events::{ + /// macros::EventContent, + /// push_rules::PushRulesEvent, + /// room::{message::MessageEventContent, topic::TopicEventContent}, + /// SyncMessageEvent, SyncStateEvent, + /// }, + /// Int, MilliSecondsSinceUnixEpoch, + /// }, + /// Client, + /// }; + /// use serde::{Deserialize, Serialize}; + /// + /// # let _ = async { + /// client + /// .register_event_handler( + /// |ev: SyncMessageEvent, room: Room, client: Client| async move { + /// // Common usage: Room event plus room and client. + /// }, + /// ) + /// .await + /// .register_event_handler(|ev: SyncStateEvent| async move { + /// // Also possible: Omit any or all arguments after the first. + /// }) + /// .await; + /// + /// // Custom events work exactly the same way, you just need to declare the content struct and + /// // use the EventContent derive macro on it. + /// #[derive(Clone, Debug, Deserialize, Serialize, EventContent)] + /// #[ruma_event(type = "org.shiny_new_2fa.token", kind = Message)] + /// struct TokenEventContent { + /// token: String, + /// #[serde(rename = "exp")] + /// expires_at: MilliSecondsSinceUnixEpoch, + /// } + /// + /// client.register_event_handler( + /// |ev: SyncMessageEvent, room: Room| async move { + /// todo!("Display the token"); + /// }, + /// ).await; + /// # }; + /// ``` + pub async fn register_event_handler(&self, handler: H) -> &Self + where + Ev: SyncEvent + DeserializeOwned + Send + 'static, + H: EventHandler, + ::Output: EventHandlerResult, + { + let event_type = H::ID.1; + self.event_handlers.write().await.entry(H::ID).or_default().push(Box::new(move |data| { + let maybe_fut = serde_json::from_str(data.raw.get()) + .map(|ev| handler.clone().handle_event(ev, data)); + + async move { + match maybe_fut { + Ok(Some(fut)) => { + fut.await.print_error(event_type); + } + Ok(None) => { + error!("Event handler for {} has an invalid context argument", event_type); + } + Err(e) => { + warn!( + "Failed to deserialize `{}` event, skipping event handler.\n\ + Deserialization error: {}", + event_type, e, + ); + } + } + } + .boxed() + })); + + self + } + + /// Register a handler for a notification. + /// + /// Similar to `.register_event_handler`, but only allows functions or + /// closures with exactly the three arguments `Notification`, `room::Room`, + /// `Client` for now. + pub async fn register_notification_handler(&self, handler: H) -> &Self + where + H: Fn(Notification, room::Room, Client) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + self.notification_handlers.write().await.push(Box::new( + move |notification, room, client| (handler)(notification, room, client).boxed(), + )); + + self } /// Get all the rooms the client knows about. @@ -990,7 +1132,7 @@ impl Client { /// /// This should be the first step when trying to login so you can call the /// appropriate method for the next step. - pub async fn get_login_types(&self) -> Result { + pub async fn get_login_types(&self) -> HttpResult { let request = get_login_types::Request::new(); self.send(request, None).await } @@ -1407,7 +1549,7 @@ impl Client { pub async fn register( &self, registration: impl Into>, - ) -> Result { + ) -> HttpResult { info!("Registering to {}", self.homeserver().await); let config = if self.appservice_mode { @@ -1497,7 +1639,7 @@ impl Client { /// # Arguments /// /// * `room_id` - The `RoomId` of the room to be joined. - pub async fn join_room_by_id(&self, room_id: &RoomId) -> Result { + pub async fn join_room_by_id(&self, room_id: &RoomId) -> HttpResult { let request = join_room_by_id::Request::new(room_id); self.send(request, None).await } @@ -1515,7 +1657,7 @@ impl Client { &self, alias: &RoomIdOrAliasId, server_names: &[Box], - ) -> Result { + ) -> HttpResult { let request = assign!(join_room_by_id_or_alias::Request::new(alias), { server_name: server_names, }); @@ -1558,7 +1700,7 @@ impl Client { limit: Option, since: Option<&str>, server: Option<&ServerName>, - ) -> Result { + ) -> HttpResult { let limit = limit.map(UInt::from); let request = assign!(get_public_rooms::Request::new(), { @@ -1599,7 +1741,7 @@ impl Client { pub async fn create_room( &self, room: impl Into>, - ) -> Result { + ) -> HttpResult { let request = room.into(); self.send(request, None).await } @@ -1639,7 +1781,7 @@ impl Client { pub async fn public_rooms_filtered( &self, room_search: impl Into>, - ) -> Result { + ) -> HttpResult { let request = room_search.into(); self.send(request, None).await } @@ -1773,7 +1915,7 @@ impl Client { let txn_id = txn_id.unwrap_or_else(Uuid::new_v4).to_string(); let request = send_message_event::Request::new(room_id, &txn_id, &content); - self.send(request, None).await + Ok(self.send(request, None).await?) } } @@ -1821,7 +1963,7 @@ impl Client { &self, request: Request, config: Option, - ) -> Result + ) -> HttpResult where Request: OutgoingRequest + Debug, HttpError: From>, @@ -1833,8 +1975,9 @@ impl Client { pub(crate) async fn send_to_device( &self, request: &ToDeviceRequest, - ) -> Result { + ) -> HttpResult { let txn_id_string = request.txn_id_string(); + let request = RumaToDeviceRequest::new_raw( request.event_type.as_str(), &txn_id_string, @@ -1867,7 +2010,7 @@ impl Client { /// } /// # }); /// ``` - pub async fn devices(&self) -> Result { + pub async fn devices(&self) -> HttpResult { let request = get_devices::Request::new(); self.send(request, None).await @@ -1924,7 +2067,7 @@ impl Client { &self, devices: &[DeviceIdBox], auth_data: Option>, - ) -> Result { + ) -> HttpResult { let mut request = delete_devices::Request::new(devices); request.auth = auth_data; @@ -1959,13 +2102,91 @@ impl Client { ); let response = self.send(request, Some(request_config)).await?; - let sync_response = self.base_client.receive_sync_response(response).await?; + self.process_sync(response).await + } - if let Some(handler) = self.event_handler.read().await.as_ref() { - handler.handle_sync(&sync_response).await; + async fn process_sync(&self, response: sync_events::Response) -> Result { + let response = self.base_client.receive_sync_response(response).await?; + let SyncResponse { + next_batch: _, + rooms, + presence, + account_data, + to_device: _, + device_lists: _, + device_one_time_keys_count: _, + ambiguity_changes: _, + notifications, + } = &response; + + self.handle_sync_events(EventKind::GlobalAccountData, &None, &account_data.events).await?; + self.handle_sync_events(EventKind::Presence, &None, &presence.events).await?; + + for (room_id, room_info) in &rooms.join { + let room = self.get_room(room_id); + if room.is_none() { + error!("Can't call event handler, room {} not found", room_id); + continue; + } + + let JoinedRoom { unread_notifications: _, timeline, state, account_data, ephemeral } = + room_info; + + self.handle_sync_events(EventKind::EphemeralRoomData, &room, &ephemeral.events).await?; + self.handle_sync_events(EventKind::RoomAccountData, &room, &account_data.events) + .await?; + self.handle_sync_state_events(&room, &state.events).await?; + self.handle_sync_timeline_events(&room, &timeline.events).await?; } - Ok(sync_response) + for (room_id, room_info) in &rooms.leave { + let room = self.get_room(room_id); + if room.is_none() { + error!("Can't call event handler, room {} not found", room_id); + continue; + } + + let LeftRoom { timeline, state, account_data } = room_info; + + self.handle_sync_events(EventKind::RoomAccountData, &room, &account_data.events) + .await?; + self.handle_sync_state_events(&room, &state.events).await?; + self.handle_sync_timeline_events(&room, &timeline.events).await?; + } + + for (room_id, room_info) in &rooms.invite { + let room = self.get_room(room_id); + if room.is_none() { + error!("Can't call event handler, room {} not found", room_id); + continue; + } + + // FIXME: Destructure room_info + self.handle_sync_events(EventKind::InitialState, &room, &room_info.invite_state.events) + .await?; + } + + for handler in &*self.notification_handlers.read().await { + for (room_id, room_notifications) in notifications { + let room = match self.get_room(room_id) { + Some(room) => room, + None => { + warn!("Can't call notification handler, room {} not found", room_id); + continue; + } + }; + + for notification in room_notifications { + matrix_sdk_common::executor::spawn((handler)( + notification.clone(), + room.clone(), + self.clone(), + )); + } + } + } + + Ok(response) } /// Repeatedly call sync to synchronize the client state with the server. @@ -2153,6 +2374,87 @@ impl Client { sync_settings.token = Some(self.sync_token().await.expect("No sync token found after initial sync")); + + self.sync_beat.notify(usize::MAX); + } + } + + #[cfg(feature = "encryption")] + #[cfg_attr(feature = "docs", doc(cfg(encryption)))] + async fn send_account_data( + &self, + content: ruma::events::AnyGlobalAccountDataEventContent, + ) -> Result { + let own_user = + self.user_id().await.ok_or_else(|| Error::from(HttpError::AuthenticationRequired))?; + let data = serde_json::value::to_raw_value(&content)?; + + let request = ruma::api::client::r0::config::set_global_account_data::Request::new( + &data, + ruma::events::EventContent::event_type(&content), + &own_user, + ); + + Ok(self.send(request, None).await?) + } + + #[cfg(feature = "encryption")] + #[cfg_attr(feature = "docs", doc(cfg(encryption)))] + pub(crate) async fn create_dm_room(&self, user_id: UserId) -> Result> { + use ruma::{ + api::client::r0::room::create_room::RoomPreset, + events::AnyGlobalAccountDataEventContent, + }; + + const SYNC_WAIT_TIME: Duration = Duration::from_secs(3); + + // First we create the DM room, where we invite the user and tell the + // invitee that the room should be a DM. + let invite = &[user_id.clone()]; + + let request = assign!( + ruma::api::client::r0::room::create_room::Request::new(), + { + invite, + is_direct: true, + preset: Some(RoomPreset::TrustedPrivateChat), + } + ); + + let response = self.send(request, None).await?; + + // Now we need to mark the room as a DM for ourselves, we fetch the + // existing `m.direct` event and append the room to the list of DMs we + // have with this user. + let mut content = self + .store() + .get_account_data_event(EventType::Direct) + .await? + .map(|e| e.deserialize()) + .transpose()? + .and_then(|e| { + if let AnyGlobalAccountDataEventContent::Direct(c) = e.content() { + Some(c) + } else { + None + } + }) + .unwrap_or_else(|| ruma::events::direct::DirectEventContent(BTreeMap::new())); + + content.entry(user_id.to_owned()).or_default().push(response.room_id.to_owned()); + + // TODO We should probably save the fact that we need to send this out + // because otherwise we might end up in a state where we have a DM that + // isn't marked as one. + self.send_account_data(AnyGlobalAccountDataEventContent::Direct(content)).await?; + + // If the room is already in our store, fetch it, otherwise wait for a + // sync to be done which should put the room into our store. + if let Some(room) = self.get_joined_room(&response.room_id) { + Ok(Some(room)) + } else { + self.sync_beat.listen().wait_timeout(SYNC_WAIT_TIME); + Ok(self.get_joined_room(&response.room_id)) } } @@ -2296,7 +2598,7 @@ impl Client { /// /// println!("{:?}", device.verified()); /// - /// let verification = device.start_verification().await.unwrap(); + /// let verification = device.request_verification().await.unwrap(); /// # }); /// ``` #[cfg(feature = "encryption")] @@ -2311,6 +2613,89 @@ impl Client { Ok(device.map(|d| Device { inner: d, client: self.clone() })) } + /// Get a E2EE identity of an user. + /// + /// # Arguments + /// + /// * `user_id` - The unique id of the user that the identity belongs to. + /// + /// Returns a `UserIdentity` if one is found and the crypto store + /// didn't throw an error. + /// + /// This will always return None if the client hasn't been logged in. + /// + /// # Example + /// + /// ```no_run + /// # use std::convert::TryFrom; + /// # use matrix_sdk::{Client, ruma::UserId}; + /// # use url::Url; + /// # use futures::executor::block_on; + /// # let alice = UserId::try_from("@alice:example.org").unwrap(); + /// # let homeserver = Url::parse("http://example.com").unwrap(); + /// # let client = Client::new(homeserver).unwrap(); + /// # block_on(async { + /// let user = client.get_user_identity(&alice).await?; + /// + /// if let Some(user) = user { + /// println!("{:?}", user.verified()); + /// + /// let verification = user.request_verification().await?; + /// } + /// # anyhow::Result::<()>::Ok(()) }); + /// ``` + #[cfg(feature = "encryption")] + #[cfg_attr(feature = "docs", doc(cfg(encryption)))] + pub async fn get_user_identity( + &self, + user_id: &UserId, + ) -> StdResult, CryptoStoreError> { + use crate::identities::UserIdentity; + + if let Some(olm) = self.base_client.olm_machine().await { + let identity = olm.get_identity(user_id).await?; + + Ok(identity.map(|i| match i { + matrix_sdk_base::crypto::UserIdentities::Own(i) => { + UserIdentity::new_own(self.clone(), i) + } + matrix_sdk_base::crypto::UserIdentities::Other(i) => { + UserIdentity::new(self.clone(), i, self.get_dm_room(user_id)) + } + })) + } else { + Ok(None) + } + } + + /// Get the status of the private cross signing keys. + /// + /// This can be used to check which private cross signing keys we have + /// stored locally. + #[cfg(feature = "encryption")] + #[cfg_attr(feature = "docs", doc(cfg(encryption)))] + pub async fn cross_signing_status(&self) -> Option { + if let Some(machine) = self.base_client.olm_machine().await { + Some(machine.cross_signing_status().await) + } else { + None + } + } + + #[cfg(feature = "encryption")] + #[cfg_attr(feature = "docs", doc(cfg(encryption)))] + fn get_dm_room(&self, user_id: &UserId) -> Option { + let rooms = self.joined_rooms(); + let room_pairs: Vec<_> = + rooms.iter().map(|r| (r.room_id().to_owned(), r.direct_target())).collect(); + trace!(rooms =? room_pairs, "Finding direct room"); + + let room = rooms.into_iter().find(|r| r.direct_target().as_ref() == Some(user_id)); + + trace!(room =? room, "Found room"); + room + } + /// Create and upload a new cross signing identity. /// /// # Arguments @@ -2741,7 +3126,7 @@ impl Client { } /// Gets information about the owner of a given access token. - pub async fn whoami(&self) -> Result { + pub async fn whoami(&self) -> HttpResult { let request = whoami::Request::new(); self.send(request, None).await } @@ -2769,8 +3154,10 @@ mod test { use std::{ collections::BTreeMap, convert::{TryFrom, TryInto}, + future, io::Cursor, str::FromStr, + sync::Arc, time::Duration, }; @@ -2800,17 +3187,18 @@ mod test { event_id, events::{ room::{ + member::MemberEventContent, message::{ImageMessageEventContent, MessageEventContent}, ImageInfo, }, - AnyMessageEventContent, AnySyncStateEvent, EventType, + AnyMessageEventContent, AnySyncStateEvent, EventType, SyncStateEvent, }, mxc_uri, room_id, thirdparty, uint, user_id, UserId, }; use serde_json::json; use super::{Client, Session, SyncSettings, Url}; - use crate::{ClientConfig, HttpError, RequestConfig, RoomMember}; + use crate::{room, ClientConfig, HttpError, RequestConfig, RoomMember}; async fn logged_in_client() -> Client { let session = Session { @@ -3070,6 +3458,56 @@ mod test { // assert_eq!(1, ignored_users.len()) } + #[tokio::test] + async fn event_handler() { + use std::sync::atomic::{AtomicU8, Ordering::SeqCst}; + + let client = logged_in_client().await; + + let member_count = Arc::new(AtomicU8::new(0)); + let typing_count = Arc::new(AtomicU8::new(0)); + let power_levels_count = Arc::new(AtomicU8::new(0)); + + client + .register_event_handler({ + let member_count = member_count.clone(); + move |_ev: SyncStateEvent, _room: room::Room| { + member_count.fetch_add(1, SeqCst); + future::ready(()) + } + }) + .await + .register_event_handler({ + let typing_count = typing_count.clone(); + move |_ev: SyncStateEvent| { + typing_count.fetch_add(1, SeqCst); + future::ready(()) + } + }) + .await + .register_event_handler({ + let power_levels_count = power_levels_count.clone(); + move |_ev: SyncStateEvent, + _client: Client, + _room: room::Room| { + power_levels_count.fetch_add(1, SeqCst); + future::ready(()) + } + }) + .await; + + let response = EventBuilder::default() + .add_room_event(EventsJson::Member) + .add_ephemeral(EventsJson::Typing) + .add_state_event(EventsJson::PowerLevels) + .build_sync_response(); + client.process_sync(response).await.unwrap(); + + assert_eq!(member_count.load(SeqCst), 1); + assert_eq!(typing_count.load(SeqCst), 1); + assert_eq!(power_levels_count.load(SeqCst), 1); + } + #[tokio::test] async fn room_creation() { let client = logged_in_client().await; @@ -3137,12 +3575,8 @@ mod test { }); if let Err(err) = client.register(user).await { - if let crate::Error::Http(HttpError::UiaaError(FromHttpResponseError::Http( - ServerError::Known(UiaaResponse::MatrixError(client_api::Error { - kind, - message, - status_code, - })), + if let HttpError::UiaaError(FromHttpResponseError::Http(ServerError::Known( + UiaaResponse::MatrixError(client_api::Error { kind, message, status_code }), ))) = err { if let client_api::error::ErrorKind::Forbidden = kind { diff --git a/matrix_sdk/src/error.rs b/matrix_sdk/src/error.rs index 6ebb9716..3a794008 100644 --- a/matrix_sdk/src/error.rs +++ b/matrix_sdk/src/error.rs @@ -40,6 +40,9 @@ use url::ParseError as UrlParseError; /// Result type of the rust-sdk. pub type Result = std::result::Result; +/// Result type of a pure HTTP request. +pub type HttpResult = std::result::Result; + /// An HTTP error, representing either a connection error or an error while /// converting the raw HTTP response into a Matrix response. #[derive(Error, Debug)] @@ -182,6 +185,30 @@ pub enum RoomKeyImportError { Export(#[from] KeyExportError), } +impl HttpError { + /// Try to destructure the error into an universal interactive auth info. + /// + /// Some requests require universal interactive auth, doing such a request + /// will always fail the first time with a 401 status code, the response + /// body will contain info how the client can authenticate. + /// + /// The request will need to be retried, this time containing additional + /// authentication data. + /// + /// This method is an convenience method to get to the info the server + /// returned on the first, failed request. + pub fn uiaa_response(&self) -> Option<&UiaaInfo> { + if let HttpError::UiaaError(FromHttpResponseError::Http(ServerError::Known( + UiaaError::AuthResponse(i), + ))) = self + { + Some(i) + } else { + None + } + } +} + impl Error { /// Try to destructure the error into an universal interactive auth info. /// diff --git a/matrix_sdk/src/event_handler.rs b/matrix_sdk/src/event_handler.rs new file mode 100644 index 00000000..60799eab --- /dev/null +++ b/matrix_sdk/src/event_handler.rs @@ -0,0 +1,439 @@ +// Copyright 2021 Jonas Platte +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Types and traits related for event handlers. For usage, see +//! [`Client::register_event_handler`]. +//! +//! ### How it works +//! +//! The `register_event_handler` method registers event handlers of different +//! signatures by actually storing boxed closures that all have the same +//! signature of `async (EventHandlerData) -> ()` where `EventHandlerData` is a +//! private type that contains all of the data an event handler *might* need. +//! +//! The stored closure takes care of deserializing the event which the +//! `EventHandlerData` contains as a (borrowed) [`serde_json::value::RawValue`], +//! extracing the context arguments from other fields of `EventHandlerData` and +//! calling / `.await`ing the event handler if the previous steps succeeded. +//! It also logs any errors from the above chain of function calls. +//! +//! For more details, see the [`EventHandler`] trait. + +use std::{borrow::Cow, future::Future, ops::Deref}; + +use matrix_sdk_base::deserialized_responses::SyncRoomEvent; +use ruma::{events::AnySyncStateEvent, serde::Raw}; +use serde::Deserialize; +use serde_json::value::RawValue as RawJsonValue; + +use crate::{room, Client}; + +#[doc(hidden)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum EventKind { + GlobalAccountData, + RoomAccountData, + EphemeralRoomData, + Message { redacted: bool }, + State { redacted: bool }, + StrippedState { redacted: bool }, + InitialState, + ToDevice, + Presence, +} + +/// A statically-known event kind/type that can be retrieved from an event sync. +pub trait SyncEvent { + #[doc(hidden)] + const ID: (EventKind, &'static str); +} + +/// Interface for event handlers. +/// +/// This trait is an abstraction for a certain kind of functions / closures, +/// specifically: +/// +/// * They must have at least one argument, which is the event itself, a type +/// that implements [`SyncEvent`]. Any additional arguments need to implement +/// the [`EventHandlerContext`] trait. +/// * Their return type has to be one of: `()`, `Result<(), impl +/// std::error::Error>` or `anyhow::Result<()>` (requires the `anyhow` Cargo +/// feature to be enabled) +/// +/// ### How it works +/// +/// This trait is basically a very constrained version of `Fn`: It requires at +/// least one argument, which is represented as its own generic parameter `Ev` +/// with the remaining parameter types being represented by the second generic +/// parameter `Ctx`; they have to be stuffed into one generic parameter as a +/// tuple because Rust doesn't have variadic generics. +/// +/// `Ev` and `Ctx` are generic parameters rather than associated types because +/// the argument list is a generic parameter for the `Fn` traits too, so a +/// single type could implement `Fn` multiple times with different argument +/// listsΒΉ. Luckily, when calling [`Client::register_event_handler`] with a +/// closure argument the trait solver takes into account that only a single one +/// of the implementations applies (even though this could theoretically change +/// through a dependency upgrade) and uses that rather than raising an ambiguity +/// error. This is the same trick used by web frameworks like actix-web and +/// axum. +/// +/// ΒΉ the only thing stopping such types from existing in stable Rust is that +/// all manual implementations of the `Fn` traits require a Nightly feature +pub trait EventHandler: Clone + Send + Sync + 'static { + /// The future returned by `handle_event`. + #[doc(hidden)] + type Future: Future + Send + 'static; + + /// The event type being handled, for example a message event of type + /// `m.room.message`. + #[doc(hidden)] + const ID: (EventKind, &'static str); + + /// Create a future for handling the given event. + /// + /// `data` provides additional data about the event, for example the room it + /// appeared in. + /// + /// Returns `None` if one of the context extractors failed. + #[doc(hidden)] + fn handle_event(&self, ev: Ev, data: EventHandlerData<'_>) -> Option; +} + +#[doc(hidden)] +#[derive(Debug)] +pub struct EventHandlerData<'a> { + pub client: Client, + pub room: Option, + pub raw: &'a RawJsonValue, +} + +/// Context for an event handler. +/// +/// This trait defines the set of types that may be used as additional arguments +/// in event handler functions after the event itself. +pub trait EventHandlerContext: Sized { + #[doc(hidden)] + fn from_data(_: &EventHandlerData<'_>) -> Option; +} + +impl EventHandlerContext for Client { + fn from_data(data: &EventHandlerData<'_>) -> Option { + Some(data.client.clone()) + } +} + +/// This event handler context argument is only applicable to room-specific +/// events. +/// +/// Trying to use it in the event handler for another event, for example a +/// global account data or presence event, will result in the event handler +/// being skipped and an error getting logged. +impl EventHandlerContext for room::Room { + fn from_data(data: &EventHandlerData<'_>) -> Option { + data.room.clone() + } +} + +/// The raw JSON form of an event. +/// +/// Used as a context argument for event handlers (see +/// [`Client::register_event_handler`]). +// FIXME: This could be made to not own the raw JSON value with some changes to +// the traits above, but only with GATs. +#[derive(Clone, Debug)] +pub struct RawEvent(pub Box); + +impl Deref for RawEvent { + type Target = RawJsonValue; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl EventHandlerContext for RawEvent { + fn from_data(data: &EventHandlerData<'_>) -> Option { + Some(Self(data.raw.to_owned())) + } +} + +/// Return types supported for event handlers implement this trait. +/// +/// It is not meant to be implemented outside of matrix-sdk. +pub trait EventHandlerResult: Sized { + #[doc(hidden)] + fn print_error(&self, event_type: &str); +} + +impl EventHandlerResult for () { + fn print_error(&self, _event_type: &str) {} +} + +impl EventHandlerResult for Result<(), E> { + fn print_error(&self, event_type: &str) { + if let Err(e) = self { + tracing::error!("Event handler for `{}` failed: {}", event_type, e); + } + } +} + +#[cfg(feature = "anyhow")] +impl EventHandlerResult for anyhow::Result<()> { + fn print_error(&self, event_type: &str) { + if let Err(e) = self { + tracing::error!("Event handler for `{}` failed: {:?}", event_type, e); + } + } +} + +#[derive(Deserialize)] +struct UnsignedDetails { + redacted_because: Option, +} + +/// Event handling internals. +impl Client { + pub(crate) async fn handle_sync_events( + &self, + kind: EventKind, + room: &Option, + events: &[Raw], + ) -> serde_json::Result<()> { + self.handle_sync_events_wrapped(kind, room, events, |x| x).await + } + + pub(crate) async fn handle_sync_state_events( + &self, + room: &Option, + state_events: &[Raw], + ) -> serde_json::Result<()> { + #[derive(Deserialize)] + struct StateEventDetails<'a> { + #[serde(borrow, rename = "type")] + event_type: Cow<'a, str>, + unsigned: Option, + } + + self.handle_sync_events_wrapped_with(room, state_events, std::convert::identity, |raw| { + let StateEventDetails { event_type, unsigned } = raw.deserialize_as()?; + let redacted = unsigned.and_then(|u| u.redacted_because).is_some(); + Ok((EventKind::State { redacted }, event_type)) + }) + .await + } + + pub(crate) async fn handle_sync_timeline_events( + &self, + room: &Option, + timeline_events: &[SyncRoomEvent], + ) -> serde_json::Result<()> { + // FIXME: add EncryptionInfo to context + #[derive(Deserialize)] + struct TimelineEventDetails<'a> { + #[serde(borrow, rename = "type")] + event_type: Cow<'a, str>, + state_key: Option, + unsigned: Option, + } + + self.handle_sync_events_wrapped_with( + room, + timeline_events, + |e| &e.event, + |raw| { + let TimelineEventDetails { event_type, state_key, unsigned } = + raw.deserialize_as()?; + + let redacted = unsigned.and_then(|u| u.redacted_because).is_some(); + let kind = match state_key { + Some(_) => EventKind::State { redacted }, + None => EventKind::Message { redacted }, + }; + + Ok((kind, event_type)) + }, + ) + .await + } + + async fn handle_sync_events_wrapped<'a, T: 'a, U: 'a>( + &self, + kind: EventKind, + room: &Option, + events: &'a [U], + get_event: impl Fn(&'a U) -> &'a Raw, + ) -> Result<(), serde_json::Error> { + #[derive(Deserialize)] + struct ExtractType<'a> { + #[serde(borrow, rename = "type")] + event_type: Cow<'a, str>, + } + + self.handle_sync_events_wrapped_with(room, events, get_event, |raw| { + Ok((kind, raw.deserialize_as::()?.event_type)) + }) + .await + } + + async fn handle_sync_events_wrapped_with<'a, T: 'a, U: 'a>( + &self, + room: &Option, + list: &'a [U], + get_event: impl Fn(&'a U) -> &'a Raw, + get_id: impl Fn(&Raw) -> serde_json::Result<(EventKind, Cow<'_, str>)>, + ) -> serde_json::Result<()> { + for x in list { + let event = get_event(x); + let (ev_kind, ev_type) = get_id(event)?; + let event_handler_id = (ev_kind, &*ev_type); + + if let Some(handlers) = self.event_handlers.read().await.get(&event_handler_id) { + for handler in &*handlers { + let data = EventHandlerData { + client: self.clone(), + room: room.clone(), + raw: event.json(), + }; + matrix_sdk_common::executor::spawn((handler)(data)); + } + } + } + + Ok(()) + } +} + +macro_rules! impl_event_handler { + ($($ty:ident),* $(,)?) => { + impl EventHandler for Fun + where + Ev: SyncEvent, + Fun: Fn(Ev, $($ty),*) -> Fut + Clone + Send + Sync + 'static, + Fut: Future + Send + 'static, + Fut::Output: EventHandlerResult, + $($ty: EventHandlerContext),* + { + type Future = Fut; + const ID: (EventKind, &'static str) = Ev::ID; + + fn handle_event(&self, ev: Ev, _d: EventHandlerData<'_>) -> Option { + Some((self)(ev, $($ty::from_data(&_d)?),*)) + } + } + }; +} + +impl_event_handler!(); +impl_event_handler!(A); +impl_event_handler!(A, B); +impl_event_handler!(A, B, C); +impl_event_handler!(A, B, C, D); +impl_event_handler!(A, B, C, D, E); +impl_event_handler!(A, B, C, D, E, F); +impl_event_handler!(A, B, C, D, E, F, G); +impl_event_handler!(A, B, C, D, E, F, G, H); + +mod static_events { + use ruma::events::{ + self, + presence::{PresenceEvent, PresenceEventContent}, + StaticEventContent, + }; + + use super::{EventKind, SyncEvent}; + + impl SyncEvent for events::GlobalAccountDataEvent + where + C: StaticEventContent + events::GlobalAccountDataEventContent, + { + const ID: (EventKind, &'static str) = (EventKind::GlobalAccountData, C::TYPE); + } + + impl SyncEvent for events::RoomAccountDataEvent + where + C: StaticEventContent + events::RoomAccountDataEventContent, + { + const ID: (EventKind, &'static str) = (EventKind::RoomAccountData, C::TYPE); + } + + impl SyncEvent for events::SyncEphemeralRoomEvent + where + C: StaticEventContent + events::EphemeralRoomEventContent, + { + const ID: (EventKind, &'static str) = (EventKind::EphemeralRoomData, C::TYPE); + } + + impl SyncEvent for events::SyncMessageEvent + where + C: StaticEventContent + events::MessageEventContent, + { + const ID: (EventKind, &'static str) = (EventKind::Message { redacted: false }, C::TYPE); + } + + impl SyncEvent for events::SyncStateEvent + where + C: StaticEventContent + events::StateEventContent, + { + const ID: (EventKind, &'static str) = (EventKind::State { redacted: false }, C::TYPE); + } + + impl SyncEvent for events::StrippedStateEvent + where + C: StaticEventContent + events::StateEventContent, + { + const ID: (EventKind, &'static str) = + (EventKind::StrippedState { redacted: false }, C::TYPE); + } + + impl SyncEvent for events::InitialStateEvent + where + C: StaticEventContent + events::StateEventContent, + { + const ID: (EventKind, &'static str) = (EventKind::InitialState, C::TYPE); + } + + impl SyncEvent for events::ToDeviceEvent + where + C: StaticEventContent + events::ToDeviceEventContent, + { + const ID: (EventKind, &'static str) = (EventKind::ToDevice, C::TYPE); + } + + impl SyncEvent for PresenceEvent { + const ID: (EventKind, &'static str) = (EventKind::Presence, PresenceEventContent::TYPE); + } + + impl SyncEvent for events::RedactedSyncMessageEvent + where + C: StaticEventContent + events::RedactedMessageEventContent, + { + const ID: (EventKind, &'static str) = (EventKind::Message { redacted: true }, C::TYPE); + } + + impl SyncEvent for events::RedactedSyncStateEvent + where + C: StaticEventContent + events::RedactedStateEventContent, + { + const ID: (EventKind, &'static str) = (EventKind::State { redacted: true }, C::TYPE); + } + + impl SyncEvent for events::RedactedStrippedStateEvent + where + C: StaticEventContent + events::RedactedStateEventContent, + { + const ID: (EventKind, &'static str) = + (EventKind::StrippedState { redacted: true }, C::TYPE); + } +} diff --git a/matrix_sdk/src/event_handler/mod.rs b/matrix_sdk/src/event_handler/mod.rs deleted file mode 100644 index e7cf2105..00000000 --- a/matrix_sdk/src/event_handler/mod.rs +++ /dev/null @@ -1,964 +0,0 @@ -// Copyright 2020 Damir JeliΔ‡ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -use std::ops::Deref; - -use matrix_sdk_base::{hoist_and_deserialize_state_event, hoist_room_event_prev_content}; -use matrix_sdk_common::async_trait; -use ruma::{ - api::client::r0::push::get_notifications::Notification, - events::{ - call::{ - answer::AnswerEventContent, candidates::CandidatesEventContent, - hangup::HangupEventContent, invite::InviteEventContent, - }, - custom::CustomEventContent, - fully_read::FullyReadEventContent, - ignored_user_list::IgnoredUserListEventContent, - presence::PresenceEvent, - push_rules::PushRulesEventContent, - reaction::ReactionEventContent, - receipt::ReceiptEventContent, - room::{ - aliases::AliasesEventContent, - avatar::AvatarEventContent, - canonical_alias::CanonicalAliasEventContent, - join_rules::JoinRulesEventContent, - member::MemberEventContent, - message::{feedback::FeedbackEventContent, MessageEventContent as MsgEventContent}, - name::NameEventContent, - power_levels::PowerLevelsEventContent, - redaction::SyncRedactionEvent, - tombstone::TombstoneEventContent, - }, - typing::TypingEventContent, - AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent, - AnySyncEphemeralRoomEvent, AnySyncMessageEvent, AnySyncRoomEvent, AnySyncStateEvent, - GlobalAccountDataEvent, RoomAccountDataEvent, StrippedStateEvent, SyncEphemeralRoomEvent, - SyncMessageEvent, SyncStateEvent, - }, - serde::Raw, - RoomId, -}; -use serde_json::value::RawValue as RawJsonValue; - -use crate::{deserialized_responses::SyncResponse, room::Room, Client}; - -pub(crate) struct Handler { - pub(crate) inner: Box, - pub(crate) client: Client, -} - -impl Deref for Handler { - type Target = dyn EventHandler; - - fn deref(&self) -> &Self::Target { - &*self.inner - } -} - -impl Handler { - fn get_room(&self, room_id: &RoomId) -> Option { - self.client.get_room(room_id) - } - - pub(crate) async fn handle_sync(&self, response: &SyncResponse) { - for event in response.account_data.events.iter().filter_map(|e| e.deserialize().ok()) { - self.handle_account_data_event(&event).await; - } - - for (room_id, room_info) in &response.rooms.join { - if let Some(room) = self.get_room(room_id) { - for event in room_info.ephemeral.events.iter().filter_map(|e| e.deserialize().ok()) - { - self.handle_ephemeral_event(room.clone(), &event).await; - } - - for event in - room_info.account_data.events.iter().filter_map(|e| e.deserialize().ok()) - { - self.handle_room_account_data_event(room.clone(), &event).await; - } - - for (raw_event, event) in room_info.state.events.iter().filter_map(|e| { - if let Ok(d) = hoist_and_deserialize_state_event(e) { - Some((e, d)) - } else { - None - } - }) { - self.handle_state_event(room.clone(), &event, raw_event).await; - } - - for (raw_event, event) in room_info.timeline.events.iter().filter_map(|e| { - if let Ok(d) = hoist_room_event_prev_content(&e.event) { - Some((&e.event, d)) - } else { - None - } - }) { - self.handle_timeline_event(room.clone(), &event, raw_event).await; - } - } - } - - for (room_id, room_info) in &response.rooms.leave { - if let Some(room) = self.get_room(room_id) { - for event in - room_info.account_data.events.iter().filter_map(|e| e.deserialize().ok()) - { - self.handle_room_account_data_event(room.clone(), &event).await; - } - - for (raw_event, event) in room_info.state.events.iter().filter_map(|e| { - if let Ok(d) = hoist_and_deserialize_state_event(e) { - Some((e, d)) - } else { - None - } - }) { - self.handle_state_event(room.clone(), &event, raw_event).await; - } - - for (raw_event, event) in room_info.timeline.events.iter().filter_map(|e| { - if let Ok(d) = hoist_room_event_prev_content(&e.event) { - Some((&e.event, d)) - } else { - None - } - }) { - self.handle_timeline_event(room.clone(), &event, raw_event).await; - } - } - } - - for (room_id, room_info) in &response.rooms.invite { - if let Some(room) = self.get_room(room_id) { - for event in - room_info.invite_state.events.iter().filter_map(|e| e.deserialize().ok()) - { - self.handle_stripped_state_event(room.clone(), &event).await; - } - } - } - - for event in response.presence.events.iter().filter_map(|e| e.deserialize().ok()) { - self.on_presence_event(&event).await; - } - - for (room_id, notifications) in &response.notifications { - if let Some(room) = self.get_room(room_id) { - for notification in notifications { - self.on_room_notification(room.clone(), notification.clone()).await; - } - } - } - } - - async fn handle_timeline_event( - &self, - room: Room, - event: &AnySyncRoomEvent, - raw_event: &Raw, - ) { - match event { - AnySyncRoomEvent::State(event) => match event { - AnySyncStateEvent::RoomMember(e) => self.on_room_member(room, e).await, - AnySyncStateEvent::RoomName(e) => self.on_room_name(room, e).await, - AnySyncStateEvent::RoomCanonicalAlias(e) => { - self.on_room_canonical_alias(room, e).await - } - AnySyncStateEvent::RoomAliases(e) => self.on_room_aliases(room, e).await, - AnySyncStateEvent::RoomAvatar(e) => self.on_room_avatar(room, e).await, - AnySyncStateEvent::RoomPowerLevels(e) => self.on_room_power_levels(room, e).await, - AnySyncStateEvent::RoomTombstone(e) => self.on_room_tombstone(room, e).await, - AnySyncStateEvent::RoomJoinRules(e) => self.on_room_join_rules(room, e).await, - AnySyncStateEvent::PolicyRuleRoom(_) - | AnySyncStateEvent::PolicyRuleServer(_) - | AnySyncStateEvent::PolicyRuleUser(_) - | AnySyncStateEvent::RoomCreate(_) - | AnySyncStateEvent::RoomEncryption(_) - | AnySyncStateEvent::RoomGuestAccess(_) - | AnySyncStateEvent::RoomHistoryVisibility(_) - | AnySyncStateEvent::RoomPinnedEvents(_) - | AnySyncStateEvent::RoomServerAcl(_) - | AnySyncStateEvent::RoomThirdPartyInvite(_) - | AnySyncStateEvent::RoomTopic(_) - | AnySyncStateEvent::SpaceChild(_) - | AnySyncStateEvent::SpaceParent(_) => {} - _ => { - if let Ok(e) = raw_event.deserialize_as::>() - { - self.on_custom_event(room, &CustomEvent::State(&e)).await; - } - } - }, - AnySyncRoomEvent::Message(event) => match event { - AnySyncMessageEvent::RoomMessage(e) => self.on_room_message(room, e).await, - AnySyncMessageEvent::RoomMessageFeedback(e) => { - self.on_room_message_feedback(room, e).await - } - AnySyncMessageEvent::RoomRedaction(e) => self.on_room_redaction(room, e).await, - AnySyncMessageEvent::Reaction(e) => self.on_room_reaction(room, e).await, - AnySyncMessageEvent::CallInvite(e) => self.on_room_call_invite(room, e).await, - AnySyncMessageEvent::CallAnswer(e) => self.on_room_call_answer(room, e).await, - AnySyncMessageEvent::CallCandidates(e) => { - self.on_room_call_candidates(room, e).await - } - AnySyncMessageEvent::CallHangup(e) => self.on_room_call_hangup(room, e).await, - AnySyncMessageEvent::KeyVerificationReady(_) - | AnySyncMessageEvent::KeyVerificationStart(_) - | AnySyncMessageEvent::KeyVerificationCancel(_) - | AnySyncMessageEvent::KeyVerificationAccept(_) - | AnySyncMessageEvent::KeyVerificationKey(_) - | AnySyncMessageEvent::KeyVerificationMac(_) - | AnySyncMessageEvent::KeyVerificationDone(_) - | AnySyncMessageEvent::RoomEncrypted(_) - | AnySyncMessageEvent::Sticker(_) => {} - _ => { - if let Ok(e) = - raw_event.deserialize_as::>() - { - self.on_custom_event(room, &CustomEvent::Message(&e)).await; - } - } - }, - AnySyncRoomEvent::RedactedState(_event) => {} - AnySyncRoomEvent::RedactedMessage(_event) => {} - } - } - - async fn handle_state_event( - &self, - room: Room, - event: &AnySyncStateEvent, - raw_event: &Raw, - ) { - match event { - AnySyncStateEvent::RoomMember(member) => self.on_state_member(room, member).await, - AnySyncStateEvent::RoomName(name) => self.on_state_name(room, name).await, - AnySyncStateEvent::RoomCanonicalAlias(canonical) => { - self.on_state_canonical_alias(room, canonical).await - } - AnySyncStateEvent::RoomAliases(aliases) => self.on_state_aliases(room, aliases).await, - AnySyncStateEvent::RoomAvatar(avatar) => self.on_state_avatar(room, avatar).await, - AnySyncStateEvent::RoomPowerLevels(power) => { - self.on_state_power_levels(room, power).await - } - AnySyncStateEvent::RoomJoinRules(rules) => self.on_state_join_rules(room, rules).await, - AnySyncStateEvent::RoomTombstone(tomb) => { - // TODO make `on_state_tombstone` method - self.on_room_tombstone(room, tomb).await - } - AnySyncStateEvent::PolicyRuleRoom(_) - | AnySyncStateEvent::PolicyRuleServer(_) - | AnySyncStateEvent::PolicyRuleUser(_) - | AnySyncStateEvent::RoomCreate(_) - | AnySyncStateEvent::RoomEncryption(_) - | AnySyncStateEvent::RoomGuestAccess(_) - | AnySyncStateEvent::RoomHistoryVisibility(_) - | AnySyncStateEvent::RoomPinnedEvents(_) - | AnySyncStateEvent::RoomServerAcl(_) - | AnySyncStateEvent::RoomThirdPartyInvite(_) - | AnySyncStateEvent::RoomTopic(_) - | AnySyncStateEvent::SpaceChild(_) - | AnySyncStateEvent::SpaceParent(_) => {} - _ => { - if let Ok(e) = raw_event.deserialize_as::>() { - self.on_custom_event(room, &CustomEvent::State(&e)).await; - } - } - } - } - - pub(crate) async fn handle_stripped_state_event( - &self, - // TODO these events are only handled in invited rooms. - room: Room, - event: &AnyStrippedStateEvent, - ) { - match event { - AnyStrippedStateEvent::RoomMember(member) => { - self.on_stripped_state_member(room, member, None).await - } - AnyStrippedStateEvent::RoomName(name) => self.on_stripped_state_name(room, name).await, - AnyStrippedStateEvent::RoomCanonicalAlias(canonical) => { - self.on_stripped_state_canonical_alias(room, canonical).await - } - AnyStrippedStateEvent::RoomAliases(aliases) => { - self.on_stripped_state_aliases(room, aliases).await - } - AnyStrippedStateEvent::RoomAvatar(avatar) => { - self.on_stripped_state_avatar(room, avatar).await - } - AnyStrippedStateEvent::RoomPowerLevels(power) => { - self.on_stripped_state_power_levels(room, power).await - } - AnyStrippedStateEvent::RoomJoinRules(rules) => { - self.on_stripped_state_join_rules(room, rules).await - } - _ => {} - } - } - - pub(crate) async fn handle_room_account_data_event( - &self, - room: Room, - event: &AnyRoomAccountDataEvent, - ) { - if let AnyRoomAccountDataEvent::FullyRead(event) = event { - self.on_non_room_fully_read(room, event).await - } - } - - pub(crate) async fn handle_account_data_event(&self, event: &AnyGlobalAccountDataEvent) { - match event { - AnyGlobalAccountDataEvent::IgnoredUserList(ignored) => { - self.on_non_room_ignored_users(ignored).await - } - AnyGlobalAccountDataEvent::PushRules(rules) => self.on_non_room_push_rules(rules).await, - _ => {} - } - } - - pub(crate) async fn handle_ephemeral_event( - &self, - room: Room, - event: &AnySyncEphemeralRoomEvent, - ) { - match event { - AnySyncEphemeralRoomEvent::Typing(typing) => { - self.on_non_room_typing(room, typing).await - } - AnySyncEphemeralRoomEvent::Receipt(receipt) => { - self.on_non_room_receipt(room, receipt).await - } - _ => {} - } - } -} - -/// This represents the various "unrecognized" events. -#[derive(Clone, Copy, Debug)] -pub enum CustomEvent<'c> { - /// A custom basic event. - Basic(&'c GlobalAccountDataEvent), - /// A custom basic event. - EphemeralRoom(&'c SyncEphemeralRoomEvent), - /// A custom room event. - Message(&'c SyncMessageEvent), - /// A custom state event. - State(&'c SyncStateEvent), - /// A custom stripped state event. - StrippedState(&'c StrippedStateEvent), -} - -/// This trait allows any type implementing `EventHandler` to specify event -/// callbacks for each event. The `Client` calls each method when the -/// corresponding event is received. -/// -/// # Examples -/// ``` -/// # use std::ops::Deref; -/// # use std::sync::Arc; -/// # use std::{env, process::exit}; -/// # use matrix_sdk::{ -/// # async_trait, -/// # EventHandler, -/// # ruma::events::{ -/// # room::message::{MessageEventContent, MessageType, TextMessageEventContent}, -/// # SyncMessageEvent -/// # }, -/// # locks::RwLock, -/// # room::Room, -/// # }; -/// -/// struct EventCallback; -/// -/// #[async_trait] -/// impl EventHandler for EventCallback { -/// async fn on_room_message(&self, room: Room, event: &SyncMessageEvent) { -/// if let Room::Joined(room) = room { -/// if let SyncMessageEvent { -/// content: -/// MessageEventContent { -/// msgtype: MessageType::Text(TextMessageEventContent { body: msg_body, .. }), -/// .. -/// }, -/// sender, -/// .. -/// } = event -/// { -/// let member = room.get_member(&sender).await.unwrap().unwrap(); -/// let name = member -/// .display_name() -/// .unwrap_or_else(|| member.user_id().as_str()); -/// println!("{}: {}", name, msg_body); -/// } -/// } -/// } -/// } -/// ``` -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -pub trait EventHandler: Send + Sync { - // ROOM EVENTS from `IncomingTimeline` - /// Fires when `Client` receives a `RoomEvent::RoomMember` event. - async fn on_room_member(&self, _: Room, _: &SyncStateEvent) {} - /// Fires when `Client` receives a `RoomEvent::RoomName` event. - async fn on_room_name(&self, _: Room, _: &SyncStateEvent) {} - /// Fires when `Client` receives a `RoomEvent::RoomCanonicalAlias` event. - async fn on_room_canonical_alias( - &self, - _: Room, - _: &SyncStateEvent, - ) { - } - /// Fires when `Client` receives a `RoomEvent::RoomAliases` event. - async fn on_room_aliases(&self, _: Room, _: &SyncStateEvent) {} - /// Fires when `Client` receives a `RoomEvent::RoomAvatar` event. - async fn on_room_avatar(&self, _: Room, _: &SyncStateEvent) {} - /// Fires when `Client` receives a `RoomEvent::RoomMessage` event. - async fn on_room_message(&self, _: Room, _: &SyncMessageEvent) {} - /// Fires when `Client` receives a `RoomEvent::RoomMessageFeedback` event. - async fn on_room_message_feedback(&self, _: Room, _: &SyncMessageEvent) {} - /// Fires when `Client` receives a `RoomEvent::Reaction` event. - async fn on_room_reaction(&self, _: Room, _: &SyncMessageEvent) {} - /// Fires when `Client` receives a `RoomEvent::CallInvite` event - async fn on_room_call_invite(&self, _: Room, _: &SyncMessageEvent) {} - /// Fires when `Client` receives a `RoomEvent::CallAnswer` event - async fn on_room_call_answer(&self, _: Room, _: &SyncMessageEvent) {} - /// Fires when `Client` receives a `RoomEvent::CallCandidates` event - async fn on_room_call_candidates(&self, _: Room, _: &SyncMessageEvent) { - } - /// Fires when `Client` receives a `RoomEvent::CallHangup` event - async fn on_room_call_hangup(&self, _: Room, _: &SyncMessageEvent) {} - /// Fires when `Client` receives a `RoomEvent::RoomRedaction` event. - async fn on_room_redaction(&self, _: Room, _: &SyncRedactionEvent) {} - /// Fires when `Client` receives a `RoomEvent::RoomPowerLevels` event. - async fn on_room_power_levels(&self, _: Room, _: &SyncStateEvent) {} - /// Fires when `Client` receives a `RoomEvent::RoomJoinRules` event. - async fn on_room_join_rules(&self, _: Room, _: &SyncStateEvent) {} - /// Fires when `Client` receives a `RoomEvent::Tombstone` event. - async fn on_room_tombstone(&self, _: Room, _: &SyncStateEvent) {} - - /// Fires when `Client` receives room events that trigger notifications - /// according to the push rules of the user. - async fn on_room_notification(&self, _: Room, _: Notification) {} - - // `RoomEvent`s from `IncomingState` - /// Fires when `Client` receives a `StateEvent::RoomMember` event. - async fn on_state_member(&self, _: Room, _: &SyncStateEvent) {} - /// Fires when `Client` receives a `StateEvent::RoomName` event. - async fn on_state_name(&self, _: Room, _: &SyncStateEvent) {} - /// Fires when `Client` receives a `StateEvent::RoomCanonicalAlias` event. - async fn on_state_canonical_alias( - &self, - _: Room, - _: &SyncStateEvent, - ) { - } - /// Fires when `Client` receives a `StateEvent::RoomAliases` event. - async fn on_state_aliases(&self, _: Room, _: &SyncStateEvent) {} - /// Fires when `Client` receives a `StateEvent::RoomAvatar` event. - async fn on_state_avatar(&self, _: Room, _: &SyncStateEvent) {} - /// Fires when `Client` receives a `StateEvent::RoomPowerLevels` event. - async fn on_state_power_levels(&self, _: Room, _: &SyncStateEvent) {} - /// Fires when `Client` receives a `StateEvent::RoomJoinRules` event. - async fn on_state_join_rules(&self, _: Room, _: &SyncStateEvent) {} - - // `AnyStrippedStateEvent`s - /// Fires when `Client` receives a - /// `AnyStrippedStateEvent::StrippedRoomMember` event. - async fn on_stripped_state_member( - &self, - _: Room, - _: &StrippedStateEvent, - _: Option, - ) { - } - /// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomName` - /// event. - async fn on_stripped_state_name(&self, _: Room, _: &StrippedStateEvent) {} - /// Fires when `Client` receives a - /// `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event. - async fn on_stripped_state_canonical_alias( - &self, - _: Room, - _: &StrippedStateEvent, - ) { - } - /// Fires when `Client` receives a - /// `AnyStrippedStateEvent::StrippedRoomAliases` event. - async fn on_stripped_state_aliases( - &self, - _: Room, - _: &StrippedStateEvent, - ) { - } - /// Fires when `Client` receives a - /// `AnyStrippedStateEvent::StrippedRoomAvatar` event. - async fn on_stripped_state_avatar(&self, _: Room, _: &StrippedStateEvent) {} - /// Fires when `Client` receives a - /// `AnyStrippedStateEvent::StrippedRoomPowerLevels` event. - async fn on_stripped_state_power_levels( - &self, - _: Room, - _: &StrippedStateEvent, - ) { - } - /// Fires when `Client` receives a - /// `AnyStrippedStateEvent::StrippedRoomJoinRules` event. - async fn on_stripped_state_join_rules( - &self, - _: Room, - _: &StrippedStateEvent, - ) { - } - - // `NonRoomEvent` (this is a type alias from ruma_events) - /// Fires when `Client` receives a `NonRoomEvent::RoomPresence` event. - async fn on_non_room_presence(&self, _: Room, _: &PresenceEvent) {} - /// Fires when `Client` receives a `NonRoomEvent::RoomName` event. - async fn on_non_room_ignored_users( - &self, - _: &GlobalAccountDataEvent, - ) { - } - /// Fires when `Client` receives a `NonRoomEvent::RoomCanonicalAlias` event. - async fn on_non_room_push_rules(&self, _: &GlobalAccountDataEvent) {} - /// Fires when `Client` receives a `NonRoomEvent::RoomAliases` event. - async fn on_non_room_fully_read( - &self, - _: Room, - _: &RoomAccountDataEvent, - ) { - } - /// Fires when `Client` receives a `NonRoomEvent::Typing` event. - async fn on_non_room_typing(&self, _: Room, _: &SyncEphemeralRoomEvent) {} - /// Fires when `Client` receives a `NonRoomEvent::Receipt` event. - /// - /// This is always a read receipt. - async fn on_non_room_receipt(&self, _: Room, _: &SyncEphemeralRoomEvent) {} - - // `PresenceEvent` is a struct so there is only the one method - /// Fires when `Client` receives a `NonRoomEvent::RoomAliases` event. - async fn on_presence_event(&self, _: &PresenceEvent) {} - - /// Fires when `Client` receives a `Event::Custom` event or if - /// deserialization fails because the event was unknown to ruma. - /// - /// The only guarantee this method can give about the event is that it is - /// valid JSON. - async fn on_unrecognized_event(&self, _: Room, _: &RawJsonValue) {} - - /// Fires when `Client` receives a `Event::Custom` event or if - /// deserialization fails because the event was unknown to ruma. - /// - /// The only guarantee this method can give about the event is that it is in - /// the shape of a valid matrix event. - async fn on_custom_event(&self, _: Room, _: &CustomEvent<'_>) {} -} - -#[cfg(test)] -mod test { - use std::{sync::Arc, time::Duration}; - - use matrix_sdk_common::{async_trait, locks::Mutex}; - use matrix_sdk_test::{async_test, test_json}; - use mockito::{mock, Matcher}; - use ruma::user_id; - #[cfg(target_arch = "wasm32")] - pub use wasm_bindgen_test::*; - - use super::*; - - #[derive(Clone)] - pub struct EvHandlerTest(Arc>>); - - #[cfg_attr(target_arch = "wasm32", async_trait(?Send))] - #[cfg_attr(not(target_arch = "wasm32"), async_trait)] - impl EventHandler for EvHandlerTest { - async fn on_room_member(&self, _: Room, _: &SyncStateEvent) { - self.0.lock().await.push("member".to_string()) - } - async fn on_room_name(&self, _: Room, _: &SyncStateEvent) { - self.0.lock().await.push("name".to_string()) - } - async fn on_room_canonical_alias( - &self, - _: Room, - _: &SyncStateEvent, - ) { - self.0.lock().await.push("canonical".to_string()) - } - async fn on_room_aliases(&self, _: Room, _: &SyncStateEvent) { - self.0.lock().await.push("aliases".to_string()) - } - async fn on_room_avatar(&self, _: Room, _: &SyncStateEvent) { - self.0.lock().await.push("avatar".to_string()) - } - async fn on_room_message(&self, _: Room, _: &SyncMessageEvent) { - self.0.lock().await.push("message".to_string()) - } - async fn on_room_message_feedback( - &self, - _: Room, - _: &SyncMessageEvent, - ) { - self.0.lock().await.push("feedback".to_string()) - } - async fn on_room_call_invite(&self, _: Room, _: &SyncMessageEvent) { - self.0.lock().await.push("call invite".to_string()) - } - async fn on_room_call_answer(&self, _: Room, _: &SyncMessageEvent) { - self.0.lock().await.push("call answer".to_string()) - } - async fn on_room_call_candidates( - &self, - _: Room, - _: &SyncMessageEvent, - ) { - self.0.lock().await.push("call candidates".to_string()) - } - async fn on_room_call_hangup(&self, _: Room, _: &SyncMessageEvent) { - self.0.lock().await.push("call hangup".to_string()) - } - async fn on_room_redaction(&self, _: Room, _: &SyncRedactionEvent) { - self.0.lock().await.push("redaction".to_string()) - } - async fn on_room_power_levels(&self, _: Room, _: &SyncStateEvent) { - self.0.lock().await.push("power".to_string()) - } - async fn on_room_tombstone(&self, _: Room, _: &SyncStateEvent) { - self.0.lock().await.push("tombstone".to_string()) - } - - async fn on_state_member(&self, _: Room, _: &SyncStateEvent) { - self.0.lock().await.push("state member".to_string()) - } - async fn on_state_name(&self, _: Room, _: &SyncStateEvent) { - self.0.lock().await.push("state name".to_string()) - } - async fn on_state_canonical_alias( - &self, - _: Room, - _: &SyncStateEvent, - ) { - self.0.lock().await.push("state canonical".to_string()) - } - async fn on_state_aliases(&self, _: Room, _: &SyncStateEvent) { - self.0.lock().await.push("state aliases".to_string()) - } - async fn on_state_avatar(&self, _: Room, _: &SyncStateEvent) { - self.0.lock().await.push("state avatar".to_string()) - } - async fn on_state_power_levels( - &self, - _: Room, - _: &SyncStateEvent, - ) { - self.0.lock().await.push("state power".to_string()) - } - async fn on_state_join_rules(&self, _: Room, _: &SyncStateEvent) { - self.0.lock().await.push("state rules".to_string()) - } - - // `AnyStrippedStateEvent`s - /// Fires when `Client` receives a - /// `AnyStrippedStateEvent::StrippedRoomMember` event. - async fn on_stripped_state_member( - &self, - _: Room, - _: &StrippedStateEvent, - _: Option, - ) { - self.0.lock().await.push("stripped state member".to_string()) - } - /// Fires when `Client` receives a - /// `AnyStrippedStateEvent::StrippedRoomName` event. - async fn on_stripped_state_name(&self, _: Room, _: &StrippedStateEvent) { - self.0.lock().await.push("stripped state name".to_string()) - } - /// Fires when `Client` receives a - /// `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event. - async fn on_stripped_state_canonical_alias( - &self, - _: Room, - _: &StrippedStateEvent, - ) { - self.0.lock().await.push("stripped state canonical".to_string()) - } - /// Fires when `Client` receives a - /// `AnyStrippedStateEvent::StrippedRoomAliases` event. - async fn on_stripped_state_aliases( - &self, - _: Room, - _: &StrippedStateEvent, - ) { - self.0.lock().await.push("stripped state aliases".to_string()) - } - /// Fires when `Client` receives a - /// `AnyStrippedStateEvent::StrippedRoomAvatar` event. - async fn on_stripped_state_avatar( - &self, - _: Room, - _: &StrippedStateEvent, - ) { - self.0.lock().await.push("stripped state avatar".to_string()) - } - /// Fires when `Client` receives a - /// `AnyStrippedStateEvent::StrippedRoomPowerLevels` event. - async fn on_stripped_state_power_levels( - &self, - _: Room, - _: &StrippedStateEvent, - ) { - self.0.lock().await.push("stripped state power".to_string()) - } - /// Fires when `Client` receives a - /// `AnyStrippedStateEvent::StrippedRoomJoinRules` event. - async fn on_stripped_state_join_rules( - &self, - _: Room, - _: &StrippedStateEvent, - ) { - self.0.lock().await.push("stripped state rules".to_string()) - } - - async fn on_non_room_presence(&self, _: Room, _: &PresenceEvent) { - self.0.lock().await.push("presence".to_string()) - } - async fn on_non_room_ignored_users( - &self, - _: &GlobalAccountDataEvent, - ) { - self.0.lock().await.push("account ignore".to_string()) - } - async fn on_non_room_push_rules(&self, _: &GlobalAccountDataEvent) { - self.0.lock().await.push("account push rules".to_string()) - } - async fn on_non_room_fully_read( - &self, - _: Room, - _: &RoomAccountDataEvent, - ) { - self.0.lock().await.push("account read".to_string()) - } - async fn on_non_room_typing( - &self, - _: Room, - _: &SyncEphemeralRoomEvent, - ) { - self.0.lock().await.push("typing event".to_string()) - } - async fn on_non_room_receipt( - &self, - _: Room, - _: &SyncEphemeralRoomEvent, - ) { - self.0.lock().await.push("receipt event".to_string()) - } - async fn on_presence_event(&self, _: &PresenceEvent) { - self.0.lock().await.push("presence event".to_string()) - } - async fn on_unrecognized_event(&self, _: Room, _: &RawJsonValue) { - self.0.lock().await.push("unrecognized event".to_string()) - } - async fn on_custom_event(&self, _: Room, _: &CustomEvent<'_>) { - self.0.lock().await.push("custom event".to_string()) - } - async fn on_room_notification(&self, _: Room, _: Notification) { - self.0.lock().await.push("notification".to_string()) - } - } - - use crate::{Client, Session, SyncSettings}; - - async fn get_client() -> Client { - let session = Session { - access_token: "1234".to_owned(), - user_id: user_id!("@example:localhost"), - device_id: "DEVICEID".into(), - }; - let homeserver = url::Url::parse(&mockito::server_url()).unwrap(); - let client = Client::new(homeserver).unwrap(); - client.restore_login(session).await.unwrap(); - client - } - - async fn mock_sync(client: &Client, response: String) { - let _m = mock("GET", Matcher::Regex(r"^/_matrix/client/r0/sync\?.*$".to_string())) - .with_status(200) - .match_header("authorization", "Bearer 1234") - .with_body(response) - .create(); - - let sync_settings = SyncSettings::new().timeout(Duration::from_millis(3000)); - let _response = client.sync_once(sync_settings).await.unwrap(); - } - - #[async_test] - async fn event_handler_joined() { - let vec = Arc::new(Mutex::new(Vec::new())); - let test_vec = Arc::clone(&vec); - let handler = Box::new(EvHandlerTest(vec)); - - let client = get_client().await; - client.set_event_handler(handler).await; - mock_sync(&client, test_json::SYNC.to_string()).await; - - let v = test_vec.lock().await; - assert_eq!( - v.as_slice(), - [ - "account ignore", - "receipt event", - "account read", - "state rules", - "state member", - "state aliases", - "state power", - "state canonical", - "state member", - "state member", - "message", - "presence event", - "notification", - ], - ) - } - - #[async_test] - async fn event_handler_invite() { - let vec = Arc::new(Mutex::new(Vec::new())); - let test_vec = Arc::clone(&vec); - let handler = Box::new(EvHandlerTest(vec)); - - let client = get_client().await; - client.set_event_handler(handler).await; - mock_sync(&client, test_json::INVITE_SYNC.to_string()).await; - - let v = test_vec.lock().await; - assert_eq!(v.as_slice(), ["stripped state name", "stripped state member", "presence event"],) - } - - #[async_test] - async fn event_handler_leave() { - let vec = Arc::new(Mutex::new(Vec::new())); - let test_vec = Arc::clone(&vec); - let handler = Box::new(EvHandlerTest(vec)); - - let client = get_client().await; - client.set_event_handler(handler).await; - mock_sync(&client, test_json::LEAVE_SYNC.to_string()).await; - - let v = test_vec.lock().await; - assert_eq!( - v.as_slice(), - [ - "account ignore", - "state rules", - "state member", - "state aliases", - "state power", - "state canonical", - "state member", - "state member", - "message", - "presence event", - "notification", - ], - ) - } - - #[async_test] - async fn event_handler_more_events() { - let vec = Arc::new(Mutex::new(Vec::new())); - let test_vec = Arc::clone(&vec); - let handler = Box::new(EvHandlerTest(vec)); - - let client = get_client().await; - client.set_event_handler(handler).await; - mock_sync(&client, test_json::MORE_SYNC.to_string()).await; - - let v = test_vec.lock().await; - assert_eq!( - v.as_slice(), - [ - "receipt event", - "typing event", - "message", - "message", // this is a message edit event - "redaction", - "message", // this is a notice event - ], - ) - } - - #[async_test] - async fn event_handler_voip() { - let vec = Arc::new(Mutex::new(Vec::new())); - let test_vec = Arc::clone(&vec); - let handler = Box::new(EvHandlerTest(vec)); - - let client = get_client().await; - client.set_event_handler(handler).await; - mock_sync(&client, test_json::VOIP_SYNC.to_string()).await; - - let v = test_vec.lock().await; - assert_eq!(v.as_slice(), ["call invite", "call answer", "call candidates", "call hangup",],) - } - - #[async_test] - async fn event_handler_two_syncs() { - let vec = Arc::new(Mutex::new(Vec::new())); - let test_vec = Arc::clone(&vec); - let handler = Box::new(EvHandlerTest(vec)); - - let client = get_client().await; - client.set_event_handler(handler).await; - mock_sync(&client, test_json::SYNC.to_string()).await; - mock_sync(&client, test_json::MORE_SYNC.to_string()).await; - - let v = test_vec.lock().await; - assert_eq!( - v.as_slice(), - [ - "account ignore", - "receipt event", - "account read", - "state rules", - "state member", - "state aliases", - "state power", - "state canonical", - "state member", - "state member", - "message", - "presence event", - "notification", - "receipt event", - "typing event", - "message", - "message", // this is a message edit event - "redaction", - "message", // this is a notice event - "notification", - "notification", - "notification", - ], - ) - } -} diff --git a/matrix_sdk/src/device.rs b/matrix_sdk/src/identities/devices.rs similarity index 55% rename from matrix_sdk/src/device.rs rename to matrix_sdk/src/identities/devices.rs index 3d50a460..fb89ee56 100644 --- a/matrix_sdk/src/device.rs +++ b/matrix_sdk/src/identities/devices.rs @@ -20,14 +20,19 @@ use matrix_sdk_base::crypto::{ }; use ruma::{events::key::verification::VerificationMethod, DeviceId, DeviceIdBox}; +use super::ManualVerifyError; use crate::{ error::Result, verification::{SasVerification, VerificationRequest}, Client, }; +/// A device represents a E2EE capable client or device of an user. +/// +/// A device is backed by [device keys] that are uploaded to the server. +/// +/// [device keys]: https://spec.matrix.org/unstable/client-server-api/#device-keys #[derive(Clone, Debug)] -/// A device represents a E2EE capable client of an user. pub struct Device { pub(crate) inner: BaseDevice, pub(crate) client: Client, @@ -42,49 +47,13 @@ impl Deref for Device { } impl Device { - /// Start a interactive verification with this `Device` + /// Request an interactive verification with this `Device`. /// - /// Returns a `Sas` object that represents the interactive verification - /// flow. - /// - /// This method has been deprecated in the spec and the - /// [`request_verification()`] method should be used instead. - /// - /// # Examples - /// - /// ```no_run - /// # use std::convert::TryFrom; - /// # use matrix_sdk::{Client, ruma::UserId}; - /// # use url::Url; - /// # use futures::executor::block_on; - /// # let alice = UserId::try_from("@alice:example.org").unwrap(); - /// # let homeserver = Url::parse("http://example.com").unwrap(); - /// # let client = Client::new(homeserver).unwrap(); - /// # block_on(async { - /// let device = client.get_device(&alice, "DEVICEID".into()) - /// .await - /// .unwrap() - /// .unwrap(); - /// - /// let verification = device.start_verification().await.unwrap(); - /// # }); - /// ``` - /// - /// [`request_verification()`]: #method.request_verification - pub async fn start_verification(&self) -> Result { - let (sas, request) = self.inner.start_verification().await?; - self.client.send_to_device(&request).await?; - - Ok(SasVerification { inner: sas, client: self.client.clone() }) - } - - /// Request an interacitve verification with this `Device` - /// - /// Returns a `VerificationRequest` object and a to-device request that - /// needs to be sent out. + /// Returns a [`VerificationRequest`] object that can be used to control the + /// verification flow. /// /// The default methods that are supported are `m.sas.v1` and - /// `m.qr_code.show.v1`, if this isn't desireable the + /// `m.qr_code.show.v1`, if this isn't desirable the /// [`request_verification_with_methods()`] method can be used to override /// this. /// @@ -99,13 +68,12 @@ impl Device { /// # let homeserver = Url::parse("http://example.com").unwrap(); /// # let client = Client::new(homeserver).unwrap(); /// # block_on(async { - /// let device = client.get_device(&alice, "DEVICEID".into()) - /// .await - /// .unwrap() - /// .unwrap(); + /// let device = client.get_device(&alice, "DEVICEID".into()).await?; /// - /// let verification = device.request_verification().await.unwrap(); - /// # }); + /// if let Some(device) = device { + /// let verification = device.request_verification().await?; + /// } + /// # anyhow::Result::<()>::Ok(()) }); /// ``` /// /// [`request_verification_with_methods()`]: @@ -117,14 +85,19 @@ impl Device { Ok(VerificationRequest { inner: verification, client: self.client.clone() }) } - /// Request an interacitve verification with this `Device` + /// Request an interactive verification with this `Device`. /// - /// Returns a `VerificationRequest` object and a to-device request that - /// needs to be sent out. + /// Returns a [`VerificationRequest`] object that can be used to control the + /// verification flow. /// /// # Arguments /// - /// * `methods` - The verification methods that we want to support. + /// * `methods` - The verification methods that we want to support. Must be + /// non-empty. + /// + /// # Panics + /// + /// This method will panic if `methods` is empty. /// /// # Examples /// @@ -143,30 +116,157 @@ impl Device { /// # let homeserver = Url::parse("http://example.com").unwrap(); /// # let client = Client::new(homeserver).unwrap(); /// # block_on(async { - /// let device = client.get_device(&alice, "DEVICEID".into()) - /// .await - /// .unwrap() - /// .unwrap(); + /// let device = client.get_device(&alice, "DEVICEID".into()).await?; /// /// // We don't want to support showing a QR code, we only support SAS /// // verification /// let methods = vec![VerificationMethod::SasV1]; /// - /// let verification = device.request_verification_with_methods(methods).await.unwrap(); - /// # }); + /// if let Some(device) = device { + /// let verification = device.request_verification_with_methods(methods).await?; + /// } + /// # anyhow::Result::<()>::Ok(()) }); /// ``` pub async fn request_verification_with_methods( &self, methods: Vec, ) -> Result { + if methods.is_empty() { + panic!("The list of verification methods can't be non-empty"); + } + let (verification, request) = self.inner.request_verification_with_methods(methods).await; self.client.send_verification_request(request).await?; Ok(VerificationRequest { inner: verification, client: self.client.clone() }) } - /// Is the device considered to be verified, either by locally trusting it - /// or using cross signing. + /// Start an interactive verification with this [`Device`] + /// + /// Returns a [`SasVerification`] object that represents the interactive + /// verification flow. + /// + /// This method has been deprecated in the spec and the + /// [`request_verification()`] method should be used instead. + /// + /// # Examples + /// + /// ```no_run + /// # use std::convert::TryFrom; + /// # use matrix_sdk::{Client, ruma::UserId}; + /// # use url::Url; + /// # use futures::executor::block_on; + /// # let alice = UserId::try_from("@alice:example.org").unwrap(); + /// # let homeserver = Url::parse("http://example.com").unwrap(); + /// # let client = Client::new(homeserver).unwrap(); + /// # block_on(async { + /// let device = client.get_device(&alice, "DEVICEID".into()).await?; + /// + /// if let Some(device) = device { + /// let verification = device.start_verification().await?; + /// } + /// # anyhow::Result::<()>::Ok(()) }); + /// ``` + /// + /// [`request_verification()`]: #method.request_verification + #[deprecated( + since = "0.4.0", + note = "directly starting a verification is deprecated in the spec. \ + Users should instead use request_verification()" + )] + pub async fn start_verification(&self) -> Result { + let (sas, request) = self.inner.start_verification().await?; + self.client.send_to_device(&request).await?; + + Ok(SasVerification { inner: sas, client: self.client.clone() }) + } + + /// Manually verify this device. + /// + /// This method will attempt to sign the device using our private cross + /// signing key. + /// + /// This method will always fail if the device belongs to someone else, we + /// can only sign our own devices. + /// + /// It can also fail if we don't have the private part of our self-signing + /// key. + /// + /// The state of our private cross signing keys can be inspected using the + /// [`Client::cross_signing_status()`] method. + /// + /// [`Client::cross_signing_status()`]: crate::Client::cross_signing_status + /// + /// # Examples + /// + /// ```no_run + /// # use std::convert::TryFrom; + /// # use matrix_sdk::{ + /// # Client, + /// # ruma::{ + /// # UserId, + /// # events::key::verification::VerificationMethod, + /// # } + /// # }; + /// # use url::Url; + /// # use futures::executor::block_on; + /// # let alice = UserId::try_from("@alice:example.org").unwrap(); + /// # let homeserver = Url::parse("http://example.com").unwrap(); + /// # let client = Client::new(homeserver).unwrap(); + /// # block_on(async { + /// let device = client.get_device(&alice, "DEVICEID".into()).await?; + /// + /// if let Some(device) = device { + /// device.verify().await?; + /// } + /// # anyhow::Result::<()>::Ok(()) }); + /// ``` + pub async fn verify(&self) -> std::result::Result<(), ManualVerifyError> { + let request = self.inner.verify().await?; + self.client.send(request, None).await?; + + Ok(()) + } + + /// Is the device considered to be verified. + /// + /// A device is considered to be verified, either if it's locally trusted, + /// or if it's signed by the appropriate cross signing key. + /// + /// If the device belongs to our own userk, the device needs to be signed by + /// our self-signing key and our own user identity needs to be verified. + /// + /// If the device belongs to some other user, the device needs to be signed + /// by the users signing key and the user identity of the user needs to be + /// verified. + //// # Examples + /// + /// ```no_run + /// # use std::convert::TryFrom; + /// # use matrix_sdk::{ + /// # Client, + /// # ruma::{ + /// # UserId, + /// # events::key::verification::VerificationMethod, + /// # } + /// # }; + /// # use url::Url; + /// # use futures::executor::block_on; + /// # let alice = UserId::try_from("@alice:example.org").unwrap(); + /// # let homeserver = Url::parse("http://example.com").unwrap(); + /// # let client = Client::new(homeserver).unwrap(); + /// # block_on(async { + /// let user = client.get_user_identity(&alice).await?; + /// + /// if let Some(user) = user { + /// if user.verified() { + /// println!("User {} is verified", user.user_id().as_str()); + /// } else { + /// println!("User {} is not verified", user.user_id().as_str()); + /// } + /// } + /// # anyhow::Result::<()>::Ok(()) }); + /// ``` pub fn verified(&self) -> bool { self.inner.verified() } @@ -187,7 +287,7 @@ impl Device { } } -/// A read only view over all devices belonging to a user. +/// The collection of all the [`Device`]s a user has. #[derive(Debug)] pub struct UserDevices { pub(crate) inner: BaseUserDevices, diff --git a/matrix_sdk/src/identities/mod.rs b/matrix_sdk/src/identities/mod.rs new file mode 100644 index 00000000..f627d286 --- /dev/null +++ b/matrix_sdk/src/identities/mod.rs @@ -0,0 +1,120 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Cryptographic identities used in Matrix. +//! +//! There are two types of cryptographic identities in Matrix. +//! +//! 1. Devices, which are backed by [device keys], they represent each +//! individual log in by an E2EE capable Matrix client. We represent devices +//! using the [`Device`] struct. +//! +//! 2. User identities, which are backed by [cross signing keys]. The user +//! identity represent a unique E2EE capable identity of any given user. This +//! identity is generally created and uploaded to the server by the first E2EE +//! capable client the user logs in with. We represent user identities using the +//! [`UserIdentity`] struct. +//! +//! A [`Device`] or an [`UserIdentity`] can be used to inspect the public keys +//! of the device or identity, or it can be used to initiate a interactive +//! verification flow. They can also be manually marked as verified. +//! +//! # Examples +//! +//! Verifying a device is pretty straightforward: +//! +//! ```no_run +//! # use std::convert::TryFrom; +//! # use matrix_sdk::{Client, ruma::UserId}; +//! # use url::Url; +//! # use futures::executor::block_on; +//! # let alice = UserId::try_from("@alice:example.org").unwrap(); +//! # let homeserver = Url::parse("http://example.com").unwrap(); +//! # let client = Client::new(homeserver).unwrap(); +//! # block_on(async { +//! let device = client.get_device(&alice, "DEVICEID".into()).await?; +//! +//! if let Some(device) = device { +//! // Let's request the device to be verified. +//! let verification = device.request_verification().await?; +//! +//! // Actually this is taking too long. +//! verification.cancel().await?; +//! +//! // Let's just mark it as verified. +//! device.verify().await?; +//! } +//! # anyhow::Result::<()>::Ok(()) }); +//! ``` +//! +//! Verifying a user identity works largely the same: +//! +//! ```no_run +//! # use std::convert::TryFrom; +//! # use matrix_sdk::{Client, ruma::UserId}; +//! # use url::Url; +//! # use futures::executor::block_on; +//! # let alice = UserId::try_from("@alice:example.org").unwrap(); +//! # let homeserver = Url::parse("http://example.com").unwrap(); +//! # let client = Client::new(homeserver).unwrap(); +//! # block_on(async { +//! let user = client.get_user_identity(&alice).await?; +//! +//! if let Some(user) = user { +//! // Let's request the user to be verified. +//! let verification = user.request_verification().await?; +//! +//! // Actually this is taking too long. +//! verification.cancel().await?; +//! +//! // Let's just mark it as verified. +//! user.verify().await?; +//! } +//! # anyhow::Result::<()>::Ok(()) }); +//! ``` +//! +//! [cross signing keys]: https://spec.matrix.org/unstable/client-server-api/#cross-signing +//! [device keys]: https://spec.matrix.org/unstable/client-server-api/#device-keys + +mod devices; +mod users; + +pub use devices::{Device, UserDevices}; +pub use matrix_sdk_base::crypto::MasterPubkey; +pub use users::UserIdentity; + +/// Error for the manual verification step, when we manually sign users or +/// devices. +#[derive(thiserror::Error, Debug)] +pub enum ManualVerifyError { + /// Error that happens when we try to upload the user or device signature. + #[error(transparent)] + Http(#[from] crate::HttpError), + /// Error that happens when we try to sign the user or device. + #[error(transparent)] + Signature(#[from] matrix_sdk_base::crypto::SignatureError), +} + +/// Error when requesting a verification. +#[derive(thiserror::Error, Debug)] +pub enum RequestVerificationError { + /// An ordinary error coming from the SDK, i.e. when we fail to send out a + /// HTTP request or if there's an error with the storage layer. + #[error(transparent)] + Sdk(#[from] crate::Error), + /// Verifying other users requires having a DM open with them, this error + /// signals that we didn't have a DM and that we failed to create one. + #[error("Couldn't create a DM with user {0} where the verification should take place")] + RoomCreation(ruma::UserId), +} diff --git a/matrix_sdk/src/identities/users.rs b/matrix_sdk/src/identities/users.rs new file mode 100644 index 00000000..5cf55332 --- /dev/null +++ b/matrix_sdk/src/identities/users.rs @@ -0,0 +1,406 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{result::Result, sync::Arc}; + +use matrix_sdk_base::{ + crypto::{ + MasterPubkey, OwnUserIdentity as InnerOwnUserIdentity, UserIdentity as InnerUserIdentity, + }, + locks::RwLock, +}; +use ruma::{ + events::{ + key::verification::VerificationMethod, + room::message::{MessageEventContent, MessageType}, + AnyMessageEventContent, + }, + UserId, +}; + +use super::{ManualVerifyError, RequestVerificationError}; +use crate::{room::Joined, verification::VerificationRequest, Client}; + +/// A struct representing a E2EE capable identity of a user. +/// +/// The identity is backed by public [cross signing] keys that users upload. +/// +/// [cross signing]: https://spec.matrix.org/unstable/client-server-api/#cross-signing +#[derive(Debug, Clone)] +pub struct UserIdentity { + inner: UserIdentities, +} + +impl UserIdentity { + pub(crate) fn new_own(client: Client, identity: InnerOwnUserIdentity) -> Self { + let identity = OwnUserIdentity { inner: identity, client }; + + Self { inner: identity.into() } + } + + pub(crate) fn new(client: Client, identity: InnerUserIdentity, room: Option) -> Self { + let identity = OtherUserIdentity { + inner: identity, + client, + direct_message_room: RwLock::new(room).into(), + }; + + Self { inner: identity.into() } + } + + /// The ID of the user this E2EE identity belongs to. + pub fn user_id(&self) -> &UserId { + match &self.inner { + UserIdentities::Own(i) => i.inner.user_id(), + UserIdentities::Other(i) => i.inner.user_id(), + } + } + + /// Request an interacitve verification with this `UserIdentity`. + /// + /// Returns a [`VerificationRequest`] object that can be used to control the + /// verification flow. + /// + /// This will send out a `m.key.verification.request` event to all the E2EE + /// capable devices we have if we're requesting verification with our own + /// user identity or will send out the event to a DM we share with the user. + /// + /// If we don't share a DM with this user one will be created before the + /// event gets sent out. + /// + /// The default methods that are supported are `m.sas.v1` and + /// `m.qr_code.show.v1`, if this isn't desirable the + /// [`request_verification_with_methods()`] method can be used to override + /// this. + /// + /// # Examples + /// + /// ```no_run + /// # use std::convert::TryFrom; + /// # use matrix_sdk::{Client, ruma::UserId}; + /// # use url::Url; + /// # let alice = UserId::try_from("@alice:example.org").unwrap(); + /// # let homeserver = Url::parse("http://example.com").unwrap(); + /// # let client = Client::new(homeserver).unwrap(); + /// # futures::executor::block_on(async { + /// let user = client.get_user_identity(&alice).await?; + /// + /// if let Some(user) = user { + /// let verification = user.request_verification().await?; + /// } + /// + /// # anyhow::Result::<()>::Ok(()) }); + /// ``` + /// + /// [`request_verification_with_methods()`]: + /// #method.request_verification_with_methods + pub async fn request_verification( + &self, + ) -> Result { + match &self.inner { + UserIdentities::Own(i) => i.request_verification(None).await, + UserIdentities::Other(i) => i.request_verification(None).await, + } + } + + /// Request an interacitve verification with this `UserIdentity`. + /// + /// Returns a [`VerificationRequest`] object that can be used to control the + /// verification flow. + /// + /// This methods behaves the same way as [`request_verification()`], + /// but the advertised verification methods can be manually selected. + /// + /// # Arguments + /// + /// * `methods` - The verification methods that we want to support. Must be + /// non-empty. + /// + /// # Panics + /// + /// This method will panic if `methods` is empty. + /// + /// # Examples + /// + /// ```no_run + /// # use std::convert::TryFrom; + /// # use matrix_sdk::{ + /// # Client, + /// # ruma::{ + /// # UserId, + /// # events::key::verification::VerificationMethod, + /// # } + /// # }; + /// # use url::Url; + /// # use futures::executor::block_on; + /// # let alice = UserId::try_from("@alice:example.org").unwrap(); + /// # let homeserver = Url::parse("http://example.com").unwrap(); + /// # let client = Client::new(homeserver).unwrap(); + /// # block_on(async { + /// let user = client.get_user_identity(&alice).await?; + /// + /// // We don't want to support showing a QR code, we only support SAS + /// // verification + /// let methods = vec![VerificationMethod::SasV1]; + /// + /// if let Some(user) = user { + /// let verification = user.request_verification_with_methods(methods).await?; + /// } + /// # anyhow::Result::<()>::Ok(()) }); + /// ``` + /// + /// [`request_verification()`]: #method.request_verification + pub async fn request_verification_with_methods( + &self, + methods: Vec, + ) -> Result { + if methods.is_empty() { + panic!("The list of verification methods can't be non-empty"); + } + + match &self.inner { + UserIdentities::Own(i) => i.request_verification(Some(methods)).await, + UserIdentities::Other(i) => i.request_verification(Some(methods)).await, + } + } + + /// Manually verify this [`UserIdentity`]. + /// + /// This method will attempt to sign the user identity using our private + /// cross signing key. Verifying can fail if we don't have the private + /// part of our user-signing key. + /// + /// The state of our private cross signing keys can be inspected using the + /// [`Client::cross_signing_status()`] method. + /// + /// [`Client::cross_signing_status()`]: crate::Client::cross_signing_status + /// + /// # Examples + /// + /// ```no_run + /// # use std::convert::TryFrom; + /// # use matrix_sdk::{ + /// # Client, + /// # ruma::{ + /// # UserId, + /// # events::key::verification::VerificationMethod, + /// # } + /// # }; + /// # use url::Url; + /// # use futures::executor::block_on; + /// # let alice = UserId::try_from("@alice:example.org").unwrap(); + /// # let homeserver = Url::parse("http://example.com").unwrap(); + /// # let client = Client::new(homeserver).unwrap(); + /// # block_on(async { + /// let user = client.get_user_identity(&alice).await?; + /// + /// if let Some(user) = user { + /// user.verify().await?; + /// } + /// # anyhow::Result::<()>::Ok(()) }); + /// ``` + pub async fn verify(&self) -> Result<(), ManualVerifyError> { + match &self.inner { + UserIdentities::Own(i) => i.verify().await, + UserIdentities::Other(i) => i.verify().await, + } + } + + /// Is the user identity considered to be verified. + /// + /// A user identity is considered to be verified if it has been signed by + /// our user-signing key, if the identity belongs to another user, or if we + /// locally marked it as verified, if the user identity belongs to us. + /// + /// If the identity belongs to another user, our own user identity needs to + /// be verified as well for the identity to be considered to be verified. + /// + /// # Examples + /// + /// ```no_run + /// # use std::convert::TryFrom; + /// # use matrix_sdk::{ + /// # Client, + /// # ruma::{ + /// # UserId, + /// # events::key::verification::VerificationMethod, + /// # } + /// # }; + /// # use url::Url; + /// # use futures::executor::block_on; + /// # let alice = UserId::try_from("@alice:example.org").unwrap(); + /// # let homeserver = Url::parse("http://example.com").unwrap(); + /// # let client = Client::new(homeserver).unwrap(); + /// # block_on(async { + /// let user = client.get_user_identity(&alice).await?; + /// + /// if let Some(user) = user { + /// if user.verified() { + /// println!("User {} is verified", user.user_id().as_str()); + /// } else { + /// println!("User {} is not verified", user.user_id().as_str()); + /// } + /// } + /// # anyhow::Result::<()>::Ok(()) }); + /// ``` + pub fn verified(&self) -> bool { + match &self.inner { + UserIdentities::Own(i) => i.inner.is_verified(), + UserIdentities::Other(i) => i.inner.verified(), + } + } + + /// Get the public part of the master key of this user identity. + /// + /// # Examples + /// + /// ```no_run + /// # use std::convert::TryFrom; + /// # use matrix_sdk::{ + /// # Client, + /// # ruma::{ + /// # UserId, + /// # events::key::verification::VerificationMethod, + /// # } + /// # }; + /// # use url::Url; + /// # use futures::executor::block_on; + /// # let alice = UserId::try_from("@alice:example.org").unwrap(); + /// # let homeserver = Url::parse("http://example.com").unwrap(); + /// # let client = Client::new(homeserver).unwrap(); + /// # block_on(async { + /// let user = client.get_user_identity(&alice).await?; + /// + /// if let Some(user) = user { + /// // Let's verify the user after we confirm that the master key + /// // matches what we expect, for this we fetch the first public key we + /// // can find, there's currently only a single key allowed so this is + /// // fine. + /// if user.master_key().get_first_key() == Some("MyMasterKey") { + /// println!( + /// "Master keys match for user {}, marking the user as verified", + /// user.user_id().as_str(), + /// ); + /// user.verify().await?; + /// } else { + /// println!("Master keys don't match for user {}", user.user_id().as_str()); + /// } + /// } + /// # anyhow::Result::<()>::Ok(()) }); + /// ``` + pub fn master_key(&self) -> &MasterPubkey { + match &self.inner { + UserIdentities::Own(i) => i.inner.master_key(), + UserIdentities::Other(i) => i.inner.master_key(), + } + } +} + +#[derive(Debug, Clone)] +enum UserIdentities { + Own(OwnUserIdentity), + Other(OtherUserIdentity), +} + +impl From for UserIdentities { + fn from(i: OwnUserIdentity) -> Self { + Self::Own(i) + } +} + +impl From for UserIdentities { + fn from(i: OtherUserIdentity) -> Self { + Self::Other(i) + } +} + +#[derive(Debug, Clone)] +struct OwnUserIdentity { + pub(crate) inner: InnerOwnUserIdentity, + pub(crate) client: Client, +} + +#[derive(Debug, Clone)] +struct OtherUserIdentity { + pub(crate) inner: InnerUserIdentity, + pub(crate) client: Client, + pub(crate) direct_message_room: Arc>>, +} + +impl OwnUserIdentity { + async fn request_verification( + &self, + methods: Option>, + ) -> Result { + let (verification, request) = if let Some(methods) = methods { + self.inner + .request_verification_with_methods(methods) + .await + .map_err(crate::Error::from)? + } else { + self.inner.request_verification().await.map_err(crate::Error::from)? + }; + + self.client.send_verification_request(request).await?; + + Ok(VerificationRequest { inner: verification, client: self.client.clone() }) + } + + async fn verify(&self) -> Result<(), ManualVerifyError> { + let request = self.inner.verify().await?; + self.client.send(request, None).await?; + + Ok(()) + } +} + +impl OtherUserIdentity { + async fn request_verification( + &self, + methods: Option>, + ) -> Result { + let content = self.inner.verification_request_content(methods.clone()).await; + + let room = if let Some(room) = self.direct_message_room.read().await.as_ref() { + room.clone() + } else if let Some(room) = + self.client.create_dm_room(self.inner.user_id().to_owned()).await? + { + room + } else { + return Err(RequestVerificationError::RoomCreation(self.inner.user_id().to_owned())); + }; + + let response = room + .send( + AnyMessageEventContent::RoomMessage(MessageEventContent::new( + MessageType::VerificationRequest(content), + )), + None, + ) + .await?; + + let verification = + self.inner.request_verification(room.room_id(), &response.event_id, methods).await; + + Ok(VerificationRequest { inner: verification, client: self.client.clone() }) + } + + async fn verify(&self) -> Result<(), ManualVerifyError> { + let request = self.inner.verify().await?; + self.client.send(request, None).await?; + + Ok(()) + } +} diff --git a/matrix_sdk/src/lib.rs b/matrix_sdk/src/lib.rs index be0c099e..e7a5497f 100644 --- a/matrix_sdk/src/lib.rs +++ b/matrix_sdk/src/lib.rs @@ -39,23 +39,24 @@ //! The following crate feature flags are available: //! //! * `encryption`: Enables end-to-end encryption support in the library. -//! * `sled_cryptostore`: Enables a Sled based store for the encryption -//! keys. If this is disabled and `encryption` support is enabled the keys will -//! by default be stored only in memory and thus lost after the client is -//! destroyed. +//! * `sled_cryptostore`: Enables a Sled based store for the encryption keys. If +//! this is disabled and `encryption` support is enabled the keys will by +//! default be stored only in memory and thus lost after the client is +//! destroyed. //! * `markdown`: Support for sending markdown formatted messages. //! * `socks`: Enables SOCKS support in reqwest, the default HTTP client. //! * `sso_login`: Enables SSO login with a local http server. //! * `require_auth_for_profile_requests`: Whether to send the access token in -//! the authentication -//! header when calling endpoints that retrieve profile data. This matches the -//! synapse configuration `require_auth_for_profile_requests`. Enabled by -//! default. +//! the authentication header when calling endpoints that retrieve profile +//! data. This matches the synapse configuration +//! `require_auth_for_profile_requests`. Enabled by default. //! * `appservice`: Enables low-level appservice functionality. For an //! high-level API there's the `matrix-sdk-appservice` crate +//! * `anyhow`: Support for returning `anyhow::Result<()>` from event handlers. #![deny( missing_debug_implementations, + missing_docs, dead_code, missing_docs, trivial_casts, @@ -90,7 +91,7 @@ pub use ruma; mod client; mod error; -mod event_handler; +pub mod event_handler; mod http_client; /// High-level room API pub mod room; @@ -98,16 +99,14 @@ pub mod room; mod room_member; #[cfg(feature = "encryption")] -mod device; +#[cfg_attr(feature = "docs", doc(cfg(encryption)))] +pub mod identities; #[cfg(feature = "encryption")] +#[cfg_attr(feature = "docs", doc(cfg(encryption)))] pub mod verification; pub use client::{Client, ClientConfig, LoopCtrl, RequestConfig, SyncSettings}; -#[cfg(feature = "encryption")] -#[cfg_attr(feature = "docs", doc(cfg(encryption)))] -pub use device::Device; -pub use error::{Error, HttpError, Result}; -pub use event_handler::{CustomEvent, EventHandler}; +pub use error::{Error, HttpError, HttpResult, Result}; pub use http_client::HttpSend; pub use room_member::RoomMember; #[cfg(not(target_arch = "wasm32"))] diff --git a/matrix_sdk/src/room/common.rs b/matrix_sdk/src/room/common.rs index fc7cc8a1..091074e1 100644 --- a/matrix_sdk/src/room/common.rs +++ b/matrix_sdk/src/room/common.rs @@ -14,6 +14,7 @@ use ruma::{ }; use crate::{ + error::HttpResult, media::{MediaFormat, MediaRequest, MediaType}, room::RoomType, BaseRoom, Client, Result, RoomMember, @@ -143,7 +144,7 @@ impl Common { pub async fn messages( &self, request: impl Into>, - ) -> Result { + ) -> HttpResult { let request = request.into(); self.client.send(request, None).await } @@ -376,4 +377,25 @@ impl Common { .await .map_err(Into::into) } + + /// Check if all members of this room are verified and all their devices are + /// verified. + /// + /// Returns true if all devices in the room are verified, otherwise false. + #[cfg(feature = "encryption")] + #[cfg_attr(feature = "docs", doc(cfg(encryption)))] + pub async fn contains_only_verified_devices(&self) -> Result { + let user_ids = self.client.store().get_user_ids(self.room_id()).await?; + + for user_id in user_ids { + let devices = self.client.get_user_devices(&user_id).await?; + let any_unverified = devices.devices().any(|d| !d.verified()); + + if any_unverified { + return Ok(false); + } + } + + Ok(true) + } } diff --git a/matrix_sdk/src/room/joined.rs b/matrix_sdk/src/room/joined.rs index 5de22dc6..048ba7a3 100644 --- a/matrix_sdk/src/room/joined.rs +++ b/matrix_sdk/src/room/joined.rs @@ -46,7 +46,7 @@ use ruma::{ #[cfg(feature = "encryption")] use tracing::instrument; -use crate::{room::Common, BaseRoom, Client, Error, Result, RoomType}; +use crate::{error::HttpResult, room::Common, BaseRoom, Client, HttpError, Result, RoomType}; const TYPING_NOTICE_TIMEOUT: Duration = Duration::from_secs(4); const TYPING_NOTICE_RESEND_TIMEOUT: Duration = Duration::from_secs(3); @@ -562,7 +562,7 @@ impl Joined { &self, content: impl Into, state_key: &str, - ) -> Result { + ) -> HttpResult { let content = content.into(); let request = send_state_event::Request::new(self.inner.room_id(), state_key, &content); @@ -606,7 +606,7 @@ impl Joined { event_id: &EventId, reason: Option<&str>, txn_id: Option, - ) -> Result { + ) -> HttpResult { let txn_id = txn_id.unwrap_or_else(Uuid::new_v4).to_string(); let request = assign!(redact_event::Request::new(self.inner.room_id(), event_id, &txn_id), { @@ -642,8 +642,8 @@ impl Joined { /// room.set_tag("u.work", tag_info ); /// # }) /// ``` - pub async fn set_tag(&self, tag: &str, tag_info: TagInfo) -> Result { - let user_id = self.client.user_id().await.ok_or(Error::AuthenticationRequired)?; + pub async fn set_tag(&self, tag: &str, tag_info: TagInfo) -> HttpResult { + let user_id = self.client.user_id().await.ok_or(HttpError::AuthenticationRequired)?; let request = create_tag::Request::new(&user_id, self.inner.room_id(), tag, tag_info); self.client.send(request, None).await } @@ -654,8 +654,8 @@ impl Joined { /// /// # Arguments /// * `tag` - The tag to remove. - pub async fn remove_tag(&self, tag: &str) -> Result { - let user_id = self.client.user_id().await.ok_or(Error::AuthenticationRequired)?; + pub async fn remove_tag(&self, tag: &str) -> HttpResult { + let user_id = self.client.user_id().await.ok_or(HttpError::AuthenticationRequired)?; let request = delete_tag::Request::new(&user_id, self.inner.room_id(), tag); self.client.send(request, None).await } diff --git a/matrix_sdk/src/room/room.rs b/matrix_sdk/src/room/room.rs deleted file mode 100644 index 8b137891..00000000 --- a/matrix_sdk/src/room/room.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/matrix_sdk/src/verification/mod.rs b/matrix_sdk/src/verification/mod.rs index ca7ae6c0..b59eb706 100644 --- a/matrix_sdk/src/verification/mod.rs +++ b/matrix_sdk/src/verification/mod.rs @@ -25,9 +25,10 @@ //! [VerificationRequest::is_ready()] method returns true, the verification can //! transition into one of the supported verification flows: //! -//! * [SasVerification] - Interactive verification using a short authentication +//! * [`SasVerification`] - Interactive verification using a short +//! authentication //! string. -//! * [QrVerification] - Interactive verification using QR codes. +//! * [`QrVerification`] - Interactive verification using QR codes. #[cfg(feature = "qrcode")] mod qrcode; diff --git a/matrix_sdk_appservice/Cargo.toml b/matrix_sdk_appservice/Cargo.toml index 7a3a98e3..8191bc4f 100644 --- a/matrix_sdk_appservice/Cargo.toml +++ b/matrix_sdk_appservice/Cargo.toml @@ -8,7 +8,17 @@ name = "matrix-sdk-appservice" version = "0.1.0" [features] -default = ["warp"] +default = ["warp", "native-tls"] + +encryption = ["matrix-sdk/encryption"] +sled_state_store = ["matrix-sdk/sled_state_store"] +sled_cryptostore = ["matrix-sdk/sled_cryptostore"] +markdown = ["matrix-sdk/markdown"] +native-tls = ["matrix-sdk/native-tls"] +rustls-tls = ["matrix-sdk/rustls-tls"] +socks = ["matrix-sdk/socks"] +sso_login = ["matrix-sdk/sso_login"] +require_auth_for_profile_requests = ["matrix-sdk/require_auth_for_profile_requests"] docs = ["warp"] @@ -26,7 +36,7 @@ tracing = "0.1" url = "2" warp = { git = "https://github.com/seanmonstar/warp.git", rev = "629405", optional = true, default-features = false } -matrix-sdk = { version = "0.3", path = "../matrix_sdk", default-features = false, features = ["appservice", "native-tls"] } +matrix-sdk = { version = "0.3", path = "../matrix_sdk", default-features = false, features = ["appservice"] } [dependencies.ruma] version = "0.3.0" diff --git a/matrix_sdk_appservice/examples/appservice_autojoin.rs b/matrix_sdk_appservice/examples/appservice_autojoin.rs index f792f242..a5b8b9ff 100644 --- a/matrix_sdk_appservice/examples/appservice_autojoin.rs +++ b/matrix_sdk_appservice/examples/appservice_autojoin.rs @@ -2,7 +2,6 @@ use std::{convert::TryFrom, env}; use matrix_sdk_appservice::{ matrix_sdk::{ - async_trait, room::Room, ruma::{ events::{ @@ -11,51 +10,27 @@ use matrix_sdk_appservice::{ }, UserId, }, - EventHandler, }, - AppService, AppServiceRegistration, + AppService, AppServiceRegistration, Result, }; -use tracing::{error, trace}; +use tracing::trace; -struct AppServiceEventHandler { +pub async fn handle_room_member( appservice: AppService, -} + room: Room, + event: SyncStateEvent, +) -> Result<()> { + if !appservice.user_id_is_in_namespace(&event.state_key)? { + trace!("not an appservice user: {}", event.state_key); + } else if let MembershipState::Invite = event.content.membership { + let user_id = UserId::try_from(event.state_key.as_str())?; + appservice.register_virtual_user(user_id.localpart()).await?; -impl AppServiceEventHandler { - pub fn new(appservice: AppService) -> Self { - Self { appservice } + let client = appservice.virtual_user_client(user_id.localpart()).await?; + client.join_room_by_id(room.room_id()).await?; } - pub async fn handle_room_member( - &self, - room: Room, - event: &SyncStateEvent, - ) -> Result<(), Box> { - if !self.appservice.user_id_is_in_namespace(&event.state_key)? { - trace!("not an appservice user: {}", event.state_key); - } else if let MembershipState::Invite = event.content.membership { - let user_id = UserId::try_from(event.state_key.clone())?; - - let appservice = self.appservice.clone(); - appservice.register_virtual_user(user_id.localpart()).await?; - - let client = appservice.virtual_user_client(user_id.localpart()).await?; - - client.join_room_by_id(room.room_id()).await?; - } - - Ok(()) - } -} - -#[async_trait] -impl EventHandler for AppServiceEventHandler { - async fn on_room_member(&self, room: Room, event: &SyncStateEvent) { - match self.handle_room_member(room, event).await { - Ok(_) => (), - Err(error) => error!("{:?}", error), - } - } + Ok(()) } #[tokio::main] @@ -68,7 +43,14 @@ pub async fn main() -> Result<(), Box> { let registration = AppServiceRegistration::try_from_yaml_file("./tests/registration.yaml")?; let mut appservice = AppService::new(homeserver_url, server_name, registration).await?; - appservice.set_event_handler(Box::new(AppServiceEventHandler::new(appservice.clone()))).await?; + appservice + .register_event_handler({ + let appservice = appservice.clone(); + move |event: SyncStateEvent, room: Room| { + handle_room_member(appservice.clone(), room, event) + } + }) + .await?; let (host, port) = appservice.registration().get_host_and_port()?; appservice.run(host, port).await?; diff --git a/matrix_sdk_appservice/src/error.rs b/matrix_sdk_appservice/src/error.rs index 1a65a66b..9c0fda4c 100644 --- a/matrix_sdk_appservice/src/error.rs +++ b/matrix_sdk_appservice/src/error.rs @@ -87,3 +87,9 @@ impl From for Error { Self::WarpRejection(format!("{:?}", rejection)) } } + +impl From for Error { + fn from(e: matrix_sdk::HttpError) -> Self { + matrix_sdk::Error::from(e).into() + } +} diff --git a/matrix_sdk_appservice/src/lib.rs b/matrix_sdk_appservice/src/lib.rs index 71de60ec..46269c14 100644 --- a/matrix_sdk_appservice/src/lib.rs +++ b/matrix_sdk_appservice/src/lib.rs @@ -34,14 +34,10 @@ //! ```no_run //! # async { //! # -//! # use matrix_sdk::{async_trait, EventHandler}; -//! # -//! # struct MyEventHandler; -//! # -//! # #[async_trait] -//! # impl EventHandler for MyEventHandler {} -//! # -//! use matrix_sdk_appservice::{AppService, AppServiceRegistration}; +//! use matrix_sdk_appservice::{ +//! ruma::events::{SyncStateEvent, room::member::MemberEventContent}, +//! AppService, AppServiceRegistration +//! }; //! //! let homeserver_url = "http://127.0.0.1:8008"; //! let server_name = "localhost"; @@ -59,7 +55,9 @@ //! ")?; //! //! let mut appservice = AppService::new(homeserver_url, server_name, registration).await?; -//! appservice.set_event_handler(Box::new(MyEventHandler)).await?; +//! appservice.register_event_handler(|_ev: SyncStateEvent| async { +//! // do stuff +//! }); //! //! let (host, port) = appservice.registration().get_host_and_port()?; //! appservice.run(host, port).await?; @@ -80,6 +78,7 @@ compile_error!("one webserver feature must be enabled. available ones: `warp`"); use std::{ convert::{TryFrom, TryInto}, fs::File, + future::Future, ops::Deref, path::PathBuf, sync::Arc, @@ -92,7 +91,10 @@ pub use matrix_sdk; #[doc(no_inline)] pub use matrix_sdk::ruma; use matrix_sdk::{ - bytes::Bytes, reqwest::Url, Client, ClientConfig, EventHandler, HttpError, Session, + bytes::Bytes, + event_handler::{EventHandler, EventHandlerResult, SyncEvent}, + reqwest::Url, + Client, ClientConfig, Session, }; use regex::Regex; use ruma::{ @@ -106,12 +108,13 @@ use ruma::{ }, assign, identifiers, DeviceId, ServerNameBox, UserId, }; +use serde::de::DeserializeOwned; use tracing::{info, warn}; mod error; mod webserver; -pub type Result = std::result::Result; +pub type Result = std::result::Result; pub type Host = String; pub type Port = u16; @@ -354,8 +357,8 @@ impl AppService { Ok(entry.value().clone()) } - /// Convenience wrapper around [`Client::set_event_handler()`] that attaches - /// the event handler to the [`MainUser`]'s [`Client`] + /// Convenience wrapper around [`Client::register_event_handler()`] that + /// attaches the event handler to the [`MainUser`]'s [`Client`] /// /// Note that the event handler in the [`AppService`] context only triggers /// [`join` room `timeline` events], so no state events or events from the @@ -370,10 +373,14 @@ impl AppService { /// /// [`join` room `timeline` events]: https://spec.matrix.org/unstable/client-server-api/#get_matrixclientr0sync /// [MSC2409]: https://github.com/matrix-org/matrix-doc/pull/2409 - pub async fn set_event_handler(&mut self, handler: Box) -> Result<()> { + pub async fn register_event_handler(&mut self, handler: H) -> Result<()> + where + Ev: SyncEvent + DeserializeOwned + Send + 'static, + H: EventHandler, + ::Output: EventHandlerResult, + { let client = self.get_cached_client(None)?; - - client.set_event_handler(handler).await; + client.register_event_handler(handler).await; Ok(()) } @@ -395,9 +402,9 @@ impl AppService { match client.register(request).await { Ok(_) => (), Err(error) => match error { - matrix_sdk::Error::Http(HttpError::UiaaError(FromHttpResponseError::Http( + matrix_sdk::HttpError::UiaaError(FromHttpResponseError::Http( ServerError::Known(UiaaResponse::MatrixError(ref matrix_error)), - ))) => { + )) => { match matrix_error.kind { ErrorKind::UserInUse => { // TODO: persist the fact that we registered that user diff --git a/matrix_sdk_appservice/tests/tests.rs b/matrix_sdk_appservice/tests/tests.rs index 8174daef..8551c65d 100644 --- a/matrix_sdk_appservice/tests/tests.rs +++ b/matrix_sdk_appservice/tests/tests.rs @@ -1,13 +1,14 @@ -use std::sync::{Arc, Mutex}; +use std::{ + future, + sync::{Arc, Mutex}, +}; use matrix_sdk::{ - async_trait, - room::Room, ruma::{ api::appservice::Registration, events::{room::member::MemberEventContent, SyncStateEvent}, }, - ClientConfig, EventHandler, RequestConfig, + ClientConfig, RequestConfig, }; use matrix_sdk_appservice::*; use matrix_sdk_test::{appservice::TransactionBuilder, async_test, EventsJson}; @@ -203,28 +204,17 @@ async fn test_no_access_token() -> Result<()> { async fn test_event_handler() -> Result<()> { let mut appservice = appservice(None).await?; - #[derive(Clone)] - struct Example { - pub on_state_member: Arc>, - } - - impl Example { - pub fn new() -> Self { - #[allow(clippy::mutex_atomic)] - Self { on_state_member: Arc::new(Mutex::new(false)) } - } - } - - #[async_trait] - impl EventHandler for Example { - async fn on_room_member(&self, _: Room, _: &SyncStateEvent) { - let on_state_member = self.on_state_member.clone(); - *on_state_member.lock().unwrap() = true; - } - } - - let example = Example::new(); - appservice.set_event_handler(Box::new(example.clone())).await?; + #[allow(clippy::mutex_atomic)] + let on_state_member = Arc::new(Mutex::new(false)); + appservice + .register_event_handler({ + let on_state_member = on_state_member.clone(); + move |_ev: SyncStateEvent| { + *on_state_member.lock().unwrap() = true; + future::ready(()) + } + }) + .await?; let uri = "/_matrix/app/v1/transactions/1?access_token=hs_token"; @@ -241,7 +231,7 @@ async fn test_event_handler() -> Result<()> { .await .unwrap(); - let on_room_member_called = *example.on_state_member.lock().unwrap(); + let on_room_member_called = *on_state_member.lock().unwrap(); assert!(on_room_member_called); Ok(()) diff --git a/matrix_sdk_base/src/client.rs b/matrix_sdk_base/src/client.rs index 399b5348..6d824657 100644 --- a/matrix_sdk_base/src/client.rs +++ b/matrix_sdk_base/src/client.rs @@ -50,16 +50,15 @@ use ruma::{ use ruma::{ api::client::r0::{self as api, push::get_notifications::Notification}, events::{ - room::member::{MemberEventContent, MembershipState}, - AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent, - AnySyncEphemeralRoomEvent, AnySyncRoomEvent, AnySyncStateEvent, EventContent, EventType, - StateEvent, + room::member::MembershipState, AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, + AnyStrippedStateEvent, AnySyncEphemeralRoomEvent, AnySyncRoomEvent, AnySyncStateEvent, + EventContent, EventType, }, push::{Action, PushConditionRoomCtx, Ruleset}, serde::Raw, MilliSecondsSinceUnixEpoch, RoomId, UInt, UserId, }; -use tracing::{info, warn}; +use tracing::{info, trace, warn}; use zeroize::Zeroizing; use crate::{ @@ -71,97 +70,6 @@ use crate::{ pub type Token = String; -/// A deserialization wrapper for extracting the prev_content field when -/// found in an `unsigned` field. -/// -/// Represents the outer `unsigned` field -#[derive(serde::Deserialize)] -pub struct AdditionalEventData { - unsigned: AdditionalUnsignedData, -} - -/// A deserialization wrapper for extracting the prev_content field when -/// found in an `unsigned` field. -/// -/// Represents the inner `prev_content` field -#[derive(serde::Deserialize)] -pub struct AdditionalUnsignedData { - pub prev_content: Option>, -} - -/// Transform an `AnySyncStateEvent` by hoisting `prev_content` field from -/// `unsigned` to the top level. -/// -/// Due to a [bug in synapse][synapse-bug], `prev_content` often ends up in -/// `unsigned` contrary to the C2S spec. Some more discussion can be found -/// [here][discussion]. Until this is fixed in synapse or handled in Ruma, we -/// use this to hoist up `prev_content` to the top level. -/// -/// [synapse-bug]: -/// [discussion]: -pub fn hoist_and_deserialize_state_event( - event: &Raw, -) -> StdResult { - let prev_content = event.deserialize_as::()?.unsigned.prev_content; - - let mut ev = event.deserialize()?; - - if let AnySyncStateEvent::RoomMember(ref mut member) = ev { - if member.prev_content.is_none() { - member.prev_content = prev_content.and_then(|e| e.deserialize().ok()); - } - } - - Ok(ev) -} - -fn hoist_member_event( - event: &Raw>, -) -> StdResult, serde_json::Error> { - let prev_content = event.deserialize_as::()?.unsigned.prev_content; - - let mut e = event.deserialize()?; - - if e.prev_content.is_none() { - e.prev_content = prev_content.and_then(|e| e.deserialize().ok()); - } - - Ok(e) -} - -/// Transform an `AnySyncRoomEvent` by hoisting `prev_content` field from -/// `unsigned` to the top level. -/// -/// Due to a [bug in synapse][synapse-bug], `prev_content` often ends up in -/// `unsigned` contrary to the C2S spec. Some more discussion can be found -/// [here][discussion]. Until this is fixed in synapse or handled in Ruma, we -/// use this to hoist up `prev_content` to the top level. -/// -/// [synapse-bug]: -/// [discussion]: -pub fn hoist_room_event_prev_content( - event: &Raw, -) -> StdResult { - let prev_content = event - .deserialize_as::() - .map(|more_unsigned| more_unsigned.unsigned) - .map(|additional| additional.prev_content)? - .and_then(|p| p.deserialize().ok()); - - let mut ev = event.deserialize()?; - - match &mut ev { - AnySyncRoomEvent::State(AnySyncStateEvent::RoomMember(ref mut member)) - if member.prev_content.is_none() => - { - member.prev_content = prev_content; - } - _ => (), - } - - Ok(ev) -} - /// A no IO Client implementation. /// /// This Client is a state machine that receives responses and events and @@ -445,7 +353,7 @@ impl BaseClient { #[allow(unused_mut)] let mut event: SyncRoomEvent = event.into(); - match hoist_room_event_prev_content(&event.event) { + match event.event.deserialize() { Ok(e) => { #[allow(clippy::single_match)] match &e { @@ -611,7 +519,7 @@ impl BaseClient { let room_id = room_info.room_id.clone(); for raw_event in events { - let event = match hoist_and_deserialize_state_event(raw_event) { + let event = match raw_event.deserialize() { Ok(e) => e, Err(e) => { warn!( @@ -687,15 +595,23 @@ impl BaseClient { let mut account_data = BTreeMap::new(); for raw_event in events { - let event = if let Ok(e) = raw_event.deserialize() { - e - } else { - continue; + let event = match raw_event.deserialize() { + Ok(e) => e, + Err(e) => { + warn!(error =? e, "Failed to deserialize a global account data event"); + continue; + } }; if let AnyGlobalAccountDataEvent::Direct(e) = &event { for (user_id, rooms) in e.content.iter() { for room_id in rooms { + trace!( + room_id = room_id.as_str(), + target = user_id.as_str(), + "Marking room as direct room" + ); + if let Some(room) = changes.room_infos.get_mut(room_id) { room.base_info.dm_target = Some(user_id.clone()); } else if let Some(room) = self.store.get_room(room_id) { @@ -916,6 +832,12 @@ impl BaseClient { new_rooms.invite.insert(room_id, new_info); } + // TODO remove this, we're processing account data events here again + // because we want to have the push rules in place before we process + // rooms and their events, but we want to create the rooms before we + // process the `m.direct` account data event. + self.handle_account_data(&account_data.events, &mut changes).await; + changes.presence = presence .events .iter() @@ -976,7 +898,7 @@ impl BaseClient { let members: Vec = response .chunk .iter() - .filter_map(|e| hoist_member_event(e).ok().and_then(|e| MemberEvent::try_from(e).ok())) + .filter_map(|e| e.deserialize().ok().and_then(|e| MemberEvent::try_from(e).ok())) .collect(); let mut ambiguity_cache = AmbiguityCache::new(self.store.clone()); diff --git a/matrix_sdk_base/src/lib.rs b/matrix_sdk_base/src/lib.rs index 051f2181..0b3c9991 100644 --- a/matrix_sdk_base/src/lib.rs +++ b/matrix_sdk_base/src/lib.rs @@ -15,7 +15,7 @@ //! This crate implements a [Matrix](https://matrix.org/) client library. //! -//! ## Crate Feature Flags +//! ## Crate Feature Flags //! //! The following crate feature flags are available: //! @@ -50,9 +50,7 @@ mod rooms; mod session; mod store; -pub use client::{ - hoist_and_deserialize_state_event, hoist_room_event_prev_content, BaseClient, BaseClientConfig, -}; +pub use client::{BaseClient, BaseClientConfig}; #[cfg(feature = "encryption")] #[cfg_attr(feature = "docs", doc(cfg(encryption)))] pub use matrix_sdk_crypto as crypto; diff --git a/matrix_sdk_base/src/store/memory_store.rs b/matrix_sdk_base/src/store/memory_store.rs index 311e6373..eb415898 100644 --- a/matrix_sdk_base/src/store/memory_store.rs +++ b/matrix_sdk_base/src/store/memory_store.rs @@ -66,30 +66,32 @@ pub struct MemoryStore { room_event_receipts: Arc>>>>, media: Arc>>>, + custom: Arc, Vec>>, } impl MemoryStore { #[allow(dead_code)] pub fn new() -> Self { Self { - sync_token: Arc::new(RwLock::new(None)), - filters: DashMap::new().into(), - account_data: DashMap::new().into(), - members: DashMap::new().into(), - profiles: DashMap::new().into(), - display_names: DashMap::new().into(), - joined_user_ids: DashMap::new().into(), - invited_user_ids: DashMap::new().into(), - room_info: DashMap::new().into(), - room_state: DashMap::new().into(), - room_account_data: DashMap::new().into(), - stripped_room_info: DashMap::new().into(), - stripped_room_state: DashMap::new().into(), - stripped_members: DashMap::new().into(), - presence: DashMap::new().into(), - room_user_receipts: DashMap::new().into(), - room_event_receipts: DashMap::new().into(), + sync_token: Default::default(), + filters: Default::default(), + account_data: Default::default(), + members: Default::default(), + profiles: Default::default(), + display_names: Default::default(), + joined_user_ids: Default::default(), + invited_user_ids: Default::default(), + room_info: Default::default(), + room_state: Default::default(), + room_account_data: Default::default(), + stripped_room_info: Default::default(), + stripped_room_state: Default::default(), + stripped_members: Default::default(), + presence: Default::default(), + room_user_receipts: Default::default(), + room_event_receipts: Default::default(), media: Arc::new(Mutex::new(LruCache::new(100))), + custom: DashMap::new().into(), } } @@ -407,6 +409,14 @@ impl MemoryStore { .unwrap_or_else(Vec::new)) } + async fn get_custom_value(&self, key: &[u8]) -> Result>> { + Ok(self.custom.get(key).map(|e| e.value().clone())) + } + + async fn set_custom_value(&self, key: &[u8], value: Vec) -> Result>> { + Ok(self.custom.insert(key.to_vec(), value)) + } + async fn add_media_content(&self, request: &MediaRequest, data: Vec) -> Result<()> { self.media.lock().await.put(request.unique_key(), data); @@ -563,6 +573,14 @@ impl StateStore for MemoryStore { self.get_event_room_receipt_events(room_id, receipt_type, event_id).await } + async fn get_custom_value(&self, key: &[u8]) -> Result>> { + self.get_custom_value(key).await + } + + async fn set_custom_value(&self, key: &[u8], value: Vec) -> Result>> { + self.set_custom_value(key, value).await + } + async fn add_media_content(&self, request: &MediaRequest, data: Vec) -> Result<()> { self.add_media_content(request, data).await } diff --git a/matrix_sdk_base/src/store/mod.rs b/matrix_sdk_base/src/store/mod.rs index 17c77c1a..0b4dc74a 100644 --- a/matrix_sdk_base/src/store/mod.rs +++ b/matrix_sdk_base/src/store/mod.rs @@ -263,6 +263,22 @@ pub trait StateStore: AsyncTraitDeps { event_id: &EventId, ) -> Result>; + /// Get arbitrary data from the custom store + /// + /// # Arguments + /// + /// * `key` - The key to fetch data for + async fn get_custom_value(&self, key: &[u8]) -> Result>>; + + /// Put arbitrary data into the custom store + /// + /// # Arguments + /// + /// * `key` - The key to insert data into + /// + /// * `value` - The value to insert + async fn set_custom_value(&self, key: &[u8], value: Vec) -> Result>>; + /// Add a media file's content in the media store. /// /// # Arguments @@ -310,15 +326,12 @@ pub struct Store { impl Store { fn new(inner: Box) -> Self { - let session = Arc::new(RwLock::new(None)); - let sync_token = Arc::new(RwLock::new(None)); - Self { inner: inner.into(), - session, - sync_token, - rooms: DashMap::new().into(), - stripped_rooms: DashMap::new().into(), + session: Default::default(), + sync_token: Default::default(), + rooms: Default::default(), + stripped_rooms: Default::default(), } } diff --git a/matrix_sdk_base/src/store/sled_store/mod.rs b/matrix_sdk_base/src/store/sled_store/mod.rs index ea7db5c2..242744da 100644 --- a/matrix_sdk_base/src/store/sled_store/mod.rs +++ b/matrix_sdk_base/src/store/sled_store/mod.rs @@ -189,6 +189,7 @@ pub struct SledStore { room_user_receipts: Tree, room_event_receipts: Tree, media: Tree, + custom: Tree, } impl std::fmt::Debug for SledStore { @@ -226,6 +227,8 @@ impl SledStore { let media = db.open_tree("media")?; + let custom = db.open_tree("custom")?; + Ok(Self { path, inner: db, @@ -247,6 +250,7 @@ impl SledStore { room_user_receipts, room_event_receipts, media, + custom, }) } @@ -762,6 +766,17 @@ impl SledStore { .map(|m| m.to_vec())) } + async fn get_custom_value(&self, key: &[u8]) -> Result>> { + Ok(self.custom.get(key)?.map(|v| v.to_vec())) + } + + async fn set_custom_value(&self, key: &[u8], value: Vec) -> Result>> { + let ret = self.custom.insert(key, value)?.map(|v| v.to_vec()); + self.inner.flush_async().await?; + + Ok(ret) + } + async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> { self.media.remove( (request.media_type.unique_key().as_str(), request.format.unique_key().as_str()) @@ -899,6 +914,14 @@ impl StateStore for SledStore { self.get_event_room_receipt_events(room_id, receipt_type, event_id).await } + async fn get_custom_value(&self, key: &[u8]) -> Result>> { + self.get_custom_value(key).await + } + + async fn set_custom_value(&self, key: &[u8], value: Vec) -> Result>> { + self.set_custom_value(key, value).await + } + async fn add_media_content(&self, request: &MediaRequest, data: Vec) -> Result<()> { self.add_media_content(request, data).await } @@ -939,7 +962,7 @@ mod test { }; use serde_json::json; - use super::{SledStore, StateChanges}; + use super::{Result, SledStore, StateChanges}; use crate::{ deserialized_responses::MemberEvent, media::{MediaFormat, MediaRequest, MediaThumbnailSize, MediaType}, @@ -1155,4 +1178,19 @@ mod test { assert!(store.get_media_content(&request_file).await.unwrap().is_none()); assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none()); } + + #[async_test] + async fn test_custom_storage() -> Result<()> { + let key = "my_key"; + let value = &[0, 1, 2, 3]; + let store = SledStore::open()?; + + store.set_custom_value(key.as_bytes(), value.to_vec()).await?; + + let read = store.get_custom_value(key.as_bytes()).await?; + + assert_eq!(Some(value.as_ref()), read.as_deref()); + + Ok(()) + } } diff --git a/matrix_sdk_crypto/Cargo.toml b/matrix_sdk_crypto/Cargo.toml index d24e7af1..1002bfd2 100644 --- a/matrix_sdk_crypto/Cargo.toml +++ b/matrix_sdk_crypto/Cargo.toml @@ -58,7 +58,7 @@ indoc = "1.0.3" criterion = { version = "0.3.4", features = ["async", "async_tokio", "html_reports"] } [target.'cfg(target_os = "linux")'.dev-dependencies] -pprof = { version = "0.4.3", features = ["flamegraph"] } +pprof = { version = "0.5.0", features = ["flamegraph", "criterion"] } [[bench]] name = "crypto_bench" diff --git a/matrix_sdk_crypto/benches/README.md b/matrix_sdk_crypto/benches/README.md new file mode 100644 index 00000000..816d9029 --- /dev/null +++ b/matrix_sdk_crypto/benches/README.md @@ -0,0 +1,80 @@ +# Benchmarks for the rust-sdk crypto layer + +This directory contains various benchmarks that test critical functionality in +the crypto layer in the rust-sdk. + +We're using [Criterion] for the benchmarks, the full documentation for Criterion +can be found [here](https://bheisler.github.io/criterion.rs/book/criterion_rs.html). + +## Running the benchmarks + +The benchmark can be simply run by using the `bench` command of `cargo`: + +```bash +$ cargo bench +``` + +This will work from the workspace directory of the rust-sdk. + +If you want to pass options to the benchmark [you'll need to specify the name of +the benchmark](https://bheisler.github.io/criterion.rs/book/faq.html#cargo-bench-gives-unrecognized-option-errors-for-valid-command-line-options): + +```bash +$ cargo bench --bench crypto_bench -- # Your options go here +``` + +If you want to run only a specific benchmark, simply pass the name of the +benchmark as an argument: + +```bash +$ cargo bench --bench crypto_bench "Room key sharing/" +``` + +After the benchmarks are done, a HTML report can be found in `target/criterion/report/index.html`. + +### Using a baseline for the benchmark + +The benchmarks will by default compare the results to the previous run of the +benchmark. If you are improving the performance of a specific feature and run +the benchmark many times, it may be useful to store a baseline to compare +against instead. + +The `--save-baseline` switch can be used to create a baseline for the benchmark. + +```bash +$ cargo bench --bench crypto_bench -- --save-baseline libolm +``` + +After you make your changes you can use the baseline to compare the results like +so: + +```bash +$ cargo bench --bench crypto_bench -- --baseline libolm +``` + +### Generating Flame Graphs for the benchmarks + +The benchmarks support profiling and generating [Flame Graphs] while they run in +profiling mode using [pprof]. + +Profiling usually requieres root permissions, to avoid the need for root +permissions you can adjust the value of `perf_event_paranoid`, e.g. the most +permisive value is `-1`: + +```bash +$ echo -1 | sudo tee /proc/sys/kernel/perf_event_paranoid +``` + +To generate flame graphs feature simply enable the profiling mode using the +`--profile-time` command line flag: + +```bash +$ cargo bench --bench crypto_bench -- --profile-time=5 +``` + +After the benchmarks are done, a flame graph for each individual benchmark can be +found in `target/criterion//profile/flamegraph.svg`. + +[pprof]: https://docs.rs/pprof/0.5.0/pprof/index.html# +[Criterion]: https://docs.rs/criterion/0.3.5/criterion/ +[Flame Graphs]: https://www.brendangregg.com/flamegraphs.html diff --git a/matrix_sdk_crypto/benches/crypto_bench.rs b/matrix_sdk_crypto/benches/crypto_bench.rs index 280d01af..d919b9db 100644 --- a/matrix_sdk_crypto/benches/crypto_bench.rs +++ b/matrix_sdk_crypto/benches/crypto_bench.rs @@ -1,6 +1,3 @@ -#[cfg(target_os = "linux")] -mod perf; - use std::sync::Arc; use criterion::*; @@ -262,7 +259,10 @@ pub fn devices_missing_sessions_collecting(c: &mut Criterion) { fn criterion() -> Criterion { #[cfg(target_os = "linux")] - let criterion = Criterion::default().with_profiler(perf::FlamegraphProfiler::new(100)); + let criterion = Criterion::default().with_profiler(pprof::criterion::PProfProfiler::new( + 100, + pprof::criterion::Output::Flamegraph(None), + )); #[cfg(not(target_os = "linux"))] let criterion = Criterion::default(); diff --git a/matrix_sdk_crypto/benches/perf.rs b/matrix_sdk_crypto/benches/perf.rs deleted file mode 100644 index b2dfb878..00000000 --- a/matrix_sdk_crypto/benches/perf.rs +++ /dev/null @@ -1,76 +0,0 @@ -//! This is a simple Criterion Profiler implementation using pprof. -//! -//! It's mostly a direct copy from here: https://www.jibbow.com/posts/criterion-flamegraphs/ -use std::{fs::File, os::raw::c_int, path::Path}; - -use criterion::profiler::Profiler; -use pprof::ProfilerGuard; - -/// Small custom profiler that can be used with Criterion to create a flamegraph -/// for benchmarks. Also see [the Criterion documentation on -/// this][custom-profiler]. -/// -/// ## Example on how to enable the custom profiler: -/// -/// ``` -/// mod perf; -/// use perf::FlamegraphProfiler; -/// -/// fn fibonacci_profiled(criterion: &mut Criterion) { -/// // Use the criterion struct as normal here. -/// } -/// -/// fn custom() -> Criterion { -/// Criterion::default().with_profiler(FlamegraphProfiler::new()) -/// } -/// -/// criterion_group! { -/// name = benches; -/// config = custom(); -/// targets = fibonacci_profiled -/// } -/// ``` -/// -/// The neat thing about this is that it will sample _only_ the benchmark, and -/// not other stuff like the setup process. -/// -/// Further, it will only kick in if `--profile-time