Merge branch 'media-store'

master
Damir Jelić 2021-05-31 10:36:20 +02:00
commit d58a190712
10 changed files with 767 additions and 16 deletions

View File

@ -14,7 +14,11 @@
// limitations under the License. // limitations under the License.
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use std::{collections::BTreeMap, io::Write, path::PathBuf}; use std::{
collections::BTreeMap,
io::{Cursor, Write},
path::PathBuf,
};
#[cfg(feature = "sso_login")] #[cfg(feature = "sso_login")]
use std::{ use std::{
collections::HashMap, collections::HashMap,
@ -38,10 +42,13 @@ use http::Response;
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use matrix_sdk_base::crypto::{ use matrix_sdk_base::crypto::{
decrypt_key_export, encrypt_key_export, olm::InboundGroupSession, store::CryptoStoreError, decrypt_key_export, encrypt_key_export, olm::InboundGroupSession, store::CryptoStoreError,
OutgoingRequests, RoomMessageRequest, ToDeviceRequest, AttachmentDecryptor, OutgoingRequests, RoomMessageRequest, ToDeviceRequest,
}; };
use matrix_sdk_base::{ use matrix_sdk_base::{
deserialized_responses::SyncResponse, events::AnyMessageEventContent, identifiers::MxcUri, deserialized_responses::SyncResponse,
events::AnyMessageEventContent,
identifiers::MxcUri,
media::{MediaEventContent, MediaFormat, MediaRequest, MediaThumbnailSize, MediaType},
BaseClient, BaseClientConfig, SendAccessToken, Session, Store, BaseClient, BaseClientConfig, SendAccessToken, Session, Store,
}; };
use mime::{self, Mime}; use mime::{self, Mime};
@ -2465,6 +2472,208 @@ impl Client {
Ok(olm.import_keys(import, |_, _| {}).await?) Ok(olm.import_keys(import, |_, _| {}).await?)
} }
/// Get a media file's content.
///
/// If the content is encrypted and encryption is enabled, the content will
/// be decrypted.
///
/// # Arguments
///
/// * `request` - The `MediaRequest` of the content.
///
/// * `use_cache` - If we should use the media cache for this request.
pub async fn get_media_content(
&self,
request: &MediaRequest,
use_cache: bool,
) -> Result<Vec<u8>> {
let content = if use_cache {
self.base_client.store().get_media_content(request).await?
} else {
None
};
if let Some(content) = content {
Ok(content)
} else {
let content: Vec<u8> = match &request.media_type {
MediaType::Encrypted(file) => {
let content: Vec<u8> =
self.send(get_content::Request::from_url(&file.url)?, None).await?.file;
#[cfg(feature = "encryption")]
let content = {
let mut cursor = Cursor::new(content);
let mut reader =
AttachmentDecryptor::new(&mut cursor, file.as_ref().clone().into())?;
let mut decrypted = Vec::new();
reader.read_to_end(&mut decrypted)?;
decrypted
};
content
}
MediaType::Uri(uri) => {
if let MediaFormat::Thumbnail(size) = &request.format {
self.send(
get_content_thumbnail::Request::from_url(
&uri,
size.width,
size.height,
)?,
None,
)
.await?
.file
} else {
self.send(get_content::Request::from_url(&uri)?, None).await?.file
}
}
};
if use_cache {
self.base_client.store().add_media_content(request, content.clone()).await?;
}
Ok(content)
}
}
/// Remove a media file's content from the store.
///
/// # Arguments
///
/// * `request` - The `MediaRequest` of the content.
pub async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> {
Ok(self.base_client.store().remove_media_content(request).await?)
}
/// Delete all the media content corresponding to the given
/// uri from the store.
///
/// # Arguments
///
/// * `uri` - The `MxcUri` of the files.
pub async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> {
Ok(self.base_client.store().remove_media_content_for_uri(&uri).await?)
}
/// Get the file of the given media event content.
///
/// If the content is encrypted and encryption is enabled, the content will
/// be decrypted.
///
/// Returns `Ok(None)` if the event content has no file.
///
/// This is a convenience method that calls the
/// [`get_media_content`](#method.get_media_content) method.
///
/// # Arguments
///
/// * `event_content` - The media event content.
///
/// * `use_cache` - If we should use the media cache for this file.
pub async fn get_file(
&self,
event_content: impl MediaEventContent,
use_cache: bool,
) -> Result<Option<Vec<u8>>> {
if let Some(media_type) = event_content.file() {
Ok(Some(
self.get_media_content(
&MediaRequest { media_type, format: MediaFormat::File },
use_cache,
)
.await?,
))
} else {
Ok(None)
}
}
/// Remove the file of the given media event content from the cache.
///
/// This is a convenience method that calls the
/// [`remove_media_content`](#method.remove_media_content) method.
///
/// # Arguments
///
/// * `event_content` - The media event content.
pub async fn remove_file(&self, event_content: impl MediaEventContent) -> Result<()> {
if let Some(media_type) = event_content.file() {
self.remove_media_content(&MediaRequest { media_type, format: MediaFormat::File })
.await?
}
Ok(())
}
/// Get a thumbnail of the given media event content.
///
/// If the content is encrypted and encryption is enabled, the content will
/// be decrypted.
///
/// Returns `Ok(None)` if the event content has no thumbnail.
///
/// This is a convenience method that calls the
/// [`get_media_content`](#method.get_media_content) method.
///
/// # Arguments
///
/// * `event_content` - The media event content.
///
/// * `size` - The _desired_ size of the thumbnail. The actual thumbnail may
/// not match the size specified.
///
/// * `use_cache` - If we should use the media cache for this thumbnail.
pub async fn get_thumbnail(
&self,
event_content: impl MediaEventContent,
size: MediaThumbnailSize,
use_cache: bool,
) -> Result<Option<Vec<u8>>> {
if let Some(media_type) = event_content.thumbnail() {
Ok(Some(
self.get_media_content(
&MediaRequest { media_type, format: MediaFormat::Thumbnail(size) },
use_cache,
)
.await?,
))
} else {
Ok(None)
}
}
/// Remove the thumbnail of the given media event content from the cache.
///
/// This is a convenience method that calls the
/// [`remove_media_content`](#method.remove_media_content) method.
///
/// # Arguments
///
/// * `event_content` - The media event content.
///
/// * `size` - The _desired_ size of the thumbnail. Must match the size
/// requested with [`get_thumbnail`](#method.get_thumbnail).
pub async fn remove_thumbnail(
&self,
event_content: impl MediaEventContent,
size: MediaThumbnailSize,
) -> Result<()> {
if let Some(media_type) = event_content.file() {
self.remove_media_content(&MediaRequest {
media_type,
format: MediaFormat::Thumbnail(size),
})
.await?
}
Ok(())
}
} }
#[cfg(test)] #[cfg(test)]
@ -2477,7 +2686,13 @@ mod test {
time::Duration, time::Duration,
}; };
use matrix_sdk_base::identifiers::mxc_uri; use matrix_sdk_base::{
api::r0::media::get_content_thumbnail::Method,
events::room::{message::ImageMessageEventContent, ImageInfo},
identifiers::mxc_uri,
media::{MediaFormat, MediaRequest, MediaThumbnailSize, MediaType},
uint,
};
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::{ api::r0::{
account::register::Request as RegistrationRequest, account::register::Request as RegistrationRequest,
@ -3554,4 +3769,74 @@ mod test {
panic!("this request should return an `Err` variant") panic!("this request should return an `Err` variant")
} }
} }
#[tokio::test]
async fn get_media_content() {
let client = logged_in_client().await;
let request = MediaRequest {
media_type: MediaType::Uri(mxc_uri!("mxc://localhost/textfile")),
format: MediaFormat::File,
};
let m = mock(
"GET",
Matcher::Regex(r"^/_matrix/media/r0/download/localhost/textfile\?.*$".to_string()),
)
.with_status(200)
.with_body("Some very interesting text.")
.expect(2)
.create();
assert!(client.get_media_content(&request, true).await.is_ok());
assert!(client.get_media_content(&request, true).await.is_ok());
assert!(client.get_media_content(&request, false).await.is_ok());
m.assert();
}
#[tokio::test]
async fn get_media_file() {
let client = logged_in_client().await;
let event_content = ImageMessageEventContent::plain(
"filename.jpg".into(),
mxc_uri!("mxc://example.org/image"),
Some(Box::new(assign!(ImageInfo::new(), {
height: Some(uint!(398)),
width: Some(uint!(394)),
mimetype: Some("image/jpeg".into()),
size: Some(uint!(31037)),
}))),
);
let m = mock(
"GET",
Matcher::Regex(r"^/_matrix/media/r0/download/example%2Eorg/image\?.*$".to_string()),
)
.with_status(200)
.with_body("binaryjpegdata")
.create();
assert!(client.get_file(event_content.clone(), true).await.is_ok());
assert!(client.get_file(event_content.clone(), true).await.is_ok());
m.assert();
let m = mock(
"GET",
Matcher::Regex(r"^/_matrix/media/r0/thumbnail/example%2Eorg/image\?.*$".to_string()),
)
.with_status(200)
.with_body("smallerbinaryjpegdata")
.create();
assert!(client
.get_thumbnail(
event_content,
MediaThumbnailSize { method: Method::Scale, width: uint!(100), height: uint!(100) },
true
)
.await
.is_ok());
m.assert();
}
} }

