diff --git a/matrix_sdk/Cargo.toml b/matrix_sdk/Cargo.toml index abee49cc..a581c7b0 100644 --- a/matrix_sdk/Cargo.toml +++ b/matrix_sdk/Cargo.toml @@ -26,7 +26,7 @@ rustls-tls = ["reqwest/rustls-tls"] socks = ["reqwest/socks"] sso_login = ["warp", "rand", "tokio-stream"] require_auth_for_profile_requests = [] -appservice = ["matrix-sdk-common/appservice", "serde_yaml"] +appservice = ["matrix-sdk-common/appservice"] docs = ["encryption", "sled_cryptostore", "sled_state_store", "sso_login"] @@ -41,7 +41,6 @@ url = "2.2.0" zeroize = "1.2.0" mime = "0.3.16" rand = { version = "0.8.2", optional = true } -serde_yaml = { version = "0.8", optional = true } bytes = "1.0.1" matrix-sdk-common = { version = "0.2.0", path = "../matrix_sdk_common" } diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index b61aed6b..646637cd 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -14,7 +14,11 @@ // limitations under the License. #[cfg(feature = "encryption")] -use std::{collections::BTreeMap, io::Write, path::PathBuf}; +use std::{ + collections::BTreeMap, + io::{Cursor, Write}, + path::PathBuf, +}; #[cfg(feature = "sso_login")] use std::{ collections::HashMap, @@ -38,10 +42,13 @@ use http::Response; #[cfg(feature = "encryption")] use matrix_sdk_base::crypto::{ decrypt_key_export, encrypt_key_export, olm::InboundGroupSession, store::CryptoStoreError, - OutgoingRequests, RoomMessageRequest, ToDeviceRequest, + AttachmentDecryptor, OutgoingRequests, RoomMessageRequest, ToDeviceRequest, }; use matrix_sdk_base::{ - deserialized_responses::SyncResponse, events::AnyMessageEventContent, identifiers::MxcUri, + deserialized_responses::SyncResponse, + events::AnyMessageEventContent, + identifiers::MxcUri, + media::{MediaEventContent, MediaFormat, MediaRequest, MediaThumbnailSize, MediaType}, BaseClient, BaseClientConfig, SendAccessToken, Session, Store, }; use mime::{self, Mime}; @@ -83,19 +90,22 @@ use matrix_sdk_common::api::r0::{ }, }; use matrix_sdk_common::{ - api::r0::{ - account::register, - device::{delete_devices, get_devices}, - directory::{get_public_rooms, get_public_rooms_filtered}, - filter::{create_filter::Request as FilterUploadRequest, FilterDefinition}, - media::{create_content, get_content, get_content_thumbnail}, - membership::{join_room_by_id, join_room_by_id_or_alias}, - message::send_message_event, - profile::{get_avatar_url, get_display_name, set_avatar_url, set_display_name}, - room::create_room, - session::{get_login_types, login, sso_login}, - sync::sync_events, - uiaa::AuthData, + api::{ + r0::{ + account::register, + device::{delete_devices, get_devices}, + directory::{get_public_rooms, get_public_rooms_filtered}, + filter::{create_filter::Request as FilterUploadRequest, FilterDefinition}, + media::{create_content, get_content, get_content_thumbnail}, + membership::{join_room_by_id, join_room_by_id_or_alias}, + message::send_message_event, + profile::{get_avatar_url, get_display_name, set_avatar_url, set_display_name}, + room::create_room, + session::{get_login_types, login, sso_login}, + sync::sync_events, + uiaa::AuthData, + }, + unversioned::{discover_homeserver, get_supported_versions}, }, assign, identifiers::{DeviceIdBox, RoomId, RoomIdOrAliasId, ServerName, UserId}, @@ -139,7 +149,7 @@ const SSO_SERVER_BIND_TRIES: u8 = 10; #[derive(Clone)] pub struct Client { /// The URL of the homeserver to connect to. - homeserver: Arc, + homeserver: Arc>, /// The underlying HTTP client. http_client: HttpClient, /// User session data. @@ -161,7 +171,7 @@ pub struct Client { #[cfg(not(tarpaulin_include))] impl Debug for Client { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> StdResult<(), fmt::Error> { - write!(fmt, "Client {{ homeserver: {} }}", self.homeserver) + write!(fmt, "Client") } } @@ -291,6 +301,11 @@ impl ClientConfig { self } + /// Get the [`RequestConfig`] + pub fn get_request_config(&self) -> &RequestConfig { + &self.request_config + } + /// Specify a client to handle sending requests and receiving responses. /// /// Any type that implements the `HttpSend` trait can be used to @@ -499,7 +514,7 @@ impl Client { /// /// * `config` - Configuration for the client. pub fn new_with_config(homeserver_url: Url, config: ClientConfig) -> Result { - let homeserver = Arc::new(homeserver_url); + let homeserver = Arc::new(RwLock::new(homeserver_url)); let client = if let Some(client) = config.client { client @@ -510,12 +525,8 @@ impl Client { let base_client = BaseClient::new_with_config(config.base_config)?; let session = base_client.session().clone(); - let http_client = HttpClient { - homeserver: homeserver.clone(), - inner: client, - session, - request_config: config.request_config, - }; + let http_client = + HttpClient::new(client, homeserver.clone(), session, config.request_config); Ok(Self { homeserver, @@ -531,6 +542,89 @@ impl Client { }) } + /// Creates a new client for making HTTP requests to the homeserver of the + /// given user. Follows homeserver discovery directions described + /// [here](https://spec.matrix.org/unstable/client-server-api/#well-known-uri). + /// + /// # Arguments + /// + /// * `user_id` - The id of the user whose homeserver the client should + /// connect to. + /// + /// # Example + /// ```no_run + /// # use std::convert::TryFrom; + /// # use matrix_sdk::{Client, identifiers::UserId}; + /// # use futures::executor::block_on; + /// let alice = UserId::try_from("@alice:example.org").unwrap(); + /// # block_on(async { + /// let client = Client::new_from_user_id(alice.clone()).await.unwrap(); + /// client.login(alice.localpart(), "password", None, None).await.unwrap(); + /// # }); + /// ``` + pub async fn new_from_user_id(user_id: UserId) -> Result { + let config = ClientConfig::new(); + Client::new_from_user_id_with_config(user_id, config).await + } + + /// Creates a new client for making HTTP requests to the homeserver of the + /// given user and configuration. Follows homeserver discovery directions + /// described [here](https://spec.matrix.org/unstable/client-server-api/#well-known-uri). + /// + /// # Arguments + /// + /// * `user_id` - The id of the user whose homeserver the client should + /// connect to. + /// + /// * `config` - Configuration for the client. + pub async fn new_from_user_id_with_config( + user_id: UserId, + config: ClientConfig, + ) -> Result { + let homeserver = Client::homeserver_from_user_id(user_id)?; + let mut client = Client::new_with_config(homeserver, config)?; + + let well_known = client.discover_homeserver().await?; + let well_known = Url::parse(well_known.homeserver.base_url.as_ref())?; + client.set_homeserver(well_known).await; + client.get_supported_versions().await?; + Ok(client) + } + + fn homeserver_from_user_id(user_id: UserId) -> Result { + let homeserver = format!("https://{}", user_id.server_name()); + #[allow(unused_mut)] + let mut result = Url::parse(homeserver.as_str())?; + // Mockito only knows how to test http endpoints: + // https://github.com/lipanski/mockito/issues/127 + #[cfg(test)] + let _ = result.set_scheme("http"); + Ok(result) + } + + async fn discover_homeserver(&self) -> Result { + self.send(discover_homeserver::Request::new(), Some(RequestConfig::new().disable_retry())) + .await + } + + /// Change the homeserver URL used by this client. + /// + /// # Arguments + /// + /// * `homeserver_url` - The new URL to use. + pub async fn set_homeserver(&mut self, homeserver_url: Url) { + let mut homeserver = self.homeserver.write().await; + *homeserver = homeserver_url; + } + + async fn get_supported_versions(&self) -> Result { + self.send( + get_supported_versions::Request::new(), + Some(RequestConfig::new().disable_retry()), + ) + .await + } + /// Process a [transaction] received from the homeserver /// /// # Arguments @@ -563,8 +657,8 @@ impl Client { } /// The Homeserver of the client. - pub fn homeserver(&self) -> &Url { - &self.homeserver + pub async fn homeserver(&self) -> Url { + self.homeserver.read().await.clone() } /// Get the user id of the current owner of the client. @@ -863,8 +957,8 @@ impl Client { /// successful SSO login. /// /// [`login_with_token`]: #method.login_with_token - pub fn get_sso_login_url(&self, redirect_url: &str) -> Result { - let homeserver = self.homeserver(); + pub async fn get_sso_login_url(&self, redirect_url: &str) -> Result { + let homeserver = self.homeserver().await; let request = sso_login::Request::new(redirect_url) .try_into_http_request::>(homeserver.as_str(), SendAccessToken::None); match request { @@ -925,7 +1019,7 @@ impl Client { device_id: Option<&str>, initial_device_display_name: Option<&str>, ) -> Result { - info!("Logging in to {} as {:?}", self.homeserver, user); + info!("Logging in to {} as {:?}", self.homeserver().await, user); let request = assign!( login::Request::new( @@ -1034,7 +1128,7 @@ impl Client { where C: Future>, { - info!("Logging in to {}", self.homeserver); + info!("Logging in to {}", self.homeserver().await); let (signal_tx, signal_rx) = oneshot::channel(); let (data_tx, data_rx) = oneshot::channel(); let data_tx_mutex = Arc::new(std::sync::Mutex::new(Some(data_tx))); @@ -1106,7 +1200,7 @@ impl Client { tokio::spawn(server); - let sso_url = self.get_sso_login_url(redirect_url.as_str()).unwrap(); + let sso_url = self.get_sso_login_url(redirect_url.as_str()).await.unwrap(); match use_sso_login_url(sso_url).await { Ok(t) => t, @@ -1190,7 +1284,7 @@ impl Client { device_id: Option<&str>, initial_device_display_name: Option<&str>, ) -> Result { - info!("Logging in to {}", self.homeserver); + info!("Logging in to {}", self.homeserver().await); let request = assign!( login::Request::new( @@ -1261,7 +1355,7 @@ impl Client { &self, registration: impl Into>, ) -> Result { - info!("Registering to {}", self.homeserver); + info!("Registering to {}", self.homeserver().await); let request = registration.into(); self.send(request, None).await @@ -1915,7 +2009,7 @@ impl Client { #[cfg(feature = "encryption")] { // This is needed because sometimes we need to automatically - // claim some one-time keys to unwedge an exisitng Olm session. + // claim some one-time keys to unwedge an existing Olm session. if let Err(e) = self.claim_one_time_keys([].iter()).await { warn!("Error while claiming one-time keys {:?}", e); } @@ -2382,13 +2476,227 @@ impl Client { Ok(olm.import_keys(import, |_, _| {}).await?) } + + /// Get a media file's content. + /// + /// If the content is encrypted and encryption is enabled, the content will + /// be decrypted. + /// + /// # Arguments + /// + /// * `request` - The `MediaRequest` of the content. + /// + /// * `use_cache` - If we should use the media cache for this request. + pub async fn get_media_content( + &self, + request: &MediaRequest, + use_cache: bool, + ) -> Result> { + let content = if use_cache { + self.base_client.store().get_media_content(request).await? + } else { + None + }; + + if let Some(content) = content { + Ok(content) + } else { + let content: Vec = match &request.media_type { + MediaType::Encrypted(file) => { + let content: Vec = + self.send(get_content::Request::from_url(&file.url)?, None).await?.file; + + #[cfg(feature = "encryption")] + let content = { + let mut cursor = Cursor::new(content); + let mut reader = + AttachmentDecryptor::new(&mut cursor, file.as_ref().clone().into())?; + + let mut decrypted = Vec::new(); + reader.read_to_end(&mut decrypted)?; + + decrypted + }; + + content + } + MediaType::Uri(uri) => { + if let MediaFormat::Thumbnail(size) = &request.format { + self.send( + get_content_thumbnail::Request::from_url( + &uri, + size.width, + size.height, + )?, + None, + ) + .await? + .file + } else { + self.send(get_content::Request::from_url(&uri)?, None).await?.file + } + } + }; + + if use_cache { + self.base_client.store().add_media_content(request, content.clone()).await?; + } + + Ok(content) + } + } + + /// Remove a media file's content from the store. + /// + /// # Arguments + /// + /// * `request` - The `MediaRequest` of the content. + pub async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> { + Ok(self.base_client.store().remove_media_content(request).await?) + } + + /// Delete all the media content corresponding to the given + /// uri from the store. + /// + /// # Arguments + /// + /// * `uri` - The `MxcUri` of the files. + pub async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> { + Ok(self.base_client.store().remove_media_content_for_uri(&uri).await?) + } + + /// Get the file of the given media event content. + /// + /// If the content is encrypted and encryption is enabled, the content will + /// be decrypted. + /// + /// Returns `Ok(None)` if the event content has no file. + /// + /// This is a convenience method that calls the + /// [`get_media_content`](#method.get_media_content) method. + /// + /// # Arguments + /// + /// * `event_content` - The media event content. + /// + /// * `use_cache` - If we should use the media cache for this file. + pub async fn get_file( + &self, + event_content: impl MediaEventContent, + use_cache: bool, + ) -> Result>> { + if let Some(media_type) = event_content.file() { + Ok(Some( + self.get_media_content( + &MediaRequest { media_type, format: MediaFormat::File }, + use_cache, + ) + .await?, + )) + } else { + Ok(None) + } + } + + /// Remove the file of the given media event content from the cache. + /// + /// This is a convenience method that calls the + /// [`remove_media_content`](#method.remove_media_content) method. + /// + /// # Arguments + /// + /// * `event_content` - The media event content. + pub async fn remove_file(&self, event_content: impl MediaEventContent) -> Result<()> { + if let Some(media_type) = event_content.file() { + self.remove_media_content(&MediaRequest { media_type, format: MediaFormat::File }) + .await? + } + + Ok(()) + } + + /// Get a thumbnail of the given media event content. + /// + /// If the content is encrypted and encryption is enabled, the content will + /// be decrypted. + /// + /// Returns `Ok(None)` if the event content has no thumbnail. + /// + /// This is a convenience method that calls the + /// [`get_media_content`](#method.get_media_content) method. + /// + /// # Arguments + /// + /// * `event_content` - The media event content. + /// + /// * `size` - The _desired_ size of the thumbnail. The actual thumbnail may + /// not match the size specified. + /// + /// * `use_cache` - If we should use the media cache for this thumbnail. + pub async fn get_thumbnail( + &self, + event_content: impl MediaEventContent, + size: MediaThumbnailSize, + use_cache: bool, + ) -> Result>> { + if let Some(media_type) = event_content.thumbnail() { + Ok(Some( + self.get_media_content( + &MediaRequest { media_type, format: MediaFormat::Thumbnail(size) }, + use_cache, + ) + .await?, + )) + } else { + Ok(None) + } + } + + /// Remove the thumbnail of the given media event content from the cache. + /// + /// This is a convenience method that calls the + /// [`remove_media_content`](#method.remove_media_content) method. + /// + /// # Arguments + /// + /// * `event_content` - The media event content. + /// + /// * `size` - The _desired_ size of the thumbnail. Must match the size + /// requested with [`get_thumbnail`](#method.get_thumbnail). + pub async fn remove_thumbnail( + &self, + event_content: impl MediaEventContent, + size: MediaThumbnailSize, + ) -> Result<()> { + if let Some(media_type) = event_content.file() { + self.remove_media_content(&MediaRequest { + media_type, + format: MediaFormat::Thumbnail(size), + }) + .await? + } + + Ok(()) + } } #[cfg(test)] mod test { - use std::{collections::BTreeMap, convert::TryInto, io::Cursor, str::FromStr, time::Duration}; + use std::{ + collections::BTreeMap, + convert::{TryFrom, TryInto}, + io::Cursor, + str::FromStr, + time::Duration, + }; - use matrix_sdk_base::identifiers::mxc_uri; + use matrix_sdk_base::{ + api::r0::media::get_content_thumbnail::Method, + events::room::{message::ImageMessageEventContent, ImageInfo}, + identifiers::mxc_uri, + media::{MediaFormat, MediaRequest, MediaThumbnailSize, MediaType}, + uint, + }; use matrix_sdk_common::{ api::r0::{ account::register::Request as RegistrationRequest, @@ -2398,7 +2706,7 @@ mod test { assign, directory::Filter, events::{room::message::MessageEventContent, AnyMessageEventContent}, - identifiers::{event_id, room_id, user_id}, + identifiers::{event_id, room_id, user_id, UserId}, thirdparty, }; use matrix_sdk_test::{test_json, EventBuilder, EventsJson}; @@ -2424,6 +2732,62 @@ mod test { client } + #[tokio::test] + async fn set_homeserver() { + let homeserver = Url::from_str("http://example.com/").unwrap(); + + let mut client = Client::new(homeserver).unwrap(); + + let homeserver = Url::from_str(&mockito::server_url()).unwrap(); + + client.set_homeserver(homeserver.clone()).await; + + assert_eq!(client.homeserver().await, homeserver); + } + + #[tokio::test] + async fn successful_discovery() { + let server_url = mockito::server_url(); + let domain = server_url.strip_prefix("http://").unwrap(); + let alice = UserId::try_from("@alice:".to_string() + domain).unwrap(); + + let _m_well_known = mock("GET", "/.well-known/matrix/client") + .with_status(200) + .with_body( + test_json::WELL_KNOWN.to_string().replace("HOMESERVER_URL", server_url.as_ref()), + ) + .create(); + + let _m_versions = mock("GET", "/_matrix/client/versions") + .with_status(200) + .with_body(test_json::VERSIONS.to_string()) + .create(); + let client = Client::new_from_user_id(alice).await.unwrap(); + + assert_eq!(client.homeserver().await, Url::parse(server_url.as_ref()).unwrap()); + } + + #[tokio::test] + async fn discovery_broken_server() { + let server_url = mockito::server_url(); + let domain = server_url.strip_prefix("http://").unwrap(); + let alice = UserId::try_from("@alice:".to_string() + domain).unwrap(); + + let _m = mock("GET", "/.well-known/matrix/client") + .with_status(200) + .with_body( + test_json::WELL_KNOWN.to_string().replace("HOMESERVER_URL", server_url.as_ref()), + ) + .create(); + + if Client::new_from_user_id(alice).await.is_ok() { + panic!( + "Creating a client from a user ID should fail when the \ + .well-known server returns no version information." + ); + } + } + #[tokio::test] async fn login() { let homeserver = Url::from_str(&mockito::server_url()).unwrap(); @@ -2513,7 +2877,7 @@ mod test { .any(|flow| matches!(flow, LoginType::Sso(_))); assert!(can_sso); - let sso_url = client.get_sso_login_url("http://127.0.0.1:3030"); + let sso_url = client.get_sso_login_url("http://127.0.0.1:3030").await; assert!(sso_url.is_ok()); let _m = mock("POST", "/_matrix/client/r0/login") @@ -2625,7 +2989,7 @@ mod test { client.base_client.receive_sync_response(response).await.unwrap(); let room_id = room_id!("!SVkFJHzfwvuaIEawgC:localhost"); - assert_eq!(client.homeserver(), &Url::parse(&mockito::server_url()).unwrap()); + assert_eq!(client.homeserver().await, Url::parse(&mockito::server_url()).unwrap()); let room = client.get_joined_room(&room_id); assert!(room.is_some()); @@ -3279,6 +3643,7 @@ mod test { .with_status(200) .match_header("authorization", "Bearer 1234") .with_body(test_json::SYNC.to_string()) + .expect_at_least(1) .create(); let sync_settings = SyncSettings::new().timeout(Duration::from_millis(3000)); @@ -3288,6 +3653,19 @@ mod test { let room = client.get_joined_room(&room_id!("!SVkFJHzfwvuaIEawgC:localhost")).unwrap(); assert_eq!("tutorial".to_string(), room.display_name().await.unwrap()); + + let _m = mock("GET", Matcher::Regex(r"^/_matrix/client/r0/sync\?.*$".to_string())) + .with_status(200) + .match_header("authorization", "Bearer 1234") + .with_body(test_json::INVITE_SYNC.to_string()) + .expect_at_least(1) + .create(); + + let _response = client.sync_once(SyncSettings::new()).await.unwrap(); + + let invited_room = client.get_invited_room(&room_id!("!696r7674:example.com")).unwrap(); + + assert_eq!("My Room Name".to_string(), invited_room.display_name().await.unwrap()); } #[tokio::test] @@ -3395,4 +3773,74 @@ mod test { panic!("this request should return an `Err` variant") } } + + #[tokio::test] + async fn get_media_content() { + let client = logged_in_client().await; + + let request = MediaRequest { + media_type: MediaType::Uri(mxc_uri!("mxc://localhost/textfile")), + format: MediaFormat::File, + }; + + let m = mock( + "GET", + Matcher::Regex(r"^/_matrix/media/r0/download/localhost/textfile\?.*$".to_string()), + ) + .with_status(200) + .with_body("Some very interesting text.") + .expect(2) + .create(); + + assert!(client.get_media_content(&request, true).await.is_ok()); + assert!(client.get_media_content(&request, true).await.is_ok()); + assert!(client.get_media_content(&request, false).await.is_ok()); + m.assert(); + } + + #[tokio::test] + async fn get_media_file() { + let client = logged_in_client().await; + + let event_content = ImageMessageEventContent::plain( + "filename.jpg".into(), + mxc_uri!("mxc://example.org/image"), + Some(Box::new(assign!(ImageInfo::new(), { + height: Some(uint!(398)), + width: Some(uint!(394)), + mimetype: Some("image/jpeg".into()), + size: Some(uint!(31037)), + }))), + ); + + let m = mock( + "GET", + Matcher::Regex(r"^/_matrix/media/r0/download/example%2Eorg/image\?.*$".to_string()), + ) + .with_status(200) + .with_body("binaryjpegdata") + .create(); + + assert!(client.get_file(event_content.clone(), true).await.is_ok()); + assert!(client.get_file(event_content.clone(), true).await.is_ok()); + m.assert(); + + let m = mock( + "GET", + Matcher::Regex(r"^/_matrix/media/r0/thumbnail/example%2Eorg/image\?.*$".to_string()), + ) + .with_status(200) + .with_body("smallerbinaryjpegdata") + .create(); + + assert!(client + .get_thumbnail( + event_content, + MediaThumbnailSize { method: Method::Scale, width: uint!(100), height: uint!(100) }, + true + ) + .await + .is_ok()); + m.assert(); + } } diff --git a/matrix_sdk/src/error.rs b/matrix_sdk/src/error.rs index fc12b91a..4fa4ff0f 100644 --- a/matrix_sdk/src/error.rs +++ b/matrix_sdk/src/error.rs @@ -18,7 +18,7 @@ use std::io::Error as IoError; use http::StatusCode; #[cfg(feature = "encryption")] -use matrix_sdk_base::crypto::store::CryptoStoreError; +use matrix_sdk_base::crypto::{store::CryptoStoreError, DecryptorError}; use matrix_sdk_base::{Error as MatrixError, StoreError}; use matrix_sdk_common::{ api::{ @@ -31,6 +31,7 @@ use matrix_sdk_common::{ use reqwest::Error as ReqwestError; use serde_json::Error as JsonError; use thiserror::Error; +use url::ParseError as UrlParseError; /// Result type of the rust-sdk. pub type Result = std::result::Result; @@ -121,13 +122,22 @@ pub enum Error { #[error(transparent)] CryptoStoreError(#[from] CryptoStoreError), - /// An error occured in the state store. + /// An error occurred during decryption. + #[cfg(feature = "encryption")] + #[error(transparent)] + DecryptorError(#[from] DecryptorError), + + /// An error occurred in the state store. #[error(transparent)] StateStore(#[from] StoreError), /// An error encountered when trying to parse an identifier. #[error(transparent)] Identifier(#[from] IdentifierError), + + /// An error encountered when trying to parse a url. + #[error(transparent)] + Url(#[from] UrlParseError), } impl Error { diff --git a/matrix_sdk/src/event_handler/mod.rs b/matrix_sdk/src/event_handler/mod.rs index 3290ef68..53c46784 100644 --- a/matrix_sdk/src/event_handler/mod.rs +++ b/matrix_sdk/src/event_handler/mod.rs @@ -379,7 +379,7 @@ pub trait EventHandler: Send + Sync { async fn on_room_redaction(&self, _: Room, _: &SyncRedactionEvent) {} /// Fires when `Client` receives a `RoomEvent::RoomPowerLevels` event. async fn on_room_power_levels(&self, _: Room, _: &SyncStateEvent) {} - /// Fires when `Client` receives a `RoomEvent::Tombstone` event. + /// Fires when `Client` receives a `RoomEvent::RoomJoinRules` event. async fn on_room_join_rules(&self, _: Room, _: &SyncStateEvent) {} /// Fires when `Client` receives a `RoomEvent::Tombstone` event. async fn on_room_tombstone(&self, _: Room, _: &SyncStateEvent) {} diff --git a/matrix_sdk/src/http_client.rs b/matrix_sdk/src/http_client.rs index aced4169..63c83e60 100644 --- a/matrix_sdk/src/http_client.rs +++ b/matrix_sdk/src/http_client.rs @@ -97,7 +97,7 @@ pub trait HttpSend: AsyncTraitDeps { #[derive(Clone, Debug)] pub(crate) struct HttpClient { pub(crate) inner: Arc, - pub(crate) homeserver: Arc, + pub(crate) homeserver: Arc>, pub(crate) session: Arc>>, pub(crate) request_config: RequestConfig, } @@ -106,6 +106,15 @@ pub(crate) struct HttpClient { use crate::OutgoingRequestAppserviceExt; impl HttpClient { + pub(crate) fn new( + inner: Arc, + homeserver: Arc>, + session: Arc>>, + request_config: RequestConfig, + ) -> Self { + HttpClient { inner, homeserver, session, request_config } + } + async fn send_request( &self, request: Request, @@ -124,7 +133,7 @@ impl HttpClient { let request = if !self.request_config.assert_identity { self.try_into_http_request(request, session, config).await? } else { - self.try_into_http_request_with_identy_assertion(request, session, config).await? + self.try_into_http_request_with_identity_assertion(request, session, config).await? }; self.inner.send_request(request, config).await @@ -161,14 +170,17 @@ impl HttpClient { }; let http_request = request - .try_into_http_request::(&self.homeserver.to_string(), access_token)? + .try_into_http_request::( + &self.homeserver.read().await.to_string(), + access_token, + )? .map(|body| body.freeze()); Ok(http_request) } #[cfg(feature = "appservice")] - async fn try_into_http_request_with_identy_assertion( + async fn try_into_http_request_with_identity_assertion( &self, request: Request, session: Arc>>, @@ -189,7 +201,7 @@ impl HttpClient { let http_request = request .try_into_http_request_with_user_id::( - &self.homeserver.to_string(), + &self.homeserver.read().await.to_string(), access_token, user_id, )? diff --git a/matrix_sdk/src/lib.rs b/matrix_sdk/src/lib.rs index 1aa8a377..cb3892e5 100644 --- a/matrix_sdk/src/lib.rs +++ b/matrix_sdk/src/lib.rs @@ -52,8 +52,7 @@ //! synapse configuration `require_auth_for_profile_requests`. Enabled by //! default. //! * `appservice`: Enables low-level appservice functionality. For an -//! high-level API there's the -//! `matrix-sdk-appservice` crate +//! high-level API there's the `matrix-sdk-appservice` crate #![deny( missing_debug_implementations, @@ -81,7 +80,7 @@ pub use bytes::{Bytes, BytesMut}; #[cfg_attr(feature = "docs", doc(cfg(encryption)))] pub use matrix_sdk_base::crypto::{EncryptionInfo, LocalTrust}; pub use matrix_sdk_base::{ - Error as BaseError, Room as BaseRoom, RoomInfo, RoomMember as BaseRoomMember, RoomType, + media, Error as BaseError, Room as BaseRoom, RoomInfo, RoomMember as BaseRoomMember, RoomType, Session, StateChanges, StoreError, }; pub use matrix_sdk_common::*; diff --git a/matrix_sdk/src/room/common.rs b/matrix_sdk/src/room/common.rs index 82647506..cff99001 100644 --- a/matrix_sdk/src/room/common.rs +++ b/matrix_sdk/src/room/common.rs @@ -34,7 +34,7 @@ impl Common { /// # Arguments /// * `client` - The client used to make requests. /// - /// * `room` - The underlaying room. + /// * `room` - The underlying room. pub fn new(client: Client, room: BaseRoom) -> Self { // TODO: Make this private Self { inner: room, client } diff --git a/matrix_sdk/src/room/invited.rs b/matrix_sdk/src/room/invited.rs index fc7c1196..3ff681e4 100644 --- a/matrix_sdk/src/room/invited.rs +++ b/matrix_sdk/src/room/invited.rs @@ -5,7 +5,7 @@ use crate::{room::Common, BaseRoom, Client, Result, RoomType}; /// A room in the invited state. /// /// This struct contains all methodes specific to a `Room` with type -/// `RoomType::Invited`. Operations may fail once the underlaying `Room` changes +/// `RoomType::Invited`. Operations may fail once the underlying `Room` changes /// `RoomType`. #[derive(Debug, Clone)] pub struct Invited { @@ -13,13 +13,13 @@ pub struct Invited { } impl Invited { - /// Create a new `room::Invited` if the underlaying `Room` has type + /// Create a new `room::Invited` if the underlying `Room` has type /// `RoomType::Invited`. /// /// # Arguments /// * `client` - The client used to make requests. /// - /// * `room` - The underlaying room. + /// * `room` - The underlying room. pub fn new(client: Client, room: BaseRoom) -> Option { // TODO: Make this private if room.room_type() == RoomType::Invited { diff --git a/matrix_sdk/src/room/joined.rs b/matrix_sdk/src/room/joined.rs index db10e91c..2741bf69 100644 --- a/matrix_sdk/src/room/joined.rs +++ b/matrix_sdk/src/room/joined.rs @@ -48,7 +48,7 @@ const TYPING_NOTICE_RESEND_TIMEOUT: Duration = Duration::from_secs(3); /// A room in the joined state. /// /// The `JoinedRoom` contains all methodes specific to a `Room` with type -/// `RoomType::Joined`. Operations may fail once the underlaying `Room` changes +/// `RoomType::Joined`. Operations may fail once the underlying `Room` changes /// `RoomType`. #[derive(Debug, Clone)] pub struct Joined { @@ -64,13 +64,13 @@ impl Deref for Joined { } impl Joined { - /// Create a new `room::Joined` if the underlaying `BaseRoom` has type + /// Create a new `room::Joined` if the underlying `BaseRoom` has type /// `RoomType::Joined`. /// /// # Arguments /// * `client` - The client used to make requests. /// - /// * `room` - The underlaying room. + /// * `room` - The underlying room. pub fn new(client: Client, room: BaseRoom) -> Option { // TODO: Make this private if room.room_type() == RoomType::Joined { diff --git a/matrix_sdk/src/room/left.rs b/matrix_sdk/src/room/left.rs index 3169f356..0714f6c0 100644 --- a/matrix_sdk/src/room/left.rs +++ b/matrix_sdk/src/room/left.rs @@ -7,7 +7,7 @@ use crate::{room::Common, BaseRoom, Client, Result, RoomType}; /// A room in the left state. /// /// This struct contains all methodes specific to a `Room` with type -/// `RoomType::Left`. Operations may fail once the underlaying `Room` changes +/// `RoomType::Left`. Operations may fail once the underlying `Room` changes /// `RoomType`. #[derive(Debug, Clone)] pub struct Left { @@ -15,13 +15,13 @@ pub struct Left { } impl Left { - /// Create a new `room::Left` if the underlaying `Room` has type + /// Create a new `room::Left` if the underlying `Room` has type /// `RoomType::Left`. /// /// # Arguments /// * `client` - The client used to make requests. /// - /// * `room` - The underlaying room. + /// * `room` - The underlying room. pub fn new(client: Client, room: BaseRoom) -> Option { // TODO: Make this private if room.room_type() == RoomType::Left { diff --git a/matrix_sdk/src/sas.rs b/matrix_sdk/src/sas.rs index b2eed875..31de78ed 100644 --- a/matrix_sdk/src/sas.rs +++ b/matrix_sdk/src/sas.rs @@ -18,7 +18,7 @@ use matrix_sdk_base::crypto::{ use crate::{error::Result, Client}; -/// An object controling the interactive verification flow. +/// An object controlling the interactive verification flow. #[derive(Debug, Clone)] pub struct Sas { pub(crate) inner: BaseSas, diff --git a/matrix_sdk/src/verification_request.rs b/matrix_sdk/src/verification_request.rs index 88688eeb..4d2773a2 100644 --- a/matrix_sdk/src/verification_request.rs +++ b/matrix_sdk/src/verification_request.rs @@ -18,7 +18,7 @@ use matrix_sdk_base::crypto::{ use crate::{Client, Result}; -/// An object controling the interactive verification flow. +/// An object controlling the interactive verification flow. #[derive(Debug, Clone)] pub struct VerificationRequest { pub(crate) inner: BaseVerificationRequest, diff --git a/matrix_sdk_appservice/Cargo.toml b/matrix_sdk_appservice/Cargo.toml index 766fefac..5230ff1e 100644 --- a/matrix_sdk_appservice/Cargo.toml +++ b/matrix_sdk_appservice/Cargo.toml @@ -16,6 +16,7 @@ docs = [] [dependencies] actix-rt = { version = "2", optional = true } actix-web = { version = "4.0.0-beta.6", optional = true } +dashmap = "4" futures = "0.3" futures-util = "0.3" http = "0.2" diff --git a/matrix_sdk_appservice/examples/actix_autojoin.rs b/matrix_sdk_appservice/examples/actix_autojoin.rs index ad385180..4c91dd3d 100644 --- a/matrix_sdk_appservice/examples/actix_autojoin.rs +++ b/matrix_sdk_appservice/examples/actix_autojoin.rs @@ -34,9 +34,10 @@ impl EventHandler for AppserviceEventHandler { if let MembershipState::Invite = event.content.membership { let user_id = UserId::try_from(event.state_key.clone()).unwrap(); - self.appservice.register(user_id.localpart()).await.unwrap(); + let mut appservice = self.appservice.clone(); + appservice.register(user_id.localpart()).await.unwrap(); - let client = self.appservice.client(Some(user_id.localpart())).await.unwrap(); + let client = appservice.virtual_user(user_id.localpart()).await.unwrap(); client.join_room_by_id(room.room_id()).await.unwrap(); } @@ -53,7 +54,7 @@ pub async fn main() -> std::io::Result<()> { let registration = AppserviceRegistration::try_from_yaml_file("./tests/registration.yaml").unwrap(); - let appservice = Appservice::new(homeserver_url, server_name, registration).await.unwrap(); + let mut appservice = Appservice::new(homeserver_url, server_name, registration).await.unwrap(); let event_handler = AppserviceEventHandler::new(appservice.clone()); diff --git a/matrix_sdk_appservice/src/actix.rs b/matrix_sdk_appservice/src/actix.rs index bf5d2800..672f417a 100644 --- a/matrix_sdk_appservice/src/actix.rs +++ b/matrix_sdk_appservice/src/actix.rs @@ -65,7 +65,7 @@ async fn push_transactions( return Ok(HttpResponse::Unauthorized().finish()); } - appservice.client(None).await?.receive_transaction(request.incoming).await?; + appservice.get_cached_client(None)?.receive_transaction(request.incoming).await?; Ok(HttpResponse::Ok().json("{}")) } diff --git a/matrix_sdk_appservice/src/error.rs b/matrix_sdk_appservice/src/error.rs index fcbdf997..c42b3370 100644 --- a/matrix_sdk_appservice/src/error.rs +++ b/matrix_sdk_appservice/src/error.rs @@ -16,9 +16,6 @@ use thiserror::Error; #[derive(Error, Debug)] pub enum Error { - #[error("tried to run without webserver configured")] - RunWithoutServer, - #[error("missing access token")] MissingAccessToken, @@ -31,6 +28,9 @@ pub enum Error { #[error("no port found")] MissingRegistrationPort, + #[error("no client for localpart found")] + NoClientForLocalpart, + #[error(transparent)] HttpRequest(#[from] matrix_sdk::FromHttpRequestError), diff --git a/matrix_sdk_appservice/src/lib.rs b/matrix_sdk_appservice/src/lib.rs index 11c9710d..8a9d0b79 100644 --- a/matrix_sdk_appservice/src/lib.rs +++ b/matrix_sdk_appservice/src/lib.rs @@ -14,8 +14,9 @@ //! Matrix [Application Service] library //! -//! The appservice crate aims to provide a batteries-included experience. That -//! means that we +//! The appservice crate aims to provide a batteries-included experience by +//! being a thin wrapper around the [`matrix_sdk`]. That means that we +//! //! * ship with functionality to configure your webserver crate or simply run //! the webserver for you //! * receive and validate requests from the homeserver correctly @@ -57,7 +58,7 @@ //! regex: '@_appservice_.*' //! ")?; //! -//! let appservice = Appservice::new(homeserver_url, server_name, registration).await?; +//! let mut appservice = Appservice::new(homeserver_url, server_name, registration).await?; //! appservice.set_event_handler(Box::new(AppserviceEventHandler)).await?; //! //! let (host, port) = appservice.registration().get_host_and_port()?; @@ -81,8 +82,10 @@ use std::{ fs::File, ops::Deref, path::PathBuf, + sync::Arc, }; +use dashmap::DashMap; use http::Uri; #[doc(inline)] pub use matrix_sdk::api_appservice as api; @@ -98,8 +101,7 @@ use matrix_sdk::{ assign, identifiers::{self, DeviceId, ServerNameBox, UserId}, reqwest::Url, - Client, ClientConfig, EventHandler, FromHttpResponseError, HttpError, RequestConfig, - ServerError, Session, + Client, ClientConfig, EventHandler, FromHttpResponseError, HttpError, ServerError, Session, }; use regex::Regex; use tracing::warn; @@ -173,34 +175,39 @@ impl Deref for AppserviceRegistration { } } -async fn client_session_with_login_restore( - client: &Client, - registration: &AppserviceRegistration, - localpart: impl AsRef + Into>, - server_name: &ServerNameBox, -) -> Result<()> { - let session = Session { - access_token: registration.as_token.clone(), - user_id: UserId::parse_with_server_name(localpart, server_name)?, - device_id: DeviceId::new(), - }; - client.restore_login(session).await?; +type Localpart = String; - Ok(()) -} +/// The `localpart` of the user associated with the application service via +/// `sender_localpart` in [`AppserviceRegistration`]. +/// +/// Dummy type for shared documentation +#[allow(dead_code)] +pub type MainUser = (); + +/// The application service may specify the virtual user to act as through use +/// of a user_id query string parameter on the request. The user specified in +/// the query string must be covered by one of the [`AppserviceRegistration`]'s +/// `users` namespaces. +/// +/// Dummy type for shared documentation +pub type VirtualUser = (); /// Appservice #[derive(Debug, Clone)] pub struct Appservice { homeserver_url: Url, server_name: ServerNameBox, - registration: AppserviceRegistration, - client_sender_localpart: Client, + registration: Arc, + clients: Arc>, } impl Appservice { /// Create new Appservice /// + /// Also creates and caches a [`Client`] for the [`MainUser`]. + /// The default [`ClientConfig`] is used, if you want to customize it + /// use [`Self::new_with_config()`] instead. + /// /// # Arguments /// /// * `homeserver_url` - The homeserver that the client should connect to. @@ -215,28 +222,49 @@ impl Appservice { server_name: impl TryInto, registration: AppserviceRegistration, ) -> Result { - let homeserver_url = homeserver_url.try_into()?; - let server_name = server_name.try_into()?; - - let client_sender_localpart = Client::new(homeserver_url.clone())?; - - client_session_with_login_restore( - &client_sender_localpart, - ®istration, - registration.sender_localpart.as_ref(), - &server_name, + let appservice = Self::new_with_config( + homeserver_url, + server_name, + registration, + ClientConfig::default(), ) .await?; - Ok(Appservice { homeserver_url, server_name, registration, client_sender_localpart }) + Ok(appservice) } - /// Get a [`Client`] + /// Same as [`Self::new()`] but lets you provide a [`ClientConfig`] for the + /// [`Client`] + pub async fn new_with_config( + homeserver_url: impl TryInto, + server_name: impl TryInto, + registration: AppserviceRegistration, + client_config: ClientConfig, + ) -> Result { + let homeserver_url = homeserver_url.try_into()?; + let server_name = server_name.try_into()?; + let registration = Arc::new(registration); + let clients = Arc::new(DashMap::new()); + let sender_localpart = registration.sender_localpart.clone(); + + let appservice = Appservice { homeserver_url, server_name, registration, clients }; + + // we cache the [`MainUser`] by default + appservice.virtual_user_with_config(sender_localpart, client_config).await?; + + Ok(appservice) + } + + /// Create a [`Client`] for the given [`VirtualUser`]'s `localpart` /// - /// Will return a `Client` that's configured to [assert the identity] on all - /// outgoing homeserver requests if `localpart` is given. If not given - /// the `Client` will use the main user associated with this appservice, - /// that is the `sender_localpart` in the [`AppserviceRegistration`] + /// Will create and return a [`Client`] that's configured to [assert the + /// identity] on all outgoing homeserver requests if `localpart` is + /// given. + /// + /// This method is a singleton that saves the client internally for re-use + /// based on the `localpart`. The cached [`Client`] can be retrieved either + /// by calling this method again or by calling [`Self::get_cached_client()`] + /// which is non-async convenience wrapper. /// /// # Arguments /// @@ -244,26 +272,48 @@ impl Appservice { /// /// [registration]: https://matrix.org/docs/spec/application_service/r0.1.2#registration /// [assert the identity]: https://matrix.org/docs/spec/application_service/r0.1.2#identity-assertion - pub async fn client(&self, localpart: Option<&str>) -> Result { - let localpart = localpart.unwrap_or_else(|| self.registration.sender_localpart.as_ref()); + pub async fn virtual_user(&self, localpart: impl AsRef) -> Result { + let client = self.virtual_user_with_config(localpart, ClientConfig::default()).await?; - // The `as_token` in the `Session` maps to the main appservice user - // (`sender_localpart`) by default, so we don't need to assert identity - // in that case - let client = if localpart == self.registration.sender_localpart { - self.client_sender_localpart.clone() + Ok(client) + } + + /// Same as [`Self::virtual_user()`] but with the ability to pass in a + /// [`ClientConfig`] + /// + /// Since this method is a singleton follow-up calls with different + /// [`ClientConfig`]s will be ignored. + pub async fn virtual_user_with_config( + &self, + localpart: impl AsRef, + config: ClientConfig, + ) -> Result { + // TODO: check if localpart is covered by namespace? + let localpart = localpart.as_ref(); + + let client = if let Some(client) = self.clients.get(localpart) { + client.clone() } else { - let request_config = RequestConfig::default().assert_identity(); - let config = ClientConfig::default().request_config(request_config); + let user_id = UserId::parse_with_server_name(localpart, &self.server_name)?; + + // The `as_token` in the `Session` maps to the [`MainUser`] + // (`sender_localpart`) by default, so we don't need to assert identity + // in that case + if localpart != self.registration.sender_localpart { + config.get_request_config().assert_identity(); + } + let client = Client::new_with_config(self.homeserver_url.clone(), config)?; - client_session_with_login_restore( - &client, - &self.registration, - localpart, - &self.server_name, - ) - .await?; + let session = Session { + access_token: self.registration.as_token.clone(), + user_id: user_id.clone(), + // TODO: expose & proper E2EE + device_id: DeviceId::new(), + }; + + client.restore_login(session).await?; + self.clients.insert(localpart.to_owned(), client.clone()); client }; @@ -271,9 +321,28 @@ impl Appservice { Ok(client) } + /// Get cached [`Client`] + /// + /// Will return the client for the given `localpart` if previously + /// constructed with [`Self::virtual_user()`] or + /// [`Self::virtual_user_with_config()`]. + /// + /// If no `localpart` is given it assumes the [`MainUser`]'s `localpart`. If + /// no client for `localpart` is found it will return an Error. + pub fn get_cached_client(&self, localpart: Option<&str>) -> Result { + let localpart = localpart.unwrap_or_else(|| self.registration.sender_localpart.as_ref()); + + let entry = self.clients.get(localpart).ok_or(Error::NoClientForLocalpart)?; + + Ok(entry.value().clone()) + } + /// Convenience wrapper around [`Client::set_event_handler()`] - pub async fn set_event_handler(&self, handler: Box) -> Result<()> { - let client = self.client(None).await?; + /// + /// Attaches the event handler to the [`MainUser`]'s [`Client`] + pub async fn set_event_handler(&mut self, handler: Box) -> Result<()> { + let client = self.get_cached_client(None)?; + client.set_event_handler(handler).await; Ok(()) @@ -286,13 +355,13 @@ impl Appservice { /// /// * `localpart` - The localpart of the user to register. Must be covered /// by the namespaces in the [`Registration`] in order to succeed. - pub async fn register(&self, localpart: impl AsRef) -> Result<()> { + pub async fn register(&mut self, localpart: impl AsRef) -> Result<()> { let request = assign!(RegistrationRequest::new(), { username: Some(localpart.as_ref()), login_type: Some(&LoginType::ApplicationService), }); - let client = self.client(None).await?; + let client = self.get_cached_client(None)?; match client.register(request).await { Ok(_) => (), Err(error) => match error { @@ -328,7 +397,8 @@ impl Appservice { self.registration.hs_token == hs_token.as_ref() } - /// Check if given `user_id` is in any of the registration user namespaces + /// Check if given `user_id` is in any of the [`AppserviceRegistration`]'s + /// `users` namespaces pub fn user_id_is_in_namespace(&self, user_id: impl AsRef) -> Result { for user in &self.registration.namespaces.users { // TODO: precompile on Appservice construction diff --git a/matrix_sdk_appservice/tests/actix.rs b/matrix_sdk_appservice/tests/actix.rs index 4bb4b8cd..f7186698 100644 --- a/matrix_sdk_appservice/tests/actix.rs +++ b/matrix_sdk_appservice/tests/actix.rs @@ -12,7 +12,7 @@ mod actix { Appservice::new( mockito::server_url().as_ref(), "test.local", - AppserviceRegistration::try_from_yaml_file("./tests/registration.yaml").unwrap(), + AppserviceRegistration::try_from_yaml_str(include_str!("./registration.yaml")).unwrap(), ) .await .unwrap() diff --git a/matrix_sdk_appservice/tests/tests.rs b/matrix_sdk_appservice/tests/tests.rs index 61fe0ee4..dd7f0dcb 100644 --- a/matrix_sdk_appservice/tests/tests.rs +++ b/matrix_sdk_appservice/tests/tests.rs @@ -59,7 +59,7 @@ fn member_json() -> serde_json::Value { #[async_test] async fn test_event_handler() -> Result<()> { - let appservice = appservice(None).await?; + let mut appservice = appservice(None).await?; struct Example {} @@ -87,7 +87,7 @@ async fn test_event_handler() -> Result<()> { events, ); - appservice.client(None).await?.receive_transaction(incoming).await?; + appservice.get_cached_client(None)?.receive_transaction(incoming).await?; Ok(()) } @@ -105,7 +105,7 @@ async fn test_transaction() -> Result<()> { events, ); - appservice.client(None).await?.receive_transaction(incoming).await?; + appservice.get_cached_client(None)?.receive_transaction(incoming).await?; Ok(()) } diff --git a/matrix_sdk_base/Cargo.toml b/matrix_sdk_base/Cargo.toml index baf1b072..1c25de53 100644 --- a/matrix_sdk_base/Cargo.toml +++ b/matrix_sdk_base/Cargo.toml @@ -25,6 +25,7 @@ docs = ["encryption", "sled_cryptostore"] [dependencies] dashmap = "4.0.2" +lru = "0.6.5" serde = { version = "1.0.122", features = ["rc"] } serde_json = "1.0.61" tracing = "0.1.22" diff --git a/matrix_sdk_base/src/client.rs b/matrix_sdk_base/src/client.rs index d24a2f1c..4bfedba4 100644 --- a/matrix_sdk_base/src/client.rs +++ b/matrix_sdk_base/src/client.rs @@ -42,7 +42,8 @@ use matrix_sdk_common::{ events::{ room::member::{MemberEventContent, MembershipState}, AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent, - AnySyncRoomEvent, AnySyncStateEvent, EventContent, EventType, StateEvent, + AnySyncEphemeralRoomEvent, AnySyncRoomEvent, AnySyncStateEvent, EventContent, EventType, + StateEvent, }, identifiers::{RoomId, UserId}, instant::Instant, @@ -685,7 +686,7 @@ impl BaseClient { for room_id in rooms { if let Some(room) = changes.room_infos.get_mut(room_id) { room.base_info.dm_target = Some(user_id.clone()); - } else if let Some(room) = self.store.get_bare_room(room_id) { + } else if let Some(room) = self.store.get_room(room_id) { let mut info = room.clone_info(); info.base_info.dm_target = Some(user_id.clone()); changes.add_room(info); @@ -784,6 +785,15 @@ impl BaseClient { ) .await?; + if let Some(event) = + new_info.ephemeral.events.iter().find_map(|e| match e.deserialize() { + Ok(AnySyncEphemeralRoomEvent::Receipt(event)) => Some(event.content), + _ => None, + }) + { + changes.add_receipts(&room_id, event); + } + if new_info.timeline.limited { room_info.mark_members_missing(); } @@ -931,7 +941,7 @@ impl BaseClient { async fn apply_changes(&self, changes: &StateChanges) { for (room_id, room_info) in &changes.room_infos { - if let Some(room) = self.store.get_bare_room(&room_id) { + if let Some(room) = self.store.get_room(&room_id) { room.update_summary(room_info.clone()) } } @@ -958,7 +968,7 @@ impl BaseClient { .collect(); let mut ambiguity_cache = AmbiguityCache::new(self.store.clone()); - if let Some(room) = self.store.get_bare_room(room_id) { + if let Some(room) = self.store.get_room(room_id) { let mut room_info = room.clone_info(); room_info.mark_members_synced(); diff --git a/matrix_sdk_base/src/lib.rs b/matrix_sdk_base/src/lib.rs index 358f69ea..326700ce 100644 --- a/matrix_sdk_base/src/lib.rs +++ b/matrix_sdk_base/src/lib.rs @@ -45,6 +45,7 @@ pub use crate::{ mod client; mod error; +pub mod media; mod rooms; mod session; mod store; diff --git a/matrix_sdk_base/src/media.rs b/matrix_sdk_base/src/media.rs new file mode 100644 index 00000000..9d075e63 --- /dev/null +++ b/matrix_sdk_base/src/media.rs @@ -0,0 +1,216 @@ +//! Common types for [media content](https://matrix.org/docs/spec/client_server/r0.6.1#id66). + +use matrix_sdk_common::{ + api::r0::media::get_content_thumbnail::Method, + events::{ + room::{ + message::{ + AudioMessageEventContent, FileMessageEventContent, ImageMessageEventContent, + LocationMessageEventContent, VideoMessageEventContent, + }, + EncryptedFile, + }, + sticker::StickerEventContent, + }, + identifiers::MxcUri, + UInt, +}; + +const UNIQUE_SEPARATOR: &str = "_"; + +/// A trait to uniquely identify values of the same type. +pub trait UniqueKey { + /// A string that uniquely identifies `Self` compared to other values of + /// the same type. + fn unique_key(&self) -> String; +} + +/// The requested format of a media file. +#[derive(Clone, Debug)] +pub enum MediaFormat { + /// The file that was uploaded. + File, + + /// A thumbnail of the file that was uploaded. + Thumbnail(MediaThumbnailSize), +} + +impl UniqueKey for MediaFormat { + fn unique_key(&self) -> String { + match self { + Self::File => "file".into(), + Self::Thumbnail(size) => size.unique_key(), + } + } +} + +/// The requested size of a media thumbnail. +#[derive(Clone, Debug)] +pub struct MediaThumbnailSize { + /// The desired resizing method. + pub method: Method, + + /// The desired width of the thumbnail. The actual thumbnail may not match + /// the size specified. + pub width: UInt, + + /// The desired height of the thumbnail. The actual thumbnail may not match + /// the size specified. + pub height: UInt, +} + +impl UniqueKey for MediaThumbnailSize { + fn unique_key(&self) -> String { + format!("{}{}{}x{}", self.method, UNIQUE_SEPARATOR, self.width, self.height) + } +} + +/// A request for media data. +#[derive(Clone, Debug)] +pub enum MediaType { + /// A media content URI. + Uri(MxcUri), + + /// An encrypted media content. + Encrypted(Box), +} + +impl UniqueKey for MediaType { + fn unique_key(&self) -> String { + match self { + Self::Uri(uri) => uri.to_string(), + Self::Encrypted(file) => file.url.to_string(), + } + } +} + +/// A request for media data. +#[derive(Clone, Debug)] +pub struct MediaRequest { + /// The type of the media file. + pub media_type: MediaType, + + /// The requested format of the media data. + pub format: MediaFormat, +} + +impl UniqueKey for MediaRequest { + fn unique_key(&self) -> String { + format!("{}{}{}", self.media_type.unique_key(), UNIQUE_SEPARATOR, self.format.unique_key()) + } +} + +/// Trait for media event content. +pub trait MediaEventContent { + /// Get the type of the file for `Self`. + /// + /// Returns `None` if `Self` has no file. + fn file(&self) -> Option; + + /// Get the type of the thumbnail for `Self`. + /// + /// Returns `None` if `Self` has no thumbnail. + fn thumbnail(&self) -> Option; +} + +impl MediaEventContent for StickerEventContent { + fn file(&self) -> Option { + Some(MediaType::Uri(self.url.clone())) + } + + fn thumbnail(&self) -> Option { + None + } +} + +impl MediaEventContent for AudioMessageEventContent { + fn file(&self) -> Option { + self.url + .as_ref() + .map(|uri| MediaType::Uri(uri.clone())) + .or_else(|| self.file.as_ref().map(|e| MediaType::Encrypted(e.clone()))) + } + + fn thumbnail(&self) -> Option { + None + } +} + +impl MediaEventContent for FileMessageEventContent { + fn file(&self) -> Option { + self.url + .as_ref() + .map(|uri| MediaType::Uri(uri.clone())) + .or_else(|| self.file.as_ref().map(|e| MediaType::Encrypted(e.clone()))) + } + + fn thumbnail(&self) -> Option { + self.info.as_ref().and_then(|info| { + if let Some(uri) = info.thumbnail_url.as_ref() { + Some(MediaType::Uri(uri.clone())) + } else { + info.thumbnail_file.as_ref().map(|file| MediaType::Encrypted(file.clone())) + } + }) + } +} + +impl MediaEventContent for ImageMessageEventContent { + fn file(&self) -> Option { + self.url + .as_ref() + .map(|uri| MediaType::Uri(uri.clone())) + .or_else(|| self.file.as_ref().map(|e| MediaType::Encrypted(e.clone()))) + } + + fn thumbnail(&self) -> Option { + self.info + .as_ref() + .and_then(|info| { + if let Some(uri) = info.thumbnail_url.as_ref() { + Some(MediaType::Uri(uri.clone())) + } else { + info.thumbnail_file.as_ref().map(|file| MediaType::Encrypted(file.clone())) + } + }) + .or_else(|| self.url.as_ref().map(|uri| MediaType::Uri(uri.clone()))) + } +} + +impl MediaEventContent for VideoMessageEventContent { + fn file(&self) -> Option { + self.url + .as_ref() + .map(|uri| MediaType::Uri(uri.clone())) + .or_else(|| self.file.as_ref().map(|e| MediaType::Encrypted(e.clone()))) + } + + fn thumbnail(&self) -> Option { + self.info + .as_ref() + .and_then(|info| { + if let Some(uri) = info.thumbnail_url.as_ref() { + Some(MediaType::Uri(uri.clone())) + } else { + info.thumbnail_file.as_ref().map(|file| MediaType::Encrypted(file.clone())) + } + }) + .or_else(|| self.url.as_ref().map(|uri| MediaType::Uri(uri.clone()))) + } +} + +impl MediaEventContent for LocationMessageEventContent { + fn file(&self) -> Option { + None + } + + fn thumbnail(&self) -> Option { + self.info.as_ref().and_then(|info| { + if let Some(uri) = info.thumbnail_url.as_ref() { + Some(MediaType::Uri(uri.clone())) + } else { + info.thumbnail_file.as_ref().map(|file| MediaType::Encrypted(file.clone())) + } + }) + } +} diff --git a/matrix_sdk_base/src/rooms/mod.rs b/matrix_sdk_base/src/rooms/mod.rs index 39121d80..6dc9ba77 100644 --- a/matrix_sdk_base/src/rooms/mod.rs +++ b/matrix_sdk_base/src/rooms/mod.rs @@ -62,28 +62,11 @@ impl BaseRoomInfo { invited_member_count: u64, heroes: Vec, ) -> String { - let heroes_count = heroes.len() as u64; - let invited_joined = (invited_member_count + joined_member_count).saturating_sub(1); - - if heroes_count >= invited_joined { - let mut names = heroes.iter().take(3).map(|mem| mem.name()).collect::>(); - // stabilize ordering - names.sort_unstable(); - names.join(", ") - } else if heroes_count < invited_joined && invited_joined > 1 { - let mut names = heroes.iter().take(3).map(|mem| mem.name()).collect::>(); - names.sort_unstable(); - - // TODO: What length does the spec want us to use here and in - // the `else`? - format!( - "{}, and {} others", - names.join(", "), - (joined_member_count + invited_member_count) - ) - } else { - "Empty room".to_string() - } + calculate_room_name( + joined_member_count, + invited_member_count, + heroes.iter().take(3).map(|mem| mem.name()).collect::>(), + ) } /// Handle a state event for this room and update our info accordingly. @@ -164,3 +147,81 @@ impl Default for BaseRoomInfo { } } } + +/// Calculate room name according to step 3 of the [naming algorithm.][spec] +/// +/// [spec]: +fn calculate_room_name( + joined_member_count: u64, + invited_member_count: u64, + heroes: Vec<&str>, +) -> String { + let heroes_count = heroes.len() as u64; + let invited_joined = invited_member_count + joined_member_count; + let invited_joined_minus_one = invited_joined.saturating_sub(1); + + let names = if heroes_count >= invited_joined_minus_one { + let mut names = heroes; + // stabilize ordering + names.sort_unstable(); + names.join(", ") + } else if heroes_count < invited_joined_minus_one && invited_joined > 1 { + let mut names = heroes; + names.sort_unstable(); + + // TODO: What length does the spec want us to use here and in + // the `else`? + format!("{}, and {} others", names.join(", "), (invited_joined - heroes_count)) + } else { + "".to_string() + }; + + // User is alone. + if invited_joined <= 1 { + if names.is_empty() { + "Empty room".to_string() + } else { + format!("Empty room (was {})", names) + } + } else { + names + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + + fn test_calculate_room_name() { + let mut actual = calculate_room_name(2, 0, vec!["a"]); + assert_eq!("a", actual); + + actual = calculate_room_name(3, 0, vec!["a", "b"]); + assert_eq!("a, b", actual); + + actual = calculate_room_name(4, 0, vec!["a", "b", "c"]); + assert_eq!("a, b, c", actual); + + actual = calculate_room_name(5, 0, vec!["a", "b", "c"]); + assert_eq!("a, b, c, and 2 others", actual); + + actual = calculate_room_name(0, 0, vec![]); + assert_eq!("Empty room", actual); + + actual = calculate_room_name(1, 0, vec![]); + assert_eq!("Empty room", actual); + + actual = calculate_room_name(0, 1, vec![]); + assert_eq!("Empty room", actual); + + actual = calculate_room_name(1, 0, vec!["a"]); + assert_eq!("Empty room (was a)", actual); + + actual = calculate_room_name(1, 0, vec!["a", "b"]); + assert_eq!("Empty room (was a, b)", actual); + + actual = calculate_room_name(1, 0, vec!["a", "b", "c"]); + assert_eq!("Empty room (was a, b, c)", actual); + } +} diff --git a/matrix_sdk_base/src/rooms/normal.rs b/matrix_sdk_base/src/rooms/normal.rs index e86f51a6..6e2fc03c 100644 --- a/matrix_sdk_base/src/rooms/normal.rs +++ b/matrix_sdk_base/src/rooms/normal.rs @@ -24,6 +24,7 @@ use futures::{ use matrix_sdk_common::{ api::r0::sync::sync_events::RoomSummary as RumaSummary, events::{ + receipt::Receipt, room::{ create::CreateEventContent, encryption::EncryptionEventContent, guest_access::GuestAccess, history_visibility::HistoryVisibility, join_rules::JoinRule, @@ -32,7 +33,8 @@ use matrix_sdk_common::{ tag::Tags, AnyRoomAccountDataEvent, AnyStateEventContent, AnySyncStateEvent, EventType, }, - identifiers::{MxcUri, RoomAliasId, RoomId, UserId}, + identifiers::{EventId, MxcUri, RoomAliasId, RoomId, UserId}, + receipt::ReceiptType, }; use serde::{Deserialize, Serialize}; use tracing::info; @@ -449,6 +451,24 @@ impl Room { Ok(None) } } + + /// Get the read receipt as a `EventId` and `Receipt` tuple for the given + /// `user_id` in this room. + pub async fn user_read_receipt( + &self, + user_id: &UserId, + ) -> StoreResult> { + self.store.get_user_room_receipt_event(self.room_id(), ReceiptType::Read, user_id).await + } + + /// Get the read receipts as a list of `UserId` and `Receipt` tuples for the + /// given `event_id` in this room. + pub async fn event_read_receipts( + &self, + event_id: &EventId, + ) -> StoreResult> { + self.store.get_event_room_receipt_events(self.room_id(), ReceiptType::Read, event_id).await + } } /// The underlying pure data structure for joined and left rooms. @@ -466,7 +486,7 @@ pub struct RoomInfo { pub summary: RoomSummary, /// Flag remembering if the room members are synced. pub members_synced: bool, - /// The prev batch of this room we received durring the last sync. + /// The prev batch of this room we received during the last sync. pub last_prev_batch: Option, /// Base room info which holds some basic event contents important for the /// room state. diff --git a/matrix_sdk_base/src/store/memory_store.rs b/matrix_sdk_base/src/store/memory_store.rs index 51c3ea33..7d15254a 100644 --- a/matrix_sdk_base/src/store/memory_store.rs +++ b/matrix_sdk_base/src/store/memory_store.rs @@ -18,22 +18,29 @@ use std::{ }; use dashmap::{DashMap, DashSet}; +use lru::LruCache; use matrix_sdk_common::{ async_trait, events::{ presence::PresenceEvent, + receipt::Receipt, room::member::{MemberEventContent, MembershipState}, AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent, AnySyncStateEvent, EventType, }, - identifiers::{RoomId, UserId}, + identifiers::{EventId, MxcUri, RoomId, UserId}, instant::Instant, + locks::Mutex, + receipt::ReceiptType, Raw, }; use tracing::info; use super::{Result, RoomInfo, StateChanges, StateStore}; -use crate::deserialized_responses::{MemberEvent, StrippedMemberEvent}; +use crate::{ + deserialized_responses::{MemberEvent, StrippedMemberEvent}, + media::{MediaRequest, UniqueKey}, +}; #[derive(Debug, Clone)] pub struct MemoryStore { @@ -55,6 +62,12 @@ pub struct MemoryStore { Arc>>>>, stripped_members: Arc>>, presence: Arc>>, + #[allow(clippy::type_complexity)] + room_user_receipts: Arc>>>, + #[allow(clippy::type_complexity)] + room_event_receipts: + Arc>>>>, + media: Arc>>>, } impl MemoryStore { @@ -76,6 +89,9 @@ impl MemoryStore { stripped_room_state: DashMap::new().into(), stripped_members: DashMap::new().into(), presence: DashMap::new().into(), + room_user_receipts: DashMap::new().into(), + room_event_receipts: DashMap::new().into(), + media: Arc::new(Mutex::new(LruCache::new(100))), } } @@ -220,6 +236,43 @@ impl MemoryStore { } } + for (room, content) in &changes.receipts { + for (event_id, receipts) in &content.0 { + for (receipt_type, receipts) in receipts { + for (user_id, receipt) in receipts { + // Add the receipt to the room user receipts + if let Some((old_event, _)) = self + .room_user_receipts + .entry(room.clone()) + .or_insert_with(DashMap::new) + .entry(receipt_type.to_string()) + .or_insert_with(DashMap::new) + .insert(user_id.clone(), (event_id.clone(), receipt.clone())) + { + // Remove the old receipt from the room event receipts + if let Some(receipt_map) = self.room_event_receipts.get(room) { + if let Some(event_map) = receipt_map.get(receipt_type.as_ref()) { + if let Some(user_map) = event_map.get_mut(&old_event) { + user_map.remove(user_id); + } + } + } + } + + // Add the receipt to the room event receipts + self.room_event_receipts + .entry(room.clone()) + .or_insert_with(DashMap::new) + .entry(receipt_type.to_string()) + .or_insert_with(DashMap::new) + .entry(event_id.clone()) + .or_insert_with(DashMap::new) + .insert(user_id.clone(), receipt.clone()); + } + } + } + } + info!("Saved changes in {:?}", now.elapsed()); Ok(()) @@ -311,6 +364,68 @@ impl MemoryStore { .get(room_id) .and_then(|m| m.get(event_type.as_ref()).map(|e| e.clone()))) } + + async fn get_user_room_receipt_event( + &self, + room_id: &RoomId, + receipt_type: ReceiptType, + user_id: &UserId, + ) -> Result> { + Ok(self.room_user_receipts.get(room_id).and_then(|m| { + m.get(receipt_type.as_ref()).and_then(|m| m.get(user_id).map(|r| r.clone())) + })) + } + + async fn get_event_room_receipt_events( + &self, + room_id: &RoomId, + receipt_type: ReceiptType, + event_id: &EventId, + ) -> Result> { + Ok(self + .room_event_receipts + .get(room_id) + .and_then(|m| { + m.get(receipt_type.as_ref()).and_then(|m| { + m.get(event_id) + .map(|m| m.iter().map(|r| (r.key().clone(), r.value().clone())).collect()) + }) + }) + .unwrap_or_else(Vec::new)) + } + + async fn add_media_content(&self, request: &MediaRequest, data: Vec) -> Result<()> { + self.media.lock().await.put(request.unique_key(), data); + + Ok(()) + } + + async fn get_media_content(&self, request: &MediaRequest) -> Result>> { + Ok(self.media.lock().await.get(&request.unique_key()).cloned()) + } + + async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> { + self.media.lock().await.pop(&request.unique_key()); + + Ok(()) + } + + async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> { + let mut media_store = self.media.lock().await; + + let keys: Vec = media_store + .iter() + .filter_map( + |(key, _)| if key.starts_with(&uri.to_string()) { Some(key.clone()) } else { None }, + ) + .collect(); + + for key in keys { + media_store.pop(&key); + } + + Ok(()) + } } #[cfg_attr(target_arch = "wasm32", async_trait(?Send))] @@ -408,4 +523,191 @@ impl StateStore for MemoryStore { ) -> Result>> { self.get_room_account_data_event(room_id, event_type).await } + + async fn get_user_room_receipt_event( + &self, + room_id: &RoomId, + receipt_type: ReceiptType, + user_id: &UserId, + ) -> Result> { + self.get_user_room_receipt_event(room_id, receipt_type, user_id).await + } + + async fn get_event_room_receipt_events( + &self, + room_id: &RoomId, + receipt_type: ReceiptType, + event_id: &EventId, + ) -> Result> { + self.get_event_room_receipt_events(room_id, receipt_type, event_id).await + } + + async fn add_media_content(&self, request: &MediaRequest, data: Vec) -> Result<()> { + self.add_media_content(request, data).await + } + + async fn get_media_content(&self, request: &MediaRequest) -> Result>> { + self.get_media_content(request).await + } + + async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> { + self.remove_media_content(request).await + } + + async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> { + self.remove_media_content_for_uri(uri).await + } +} + +#[cfg(test)] +#[cfg(not(feature = "sled_state_store"))] +mod test { + use matrix_sdk_common::{ + api::r0::media::get_content_thumbnail::Method, + identifiers::{event_id, mxc_uri, room_id, user_id, UserId}, + receipt::ReceiptType, + uint, + }; + use matrix_sdk_test::async_test; + use serde_json::json; + + use super::{MemoryStore, StateChanges}; + use crate::media::{MediaFormat, MediaRequest, MediaThumbnailSize, MediaType}; + + fn user_id() -> UserId { + user_id!("@example:localhost") + } + + #[async_test] + async fn test_receipts_saving() { + let store = MemoryStore::new(); + + let room_id = room_id!("!test:localhost"); + + let first_event_id = event_id!("$1435641916114394fHBLK:matrix.org"); + let second_event_id = event_id!("$fHBLK1435641916114394:matrix.org"); + + let first_receipt_event = serde_json::from_value(json!({ + first_event_id.clone(): { + "m.read": { + user_id(): { + "ts": 1436451550453u64 + } + } + } + })) + .unwrap(); + + let second_receipt_event = serde_json::from_value(json!({ + second_event_id.clone(): { + "m.read": { + user_id(): { + "ts": 1436451551453u64 + } + } + } + })) + .unwrap(); + + assert!(store + .get_user_room_receipt_event(&room_id, ReceiptType::Read, &user_id()) + .await + .unwrap() + .is_none()); + assert!(store + .get_event_room_receipt_events(&room_id, ReceiptType::Read, &first_event_id) + .await + .unwrap() + .is_empty()); + assert!(store + .get_event_room_receipt_events(&room_id, ReceiptType::Read, &second_event_id) + .await + .unwrap() + .is_empty()); + + let mut changes = StateChanges::default(); + changes.add_receipts(&room_id, first_receipt_event); + + store.save_changes(&changes).await.unwrap(); + assert!(store + .get_user_room_receipt_event(&room_id, ReceiptType::Read, &user_id()) + .await + .unwrap() + .is_some(),); + assert_eq!( + store + .get_event_room_receipt_events(&room_id, ReceiptType::Read, &first_event_id) + .await + .unwrap() + .len(), + 1 + ); + assert!(store + .get_event_room_receipt_events(&room_id, ReceiptType::Read, &second_event_id) + .await + .unwrap() + .is_empty()); + + let mut changes = StateChanges::default(); + changes.add_receipts(&room_id, second_receipt_event); + + store.save_changes(&changes).await.unwrap(); + assert!(store + .get_user_room_receipt_event(&room_id, ReceiptType::Read, &user_id()) + .await + .unwrap() + .is_some()); + assert!(store + .get_event_room_receipt_events(&room_id, ReceiptType::Read, &first_event_id) + .await + .unwrap() + .is_empty()); + assert_eq!( + store + .get_event_room_receipt_events(&room_id, ReceiptType::Read, &second_event_id) + .await + .unwrap() + .len(), + 1 + ); + } + + #[async_test] + async fn test_media_content() { + let store = MemoryStore::new(); + + let uri = mxc_uri!("mxc://localhost/media"); + let content: Vec = "somebinarydata".into(); + + let request_file = + MediaRequest { media_type: MediaType::Uri(uri.clone()), format: MediaFormat::File }; + + let request_thumbnail = MediaRequest { + media_type: MediaType::Uri(uri.clone()), + format: MediaFormat::Thumbnail(MediaThumbnailSize { + method: Method::Crop, + width: uint!(100), + height: uint!(100), + }), + }; + + assert!(store.get_media_content(&request_file).await.unwrap().is_none()); + assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none()); + + store.add_media_content(&request_file, content.clone()).await.unwrap(); + assert!(store.get_media_content(&request_file).await.unwrap().is_some()); + + store.remove_media_content(&request_file).await.unwrap(); + assert!(store.get_media_content(&request_file).await.unwrap().is_none()); + + store.add_media_content(&request_file, content.clone()).await.unwrap(); + assert!(store.get_media_content(&request_file).await.unwrap().is_some()); + + store.add_media_content(&request_thumbnail, content.clone()).await.unwrap(); + assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_some()); + + store.remove_media_content_for_uri(&uri).await.unwrap(); + assert!(store.get_media_content(&request_file).await.unwrap().is_none()); + assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none()); + } } diff --git a/matrix_sdk_base/src/store/mod.rs b/matrix_sdk_base/src/store/mod.rs index a9b20dd0..bbc915ed 100644 --- a/matrix_sdk_base/src/store/mod.rs +++ b/matrix_sdk_base/src/store/mod.rs @@ -25,11 +25,15 @@ use matrix_sdk_common::{ api::r0::push::get_notifications::Notification, async_trait, events::{ - presence::PresenceEvent, room::member::MemberEventContent, AnyGlobalAccountDataEvent, - AnyRoomAccountDataEvent, AnyStrippedStateEvent, AnySyncStateEvent, EventContent, EventType, + presence::PresenceEvent, + receipt::{Receipt, ReceiptEventContent}, + room::member::MemberEventContent, + AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent, + AnySyncStateEvent, EventContent, EventType, }, - identifiers::{RoomId, UserId}, + identifiers::{EventId, MxcUri, RoomId, UserId}, locks::RwLock, + receipt::ReceiptType, AsyncTraitDeps, Raw, }; #[cfg(feature = "sled_state_store")] @@ -37,6 +41,7 @@ use sled::Db; use crate::{ deserialized_responses::{MemberEvent, StrippedMemberEvent}, + media::MediaRequest, rooms::{RoomInfo, RoomType}, Room, Session, }; @@ -210,6 +215,72 @@ pub trait StateStore: AsyncTraitDeps { room_id: &RoomId, event_type: EventType, ) -> Result>>; + + /// Get an event out of the user room receipt store. + /// + /// # Arguments + /// + /// * `room_id` - The id of the room for which the receipt should be + /// fetched. + /// + /// * `receipt_type` - The type of the receipt. + /// + /// * `user_id` - The id of the user for who the receipt should be fetched. + async fn get_user_room_receipt_event( + &self, + room_id: &RoomId, + receipt_type: ReceiptType, + user_id: &UserId, + ) -> Result>; + + /// Get events out of the event room receipt store. + /// + /// # Arguments + /// + /// * `room_id` - The id of the room for which the receipts should be + /// fetched. + /// + /// * `receipt_type` - The type of the receipts. + /// + /// * `event_id` - The id of the event for which the receipts should be + /// fetched. + async fn get_event_room_receipt_events( + &self, + room_id: &RoomId, + receipt_type: ReceiptType, + event_id: &EventId, + ) -> Result>; + + /// Add a media file's content in the media store. + /// + /// # Arguments + /// + /// * `request` - The `MediaRequest` of the file. + /// + /// * `content` - The content of the file. + async fn add_media_content(&self, request: &MediaRequest, content: Vec) -> Result<()>; + + /// Get a media file's content out of the media store. + /// + /// # Arguments + /// + /// * `request` - The `MediaRequest` of the file. + async fn get_media_content(&self, request: &MediaRequest) -> Result>>; + + /// Removes a media file's content from the media store. + /// + /// # Arguments + /// + /// * `request` - The `MediaRequest` of the file. + async fn remove_media_content(&self, request: &MediaRequest) -> Result<()>; + + /// Removes all the media files' content associated to an `MxcUri` from the + /// media store. + /// + /// # Arguments + /// + /// * `uri` - The `MxcUri` of the media files. + async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()>; } /// A state store wrapper for the SDK. @@ -291,11 +362,6 @@ impl Store { Ok((Self::new(Box::new(inner.clone())), inner.inner)) } - pub(crate) fn get_bare_room(&self, room_id: &RoomId) -> Option { - #[allow(clippy::map_clone)] - self.rooms.get(room_id).map(|r| r.clone()) - } - /// Get all the rooms this store knows about. pub fn get_rooms(&self) -> Vec { self.rooms.iter().filter_map(|r| self.get_room(r.key())).collect() @@ -303,15 +369,17 @@ impl Store { /// Get the room with the given room id. pub fn get_room(&self, room_id: &RoomId) -> Option { - self.get_bare_room(room_id).and_then(|r| match r.room_type() { - RoomType::Joined => Some(r), - RoomType::Left => Some(r), - RoomType::Invited => self.get_stripped_room(room_id), - }) + self.rooms + .get(room_id) + .and_then(|r| match r.room_type() { + RoomType::Joined => Some(r.clone()), + RoomType::Left => Some(r.clone()), + RoomType::Invited => self.get_stripped_room(room_id), + }) + .or_else(|| self.get_stripped_room(room_id)) } fn get_stripped_room(&self, room_id: &RoomId) -> Option { - #[allow(clippy::map_clone)] self.stripped_rooms.get(room_id).map(|r| r.clone()) } @@ -369,6 +437,8 @@ pub struct StateChanges { pub room_account_data: BTreeMap>>, /// A map of `RoomId` to `RoomInfo`. pub room_infos: BTreeMap, + /// A map of `RoomId` to `ReceiptEventContent`. + pub receipts: BTreeMap, /// A mapping of `RoomId` to a map of event type to a map of state key to /// `AnyStrippedStateEvent`. @@ -404,7 +474,7 @@ impl StateChanges { /// Update the `StateChanges` struct with the given `RoomInfo`. pub fn add_stripped_room(&mut self, room: RoomInfo) { - self.invited_room_info.insert(room.room_id.as_ref().to_owned(), room); + self.room_infos.insert(room.room_id.as_ref().to_owned(), room); } /// Update the `StateChanges` struct with the given `AnyBasicEvent`. @@ -462,4 +532,10 @@ impl StateChanges { pub fn add_notification(&mut self, room_id: &RoomId, notification: Notification) { self.notifications.entry(room_id.to_owned()).or_insert_with(Vec::new).push(notification); } + + /// Update the `StateChanges` struct with the given room with a new + /// `Receipts`. + pub fn add_receipts(&mut self, room_id: &RoomId, event: ReceiptEventContent) { + self.receipts.insert(room_id.to_owned(), event); + } } diff --git a/matrix_sdk_base/src/store/sled_store/mod.rs b/matrix_sdk_base/src/store/sled_store/mod.rs index c60534bc..13d35184 100644 --- a/matrix_sdk_base/src/store/sled_store/mod.rs +++ b/matrix_sdk_base/src/store/sled_store/mod.rs @@ -16,7 +16,7 @@ mod store_key; use std::{ collections::BTreeSet, - convert::TryFrom, + convert::{TryFrom, TryInto}, path::{Path, PathBuf}, sync::Arc, time::Instant, @@ -30,10 +30,12 @@ use matrix_sdk_common::{ async_trait, events::{ presence::PresenceEvent, + receipt::Receipt, room::member::{MemberEventContent, MembershipState}, AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnySyncStateEvent, EventType, }, - identifiers::{RoomId, UserId}, + identifiers::{EventId, MxcUri, RoomId, UserId}, + receipt::ReceiptType, Raw, }; use serde::{Deserialize, Serialize}; @@ -45,7 +47,10 @@ use tracing::info; use self::store_key::{EncryptedEvent, StoreKey}; use super::{Result, RoomInfo, StateChanges, StateStore, StoreError}; -use crate::deserialized_responses::MemberEvent; +use crate::{ + deserialized_responses::MemberEvent, + media::{MediaRequest, UniqueKey}, +}; #[derive(Debug, Serialize, Deserialize)] pub enum DatabaseType { @@ -127,12 +132,41 @@ impl EncodeKey for (&str, &str, &str) { } } +impl EncodeKey for (&str, &str, &str, &str) { + fn encode(&self) -> Vec { + [ + self.0.as_bytes(), + &[ENCODE_SEPARATOR], + self.1.as_bytes(), + &[ENCODE_SEPARATOR], + self.2.as_bytes(), + &[ENCODE_SEPARATOR], + self.3.as_bytes(), + &[ENCODE_SEPARATOR], + ] + .concat() + } +} + impl EncodeKey for EventType { fn encode(&self) -> Vec { self.as_str().encode() } } +/// Get the value at `position` in encoded `key`. +/// +/// The key must have been encoded with the `EncodeKey` trait. `position` +/// corresponds to the position in the tuple before the key was encoded. If it +/// wasn't encoded in a tuple, use `0`. +/// +/// Returns `None` if there is no key at `position`. +pub fn decode_key_value(key: &[u8], position: usize) -> Option { + let values: Vec<&[u8]> = key.split(|v| *v == ENCODE_SEPARATOR).collect(); + + values.get(position).map(|s| String::from_utf8_lossy(s).to_string()) +} + #[derive(Clone)] pub struct SledStore { path: Option, @@ -152,6 +186,9 @@ pub struct SledStore { stripped_room_state: Tree, stripped_members: Tree, presence: Tree, + room_user_receipts: Tree, + room_event_receipts: Tree, + media: Tree, } impl std::fmt::Debug for SledStore { @@ -184,6 +221,11 @@ impl SledStore { let stripped_members = db.open_tree("stripped_members")?; let stripped_room_state = db.open_tree("stripped_room_state")?; + let room_user_receipts = db.open_tree("room_user_receipts")?; + let room_event_receipts = db.open_tree("room_event_receipts")?; + + let media = db.open_tree("media")?; + Ok(Self { path, inner: db, @@ -202,6 +244,9 @@ impl SledStore { stripped_room_info, stripped_members, stripped_room_state, + room_user_receipts, + room_event_receipts, + media, }) } @@ -459,6 +504,58 @@ impl SledStore { ret?; + let ret: Result<(), TransactionError> = + (&self.room_user_receipts, &self.room_event_receipts).transaction( + |(room_user_receipts, room_event_receipts)| { + for (room, content) in &changes.receipts { + for (event_id, receipts) in &content.0 { + for (receipt_type, receipts) in receipts { + for (user_id, receipt) in receipts { + // Add the receipt to the room user receipts + if let Some(old) = room_user_receipts.insert( + (room.as_str(), receipt_type.as_ref(), user_id.as_str()) + .encode(), + self.serialize_event(&(event_id, receipt)) + .map_err(ConflictableTransactionError::Abort)?, + )? { + // Remove the old receipt from the room event receipts + let (old_event, _): (EventId, Receipt) = self + .deserialize_event(&old) + .map_err(ConflictableTransactionError::Abort)?; + room_event_receipts.remove( + ( + room.as_str(), + receipt_type.as_ref(), + old_event.as_str(), + user_id.as_str(), + ) + .encode(), + )?; + } + + // Add the receipt to the room event receipts + room_event_receipts.insert( + ( + room.as_str(), + receipt_type.as_ref(), + event_id.as_str(), + user_id.as_str(), + ) + .encode(), + self.serialize_event(receipt) + .map_err(ConflictableTransactionError::Abort)?, + )?; + } + } + } + } + + Ok(()) + }, + ); + + ret?; + self.inner.flush_async().await?; info!("Saved changes in {:?}", now.elapsed()); @@ -598,6 +695,79 @@ impl SledStore { .map(|m| self.deserialize_event(&m)) .transpose()?) } + + async fn get_user_room_receipt_event( + &self, + room_id: &RoomId, + receipt_type: ReceiptType, + user_id: &UserId, + ) -> Result> { + Ok(self + .room_user_receipts + .get((room_id.as_str(), receipt_type.as_ref(), user_id.as_str()).encode())? + .map(|m| self.deserialize_event(&m)) + .transpose()?) + } + + async fn get_event_room_receipt_events( + &self, + room_id: &RoomId, + receipt_type: ReceiptType, + event_id: &EventId, + ) -> Result> { + self.room_event_receipts + .scan_prefix((room_id.as_str(), receipt_type.as_ref(), event_id.as_str()).encode()) + .map(|u| { + u.map_err(StoreError::Sled).and_then(|(key, value)| { + self.deserialize_event(&value) + .map(|receipt| { + (decode_key_value(&key, 3).unwrap().try_into().unwrap(), receipt) + }) + .map_err(Into::into) + }) + }) + .collect() + } + + async fn add_media_content(&self, request: &MediaRequest, data: Vec) -> Result<()> { + self.media.insert( + (request.media_type.unique_key().as_str(), request.format.unique_key().as_str()) + .encode(), + data, + )?; + + Ok(()) + } + + async fn get_media_content(&self, request: &MediaRequest) -> Result>> { + Ok(self + .media + .get( + (request.media_type.unique_key().as_str(), request.format.unique_key().as_str()) + .encode(), + )? + .map(|m| m.to_vec())) + } + + async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> { + self.media.remove( + (request.media_type.unique_key().as_str(), request.format.unique_key().as_str()) + .encode(), + )?; + + Ok(()) + } + + async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> { + let keys = self.media.scan_prefix(uri.as_str().encode()).keys(); + + let mut batch = sled::Batch::default(); + for key in keys { + batch.remove(key?); + } + + Ok(self.media.apply_batch(batch)?) + } } #[async_trait] @@ -689,6 +859,40 @@ impl StateStore for SledStore { ) -> Result>> { self.get_room_account_data_event(room_id, event_type).await } + + async fn get_user_room_receipt_event( + &self, + room_id: &RoomId, + receipt_type: ReceiptType, + user_id: &UserId, + ) -> Result> { + self.get_user_room_receipt_event(room_id, receipt_type, user_id).await + } + + async fn get_event_room_receipt_events( + &self, + room_id: &RoomId, + receipt_type: ReceiptType, + event_id: &EventId, + ) -> Result> { + self.get_event_room_receipt_events(room_id, receipt_type, event_id).await + } + + async fn add_media_content(&self, request: &MediaRequest, data: Vec) -> Result<()> { + self.add_media_content(request, data).await + } + + async fn get_media_content(&self, request: &MediaRequest) -> Result>> { + self.get_media_content(request).await + } + + async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> { + self.remove_media_content(request).await + } + + async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> { + self.remove_media_content_for_uri(uri).await + } } #[cfg(test)] @@ -696,6 +900,7 @@ mod test { use std::convert::TryFrom; use matrix_sdk_common::{ + api::r0::media::get_content_thumbnail::Method, events::{ room::{ member::{MemberEventContent, MembershipState}, @@ -703,14 +908,19 @@ mod test { }, AnySyncStateEvent, EventType, Unsigned, }, - identifiers::{room_id, user_id, EventId, UserId}, - MilliSecondsSinceUnixEpoch, Raw, + identifiers::{event_id, mxc_uri, room_id, user_id, EventId, UserId}, + receipt::ReceiptType, + uint, MilliSecondsSinceUnixEpoch, Raw, }; use matrix_sdk_test::async_test; use serde_json::json; use super::{SledStore, StateChanges}; - use crate::{deserialized_responses::MemberEvent, StateStore}; + use crate::{ + deserialized_responses::MemberEvent, + media::{MediaFormat, MediaRequest, MediaThumbnailSize, MediaType}, + StateStore, + }; fn user_id() -> UserId { user_id!("@example:localhost") @@ -788,4 +998,137 @@ mod test { .unwrap() .is_some()); } + + #[async_test] + async fn test_receipts_saving() { + let store = SledStore::open().unwrap(); + + let room_id = room_id!("!test:localhost"); + + let first_event_id = event_id!("$1435641916114394fHBLK:matrix.org"); + let second_event_id = event_id!("$fHBLK1435641916114394:matrix.org"); + + let first_receipt_event = serde_json::from_value(json!({ + first_event_id.clone(): { + "m.read": { + user_id(): { + "ts": 1436451550453u64 + } + } + } + })) + .unwrap(); + + let second_receipt_event = serde_json::from_value(json!({ + second_event_id.clone(): { + "m.read": { + user_id(): { + "ts": 1436451551453u64 + } + } + } + })) + .unwrap(); + + assert!(store + .get_user_room_receipt_event(&room_id, ReceiptType::Read, &user_id()) + .await + .unwrap() + .is_none()); + assert!(store + .get_event_room_receipt_events(&room_id, ReceiptType::Read, &first_event_id) + .await + .unwrap() + .is_empty()); + assert!(store + .get_event_room_receipt_events(&room_id, ReceiptType::Read, &second_event_id) + .await + .unwrap() + .is_empty()); + + let mut changes = StateChanges::default(); + changes.add_receipts(&room_id, first_receipt_event); + + store.save_changes(&changes).await.unwrap(); + assert!(store + .get_user_room_receipt_event(&room_id, ReceiptType::Read, &user_id()) + .await + .unwrap() + .is_some(),); + assert_eq!( + store + .get_event_room_receipt_events(&room_id, ReceiptType::Read, &first_event_id) + .await + .unwrap() + .len(), + 1 + ); + assert!(store + .get_event_room_receipt_events(&room_id, ReceiptType::Read, &second_event_id) + .await + .unwrap() + .is_empty()); + + let mut changes = StateChanges::default(); + changes.add_receipts(&room_id, second_receipt_event); + + store.save_changes(&changes).await.unwrap(); + assert!(store + .get_user_room_receipt_event(&room_id, ReceiptType::Read, &user_id()) + .await + .unwrap() + .is_some()); + assert!(store + .get_event_room_receipt_events(&room_id, ReceiptType::Read, &first_event_id) + .await + .unwrap() + .is_empty()); + assert_eq!( + store + .get_event_room_receipt_events(&room_id, ReceiptType::Read, &second_event_id) + .await + .unwrap() + .len(), + 1 + ); + } + + #[async_test] + async fn test_media_content() { + let store = SledStore::open().unwrap(); + + let uri = mxc_uri!("mxc://localhost/media"); + let content: Vec = "somebinarydata".into(); + + let request_file = + MediaRequest { media_type: MediaType::Uri(uri.clone()), format: MediaFormat::File }; + + let request_thumbnail = MediaRequest { + media_type: MediaType::Uri(uri.clone()), + format: MediaFormat::Thumbnail(MediaThumbnailSize { + method: Method::Crop, + width: uint!(100), + height: uint!(100), + }), + }; + + assert!(store.get_media_content(&request_file).await.unwrap().is_none()); + assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none()); + + store.add_media_content(&request_file, content.clone()).await.unwrap(); + assert!(store.get_media_content(&request_file).await.unwrap().is_some()); + + store.remove_media_content(&request_file).await.unwrap(); + assert!(store.get_media_content(&request_file).await.unwrap().is_none()); + + store.add_media_content(&request_file, content.clone()).await.unwrap(); + assert!(store.get_media_content(&request_file).await.unwrap().is_some()); + + store.add_media_content(&request_thumbnail, content.clone()).await.unwrap(); + assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_some()); + + store.remove_media_content_for_uri(&uri).await.unwrap(); + assert!(store.get_media_content(&request_file).await.unwrap().is_none()); + assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none()); + } } diff --git a/matrix_sdk_crypto/src/file_encryption/attachments.rs b/matrix_sdk_crypto/src/file_encryption/attachments.rs index 81c3cf23..70354f1f 100644 --- a/matrix_sdk_crypto/src/file_encryption/attachments.rs +++ b/matrix_sdk_crypto/src/file_encryption/attachments.rs @@ -23,7 +23,7 @@ use aes_ctr::{ }; use base64::DecodeError; use getrandom::getrandom; -use matrix_sdk_common::events::room::{JsonWebKey, JsonWebKeyInit}; +use matrix_sdk_common::events::room::{EncryptedFile, JsonWebKey, JsonWebKeyInit}; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use thiserror::Error; @@ -252,6 +252,12 @@ pub struct EncryptionInfo { pub hashes: BTreeMap, } +impl From for EncryptionInfo { + fn from(file: EncryptedFile) -> Self { + Self { version: file.v, web_key: file.key, iv: file.iv, hashes: file.hashes } + } +} + #[cfg(test)] mod test { use std::io::{Cursor, Read}; diff --git a/matrix_sdk_test/src/lib.rs b/matrix_sdk_test/src/lib.rs index b86b760d..d5f8fb29 100644 --- a/matrix_sdk_test/src/lib.rs +++ b/matrix_sdk_test/src/lib.rs @@ -364,7 +364,7 @@ impl EventBuilder { } } -/// Embedded sync reponse files +/// Embedded sync response files pub enum SyncResponseFile { All, Default, diff --git a/matrix_sdk_test/src/test_json/mod.rs b/matrix_sdk_test/src/test_json/mod.rs index c743e395..64b628f6 100644 --- a/matrix_sdk_test/src/test_json/mod.rs +++ b/matrix_sdk_test/src/test_json/mod.rs @@ -42,3 +42,29 @@ lazy_static! { ] }); } + +lazy_static! { + pub static ref WELL_KNOWN: JsonValue = json!({ + "m.homeserver": { + "base_url": "HOMESERVER_URL" + } + }); +} + +lazy_static! { + pub static ref VERSIONS: JsonValue = json!({ + "versions": [ + "r0.0.1", + "r0.1.0", + "r0.2.0", + "r0.3.0", + "r0.4.0", + "r0.5.0", + "r0.6.0" + ], + "unstable_features": { + "org.matrix.label_based_filtering":true, + "org.matrix.e2e_cross_signing":true + } + }); +} diff --git a/matrix_sdk_test/src/test_json/sync.rs b/matrix_sdk_test/src/test_json/sync.rs index 862b126a..cb1dc57c 100644 --- a/matrix_sdk_test/src/test_json/sync.rs +++ b/matrix_sdk_test/src/test_json/sync.rs @@ -720,7 +720,7 @@ lazy_static! { lazy_static! { pub static ref INVITE_SYNC: JsonValue = json!({ "device_one_time_keys_count": {}, - "next_batch": "s526_47314_0_7_1_1_1_11444_1", + "next_batch": "s526_47314_0_7_1_1_1_11444_2", "device_lists": { "changed": [ "@example:example.org"