diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index 6e16314c..5a393d84 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -14,7 +14,11 @@ // limitations under the License. #[cfg(feature = "encryption")] -use std::{collections::BTreeMap, io::Write, path::PathBuf}; +use std::{ + collections::BTreeMap, + io::{Cursor, Write}, + path::PathBuf, +}; #[cfg(feature = "sso_login")] use std::{ collections::HashMap, @@ -38,10 +42,13 @@ use http::Response; #[cfg(feature = "encryption")] use matrix_sdk_base::crypto::{ decrypt_key_export, encrypt_key_export, olm::InboundGroupSession, store::CryptoStoreError, - OutgoingRequests, RoomMessageRequest, ToDeviceRequest, + AttachmentDecryptor, OutgoingRequests, RoomMessageRequest, ToDeviceRequest, }; use matrix_sdk_base::{ - deserialized_responses::SyncResponse, events::AnyMessageEventContent, identifiers::MxcUri, + deserialized_responses::SyncResponse, + events::AnyMessageEventContent, + identifiers::MxcUri, + media::{MediaEventContent, MediaFormat, MediaRequest, MediaThumbnailSize, MediaType}, BaseClient, BaseClientConfig, SendAccessToken, Session, Store, }; use mime::{self, Mime}; @@ -2465,6 +2472,208 @@ impl Client { Ok(olm.import_keys(import, |_, _| {}).await?) } + + /// Get a media file's content. + /// + /// If the content is encrypted and encryption is enabled, the content will + /// be decrypted. + /// + /// # Arguments + /// + /// * `request` - The `MediaRequest` of the content. + /// + /// * `use_cache` - If we should use the media cache for this request. + pub async fn get_media_content( + &self, + request: &MediaRequest, + use_cache: bool, + ) -> Result> { + let content = if use_cache { + self.base_client.store().get_media_content(request).await? + } else { + None + }; + + if let Some(content) = content { + Ok(content) + } else { + let content: Vec = match &request.media_type { + MediaType::Encrypted(file) => { + let content: Vec = + self.send(get_content::Request::from_url(&file.url)?, None).await?.file; + + #[cfg(feature = "encryption")] + let content = { + let mut cursor = Cursor::new(content); + let mut reader = + AttachmentDecryptor::new(&mut cursor, file.as_ref().clone().into())?; + + let mut decrypted = Vec::new(); + reader.read_to_end(&mut decrypted)?; + + decrypted + }; + + content + } + MediaType::Uri(uri) => { + if let MediaFormat::Thumbnail(size) = &request.format { + self.send( + get_content_thumbnail::Request::from_url( + &uri, + size.width, + size.height, + )?, + None, + ) + .await? + .file + } else { + self.send(get_content::Request::from_url(&uri)?, None).await?.file + } + } + }; + + if use_cache { + self.base_client.store().add_media_content(request, content.clone()).await?; + } + + Ok(content) + } + } + + /// Remove a media file's content from the store. + /// + /// # Arguments + /// + /// * `request` - The `MediaRequest` of the content. + pub async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> { + Ok(self.base_client.store().remove_media_content(request).await?) + } + + /// Delete all the media content corresponding to the given + /// uri from the store. + /// + /// # Arguments + /// + /// * `uri` - The `MxcUri` of the files. + pub async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> { + Ok(self.base_client.store().remove_media_content_for_uri(&uri).await?) + } + + /// Get the file of the given media event content. + /// + /// If the content is encrypted and encryption is enabled, the content will + /// be decrypted. + /// + /// Returns `Ok(None)` if the event content has no file. + /// + /// This is a convenience method that calls the + /// [`get_media_content`](#method.get_media_content) method. + /// + /// # Arguments + /// + /// * `event_content` - The media event content. + /// + /// * `use_cache` - If we should use the media cache for this file. + pub async fn get_file( + &self, + event_content: impl MediaEventContent, + use_cache: bool, + ) -> Result>> { + if let Some(media_type) = event_content.file() { + Ok(Some( + self.get_media_content( + &MediaRequest { media_type, format: MediaFormat::File }, + use_cache, + ) + .await?, + )) + } else { + Ok(None) + } + } + + /// Remove the file of the given media event content from the cache. + /// + /// This is a convenience method that calls the + /// [`remove_media_content`](#method.remove_media_content) method. + /// + /// # Arguments + /// + /// * `event_content` - The media event content. + pub async fn remove_file(&self, event_content: impl MediaEventContent) -> Result<()> { + if let Some(media_type) = event_content.file() { + self.remove_media_content(&MediaRequest { media_type, format: MediaFormat::File }) + .await? + } + + Ok(()) + } + + /// Get a thumbnail of the given media event content. + /// + /// If the content is encrypted and encryption is enabled, the content will + /// be decrypted. + /// + /// Returns `Ok(None)` if the event content has no thumbnail. + /// + /// This is a convenience method that calls the + /// [`get_media_content`](#method.get_media_content) method. + /// + /// # Arguments + /// + /// * `event_content` - The media event content. + /// + /// * `size` - The _desired_ size of the thumbnail. The actual thumbnail may + /// not match the size specified. + /// + /// * `use_cache` - If we should use the media cache for this thumbnail. + pub async fn get_thumbnail( + &self, + event_content: impl MediaEventContent, + size: MediaThumbnailSize, + use_cache: bool, + ) -> Result>> { + if let Some(media_type) = event_content.thumbnail() { + Ok(Some( + self.get_media_content( + &MediaRequest { media_type, format: MediaFormat::Thumbnail(size) }, + use_cache, + ) + .await?, + )) + } else { + Ok(None) + } + } + + /// Remove the thumbnail of the given media event content from the cache. + /// + /// This is a convenience method that calls the + /// [`remove_media_content`](#method.remove_media_content) method. + /// + /// # Arguments + /// + /// * `event_content` - The media event content. + /// + /// * `size` - The _desired_ size of the thumbnail. Must match the size + /// requested with [`get_thumbnail`](#method.get_thumbnail). + pub async fn remove_thumbnail( + &self, + event_content: impl MediaEventContent, + size: MediaThumbnailSize, + ) -> Result<()> { + if let Some(media_type) = event_content.file() { + self.remove_media_content(&MediaRequest { + media_type, + format: MediaFormat::Thumbnail(size), + }) + .await? + } + + Ok(()) + } } #[cfg(test)] @@ -2477,7 +2686,13 @@ mod test { time::Duration, }; - use matrix_sdk_base::identifiers::mxc_uri; + use matrix_sdk_base::{ + api::r0::media::get_content_thumbnail::Method, + events::room::{message::ImageMessageEventContent, ImageInfo}, + identifiers::mxc_uri, + media::{MediaFormat, MediaRequest, MediaThumbnailSize, MediaType}, + uint, + }; use matrix_sdk_common::{ api::r0::{ account::register::Request as RegistrationRequest, @@ -3554,4 +3769,74 @@ mod test { panic!("this request should return an `Err` variant") } } + + #[tokio::test] + async fn get_media_content() { + let client = logged_in_client().await; + + let request = MediaRequest { + media_type: MediaType::Uri(mxc_uri!("mxc://localhost/textfile")), + format: MediaFormat::File, + }; + + let m = mock( + "GET", + Matcher::Regex(r"^/_matrix/media/r0/download/localhost/textfile\?.*$".to_string()), + ) + .with_status(200) + .with_body("Some very interesting text.") + .expect(2) + .create(); + + assert!(client.get_media_content(&request, true).await.is_ok()); + assert!(client.get_media_content(&request, true).await.is_ok()); + assert!(client.get_media_content(&request, false).await.is_ok()); + m.assert(); + } + + #[tokio::test] + async fn get_media_file() { + let client = logged_in_client().await; + + let event_content = ImageMessageEventContent::plain( + "filename.jpg".into(), + mxc_uri!("mxc://example.org/image"), + Some(Box::new(assign!(ImageInfo::new(), { + height: Some(uint!(398)), + width: Some(uint!(394)), + mimetype: Some("image/jpeg".into()), + size: Some(uint!(31037)), + }))), + ); + + let m = mock( + "GET", + Matcher::Regex(r"^/_matrix/media/r0/download/example%2Eorg/image\?.*$".to_string()), + ) + .with_status(200) + .with_body("binaryjpegdata") + .create(); + + assert!(client.get_file(event_content.clone(), true).await.is_ok()); + assert!(client.get_file(event_content.clone(), true).await.is_ok()); + m.assert(); + + let m = mock( + "GET", + Matcher::Regex(r"^/_matrix/media/r0/thumbnail/example%2Eorg/image\?.*$".to_string()), + ) + .with_status(200) + .with_body("smallerbinaryjpegdata") + .create(); + + assert!(client + .get_thumbnail( + event_content, + MediaThumbnailSize { method: Method::Scale, width: uint!(100), height: uint!(100) }, + true + ) + .await + .is_ok()); + m.assert(); + } } diff --git a/matrix_sdk/src/error.rs b/matrix_sdk/src/error.rs index b5221376..4fa4ff0f 100644 --- a/matrix_sdk/src/error.rs +++ b/matrix_sdk/src/error.rs @@ -18,7 +18,7 @@ use std::io::Error as IoError; use http::StatusCode; #[cfg(feature = "encryption")] -use matrix_sdk_base::crypto::store::CryptoStoreError; +use matrix_sdk_base::crypto::{store::CryptoStoreError, DecryptorError}; use matrix_sdk_base::{Error as MatrixError, StoreError}; use matrix_sdk_common::{ api::{ @@ -122,6 +122,11 @@ pub enum Error { #[error(transparent)] CryptoStoreError(#[from] CryptoStoreError), + /// An error occurred during decryption. + #[cfg(feature = "encryption")] + #[error(transparent)] + DecryptorError(#[from] DecryptorError), + /// An error occurred in the state store. #[error(transparent)] StateStore(#[from] StoreError), diff --git a/matrix_sdk/src/lib.rs b/matrix_sdk/src/lib.rs index 8b5da4eb..9a105305 100644 --- a/matrix_sdk/src/lib.rs +++ b/matrix_sdk/src/lib.rs @@ -81,7 +81,7 @@ pub use bytes::{Bytes, BytesMut}; #[cfg_attr(feature = "docs", doc(cfg(encryption)))] pub use matrix_sdk_base::crypto::{EncryptionInfo, LocalTrust}; pub use matrix_sdk_base::{ - Error as BaseError, Room as BaseRoom, RoomInfo, RoomMember as BaseRoomMember, RoomType, + media, Error as BaseError, Room as BaseRoom, RoomInfo, RoomMember as BaseRoomMember, RoomType, Session, StateChanges, StoreError, }; pub use matrix_sdk_common::*; diff --git a/matrix_sdk_base/Cargo.toml b/matrix_sdk_base/Cargo.toml index baf1b072..1c25de53 100644 --- a/matrix_sdk_base/Cargo.toml +++ b/matrix_sdk_base/Cargo.toml @@ -25,6 +25,7 @@ docs = ["encryption", "sled_cryptostore"] [dependencies] dashmap = "4.0.2" +lru = "0.6.5" serde = { version = "1.0.122", features = ["rc"] } serde_json = "1.0.61" tracing = "0.1.22" diff --git a/matrix_sdk_base/src/lib.rs b/matrix_sdk_base/src/lib.rs index 358f69ea..326700ce 100644 --- a/matrix_sdk_base/src/lib.rs +++ b/matrix_sdk_base/src/lib.rs @@ -45,6 +45,7 @@ pub use crate::{ mod client; mod error; +pub mod media; mod rooms; mod session; mod store; diff --git a/matrix_sdk_base/src/media.rs b/matrix_sdk_base/src/media.rs new file mode 100644 index 00000000..9d075e63 --- /dev/null +++ b/matrix_sdk_base/src/media.rs @@ -0,0 +1,216 @@ +//! Common types for [media content](https://matrix.org/docs/spec/client_server/r0.6.1#id66). + +use matrix_sdk_common::{ + api::r0::media::get_content_thumbnail::Method, + events::{ + room::{ + message::{ + AudioMessageEventContent, FileMessageEventContent, ImageMessageEventContent, + LocationMessageEventContent, VideoMessageEventContent, + }, + EncryptedFile, + }, + sticker::StickerEventContent, + }, + identifiers::MxcUri, + UInt, +}; + +const UNIQUE_SEPARATOR: &str = "_"; + +/// A trait to uniquely identify values of the same type. +pub trait UniqueKey { + /// A string that uniquely identifies `Self` compared to other values of + /// the same type. + fn unique_key(&self) -> String; +} + +/// The requested format of a media file. +#[derive(Clone, Debug)] +pub enum MediaFormat { + /// The file that was uploaded. + File, + + /// A thumbnail of the file that was uploaded. + Thumbnail(MediaThumbnailSize), +} + +impl UniqueKey for MediaFormat { + fn unique_key(&self) -> String { + match self { + Self::File => "file".into(), + Self::Thumbnail(size) => size.unique_key(), + } + } +} + +/// The requested size of a media thumbnail. +#[derive(Clone, Debug)] +pub struct MediaThumbnailSize { + /// The desired resizing method. + pub method: Method, + + /// The desired width of the thumbnail. The actual thumbnail may not match + /// the size specified. + pub width: UInt, + + /// The desired height of the thumbnail. The actual thumbnail may not match + /// the size specified. + pub height: UInt, +} + +impl UniqueKey for MediaThumbnailSize { + fn unique_key(&self) -> String { + format!("{}{}{}x{}", self.method, UNIQUE_SEPARATOR, self.width, self.height) + } +} + +/// A request for media data. +#[derive(Clone, Debug)] +pub enum MediaType { + /// A media content URI. + Uri(MxcUri), + + /// An encrypted media content. + Encrypted(Box), +} + +impl UniqueKey for MediaType { + fn unique_key(&self) -> String { + match self { + Self::Uri(uri) => uri.to_string(), + Self::Encrypted(file) => file.url.to_string(), + } + } +} + +/// A request for media data. +#[derive(Clone, Debug)] +pub struct MediaRequest { + /// The type of the media file. + pub media_type: MediaType, + + /// The requested format of the media data. + pub format: MediaFormat, +} + +impl UniqueKey for MediaRequest { + fn unique_key(&self) -> String { + format!("{}{}{}", self.media_type.unique_key(), UNIQUE_SEPARATOR, self.format.unique_key()) + } +} + +/// Trait for media event content. +pub trait MediaEventContent { + /// Get the type of the file for `Self`. + /// + /// Returns `None` if `Self` has no file. + fn file(&self) -> Option; + + /// Get the type of the thumbnail for `Self`. + /// + /// Returns `None` if `Self` has no thumbnail. + fn thumbnail(&self) -> Option; +} + +impl MediaEventContent for StickerEventContent { + fn file(&self) -> Option { + Some(MediaType::Uri(self.url.clone())) + } + + fn thumbnail(&self) -> Option { + None + } +} + +impl MediaEventContent for AudioMessageEventContent { + fn file(&self) -> Option { + self.url + .as_ref() + .map(|uri| MediaType::Uri(uri.clone())) + .or_else(|| self.file.as_ref().map(|e| MediaType::Encrypted(e.clone()))) + } + + fn thumbnail(&self) -> Option { + None + } +} + +impl MediaEventContent for FileMessageEventContent { + fn file(&self) -> Option { + self.url + .as_ref() + .map(|uri| MediaType::Uri(uri.clone())) + .or_else(|| self.file.as_ref().map(|e| MediaType::Encrypted(e.clone()))) + } + + fn thumbnail(&self) -> Option { + self.info.as_ref().and_then(|info| { + if let Some(uri) = info.thumbnail_url.as_ref() { + Some(MediaType::Uri(uri.clone())) + } else { + info.thumbnail_file.as_ref().map(|file| MediaType::Encrypted(file.clone())) + } + }) + } +} + +impl MediaEventContent for ImageMessageEventContent { + fn file(&self) -> Option { + self.url + .as_ref() + .map(|uri| MediaType::Uri(uri.clone())) + .or_else(|| self.file.as_ref().map(|e| MediaType::Encrypted(e.clone()))) + } + + fn thumbnail(&self) -> Option { + self.info + .as_ref() + .and_then(|info| { + if let Some(uri) = info.thumbnail_url.as_ref() { + Some(MediaType::Uri(uri.clone())) + } else { + info.thumbnail_file.as_ref().map(|file| MediaType::Encrypted(file.clone())) + } + }) + .or_else(|| self.url.as_ref().map(|uri| MediaType::Uri(uri.clone()))) + } +} + +impl MediaEventContent for VideoMessageEventContent { + fn file(&self) -> Option { + self.url + .as_ref() + .map(|uri| MediaType::Uri(uri.clone())) + .or_else(|| self.file.as_ref().map(|e| MediaType::Encrypted(e.clone()))) + } + + fn thumbnail(&self) -> Option { + self.info + .as_ref() + .and_then(|info| { + if let Some(uri) = info.thumbnail_url.as_ref() { + Some(MediaType::Uri(uri.clone())) + } else { + info.thumbnail_file.as_ref().map(|file| MediaType::Encrypted(file.clone())) + } + }) + .or_else(|| self.url.as_ref().map(|uri| MediaType::Uri(uri.clone()))) + } +} + +impl MediaEventContent for LocationMessageEventContent { + fn file(&self) -> Option { + None + } + + fn thumbnail(&self) -> Option { + self.info.as_ref().and_then(|info| { + if let Some(uri) = info.thumbnail_url.as_ref() { + Some(MediaType::Uri(uri.clone())) + } else { + info.thumbnail_file.as_ref().map(|file| MediaType::Encrypted(file.clone())) + } + }) + } +} diff --git a/matrix_sdk_base/src/store/memory_store.rs b/matrix_sdk_base/src/store/memory_store.rs index 7306dfcb..7d15254a 100644 --- a/matrix_sdk_base/src/store/memory_store.rs +++ b/matrix_sdk_base/src/store/memory_store.rs @@ -18,6 +18,7 @@ use std::{ }; use dashmap::{DashMap, DashSet}; +use lru::LruCache; use matrix_sdk_common::{ async_trait, events::{ @@ -27,15 +28,19 @@ use matrix_sdk_common::{ AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent, AnySyncStateEvent, EventType, }, - identifiers::{EventId, RoomId, UserId}, + identifiers::{EventId, MxcUri, RoomId, UserId}, instant::Instant, + locks::Mutex, receipt::ReceiptType, Raw, }; use tracing::info; use super::{Result, RoomInfo, StateChanges, StateStore}; -use crate::deserialized_responses::{MemberEvent, StrippedMemberEvent}; +use crate::{ + deserialized_responses::{MemberEvent, StrippedMemberEvent}, + media::{MediaRequest, UniqueKey}, +}; #[derive(Debug, Clone)] pub struct MemoryStore { @@ -62,6 +67,7 @@ pub struct MemoryStore { #[allow(clippy::type_complexity)] room_event_receipts: Arc>>>>, + media: Arc>>>, } impl MemoryStore { @@ -85,6 +91,7 @@ impl MemoryStore { presence: DashMap::new().into(), room_user_receipts: DashMap::new().into(), room_event_receipts: DashMap::new().into(), + media: Arc::new(Mutex::new(LruCache::new(100))), } } @@ -386,6 +393,39 @@ impl MemoryStore { }) .unwrap_or_else(Vec::new)) } + + async fn add_media_content(&self, request: &MediaRequest, data: Vec) -> Result<()> { + self.media.lock().await.put(request.unique_key(), data); + + Ok(()) + } + + async fn get_media_content(&self, request: &MediaRequest) -> Result>> { + Ok(self.media.lock().await.get(&request.unique_key()).cloned()) + } + + async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> { + self.media.lock().await.pop(&request.unique_key()); + + Ok(()) + } + + async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> { + let mut media_store = self.media.lock().await; + + let keys: Vec = media_store + .iter() + .filter_map( + |(key, _)| if key.starts_with(&uri.to_string()) { Some(key.clone()) } else { None }, + ) + .collect(); + + for key in keys { + media_store.pop(&key); + } + + Ok(()) + } } #[cfg_attr(target_arch = "wasm32", async_trait(?Send))] @@ -501,19 +541,38 @@ impl StateStore for MemoryStore { ) -> Result> { self.get_event_room_receipt_events(room_id, receipt_type, event_id).await } + + async fn add_media_content(&self, request: &MediaRequest, data: Vec) -> Result<()> { + self.add_media_content(request, data).await + } + + async fn get_media_content(&self, request: &MediaRequest) -> Result>> { + self.get_media_content(request).await + } + + async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> { + self.remove_media_content(request).await + } + + async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> { + self.remove_media_content_for_uri(uri).await + } } #[cfg(test)] #[cfg(not(feature = "sled_state_store"))] mod test { use matrix_sdk_common::{ - identifiers::{event_id, room_id, user_id}, + api::r0::media::get_content_thumbnail::Method, + identifiers::{event_id, mxc_uri, room_id, user_id, UserId}, receipt::ReceiptType, + uint, }; use matrix_sdk_test::async_test; use serde_json::json; use super::{MemoryStore, StateChanges}; + use crate::media::{MediaFormat, MediaRequest, MediaThumbnailSize, MediaType}; fn user_id() -> UserId { user_id!("@example:localhost") @@ -612,4 +671,43 @@ mod test { 1 ); } + + #[async_test] + async fn test_media_content() { + let store = MemoryStore::new(); + + let uri = mxc_uri!("mxc://localhost/media"); + let content: Vec = "somebinarydata".into(); + + let request_file = + MediaRequest { media_type: MediaType::Uri(uri.clone()), format: MediaFormat::File }; + + let request_thumbnail = MediaRequest { + media_type: MediaType::Uri(uri.clone()), + format: MediaFormat::Thumbnail(MediaThumbnailSize { + method: Method::Crop, + width: uint!(100), + height: uint!(100), + }), + }; + + assert!(store.get_media_content(&request_file).await.unwrap().is_none()); + assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none()); + + store.add_media_content(&request_file, content.clone()).await.unwrap(); + assert!(store.get_media_content(&request_file).await.unwrap().is_some()); + + store.remove_media_content(&request_file).await.unwrap(); + assert!(store.get_media_content(&request_file).await.unwrap().is_none()); + + store.add_media_content(&request_file, content.clone()).await.unwrap(); + assert!(store.get_media_content(&request_file).await.unwrap().is_some()); + + store.add_media_content(&request_thumbnail, content.clone()).await.unwrap(); + assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_some()); + + store.remove_media_content_for_uri(&uri).await.unwrap(); + assert!(store.get_media_content(&request_file).await.unwrap().is_none()); + assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none()); + } } diff --git a/matrix_sdk_base/src/store/mod.rs b/matrix_sdk_base/src/store/mod.rs index 0c8380f2..bbc915ed 100644 --- a/matrix_sdk_base/src/store/mod.rs +++ b/matrix_sdk_base/src/store/mod.rs @@ -31,7 +31,7 @@ use matrix_sdk_common::{ AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent, AnySyncStateEvent, EventContent, EventType, }, - identifiers::{EventId, RoomId, UserId}, + identifiers::{EventId, MxcUri, RoomId, UserId}, locks::RwLock, receipt::ReceiptType, AsyncTraitDeps, Raw, @@ -41,6 +41,7 @@ use sled::Db; use crate::{ deserialized_responses::{MemberEvent, StrippedMemberEvent}, + media::MediaRequest, rooms::{RoomInfo, RoomType}, Room, Session, }; @@ -249,6 +250,37 @@ pub trait StateStore: AsyncTraitDeps { receipt_type: ReceiptType, event_id: &EventId, ) -> Result>; + + /// Add a media file's content in the media store. + /// + /// # Arguments + /// + /// * `request` - The `MediaRequest` of the file. + /// + /// * `content` - The content of the file. + async fn add_media_content(&self, request: &MediaRequest, content: Vec) -> Result<()>; + + /// Get a media file's content out of the media store. + /// + /// # Arguments + /// + /// * `request` - The `MediaRequest` of the file. + async fn get_media_content(&self, request: &MediaRequest) -> Result>>; + + /// Removes a media file's content from the media store. + /// + /// # Arguments + /// + /// * `request` - The `MediaRequest` of the file. + async fn remove_media_content(&self, request: &MediaRequest) -> Result<()>; + + /// Removes all the media files' content associated to an `MxcUri` from the + /// media store. + /// + /// # Arguments + /// + /// * `uri` - The `MxcUri` of the media files. + async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()>; } /// A state store wrapper for the SDK. diff --git a/matrix_sdk_base/src/store/sled_store/mod.rs b/matrix_sdk_base/src/store/sled_store/mod.rs index bcf46007..13d35184 100644 --- a/matrix_sdk_base/src/store/sled_store/mod.rs +++ b/matrix_sdk_base/src/store/sled_store/mod.rs @@ -34,7 +34,7 @@ use matrix_sdk_common::{ room::member::{MemberEventContent, MembershipState}, AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnySyncStateEvent, EventType, }, - identifiers::{EventId, RoomId, UserId}, + identifiers::{EventId, MxcUri, RoomId, UserId}, receipt::ReceiptType, Raw, }; @@ -47,7 +47,10 @@ use tracing::info; use self::store_key::{EncryptedEvent, StoreKey}; use super::{Result, RoomInfo, StateChanges, StateStore, StoreError}; -use crate::deserialized_responses::MemberEvent; +use crate::{ + deserialized_responses::MemberEvent, + media::{MediaRequest, UniqueKey}, +}; #[derive(Debug, Serialize, Deserialize)] pub enum DatabaseType { @@ -185,6 +188,7 @@ pub struct SledStore { presence: Tree, room_user_receipts: Tree, room_event_receipts: Tree, + media: Tree, } impl std::fmt::Debug for SledStore { @@ -220,6 +224,8 @@ impl SledStore { let room_user_receipts = db.open_tree("room_user_receipts")?; let room_event_receipts = db.open_tree("room_event_receipts")?; + let media = db.open_tree("media")?; + Ok(Self { path, inner: db, @@ -240,6 +246,7 @@ impl SledStore { stripped_room_state, room_user_receipts, room_event_receipts, + media, }) } @@ -721,6 +728,46 @@ impl SledStore { }) .collect() } + + async fn add_media_content(&self, request: &MediaRequest, data: Vec) -> Result<()> { + self.media.insert( + (request.media_type.unique_key().as_str(), request.format.unique_key().as_str()) + .encode(), + data, + )?; + + Ok(()) + } + + async fn get_media_content(&self, request: &MediaRequest) -> Result>> { + Ok(self + .media + .get( + (request.media_type.unique_key().as_str(), request.format.unique_key().as_str()) + .encode(), + )? + .map(|m| m.to_vec())) + } + + async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> { + self.media.remove( + (request.media_type.unique_key().as_str(), request.format.unique_key().as_str()) + .encode(), + )?; + + Ok(()) + } + + async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> { + let keys = self.media.scan_prefix(uri.as_str().encode()).keys(); + + let mut batch = sled::Batch::default(); + for key in keys { + batch.remove(key?); + } + + Ok(self.media.apply_batch(batch)?) + } } #[async_trait] @@ -830,6 +877,22 @@ impl StateStore for SledStore { ) -> Result> { self.get_event_room_receipt_events(room_id, receipt_type, event_id).await } + + async fn add_media_content(&self, request: &MediaRequest, data: Vec) -> Result<()> { + self.add_media_content(request, data).await + } + + async fn get_media_content(&self, request: &MediaRequest) -> Result>> { + self.get_media_content(request).await + } + + async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> { + self.remove_media_content(request).await + } + + async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> { + self.remove_media_content_for_uri(uri).await + } } #[cfg(test)] @@ -837,6 +900,7 @@ mod test { use std::convert::TryFrom; use matrix_sdk_common::{ + api::r0::media::get_content_thumbnail::Method, events::{ room::{ member::{MemberEventContent, MembershipState}, @@ -844,15 +908,19 @@ mod test { }, AnySyncStateEvent, EventType, Unsigned, }, - identifiers::{event_id, room_id, user_id, EventId, UserId}, + identifiers::{event_id, mxc_uri, room_id, user_id, EventId, UserId}, receipt::ReceiptType, - MilliSecondsSinceUnixEpoch, Raw, + uint, MilliSecondsSinceUnixEpoch, Raw, }; use matrix_sdk_test::async_test; use serde_json::json; use super::{SledStore, StateChanges}; - use crate::{deserialized_responses::MemberEvent, StateStore}; + use crate::{ + deserialized_responses::MemberEvent, + media::{MediaFormat, MediaRequest, MediaThumbnailSize, MediaType}, + StateStore, + }; fn user_id() -> UserId { user_id!("@example:localhost") @@ -1024,4 +1092,43 @@ mod test { 1 ); } + + #[async_test] + async fn test_media_content() { + let store = SledStore::open().unwrap(); + + let uri = mxc_uri!("mxc://localhost/media"); + let content: Vec = "somebinarydata".into(); + + let request_file = + MediaRequest { media_type: MediaType::Uri(uri.clone()), format: MediaFormat::File }; + + let request_thumbnail = MediaRequest { + media_type: MediaType::Uri(uri.clone()), + format: MediaFormat::Thumbnail(MediaThumbnailSize { + method: Method::Crop, + width: uint!(100), + height: uint!(100), + }), + }; + + assert!(store.get_media_content(&request_file).await.unwrap().is_none()); + assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none()); + + store.add_media_content(&request_file, content.clone()).await.unwrap(); + assert!(store.get_media_content(&request_file).await.unwrap().is_some()); + + store.remove_media_content(&request_file).await.unwrap(); + assert!(store.get_media_content(&request_file).await.unwrap().is_none()); + + store.add_media_content(&request_file, content.clone()).await.unwrap(); + assert!(store.get_media_content(&request_file).await.unwrap().is_some()); + + store.add_media_content(&request_thumbnail, content.clone()).await.unwrap(); + assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_some()); + + store.remove_media_content_for_uri(&uri).await.unwrap(); + assert!(store.get_media_content(&request_file).await.unwrap().is_none()); + assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none()); + } } diff --git a/matrix_sdk_crypto/src/file_encryption/attachments.rs b/matrix_sdk_crypto/src/file_encryption/attachments.rs index 81c3cf23..70354f1f 100644 --- a/matrix_sdk_crypto/src/file_encryption/attachments.rs +++ b/matrix_sdk_crypto/src/file_encryption/attachments.rs @@ -23,7 +23,7 @@ use aes_ctr::{ }; use base64::DecodeError; use getrandom::getrandom; -use matrix_sdk_common::events::room::{JsonWebKey, JsonWebKeyInit}; +use matrix_sdk_common::events::room::{EncryptedFile, JsonWebKey, JsonWebKeyInit}; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use thiserror::Error; @@ -252,6 +252,12 @@ pub struct EncryptionInfo { pub hashes: BTreeMap, } +impl From for EncryptionInfo { + fn from(file: EncryptedFile) -> Self { + Self { version: file.v, web_key: file.key, iv: file.iv, hashes: file.hashes } + } +} + #[cfg(test)] mod test { use std::io::{Cursor, Read};