View File

@ -18,7 +18,7 @@ use std::io::Error as IoError;
use http::StatusCode; use http::StatusCode;
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use matrix_sdk_base::crypto::store::CryptoStoreError; use matrix_sdk_base::crypto::{store::CryptoStoreError, DecryptorError};
use matrix_sdk_base::{Error as MatrixError, StoreError}; use matrix_sdk_base::{Error as MatrixError, StoreError};
use matrix_sdk_common::{ use matrix_sdk_common::{
api::{ api::{
@ -122,6 +122,11 @@ pub enum Error {
#[error(transparent)] #[error(transparent)]
CryptoStoreError(#[from] CryptoStoreError), CryptoStoreError(#[from] CryptoStoreError),
/// An error occurred during decryption.
#[cfg(feature = "encryption")]
#[error(transparent)]
DecryptorError(#[from] DecryptorError),
/// An error occurred in the state store. /// An error occurred in the state store.
#[error(transparent)] #[error(transparent)]
StateStore(#[from] StoreError), StateStore(#[from] StoreError),

View File

@ -81,7 +81,7 @@ pub use bytes::{Bytes, BytesMut};
#[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub use matrix_sdk_base::crypto::{EncryptionInfo, LocalTrust}; pub use matrix_sdk_base::crypto::{EncryptionInfo, LocalTrust};
pub use matrix_sdk_base::{ pub use matrix_sdk_base::{
Error as BaseError, Room as BaseRoom, RoomInfo, RoomMember as BaseRoomMember, RoomType, media, Error as BaseError, Room as BaseRoom, RoomInfo, RoomMember as BaseRoomMember, RoomType,
Session, StateChanges, StoreError, Session, StateChanges, StoreError,
}; };
pub use matrix_sdk_common::*; pub use matrix_sdk_common::*;

View File

@ -25,6 +25,7 @@ docs = ["encryption", "sled_cryptostore"]
[dependencies] [dependencies]
dashmap = "4.0.2" dashmap = "4.0.2"
lru = "0.6.5"
serde = { version = "1.0.122", features = ["rc"] } serde = { version = "1.0.122", features = ["rc"] }
serde_json = "1.0.61" serde_json = "1.0.61"
tracing = "0.1.22" tracing = "0.1.22"

View File

@ -45,6 +45,7 @@ pub use crate::{
mod client; mod client;
mod error; mod error;
pub mod media;
mod rooms; mod rooms;
mod session; mod session;
mod store; mod store;

View File

@ -0,0 +1,216 @@
//! Common types for [media content](https://matrix.org/docs/spec/client_server/r0.6.1#id66).
use matrix_sdk_common::{
api::r0::media::get_content_thumbnail::Method,
events::{
room::{
message::{
AudioMessageEventContent, FileMessageEventContent, ImageMessageEventContent,
LocationMessageEventContent, VideoMessageEventContent,
},
EncryptedFile,
},
sticker::StickerEventContent,
},
identifiers::MxcUri,
UInt,
};
const UNIQUE_SEPARATOR: &str = "_";
/// A trait to uniquely identify values of the same type.
pub trait UniqueKey {
/// A string that uniquely identifies `Self` compared to other values of
/// the same type.
fn unique_key(&self) -> String;
}
/// The requested format of a media file.
#[derive(Clone, Debug)]
pub enum MediaFormat {
/// The file that was uploaded.
File,
/// A thumbnail of the file that was uploaded.
Thumbnail(MediaThumbnailSize),
}
impl UniqueKey for MediaFormat {
fn unique_key(&self) -> String {
match self {
Self::File => "file".into(),
Self::Thumbnail(size) => size.unique_key(),
}
}
}
/// The requested size of a media thumbnail.
#[derive(Clone, Debug)]
pub struct MediaThumbnailSize {
/// The desired resizing method.
pub method: Method,
/// The desired width of the thumbnail. The actual thumbnail may not match
/// the size specified.
pub width: UInt,
/// The desired height of the thumbnail. The actual thumbnail may not match
/// the size specified.
pub height: UInt,
}
impl UniqueKey for MediaThumbnailSize {
fn unique_key(&self) -> String {
format!("{}{}{}x{}", self.method, UNIQUE_SEPARATOR, self.width, self.height)
}
}
/// A request for media data.
#[derive(Clone, Debug)]
pub enum MediaType {
/// A media content URI.
Uri(MxcUri),
/// An encrypted media content.
Encrypted(Box<EncryptedFile>),
}
impl UniqueKey for MediaType {
fn unique_key(&self) -> String {
match self {
Self::Uri(uri) => uri.to_string(),
Self::Encrypted(file) => file.url.to_string(),
}
}
}
/// A request for media data.
#[derive(Clone, Debug)]
pub struct MediaRequest {
/// The type of the media file.
pub media_type: MediaType,
/// The requested format of the media data.
pub format: MediaFormat,
}
impl UniqueKey for MediaRequest {
fn unique_key(&self) -> String {
format!("{}{}{}", self.media_type.unique_key(), UNIQUE_SEPARATOR, self.format.unique_key())
}
}
/// Trait for media event content.
pub trait MediaEventContent {
/// Get the type of the file for `Self`.
///
/// Returns `None` if `Self` has no file.
fn file(&self) -> Option<MediaType>;
/// Get the type of the thumbnail for `Self`.
///
/// Returns `None` if `Self` has no thumbnail.
fn thumbnail(&self) -> Option<MediaType>;
}
impl MediaEventContent for StickerEventContent {
fn file(&self) -> Option<MediaType> {
Some(MediaType::Uri(self.url.clone()))
}
fn thumbnail(&self) -> Option<MediaType> {
None
}
}
impl MediaEventContent for AudioMessageEventContent {
fn file(&self) -> Option<MediaType> {
self.url
.as_ref()
.map(|uri| MediaType::Uri(uri.clone()))
.or_else(|| self.file.as_ref().map(|e| MediaType::Encrypted(e.clone())))
}
fn thumbnail(&self) -> Option<MediaType> {
None
}
}
impl MediaEventContent for FileMessageEventContent {
fn file(&self) -> Option<MediaType> {
self.url
.as_ref()
.map(|uri| MediaType::Uri(uri.clone()))
.or_else(|| self.file.as_ref().map(|e| MediaType::Encrypted(e.clone())))
}
fn thumbnail(&self) -> Option<MediaType> {
self.info.as_ref().and_then(|info| {
if let Some(uri) = info.thumbnail_url.as_ref() {
Some(MediaType::Uri(uri.clone()))
} else {
info.thumbnail_file.as_ref().map(|file| MediaType::Encrypted(file.clone()))
}
})
}
}
impl MediaEventContent for ImageMessageEventContent {
fn file(&self) -> Option<MediaType> {
self.url
.as_ref()
.map(|uri| MediaType::Uri(uri.clone()))
.or_else(|| self.file.as_ref().map(|e| MediaType::Encrypted(e.clone())))
}
fn thumbnail(&self) -> Option<MediaType> {
self.info
.as_ref()
.and_then(|info| {
if let Some(uri) = info.thumbnail_url.as_ref() {
Some(MediaType::Uri(uri.clone()))
} else {
info.thumbnail_file.as_ref().map(|file| MediaType::Encrypted(file.clone()))
}
})
.or_else(|| self.url.as_ref().map(|uri| MediaType::Uri(uri.clone())))
}
}
impl MediaEventContent for VideoMessageEventContent {
fn file(&self) -> Option<MediaType> {
self.url
.as_ref()
.map(|uri| MediaType::Uri(uri.clone()))
.or_else(|| self.file.as_ref().map(|e| MediaType::Encrypted(e.clone())))
}
fn thumbnail(&self) -> Option<MediaType> {
self.info
.as_ref()
.and_then(|info| {
if let Some(uri) = info.thumbnail_url.as_ref() {
Some(MediaType::Uri(uri.clone()))
} else {
info.thumbnail_file.as_ref().map(|file| MediaType::Encrypted(file.clone()))
}
})
.or_else(|| self.url.as_ref().map(|uri| MediaType::Uri(uri.clone())))
}
}
impl MediaEventContent for LocationMessageEventContent {
fn file(&self) -> Option<MediaType> {
None
}
fn thumbnail(&self) -> Option<MediaType> {
self.info.as_ref().and_then(|info| {
if let Some(uri) = info.thumbnail_url.as_ref() {
Some(MediaType::Uri(uri.clone()))
} else {
info.thumbnail_file.as_ref().map(|file| MediaType::Encrypted(file.clone()))
}
})
}
}

View File

@ -18,6 +18,7 @@ use std::{
}; };
use dashmap::{DashMap, DashSet}; use dashmap::{DashMap, DashSet};
use lru::LruCache;
use matrix_sdk_common::{ use matrix_sdk_common::{
async_trait, async_trait,
events::{ events::{
@ -27,15 +28,19 @@ use matrix_sdk_common::{
AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent, AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent,
AnySyncStateEvent, EventType, AnySyncStateEvent, EventType,
}, },
identifiers::{EventId, RoomId, UserId}, identifiers::{EventId, MxcUri, RoomId, UserId},
instant::Instant, instant::Instant,
locks::Mutex,
receipt::ReceiptType, receipt::ReceiptType,
Raw, Raw,
}; };
use tracing::info; use tracing::info;
use super::{Result, RoomInfo, StateChanges, StateStore}; use super::{Result, RoomInfo, StateChanges, StateStore};
use crate::deserialized_responses::{MemberEvent, StrippedMemberEvent}; use crate::{
deserialized_responses::{MemberEvent, StrippedMemberEvent},
media::{MediaRequest, UniqueKey},
};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct MemoryStore { pub struct MemoryStore {
@ -62,6 +67,7 @@ pub struct MemoryStore {
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
room_event_receipts: room_event_receipts:
Arc<DashMap<RoomId, DashMap<String, DashMap<EventId, DashMap<UserId, Receipt>>>>>, Arc<DashMap<RoomId, DashMap<String, DashMap<EventId, DashMap<UserId, Receipt>>>>>,
media: Arc<Mutex<LruCache<String, Vec<u8>>>>,
} }
impl MemoryStore { impl MemoryStore {
@ -85,6 +91,7 @@ impl MemoryStore {
presence: DashMap::new().into(), presence: DashMap::new().into(),
room_user_receipts: DashMap::new().into(), room_user_receipts: DashMap::new().into(),
room_event_receipts: DashMap::new().into(), room_event_receipts: DashMap::new().into(),
media: Arc::new(Mutex::new(LruCache::new(100))),
} }
} }
@ -386,6 +393,39 @@ impl MemoryStore {
}) })
.unwrap_or_else(Vec::new)) .unwrap_or_else(Vec::new))
} }
async fn add_media_content(&self, request: &MediaRequest, data: Vec<u8>) -> Result<()> {
self.media.lock().await.put(request.unique_key(), data);
Ok(())
}
async fn get_media_content(&self, request: &MediaRequest) -> Result<Option<Vec<u8>>> {
Ok(self.media.lock().await.get(&request.unique_key()).cloned())
}
async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> {
self.media.lock().await.pop(&request.unique_key());
Ok(())
}
async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> {
let mut media_store = self.media.lock().await;
let keys: Vec<String> = media_store
.iter()
.filter_map(
|(key, _)| if key.starts_with(&uri.to_string()) { Some(key.clone()) } else { None },
)
.collect();
for key in keys {
media_store.pop(&key);
}
Ok(())
}
} }
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
@ -501,19 +541,38 @@ impl StateStore for MemoryStore {
) -> Result<Vec<(UserId, Receipt)>> { ) -> Result<Vec<(UserId, Receipt)>> {
self.get_event_room_receipt_events(room_id, receipt_type, event_id).await self.get_event_room_receipt_events(room_id, receipt_type, event_id).await
} }
async fn add_media_content(&self, request: &MediaRequest, data: Vec<u8>) -> Result<()> {
self.add_media_content(request, data).await
}
async fn get_media_content(&self, request: &MediaRequest) -> Result<Option<Vec<u8>>> {
self.get_media_content(request).await
}
async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> {
self.remove_media_content(request).await
}
async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> {
self.remove_media_content_for_uri(uri).await
}
} }
#[cfg(test)] #[cfg(test)]
#[cfg(not(feature = "sled_state_store"))] #[cfg(not(feature = "sled_state_store"))]
mod test { mod test {
use matrix_sdk_common::{ use matrix_sdk_common::{
identifiers::{event_id, room_id, user_id}, api::r0::media::get_content_thumbnail::Method,
identifiers::{event_id, mxc_uri, room_id, user_id, UserId},
receipt::ReceiptType, receipt::ReceiptType,
uint,
}; };
use matrix_sdk_test::async_test; use matrix_sdk_test::async_test;
use serde_json::json; use serde_json::json;
use super::{MemoryStore, StateChanges}; use super::{MemoryStore, StateChanges};
use crate::media::{MediaFormat, MediaRequest, MediaThumbnailSize, MediaType};
fn user_id() -> UserId { fn user_id() -> UserId {
user_id!("@example:localhost") user_id!("@example:localhost")
@ -612,4 +671,43 @@ mod test {
1 1
); );
} }
#[async_test]
async fn test_media_content() {
let store = MemoryStore::new();
let uri = mxc_uri!("mxc://localhost/media");
let content: Vec<u8> = "somebinarydata".into();
let request_file =
MediaRequest { media_type: MediaType::Uri(uri.clone()), format: MediaFormat::File };
let request_thumbnail = MediaRequest {
media_type: MediaType::Uri(uri.clone()),
format: MediaFormat::Thumbnail(MediaThumbnailSize {
method: Method::Crop,
width: uint!(100),
height: uint!(100),
}),
};
assert!(store.get_media_content(&request_file).await.unwrap().is_none());
assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none());
store.add_media_content(&request_file, content.clone()).await.unwrap();
assert!(store.get_media_content(&request_file).await.unwrap().is_some());
store.remove_media_content(&request_file).await.unwrap();
assert!(store.get_media_content(&request_file).await.unwrap().is_none());
store.add_media_content(&request_file, content.clone()).await.unwrap();
assert!(store.get_media_content(&request_file).await.unwrap().is_some());
store.add_media_content(&request_thumbnail, content.clone()).await.unwrap();
assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_some());
store.remove_media_content_for_uri(&uri).await.unwrap();
assert!(store.get_media_content(&request_file).await.unwrap().is_none());
assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none());
}
} }

View File

@ -31,7 +31,7 @@ use matrix_sdk_common::{
AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent, AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent,
AnySyncStateEvent, EventContent, EventType, AnySyncStateEvent, EventContent, EventType,
}, },
identifiers::{EventId, RoomId, UserId}, identifiers::{EventId, MxcUri, RoomId, UserId},
locks::RwLock, locks::RwLock,
receipt::ReceiptType, receipt::ReceiptType,
AsyncTraitDeps, Raw, AsyncTraitDeps, Raw,
@ -41,6 +41,7 @@ use sled::Db;
use crate::{ use crate::{
deserialized_responses::{MemberEvent, StrippedMemberEvent}, deserialized_responses::{MemberEvent, StrippedMemberEvent},
media::MediaRequest,
rooms::{RoomInfo, RoomType}, rooms::{RoomInfo, RoomType},
Room, Session, Room, Session,
}; };
@ -249,6 +250,37 @@ pub trait StateStore: AsyncTraitDeps {
receipt_type: ReceiptType, receipt_type: ReceiptType,
event_id: &EventId, event_id: &EventId,
) -> Result<Vec<(UserId, Receipt)>>; ) -> Result<Vec<(UserId, Receipt)>>;
/// Add a media file's content in the media store.
///
/// # Arguments
///
/// * `request` - The `MediaRequest` of the file.
///
/// * `content` - The content of the file.
async fn add_media_content(&self, request: &MediaRequest, content: Vec<u8>) -> Result<()>;
/// Get a media file's content out of the media store.
///
/// # Arguments
///
/// * `request` - The `MediaRequest` of the file.
async fn get_media_content(&self, request: &MediaRequest) -> Result<Option<Vec<u8>>>;
/// Removes a media file's content from the media store.
///
/// # Arguments
///
/// * `request` - The `MediaRequest` of the file.
async fn remove_media_content(&self, request: &MediaRequest) -> Result<()>;
/// Removes all the media files' content associated to an `MxcUri` from the
/// media store.
///
/// # Arguments
///
/// * `uri` - The `MxcUri` of the media files.
async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()>;
} }
/// A state store wrapper for the SDK. /// A state store wrapper for the SDK.

View File

@ -34,7 +34,7 @@ use matrix_sdk_common::{
room::member::{MemberEventContent, MembershipState}, room::member::{MemberEventContent, MembershipState},
AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnySyncStateEvent, EventType, AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnySyncStateEvent, EventType,
}, },
identifiers::{EventId, RoomId, UserId}, identifiers::{EventId, MxcUri, RoomId, UserId},
receipt::ReceiptType, receipt::ReceiptType,
Raw, Raw,
}; };
@ -47,7 +47,10 @@ use tracing::info;
use self::store_key::{EncryptedEvent, StoreKey}; use self::store_key::{EncryptedEvent, StoreKey};
use super::{Result, RoomInfo, StateChanges, StateStore, StoreError}; use super::{Result, RoomInfo, StateChanges, StateStore, StoreError};
use crate::deserialized_responses::MemberEvent; use crate::{
deserialized_responses::MemberEvent,
media::{MediaRequest, UniqueKey},
};
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub enum DatabaseType { pub enum DatabaseType {
@ -185,6 +188,7 @@ pub struct SledStore {
presence: Tree, presence: Tree,
room_user_receipts: Tree, room_user_receipts: Tree,
room_event_receipts: Tree, room_event_receipts: Tree,
media: Tree,
} }
impl std::fmt::Debug for SledStore { impl std::fmt::Debug for SledStore {
@ -220,6 +224,8 @@ impl SledStore {
let room_user_receipts = db.open_tree("room_user_receipts")?; let room_user_receipts = db.open_tree("room_user_receipts")?;
let room_event_receipts = db.open_tree("room_event_receipts")?; let room_event_receipts = db.open_tree("room_event_receipts")?;
let media = db.open_tree("media")?;
Ok(Self { Ok(Self {
path, path,
inner: db, inner: db,
@ -240,6 +246,7 @@ impl SledStore {
stripped_room_state, stripped_room_state,
room_user_receipts, room_user_receipts,
room_event_receipts, room_event_receipts,
media,
}) })
} }
@ -721,6 +728,46 @@ impl SledStore {
}) })
.collect() .collect()
} }
async fn add_media_content(&self, request: &MediaRequest, data: Vec<u8>) -> Result<()> {
self.media.insert(
(request.media_type.unique_key().as_str(), request.format.unique_key().as_str())
.encode(),
data,
)?;
Ok(())
}
async fn get_media_content(&self, request: &MediaRequest) -> Result<Option<Vec<u8>>> {
Ok(self
.media
.get(
(request.media_type.unique_key().as_str(), request.format.unique_key().as_str())
.encode(),
)?
.map(|m| m.to_vec()))
}
async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> {
self.media.remove(
(request.media_type.unique_key().as_str(), request.format.unique_key().as_str())
.encode(),
)?;
Ok(())
}
async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> {
let keys = self.media.scan_prefix(uri.as_str().encode()).keys();
let mut batch = sled::Batch::default();
for key in keys {
batch.remove(key?);
}
Ok(self.media.apply_batch(batch)?)
}
} }
#[async_trait] #[async_trait]
@ -830,6 +877,22 @@ impl StateStore for SledStore {
) -> Result<Vec<(UserId, Receipt)>> { ) -> Result<Vec<(UserId, Receipt)>> {
self.get_event_room_receipt_events(room_id, receipt_type, event_id).await self.get_event_room_receipt_events(room_id, receipt_type, event_id).await
} }
async fn add_media_content(&self, request: &MediaRequest, data: Vec<u8>) -> Result<()> {
self.add_media_content(request, data).await
}
async fn get_media_content(&self, request: &MediaRequest) -> Result<Option<Vec<u8>>> {
self.get_media_content(request).await
}
async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> {
self.remove_media_content(request).await
}
async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> {
self.remove_media_content_for_uri(uri).await
}
} }
#[cfg(test)] #[cfg(test)]
@ -837,6 +900,7 @@ mod test {
use std::convert::TryFrom; use std::convert::TryFrom;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::media::get_content_thumbnail::Method,
events::{ events::{
room::{ room::{
member::{MemberEventContent, MembershipState}, member::{MemberEventContent, MembershipState},
@ -844,15 +908,19 @@ mod test {
}, },
AnySyncStateEvent, EventType, Unsigned, AnySyncStateEvent, EventType, Unsigned,
}, },
identifiers::{event_id, room_id, user_id, EventId, UserId}, identifiers::{event_id, mxc_uri, room_id, user_id, EventId, UserId},
receipt::ReceiptType, receipt::ReceiptType,
MilliSecondsSinceUnixEpoch, Raw, uint, MilliSecondsSinceUnixEpoch, Raw,
}; };
use matrix_sdk_test::async_test; use matrix_sdk_test::async_test;
use serde_json::json; use serde_json::json;
use super::{SledStore, StateChanges}; use super::{SledStore, StateChanges};
use crate::{deserialized_responses::MemberEvent, StateStore}; use crate::{
deserialized_responses::MemberEvent,
media::{MediaFormat, MediaRequest, MediaThumbnailSize, MediaType},
StateStore,
};
fn user_id() -> UserId { fn user_id() -> UserId {
user_id!("@example:localhost") user_id!("@example:localhost")
@ -1024,4 +1092,43 @@ mod test {
1 1
); );
} }
#[async_test]
async fn test_media_content() {
let store = SledStore::open().unwrap();
let uri = mxc_uri!("mxc://localhost/media");
let content: Vec<u8> = "somebinarydata".into();
let request_file =
MediaRequest { media_type: MediaType::Uri(uri.clone()), format: MediaFormat::File };
let request_thumbnail = MediaRequest {
media_type: MediaType::Uri(uri.clone()),
format: MediaFormat::Thumbnail(MediaThumbnailSize {
method: Method::Crop,
width: uint!(100),
height: uint!(100),
}),
};
assert!(store.get_media_content(&request_file).await.unwrap().is_none());
assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none());
store.add_media_content(&request_file, content.clone()).await.unwrap();
assert!(store.get_media_content(&request_file).await.unwrap().is_some());
store.remove_media_content(&request_file).await.unwrap();
assert!(store.get_media_content(&request_file).await.unwrap().is_none());
store.add_media_content(&request_file, content.clone()).await.unwrap();
assert!(store.get_media_content(&request_file).await.unwrap().is_some());
store.add_media_content(&request_thumbnail, content.clone()).await.unwrap();
assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_some());
store.remove_media_content_for_uri(&uri).await.unwrap();
assert!(store.get_media_content(&request_file).await.unwrap().is_none());
assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none());
}
} }

View File

@ -23,7 +23,7 @@ use aes_ctr::{
}; };
use base64::DecodeError; use base64::DecodeError;
use getrandom::getrandom; use getrandom::getrandom;
use matrix_sdk_common::events::room::{JsonWebKey, JsonWebKeyInit}; use matrix_sdk_common::events::room::{EncryptedFile, JsonWebKey, JsonWebKeyInit};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
use thiserror::Error; use thiserror::Error;
@ -252,6 +252,12 @@ pub struct EncryptionInfo {
pub hashes: BTreeMap<String, String>, pub hashes: BTreeMap<String, String>,
} }
impl From<EncryptedFile> for EncryptionInfo {
fn from(file: EncryptedFile) -> Self {
Self { version: file.v, web_key: file.key, iv: file.iv, hashes: file.hashes }
}
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use std::io::{Cursor, Read}; use std::io::{Cursor, Read};