Merge branch 'master' into verification-improvements

master
Damir Jelić 2021-06-04 18:32:20 +02:00
commit 7cca358399
32 changed files with 1790 additions and 188 deletions

View File

@ -26,7 +26,7 @@ rustls-tls = ["reqwest/rustls-tls"]
socks = ["reqwest/socks"]
sso_login = ["warp", "rand", "tokio-stream"]
require_auth_for_profile_requests = []
appservice = ["matrix-sdk-common/appservice", "serde_yaml"]
appservice = ["matrix-sdk-common/appservice"]
docs = ["encryption", "sled_cryptostore", "sled_state_store", "sso_login"]
@ -41,7 +41,6 @@ url = "2.2.0"
zeroize = "1.2.0"
mime = "0.3.16"
rand = { version = "0.8.2", optional = true }
serde_yaml = { version = "0.8", optional = true }
bytes = "1.0.1"
matrix-sdk-common = { version = "0.2.0", path = "../matrix_sdk_common" }

View File

@ -14,7 +14,11 @@
// limitations under the License.
#[cfg(feature = "encryption")]
use std::{collections::BTreeMap, io::Write, path::PathBuf};
use std::{
collections::BTreeMap,
io::{Cursor, Write},
path::PathBuf,
};
#[cfg(feature = "sso_login")]
use std::{
collections::HashMap,
@ -38,10 +42,13 @@ use http::Response;
#[cfg(feature = "encryption")]
use matrix_sdk_base::crypto::{
decrypt_key_export, encrypt_key_export, olm::InboundGroupSession, store::CryptoStoreError,
OutgoingRequests, RoomMessageRequest, ToDeviceRequest,
AttachmentDecryptor, OutgoingRequests, RoomMessageRequest, ToDeviceRequest,
};
use matrix_sdk_base::{
deserialized_responses::SyncResponse, events::AnyMessageEventContent, identifiers::MxcUri,
deserialized_responses::SyncResponse,
events::AnyMessageEventContent,
identifiers::MxcUri,
media::{MediaEventContent, MediaFormat, MediaRequest, MediaThumbnailSize, MediaType},
BaseClient, BaseClientConfig, SendAccessToken, Session, Store,
};
use mime::{self, Mime};
@ -83,7 +90,8 @@ use matrix_sdk_common::api::r0::{
},
};
use matrix_sdk_common::{
api::r0::{
api::{
r0::{
account::register,
device::{delete_devices, get_devices},
directory::{get_public_rooms, get_public_rooms_filtered},
@ -97,6 +105,8 @@ use matrix_sdk_common::{
sync::sync_events,
uiaa::AuthData,
},
unversioned::{discover_homeserver, get_supported_versions},
},
assign,
identifiers::{DeviceIdBox, RoomId, RoomIdOrAliasId, ServerName, UserId},
instant::{Duration, Instant},
@ -139,7 +149,7 @@ const SSO_SERVER_BIND_TRIES: u8 = 10;
#[derive(Clone)]
pub struct Client {
/// The URL of the homeserver to connect to.
homeserver: Arc<Url>,
homeserver: Arc<RwLock<Url>>,
/// The underlying HTTP client.
http_client: HttpClient,
/// User session data.
@ -161,7 +171,7 @@ pub struct Client {
#[cfg(not(tarpaulin_include))]
impl Debug for Client {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> StdResult<(), fmt::Error> {
write!(fmt, "Client {{ homeserver: {} }}", self.homeserver)
write!(fmt, "Client")
}
}
@ -291,6 +301,11 @@ impl ClientConfig {
self
}
/// Get the [`RequestConfig`]
pub fn get_request_config(&self) -> &RequestConfig {
&self.request_config
}
/// Specify a client to handle sending requests and receiving responses.
///
/// Any type that implements the `HttpSend` trait can be used to
@ -499,7 +514,7 @@ impl Client {
///
/// * `config` - Configuration for the client.
pub fn new_with_config(homeserver_url: Url, config: ClientConfig) -> Result<Self> {
let homeserver = Arc::new(homeserver_url);
let homeserver = Arc::new(RwLock::new(homeserver_url));
let client = if let Some(client) = config.client {
client
@ -510,12 +525,8 @@ impl Client {
let base_client = BaseClient::new_with_config(config.base_config)?;
let session = base_client.session().clone();
let http_client = HttpClient {
homeserver: homeserver.clone(),
inner: client,
session,
request_config: config.request_config,
};
let http_client =
HttpClient::new(client, homeserver.clone(), session, config.request_config);
Ok(Self {
homeserver,
@ -531,6 +542,89 @@ impl Client {
})
}
/// Creates a new client for making HTTP requests to the homeserver of the
/// given user. Follows homeserver discovery directions described
/// [here](https://spec.matrix.org/unstable/client-server-api/#well-known-uri).
///
/// # Arguments
///
/// * `user_id` - The id of the user whose homeserver the client should
/// connect to.
///
/// # Example
/// ```no_run
/// # use std::convert::TryFrom;
/// # use matrix_sdk::{Client, identifiers::UserId};
/// # use futures::executor::block_on;
/// let alice = UserId::try_from("@alice:example.org").unwrap();
/// # block_on(async {
/// let client = Client::new_from_user_id(alice.clone()).await.unwrap();
/// client.login(alice.localpart(), "password", None, None).await.unwrap();
/// # });
/// ```
pub async fn new_from_user_id(user_id: UserId) -> Result<Self> {
let config = ClientConfig::new();
Client::new_from_user_id_with_config(user_id, config).await
}
/// Creates a new client for making HTTP requests to the homeserver of the
/// given user and configuration. Follows homeserver discovery directions
/// described [here](https://spec.matrix.org/unstable/client-server-api/#well-known-uri).
///
/// # Arguments
///
/// * `user_id` - The id of the user whose homeserver the client should
/// connect to.
///
/// * `config` - Configuration for the client.
pub async fn new_from_user_id_with_config(
user_id: UserId,
config: ClientConfig,
) -> Result<Self> {
let homeserver = Client::homeserver_from_user_id(user_id)?;
let mut client = Client::new_with_config(homeserver, config)?;
let well_known = client.discover_homeserver().await?;
let well_known = Url::parse(well_known.homeserver.base_url.as_ref())?;
client.set_homeserver(well_known).await;
client.get_supported_versions().await?;
Ok(client)
}
fn homeserver_from_user_id(user_id: UserId) -> Result<Url> {
let homeserver = format!("https://{}", user_id.server_name());
#[allow(unused_mut)]
let mut result = Url::parse(homeserver.as_str())?;
// Mockito only knows how to test http endpoints:
// https://github.com/lipanski/mockito/issues/127
#[cfg(test)]
let _ = result.set_scheme("http");
Ok(result)
}
async fn discover_homeserver(&self) -> Result<discover_homeserver::Response> {
self.send(discover_homeserver::Request::new(), Some(RequestConfig::new().disable_retry()))
.await
}
/// Change the homeserver URL used by this client.
///
/// # Arguments
///
/// * `homeserver_url` - The new URL to use.
pub async fn set_homeserver(&mut self, homeserver_url: Url) {
let mut homeserver = self.homeserver.write().await;
*homeserver = homeserver_url;
}
async fn get_supported_versions(&self) -> Result<get_supported_versions::Response> {
self.send(
get_supported_versions::Request::new(),
Some(RequestConfig::new().disable_retry()),
)
.await
}
/// Process a [transaction] received from the homeserver
///
/// # Arguments
@ -563,8 +657,8 @@ impl Client {
}
/// The Homeserver of the client.
pub fn homeserver(&self) -> &Url {
&self.homeserver
pub async fn homeserver(&self) -> Url {
self.homeserver.read().await.clone()
}
/// Get the user id of the current owner of the client.
@ -863,8 +957,8 @@ impl Client {
/// successful SSO login.
///
/// [`login_with_token`]: #method.login_with_token
pub fn get_sso_login_url(&self, redirect_url: &str) -> Result<String> {
let homeserver = self.homeserver();
pub async fn get_sso_login_url(&self, redirect_url: &str) -> Result<String> {
let homeserver = self.homeserver().await;
let request = sso_login::Request::new(redirect_url)
.try_into_http_request::<Vec<u8>>(homeserver.as_str(), SendAccessToken::None);
match request {
@ -925,7 +1019,7 @@ impl Client {
device_id: Option<&str>,
initial_device_display_name: Option<&str>,
) -> Result<login::Response> {
info!("Logging in to {} as {:?}", self.homeserver, user);
info!("Logging in to {} as {:?}", self.homeserver().await, user);
let request = assign!(
login::Request::new(
@ -1034,7 +1128,7 @@ impl Client {
where
C: Future<Output = Result<()>>,
{
info!("Logging in to {}", self.homeserver);
info!("Logging in to {}", self.homeserver().await);
let (signal_tx, signal_rx) = oneshot::channel();
let (data_tx, data_rx) = oneshot::channel();
let data_tx_mutex = Arc::new(std::sync::Mutex::new(Some(data_tx)));
@ -1106,7 +1200,7 @@ impl Client {
tokio::spawn(server);
let sso_url = self.get_sso_login_url(redirect_url.as_str()).unwrap();
let sso_url = self.get_sso_login_url(redirect_url.as_str()).await.unwrap();
match use_sso_login_url(sso_url).await {
Ok(t) => t,
@ -1190,7 +1284,7 @@ impl Client {
device_id: Option<&str>,
initial_device_display_name: Option<&str>,
) -> Result<login::Response> {
info!("Logging in to {}", self.homeserver);
info!("Logging in to {}", self.homeserver().await);
let request = assign!(
login::Request::new(
@ -1261,7 +1355,7 @@ impl Client {
&self,
registration: impl Into<register::Request<'_>>,
) -> Result<register::Response> {
info!("Registering to {}", self.homeserver);
info!("Registering to {}", self.homeserver().await);
let request = registration.into();
self.send(request, None).await
@ -1915,7 +2009,7 @@ impl Client {
#[cfg(feature = "encryption")]
{
// This is needed because sometimes we need to automatically
// claim some one-time keys to unwedge an exisitng Olm session.
// claim some one-time keys to unwedge an existing Olm session.
if let Err(e) = self.claim_one_time_keys([].iter()).await {
warn!("Error while claiming one-time keys {:?}", e);
}
@ -2382,13 +2476,227 @@ impl Client {
Ok(olm.import_keys(import, |_, _| {}).await?)
}
/// Get a media file's content.
///
/// If the content is encrypted and encryption is enabled, the content will
/// be decrypted.
///
/// # Arguments
///
/// * `request` - The `MediaRequest` of the content.
///
/// * `use_cache` - If we should use the media cache for this request.
pub async fn get_media_content(
&self,
request: &MediaRequest,
use_cache: bool,
) -> Result<Vec<u8>> {
let content = if use_cache {
self.base_client.store().get_media_content(request).await?
} else {
None
};
if let Some(content) = content {
Ok(content)
} else {
let content: Vec<u8> = match &request.media_type {
MediaType::Encrypted(file) => {
let content: Vec<u8> =
self.send(get_content::Request::from_url(&file.url)?, None).await?.file;
#[cfg(feature = "encryption")]
let content = {
let mut cursor = Cursor::new(content);
let mut reader =
AttachmentDecryptor::new(&mut cursor, file.as_ref().clone().into())?;
let mut decrypted = Vec::new();
reader.read_to_end(&mut decrypted)?;
decrypted
};
content
}
MediaType::Uri(uri) => {
if let MediaFormat::Thumbnail(size) = &request.format {
self.send(
get_content_thumbnail::Request::from_url(
&uri,
size.width,
size.height,
)?,
None,
)
.await?
.file
} else {
self.send(get_content::Request::from_url(&uri)?, None).await?.file
}
}
};
if use_cache {
self.base_client.store().add_media_content(request, content.clone()).await?;
}
Ok(content)
}
}
/// Remove a media file's content from the store.
///
/// # Arguments
///
/// * `request` - The `MediaRequest` of the content.
pub async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> {
Ok(self.base_client.store().remove_media_content(request).await?)
}
/// Delete all the media content corresponding to the given
/// uri from the store.
///
/// # Arguments
///
/// * `uri` - The `MxcUri` of the files.
pub async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> {
Ok(self.base_client.store().remove_media_content_for_uri(&uri).await?)
}
/// Get the file of the given media event content.
///
/// If the content is encrypted and encryption is enabled, the content will
/// be decrypted.
///
/// Returns `Ok(None)` if the event content has no file.
///
/// This is a convenience method that calls the
/// [`get_media_content`](#method.get_media_content) method.
///
/// # Arguments
///
/// * `event_content` - The media event content.
///
/// * `use_cache` - If we should use the media cache for this file.
pub async fn get_file(
&self,
event_content: impl MediaEventContent,
use_cache: bool,
) -> Result<Option<Vec<u8>>> {
if let Some(media_type) = event_content.file() {
Ok(Some(
self.get_media_content(
&MediaRequest { media_type, format: MediaFormat::File },
use_cache,
)
.await?,
))
} else {
Ok(None)
}
}
/// Remove the file of the given media event content from the cache.
///
/// This is a convenience method that calls the
/// [`remove_media_content`](#method.remove_media_content) method.
///
/// # Arguments
///
/// * `event_content` - The media event content.
pub async fn remove_file(&self, event_content: impl MediaEventContent) -> Result<()> {
if let Some(media_type) = event_content.file() {
self.remove_media_content(&MediaRequest { media_type, format: MediaFormat::File })
.await?
}
Ok(())
}
/// Get a thumbnail of the given media event content.
///
/// If the content is encrypted and encryption is enabled, the content will
/// be decrypted.
///
/// Returns `Ok(None)` if the event content has no thumbnail.
///
/// This is a convenience method that calls the
/// [`get_media_content`](#method.get_media_content) method.
///
/// # Arguments
///
/// * `event_content` - The media event content.
///
/// * `size` - The _desired_ size of the thumbnail. The actual thumbnail may
/// not match the size specified.
///
/// * `use_cache` - If we should use the media cache for this thumbnail.
pub async fn get_thumbnail(
&self,
event_content: impl MediaEventContent,
size: MediaThumbnailSize,
use_cache: bool,
) -> Result<Option<Vec<u8>>> {
if let Some(media_type) = event_content.thumbnail() {
Ok(Some(
self.get_media_content(
&MediaRequest { media_type, format: MediaFormat::Thumbnail(size) },
use_cache,
)
.await?,
))
} else {
Ok(None)
}
}
/// Remove the thumbnail of the given media event content from the cache.
///
/// This is a convenience method that calls the
/// [`remove_media_content`](#method.remove_media_content) method.
///
/// # Arguments
///
/// * `event_content` - The media event content.
///
/// * `size` - The _desired_ size of the thumbnail. Must match the size
/// requested with [`get_thumbnail`](#method.get_thumbnail).
pub async fn remove_thumbnail(
&self,
event_content: impl MediaEventContent,
size: MediaThumbnailSize,
) -> Result<()> {
if let Some(media_type) = event_content.file() {
self.remove_media_content(&MediaRequest {
media_type,
format: MediaFormat::Thumbnail(size),
})
.await?
}
Ok(())
}
}
#[cfg(test)]
mod test {
use std::{collections::BTreeMap, convert::TryInto, io::Cursor, str::FromStr, time::Duration};
use std::{
collections::BTreeMap,
convert::{TryFrom, TryInto},
io::Cursor,
str::FromStr,
time::Duration,
};
use matrix_sdk_base::identifiers::mxc_uri;
use matrix_sdk_base::{
api::r0::media::get_content_thumbnail::Method,
events::room::{message::ImageMessageEventContent, ImageInfo},
identifiers::mxc_uri,
media::{MediaFormat, MediaRequest, MediaThumbnailSize, MediaType},
uint,
};
use matrix_sdk_common::{
api::r0::{
account::register::Request as RegistrationRequest,
@ -2398,7 +2706,7 @@ mod test {
assign,
directory::Filter,
events::{room::message::MessageEventContent, AnyMessageEventContent},
identifiers::{event_id, room_id, user_id},
identifiers::{event_id, room_id, user_id, UserId},
thirdparty,
};
use matrix_sdk_test::{test_json, EventBuilder, EventsJson};
@ -2424,6 +2732,62 @@ mod test {
client
}
#[tokio::test]
async fn set_homeserver() {
let homeserver = Url::from_str("http://example.com/").unwrap();
let mut client = Client::new(homeserver).unwrap();
let homeserver = Url::from_str(&mockito::server_url()).unwrap();
client.set_homeserver(homeserver.clone()).await;
assert_eq!(client.homeserver().await, homeserver);
}
#[tokio::test]
async fn successful_discovery() {
let server_url = mockito::server_url();
let domain = server_url.strip_prefix("http://").unwrap();
let alice = UserId::try_from("@alice:".to_string() + domain).unwrap();
let _m_well_known = mock("GET", "/.well-known/matrix/client")
.with_status(200)
.with_body(
test_json::WELL_KNOWN.to_string().replace("HOMESERVER_URL", server_url.as_ref()),
)
.create();
let _m_versions = mock("GET", "/_matrix/client/versions")
.with_status(200)
.with_body(test_json::VERSIONS.to_string())
.create();
let client = Client::new_from_user_id(alice).await.unwrap();
assert_eq!(client.homeserver().await, Url::parse(server_url.as_ref()).unwrap());
}
#[tokio::test]
async fn discovery_broken_server() {
let server_url = mockito::server_url();
let domain = server_url.strip_prefix("http://").unwrap();
let alice = UserId::try_from("@alice:".to_string() + domain).unwrap();
let _m = mock("GET", "/.well-known/matrix/client")
.with_status(200)
.with_body(
test_json::WELL_KNOWN.to_string().replace("HOMESERVER_URL", server_url.as_ref()),
)
.create();
if Client::new_from_user_id(alice).await.is_ok() {
panic!(
"Creating a client from a user ID should fail when the \
.well-known server returns no version information."
);
}
}
#[tokio::test]
async fn login() {
let homeserver = Url::from_str(&mockito::server_url()).unwrap();
@ -2513,7 +2877,7 @@ mod test {
.any(|flow| matches!(flow, LoginType::Sso(_)));
assert!(can_sso);
let sso_url = client.get_sso_login_url("http://127.0.0.1:3030");
let sso_url = client.get_sso_login_url("http://127.0.0.1:3030").await;
assert!(sso_url.is_ok());
let _m = mock("POST", "/_matrix/client/r0/login")
@ -2625,7 +2989,7 @@ mod test {
client.base_client.receive_sync_response(response).await.unwrap();
let room_id = room_id!("!SVkFJHzfwvuaIEawgC:localhost");
assert_eq!(client.homeserver(), &Url::parse(&mockito::server_url()).unwrap());
assert_eq!(client.homeserver().await, Url::parse(&mockito::server_url()).unwrap());
let room = client.get_joined_room(&room_id);
assert!(room.is_some());
@ -3279,6 +3643,7 @@ mod test {
.with_status(200)
.match_header("authorization", "Bearer 1234")
.with_body(test_json::SYNC.to_string())
.expect_at_least(1)
.create();
let sync_settings = SyncSettings::new().timeout(Duration::from_millis(3000));
@ -3288,6 +3653,19 @@ mod test {
let room = client.get_joined_room(&room_id!("!SVkFJHzfwvuaIEawgC:localhost")).unwrap();
assert_eq!("tutorial".to_string(), room.display_name().await.unwrap());
let _m = mock("GET", Matcher::Regex(r"^/_matrix/client/r0/sync\?.*$".to_string()))
.with_status(200)
.match_header("authorization", "Bearer 1234")
.with_body(test_json::INVITE_SYNC.to_string())
.expect_at_least(1)
.create();
let _response = client.sync_once(SyncSettings::new()).await.unwrap();
let invited_room = client.get_invited_room(&room_id!("!696r7674:example.com")).unwrap();
assert_eq!("My Room Name".to_string(), invited_room.display_name().await.unwrap());
}
#[tokio::test]
@ -3395,4 +3773,74 @@ mod test {
panic!("this request should return an `Err` variant")
}
}
#[tokio::test]
async fn get_media_content() {
let client = logged_in_client().await;
let request = MediaRequest {
media_type: MediaType::Uri(mxc_uri!("mxc://localhost/textfile")),
format: MediaFormat::File,
};
let m = mock(
"GET",
Matcher::Regex(r"^/_matrix/media/r0/download/localhost/textfile\?.*$".to_string()),
)
.with_status(200)
.with_body("Some very interesting text.")
.expect(2)
.create();
assert!(client.get_media_content(&request, true).await.is_ok());
assert!(client.get_media_content(&request, true).await.is_ok());
assert!(client.get_media_content(&request, false).await.is_ok());
m.assert();
}
#[tokio::test]
async fn get_media_file() {
let client = logged_in_client().await;
let event_content = ImageMessageEventContent::plain(
"filename.jpg".into(),
mxc_uri!("mxc://example.org/image"),
Some(Box::new(assign!(ImageInfo::new(), {
height: Some(uint!(398)),
width: Some(uint!(394)),
mimetype: Some("image/jpeg".into()),
size: Some(uint!(31037)),
}))),
);
let m = mock(
"GET",
Matcher::Regex(r"^/_matrix/media/r0/download/example%2Eorg/image\?.*$".to_string()),
)
.with_status(200)
.with_body("binaryjpegdata")
.create();
assert!(client.get_file(event_content.clone(), true).await.is_ok());
assert!(client.get_file(event_content.clone(), true).await.is_ok());
m.assert();
let m = mock(
"GET",
Matcher::Regex(r"^/_matrix/media/r0/thumbnail/example%2Eorg/image\?.*$".to_string()),
)
.with_status(200)
.with_body("smallerbinaryjpegdata")
.create();
assert!(client
.get_thumbnail(
event_content,
MediaThumbnailSize { method: Method::Scale, width: uint!(100), height: uint!(100) },
true
)
.await
.is_ok());
m.assert();
}
}

View File

@ -18,7 +18,7 @@ use std::io::Error as IoError;
use http::StatusCode;
#[cfg(feature = "encryption")]
use matrix_sdk_base::crypto::store::CryptoStoreError;
use matrix_sdk_base::crypto::{store::CryptoStoreError, DecryptorError};
use matrix_sdk_base::{Error as MatrixError, StoreError};
use matrix_sdk_common::{
api::{
@ -31,6 +31,7 @@ use matrix_sdk_common::{
use reqwest::Error as ReqwestError;
use serde_json::Error as JsonError;
use thiserror::Error;
use url::ParseError as UrlParseError;
/// Result type of the rust-sdk.
pub type Result<T> = std::result::Result<T, Error>;
@ -121,13 +122,22 @@ pub enum Error {
#[error(transparent)]
CryptoStoreError(#[from] CryptoStoreError),
/// An error occured in the state store.
/// An error occurred during decryption.
#[cfg(feature = "encryption")]
#[error(transparent)]
DecryptorError(#[from] DecryptorError),
/// An error occurred in the state store.
#[error(transparent)]
StateStore(#[from] StoreError),
/// An error encountered when trying to parse an identifier.
#[error(transparent)]
Identifier(#[from] IdentifierError),
/// An error encountered when trying to parse a url.
#[error(transparent)]
Url(#[from] UrlParseError),
}
impl Error {

View File

@ -379,7 +379,7 @@ pub trait EventHandler: Send + Sync {
async fn on_room_redaction(&self, _: Room, _: &SyncRedactionEvent) {}
/// Fires when `Client` receives a `RoomEvent::RoomPowerLevels` event.
async fn on_room_power_levels(&self, _: Room, _: &SyncStateEvent<PowerLevelsEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::Tombstone` event.
/// Fires when `Client` receives a `RoomEvent::RoomJoinRules` event.
async fn on_room_join_rules(&self, _: Room, _: &SyncStateEvent<JoinRulesEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::Tombstone` event.
async fn on_room_tombstone(&self, _: Room, _: &SyncStateEvent<TombstoneEventContent>) {}

View File

@ -97,7 +97,7 @@ pub trait HttpSend: AsyncTraitDeps {
#[derive(Clone, Debug)]
pub(crate) struct HttpClient {
pub(crate) inner: Arc<dyn HttpSend>,
pub(crate) homeserver: Arc<Url>,
pub(crate) homeserver: Arc<RwLock<Url>>,
pub(crate) session: Arc<RwLock<Option<Session>>>,
pub(crate) request_config: RequestConfig,
}
@ -106,6 +106,15 @@ pub(crate) struct HttpClient {
use crate::OutgoingRequestAppserviceExt;
impl HttpClient {
pub(crate) fn new(
inner: Arc<dyn HttpSend>,
homeserver: Arc<RwLock<Url>>,
session: Arc<RwLock<Option<Session>>>,
request_config: RequestConfig,
) -> Self {
HttpClient { inner, homeserver, session, request_config }
}
async fn send_request<Request: OutgoingRequest>(
&self,
request: Request,
@ -124,7 +133,7 @@ impl HttpClient {
let request = if !self.request_config.assert_identity {
self.try_into_http_request(request, session, config).await?
} else {
self.try_into_http_request_with_identy_assertion(request, session, config).await?
self.try_into_http_request_with_identity_assertion(request, session, config).await?
};
self.inner.send_request(request, config).await
@ -161,14 +170,17 @@ impl HttpClient {
};
let http_request = request
.try_into_http_request::<BytesMut>(&self.homeserver.to_string(), access_token)?
.try_into_http_request::<BytesMut>(
&self.homeserver.read().await.to_string(),
access_token,
)?
.map(|body| body.freeze());
Ok(http_request)
}
#[cfg(feature = "appservice")]
async fn try_into_http_request_with_identy_assertion<Request: OutgoingRequest>(
async fn try_into_http_request_with_identity_assertion<Request: OutgoingRequest>(
&self,
request: Request,
session: Arc<RwLock<Option<Session>>>,
@ -189,7 +201,7 @@ impl HttpClient {
let http_request = request
.try_into_http_request_with_user_id::<BytesMut>(
&self.homeserver.to_string(),
&self.homeserver.read().await.to_string(),
access_token,
user_id,
)?

View File

@ -52,8 +52,7 @@
//! synapse configuration `require_auth_for_profile_requests`. Enabled by
//! default.
//! * `appservice`: Enables low-level appservice functionality. For an
//! high-level API there's the
//! `matrix-sdk-appservice` crate
//! high-level API there's the `matrix-sdk-appservice` crate
#![deny(
missing_debug_implementations,
@ -81,7 +80,7 @@ pub use bytes::{Bytes, BytesMut};
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub use matrix_sdk_base::crypto::{EncryptionInfo, LocalTrust};
pub use matrix_sdk_base::{
Error as BaseError, Room as BaseRoom, RoomInfo, RoomMember as BaseRoomMember, RoomType,
media, Error as BaseError, Room as BaseRoom, RoomInfo, RoomMember as BaseRoomMember, RoomType,
Session, StateChanges, StoreError,
};
pub use matrix_sdk_common::*;

View File

@ -34,7 +34,7 @@ impl Common {
/// # Arguments
/// * `client` - The client used to make requests.
///
/// * `room` - The underlaying room.
/// * `room` - The underlying room.
pub fn new(client: Client, room: BaseRoom) -> Self {
// TODO: Make this private
Self { inner: room, client }

View File

@ -5,7 +5,7 @@ use crate::{room::Common, BaseRoom, Client, Result, RoomType};
/// A room in the invited state.
///
/// This struct contains all methodes specific to a `Room` with type
/// `RoomType::Invited`. Operations may fail once the underlaying `Room` changes
/// `RoomType::Invited`. Operations may fail once the underlying `Room` changes
/// `RoomType`.
#[derive(Debug, Clone)]
pub struct Invited {
@ -13,13 +13,13 @@ pub struct Invited {
}
impl Invited {
/// Create a new `room::Invited` if the underlaying `Room` has type
/// Create a new `room::Invited` if the underlying `Room` has type
/// `RoomType::Invited`.
///
/// # Arguments
/// * `client` - The client used to make requests.
///
/// * `room` - The underlaying room.
/// * `room` - The underlying room.
pub fn new(client: Client, room: BaseRoom) -> Option<Self> {
// TODO: Make this private
if room.room_type() == RoomType::Invited {

View File

@ -48,7 +48,7 @@ const TYPING_NOTICE_RESEND_TIMEOUT: Duration = Duration::from_secs(3);
/// A room in the joined state.
///
/// The `JoinedRoom` contains all methodes specific to a `Room` with type
/// `RoomType::Joined`. Operations may fail once the underlaying `Room` changes
/// `RoomType::Joined`. Operations may fail once the underlying `Room` changes
/// `RoomType`.
#[derive(Debug, Clone)]
pub struct Joined {
@ -64,13 +64,13 @@ impl Deref for Joined {
}
impl Joined {
/// Create a new `room::Joined` if the underlaying `BaseRoom` has type
/// Create a new `room::Joined` if the underlying `BaseRoom` has type
/// `RoomType::Joined`.
///
/// # Arguments
/// * `client` - The client used to make requests.
///
/// * `room` - The underlaying room.
/// * `room` - The underlying room.
pub fn new(client: Client, room: BaseRoom) -> Option<Self> {
// TODO: Make this private
if room.room_type() == RoomType::Joined {

View File

@ -7,7 +7,7 @@ use crate::{room::Common, BaseRoom, Client, Result, RoomType};
/// A room in the left state.
///
/// This struct contains all methodes specific to a `Room` with type
/// `RoomType::Left`. Operations may fail once the underlaying `Room` changes
/// `RoomType::Left`. Operations may fail once the underlying `Room` changes
/// `RoomType`.
#[derive(Debug, Clone)]
pub struct Left {
@ -15,13 +15,13 @@ pub struct Left {
}
impl Left {
/// Create a new `room::Left` if the underlaying `Room` has type
/// Create a new `room::Left` if the underlying `Room` has type
/// `RoomType::Left`.
///
/// # Arguments
/// * `client` - The client used to make requests.
///
/// * `room` - The underlaying room.
/// * `room` - The underlying room.
pub fn new(client: Client, room: BaseRoom) -> Option<Self> {
// TODO: Make this private
if room.room_type() == RoomType::Left {

View File

@ -18,7 +18,7 @@ use matrix_sdk_base::crypto::{
use crate::{error::Result, Client};
/// An object controling the interactive verification flow.
/// An object controlling the interactive verification flow.
#[derive(Debug, Clone)]
pub struct Sas {
pub(crate) inner: BaseSas,

View File

@ -18,7 +18,7 @@ use matrix_sdk_base::crypto::{
use crate::{Client, Result};
/// An object controling the interactive verification flow.
/// An object controlling the interactive verification flow.
#[derive(Debug, Clone)]
pub struct VerificationRequest {
pub(crate) inner: BaseVerificationRequest,

View File

@ -16,6 +16,7 @@ docs = []
[dependencies]
actix-rt = { version = "2", optional = true }
actix-web = { version = "4.0.0-beta.6", optional = true }
dashmap = "4"
futures = "0.3"
futures-util = "0.3"
http = "0.2"

View File

@ -34,9 +34,10 @@ impl EventHandler for AppserviceEventHandler {
if let MembershipState::Invite = event.content.membership {
let user_id = UserId::try_from(event.state_key.clone()).unwrap();
self.appservice.register(user_id.localpart()).await.unwrap();
let mut appservice = self.appservice.clone();
appservice.register(user_id.localpart()).await.unwrap();
let client = self.appservice.client(Some(user_id.localpart())).await.unwrap();
let client = appservice.virtual_user(user_id.localpart()).await.unwrap();
client.join_room_by_id(room.room_id()).await.unwrap();
}
@ -53,7 +54,7 @@ pub async fn main() -> std::io::Result<()> {
let registration =
AppserviceRegistration::try_from_yaml_file("./tests/registration.yaml").unwrap();
let appservice = Appservice::new(homeserver_url, server_name, registration).await.unwrap();
let mut appservice = Appservice::new(homeserver_url, server_name, registration).await.unwrap();
let event_handler = AppserviceEventHandler::new(appservice.clone());

View File

@ -65,7 +65,7 @@ async fn push_transactions(
return Ok(HttpResponse::Unauthorized().finish());
}
appservice.client(None).await?.receive_transaction(request.incoming).await?;
appservice.get_cached_client(None)?.receive_transaction(request.incoming).await?;
Ok(HttpResponse::Ok().json("{}"))
}

View File

@ -16,9 +16,6 @@ use thiserror::Error;
#[derive(Error, Debug)]
pub enum Error {
#[error("tried to run without webserver configured")]
RunWithoutServer,
#[error("missing access token")]
MissingAccessToken,
@ -31,6 +28,9 @@ pub enum Error {
#[error("no port found")]
MissingRegistrationPort,
#[error("no client for localpart found")]
NoClientForLocalpart,
#[error(transparent)]
HttpRequest(#[from] matrix_sdk::FromHttpRequestError),

View File

@ -14,8 +14,9 @@
//! Matrix [Application Service] library
//!
//! The appservice crate aims to provide a batteries-included experience. That
//! means that we
//! The appservice crate aims to provide a batteries-included experience by
//! being a thin wrapper around the [`matrix_sdk`]. That means that we
//!
//! * ship with functionality to configure your webserver crate or simply run
//! the webserver for you
//! * receive and validate requests from the homeserver correctly
@ -57,7 +58,7 @@
//! regex: '@_appservice_.*'
//! ")?;
//!
//! let appservice = Appservice::new(homeserver_url, server_name, registration).await?;
//! let mut appservice = Appservice::new(homeserver_url, server_name, registration).await?;
//! appservice.set_event_handler(Box::new(AppserviceEventHandler)).await?;
//!
//! let (host, port) = appservice.registration().get_host_and_port()?;
@ -81,8 +82,10 @@ use std::{
fs::File,
ops::Deref,
path::PathBuf,
sync::Arc,
};
use dashmap::DashMap;
use http::Uri;
#[doc(inline)]
pub use matrix_sdk::api_appservice as api;
@ -98,8 +101,7 @@ use matrix_sdk::{
assign,
identifiers::{self, DeviceId, ServerNameBox, UserId},
reqwest::Url,
Client, ClientConfig, EventHandler, FromHttpResponseError, HttpError, RequestConfig,
ServerError, Session,
Client, ClientConfig, EventHandler, FromHttpResponseError, HttpError, ServerError, Session,
};
use regex::Regex;
use tracing::warn;
@ -173,34 +175,39 @@ impl Deref for AppserviceRegistration {
}
}
async fn client_session_with_login_restore(
client: &Client,
registration: &AppserviceRegistration,
localpart: impl AsRef<str> + Into<Box<str>>,
server_name: &ServerNameBox,
) -> Result<()> {
let session = Session {
access_token: registration.as_token.clone(),
user_id: UserId::parse_with_server_name(localpart, server_name)?,
device_id: DeviceId::new(),
};
client.restore_login(session).await?;
type Localpart = String;
Ok(())
}
/// The `localpart` of the user associated with the application service via
/// `sender_localpart` in [`AppserviceRegistration`].
///
/// Dummy type for shared documentation
#[allow(dead_code)]
pub type MainUser = ();
/// The application service may specify the virtual user to act as through use
/// of a user_id query string parameter on the request. The user specified in
/// the query string must be covered by one of the [`AppserviceRegistration`]'s
/// `users` namespaces.
///
/// Dummy type for shared documentation
pub type VirtualUser = ();
/// Appservice
#[derive(Debug, Clone)]
pub struct Appservice {
homeserver_url: Url,
server_name: ServerNameBox,
registration: AppserviceRegistration,
client_sender_localpart: Client,
registration: Arc<AppserviceRegistration>,
clients: Arc<DashMap<Localpart, Client>>,
}
impl Appservice {
/// Create new Appservice
///
/// Also creates and caches a [`Client`] for the [`MainUser`].
/// The default [`ClientConfig`] is used, if you want to customize it
/// use [`Self::new_with_config()`] instead.
///
/// # Arguments
///
/// * `homeserver_url` - The homeserver that the client should connect to.
@ -215,28 +222,49 @@ impl Appservice {
server_name: impl TryInto<ServerNameBox, Error = identifiers::Error>,
registration: AppserviceRegistration,
) -> Result<Self> {
let homeserver_url = homeserver_url.try_into()?;
let server_name = server_name.try_into()?;
let client_sender_localpart = Client::new(homeserver_url.clone())?;
client_session_with_login_restore(
&client_sender_localpart,
&registration,
registration.sender_localpart.as_ref(),
&server_name,
let appservice = Self::new_with_config(
homeserver_url,
server_name,
registration,
ClientConfig::default(),
)
.await?;
Ok(Appservice { homeserver_url, server_name, registration, client_sender_localpart })
Ok(appservice)
}
/// Get a [`Client`]
/// Same as [`Self::new()`] but lets you provide a [`ClientConfig`] for the
/// [`Client`]
pub async fn new_with_config(
homeserver_url: impl TryInto<Url, Error = url::ParseError>,
server_name: impl TryInto<ServerNameBox, Error = identifiers::Error>,
registration: AppserviceRegistration,
client_config: ClientConfig,
) -> Result<Self> {
let homeserver_url = homeserver_url.try_into()?;
let server_name = server_name.try_into()?;
let registration = Arc::new(registration);
let clients = Arc::new(DashMap::new());
let sender_localpart = registration.sender_localpart.clone();
let appservice = Appservice { homeserver_url, server_name, registration, clients };
// we cache the [`MainUser`] by default
appservice.virtual_user_with_config(sender_localpart, client_config).await?;
Ok(appservice)
}
/// Create a [`Client`] for the given [`VirtualUser`]'s `localpart`
///
/// Will return a `Client` that's configured to [assert the identity] on all
/// outgoing homeserver requests if `localpart` is given. If not given
/// the `Client` will use the main user associated with this appservice,
/// that is the `sender_localpart` in the [`AppserviceRegistration`]
/// Will create and return a [`Client`] that's configured to [assert the
/// identity] on all outgoing homeserver requests if `localpart` is
/// given.
///
/// This method is a singleton that saves the client internally for re-use
/// based on the `localpart`. The cached [`Client`] can be retrieved either
/// by calling this method again or by calling [`Self::get_cached_client()`]
/// which is non-async convenience wrapper.
///
/// # Arguments
///
@ -244,26 +272,48 @@ impl Appservice {
///
/// [registration]: https://matrix.org/docs/spec/application_service/r0.1.2#registration
/// [assert the identity]: https://matrix.org/docs/spec/application_service/r0.1.2#identity-assertion
pub async fn client(&self, localpart: Option<&str>) -> Result<Client> {
let localpart = localpart.unwrap_or_else(|| self.registration.sender_localpart.as_ref());
pub async fn virtual_user(&self, localpart: impl AsRef<str>) -> Result<Client> {
let client = self.virtual_user_with_config(localpart, ClientConfig::default()).await?;
// The `as_token` in the `Session` maps to the main appservice user
Ok(client)
}
/// Same as [`Self::virtual_user()`] but with the ability to pass in a
/// [`ClientConfig`]
///
/// Since this method is a singleton follow-up calls with different
/// [`ClientConfig`]s will be ignored.
pub async fn virtual_user_with_config(
&self,
localpart: impl AsRef<str>,
config: ClientConfig,
) -> Result<Client> {
// TODO: check if localpart is covered by namespace?
let localpart = localpart.as_ref();
let client = if let Some(client) = self.clients.get(localpart) {
client.clone()
} else {
let user_id = UserId::parse_with_server_name(localpart, &self.server_name)?;
// The `as_token` in the `Session` maps to the [`MainUser`]
// (`sender_localpart`) by default, so we don't need to assert identity
// in that case
let client = if localpart == self.registration.sender_localpart {
self.client_sender_localpart.clone()
} else {
let request_config = RequestConfig::default().assert_identity();
let config = ClientConfig::default().request_config(request_config);
if localpart != self.registration.sender_localpart {
config.get_request_config().assert_identity();
}
let client = Client::new_with_config(self.homeserver_url.clone(), config)?;
client_session_with_login_restore(
&client,
&self.registration,
localpart,
&self.server_name,
)
.await?;
let session = Session {
access_token: self.registration.as_token.clone(),
user_id: user_id.clone(),
// TODO: expose & proper E2EE
device_id: DeviceId::new(),
};
client.restore_login(session).await?;
self.clients.insert(localpart.to_owned(), client.clone());
client
};
@ -271,9 +321,28 @@ impl Appservice {
Ok(client)
}
/// Get cached [`Client`]
///
/// Will return the client for the given `localpart` if previously
/// constructed with [`Self::virtual_user()`] or
/// [`Self::virtual_user_with_config()`].
///
/// If no `localpart` is given it assumes the [`MainUser`]'s `localpart`. If
/// no client for `localpart` is found it will return an Error.
pub fn get_cached_client(&self, localpart: Option<&str>) -> Result<Client> {
let localpart = localpart.unwrap_or_else(|| self.registration.sender_localpart.as_ref());
let entry = self.clients.get(localpart).ok_or(Error::NoClientForLocalpart)?;
Ok(entry.value().clone())
}
/// Convenience wrapper around [`Client::set_event_handler()`]
pub async fn set_event_handler(&self, handler: Box<dyn EventHandler>) -> Result<()> {
let client = self.client(None).await?;
///
/// Attaches the event handler to the [`MainUser`]'s [`Client`]
pub async fn set_event_handler(&mut self, handler: Box<dyn EventHandler>) -> Result<()> {
let client = self.get_cached_client(None)?;
client.set_event_handler(handler).await;
Ok(())
@ -286,13 +355,13 @@ impl Appservice {
///
/// * `localpart` - The localpart of the user to register. Must be covered
/// by the namespaces in the [`Registration`] in order to succeed.
pub async fn register(&self, localpart: impl AsRef<str>) -> Result<()> {
pub async fn register(&mut self, localpart: impl AsRef<str>) -> Result<()> {
let request = assign!(RegistrationRequest::new(), {
username: Some(localpart.as_ref()),
login_type: Some(&LoginType::ApplicationService),
});
let client = self.client(None).await?;
let client = self.get_cached_client(None)?;
match client.register(request).await {
Ok(_) => (),
Err(error) => match error {
@ -328,7 +397,8 @@ impl Appservice {
self.registration.hs_token == hs_token.as_ref()
}
/// Check if given `user_id` is in any of the registration user namespaces
/// Check if given `user_id` is in any of the [`AppserviceRegistration`]'s
/// `users` namespaces
pub fn user_id_is_in_namespace(&self, user_id: impl AsRef<str>) -> Result<bool> {
for user in &self.registration.namespaces.users {
// TODO: precompile on Appservice construction

View File

@ -12,7 +12,7 @@ mod actix {
Appservice::new(
mockito::server_url().as_ref(),
"test.local",
AppserviceRegistration::try_from_yaml_file("./tests/registration.yaml").unwrap(),
AppserviceRegistration::try_from_yaml_str(include_str!("./registration.yaml")).unwrap(),
)
.await
.unwrap()

View File

@ -59,7 +59,7 @@ fn member_json() -> serde_json::Value {
#[async_test]
async fn test_event_handler() -> Result<()> {
let appservice = appservice(None).await?;
let mut appservice = appservice(None).await?;
struct Example {}
@ -87,7 +87,7 @@ async fn test_event_handler() -> Result<()> {
events,
);
appservice.client(None).await?.receive_transaction(incoming).await?;
appservice.get_cached_client(None)?.receive_transaction(incoming).await?;
Ok(())
}
@ -105,7 +105,7 @@ async fn test_transaction() -> Result<()> {
events,
);
appservice.client(None).await?.receive_transaction(incoming).await?;
appservice.get_cached_client(None)?.receive_transaction(incoming).await?;
Ok(())
}

View File

@ -25,6 +25,7 @@ docs = ["encryption", "sled_cryptostore"]
[dependencies]
dashmap = "4.0.2"
lru = "0.6.5"
serde = { version = "1.0.122", features = ["rc"] }
serde_json = "1.0.61"
tracing = "0.1.22"

View File

@ -42,7 +42,8 @@ use matrix_sdk_common::{
events::{
room::member::{MemberEventContent, MembershipState},
AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent,
AnySyncRoomEvent, AnySyncStateEvent, EventContent, EventType, StateEvent,
AnySyncEphemeralRoomEvent, AnySyncRoomEvent, AnySyncStateEvent, EventContent, EventType,
StateEvent,
},
identifiers::{RoomId, UserId},
instant::Instant,
@ -685,7 +686,7 @@ impl BaseClient {
for room_id in rooms {
if let Some(room) = changes.room_infos.get_mut(room_id) {
room.base_info.dm_target = Some(user_id.clone());
} else if let Some(room) = self.store.get_bare_room(room_id) {
} else if let Some(room) = self.store.get_room(room_id) {
let mut info = room.clone_info();
info.base_info.dm_target = Some(user_id.clone());
changes.add_room(info);
@ -784,6 +785,15 @@ impl BaseClient {
)
.await?;
if let Some(event) =
new_info.ephemeral.events.iter().find_map(|e| match e.deserialize() {
Ok(AnySyncEphemeralRoomEvent::Receipt(event)) => Some(event.content),
_ => None,
})
{
changes.add_receipts(&room_id, event);
}
if new_info.timeline.limited {
room_info.mark_members_missing();
}
@ -931,7 +941,7 @@ impl BaseClient {
async fn apply_changes(&self, changes: &StateChanges) {
for (room_id, room_info) in &changes.room_infos {
if let Some(room) = self.store.get_bare_room(&room_id) {
if let Some(room) = self.store.get_room(&room_id) {
room.update_summary(room_info.clone())
}
}
@ -958,7 +968,7 @@ impl BaseClient {
.collect();
let mut ambiguity_cache = AmbiguityCache::new(self.store.clone());
if let Some(room) = self.store.get_bare_room(room_id) {
if let Some(room) = self.store.get_room(room_id) {
let mut room_info = room.clone_info();
room_info.mark_members_synced();

View File

@ -45,6 +45,7 @@ pub use crate::{
mod client;
mod error;
pub mod media;
mod rooms;
mod session;
mod store;

View File

@ -0,0 +1,216 @@
//! Common types for [media content](https://matrix.org/docs/spec/client_server/r0.6.1#id66).
use matrix_sdk_common::{
api::r0::media::get_content_thumbnail::Method,
events::{
room::{
message::{
AudioMessageEventContent, FileMessageEventContent, ImageMessageEventContent,
LocationMessageEventContent, VideoMessageEventContent,
},
EncryptedFile,
},
sticker::StickerEventContent,
},
identifiers::MxcUri,
UInt,
};
const UNIQUE_SEPARATOR: &str = "_";
/// A trait to uniquely identify values of the same type.
pub trait UniqueKey {
/// A string that uniquely identifies `Self` compared to other values of
/// the same type.
fn unique_key(&self) -> String;
}
/// The requested format of a media file.
#[derive(Clone, Debug)]
pub enum MediaFormat {
/// The file that was uploaded.
File,
/// A thumbnail of the file that was uploaded.
Thumbnail(MediaThumbnailSize),
}
impl UniqueKey for MediaFormat {
fn unique_key(&self) -> String {
match self {
Self::File => "file".into(),
Self::Thumbnail(size) => size.unique_key(),
}
}
}
/// The requested size of a media thumbnail.
#[derive(Clone, Debug)]
pub struct MediaThumbnailSize {
/// The desired resizing method.
pub method: Method,
/// The desired width of the thumbnail. The actual thumbnail may not match
/// the size specified.
pub width: UInt,
/// The desired height of the thumbnail. The actual thumbnail may not match
/// the size specified.
pub height: UInt,
}
impl UniqueKey for MediaThumbnailSize {
fn unique_key(&self) -> String {
format!("{}{}{}x{}", self.method, UNIQUE_SEPARATOR, self.width, self.height)
}
}
/// A request for media data.
#[derive(Clone, Debug)]
pub enum MediaType {
/// A media content URI.
Uri(MxcUri),
/// An encrypted media content.
Encrypted(Box<EncryptedFile>),
}
impl UniqueKey for MediaType {
fn unique_key(&self) -> String {
match self {
Self::Uri(uri) => uri.to_string(),
Self::Encrypted(file) => file.url.to_string(),
}
}
}
/// A request for media data.
#[derive(Clone, Debug)]
pub struct MediaRequest {
/// The type of the media file.
pub media_type: MediaType,
/// The requested format of the media data.
pub format: MediaFormat,
}
impl UniqueKey for MediaRequest {
fn unique_key(&self) -> String {
format!("{}{}{}", self.media_type.unique_key(), UNIQUE_SEPARATOR, self.format.unique_key())
}
}
/// Trait for media event content.
pub trait MediaEventContent {
/// Get the type of the file for `Self`.
///
/// Returns `None` if `Self` has no file.
fn file(&self) -> Option<MediaType>;
/// Get the type of the thumbnail for `Self`.
///
/// Returns `None` if `Self` has no thumbnail.
fn thumbnail(&self) -> Option<MediaType>;
}
impl MediaEventContent for StickerEventContent {
fn file(&self) -> Option<MediaType> {
Some(MediaType::Uri(self.url.clone()))
}
fn thumbnail(&self) -> Option<MediaType> {
None
}
}
impl MediaEventContent for AudioMessageEventContent {
fn file(&self) -> Option<MediaType> {
self.url
.as_ref()
.map(|uri| MediaType::Uri(uri.clone()))
.or_else(|| self.file.as_ref().map(|e| MediaType::Encrypted(e.clone())))
}
fn thumbnail(&self) -> Option<MediaType> {
None
}
}
impl MediaEventContent for FileMessageEventContent {
fn file(&self) -> Option<MediaType> {
self.url
.as_ref()
.map(|uri| MediaType::Uri(uri.clone()))
.or_else(|| self.file.as_ref().map(|e| MediaType::Encrypted(e.clone())))
}
fn thumbnail(&self) -> Option<MediaType> {
self.info.as_ref().and_then(|info| {
if let Some(uri) = info.thumbnail_url.as_ref() {
Some(MediaType::Uri(uri.clone()))
} else {
info.thumbnail_file.as_ref().map(|file| MediaType::Encrypted(file.clone()))
}
})
}
}
impl MediaEventContent for ImageMessageEventContent {
fn file(&self) -> Option<MediaType> {
self.url
.as_ref()
.map(|uri| MediaType::Uri(uri.clone()))
.or_else(|| self.file.as_ref().map(|e| MediaType::Encrypted(e.clone())))
}
fn thumbnail(&self) -> Option<MediaType> {
self.info
.as_ref()
.and_then(|info| {
if let Some(uri) = info.thumbnail_url.as_ref() {
Some(MediaType::Uri(uri.clone()))
} else {
info.thumbnail_file.as_ref().map(|file| MediaType::Encrypted(file.clone()))
}
})
.or_else(|| self.url.as_ref().map(|uri| MediaType::Uri(uri.clone())))
}
}
impl MediaEventContent for VideoMessageEventContent {
fn file(&self) -> Option<MediaType> {
self.url
.as_ref()
.map(|uri| MediaType::Uri(uri.clone()))
.or_else(|| self.file.as_ref().map(|e| MediaType::Encrypted(e.clone())))
}
fn thumbnail(&self) -> Option<MediaType> {
self.info
.as_ref()
.and_then(|info| {
if let Some(uri) = info.thumbnail_url.as_ref() {
Some(MediaType::Uri(uri.clone()))
} else {
info.thumbnail_file.as_ref().map(|file| MediaType::Encrypted(file.clone()))
}
})
.or_else(|| self.url.as_ref().map(|uri| MediaType::Uri(uri.clone())))
}
}
impl MediaEventContent for LocationMessageEventContent {
fn file(&self) -> Option<MediaType> {
None
}
fn thumbnail(&self) -> Option<MediaType> {
self.info.as_ref().and_then(|info| {
if let Some(uri) = info.thumbnail_url.as_ref() {
Some(MediaType::Uri(uri.clone()))
} else {
info.thumbnail_file.as_ref().map(|file| MediaType::Encrypted(file.clone()))
}
})
}
}

View File

@ -62,28 +62,11 @@ impl BaseRoomInfo {
invited_member_count: u64,
heroes: Vec<RoomMember>,
) -> String {
let heroes_count = heroes.len() as u64;
let invited_joined = (invited_member_count + joined_member_count).saturating_sub(1);
if heroes_count >= invited_joined {
let mut names = heroes.iter().take(3).map(|mem| mem.name()).collect::<Vec<&str>>();
// stabilize ordering
names.sort_unstable();
names.join(", ")
} else if heroes_count < invited_joined && invited_joined > 1 {
let mut names = heroes.iter().take(3).map(|mem| mem.name()).collect::<Vec<&str>>();
names.sort_unstable();
// TODO: What length does the spec want us to use here and in
// the `else`?
format!(
"{}, and {} others",
names.join(", "),
(joined_member_count + invited_member_count)
calculate_room_name(
joined_member_count,
invited_member_count,
heroes.iter().take(3).map(|mem| mem.name()).collect::<Vec<&str>>(),
)
} else {
"Empty room".to_string()
}
}
/// Handle a state event for this room and update our info accordingly.
@ -164,3 +147,81 @@ impl Default for BaseRoomInfo {
}
}
}
/// Calculate room name according to step 3 of the [naming algorithm.][spec]
///
/// [spec]: <https://matrix.org/docs/spec/client_server/latest#calculating-the-display-name-for-a-room>
fn calculate_room_name(
joined_member_count: u64,
invited_member_count: u64,
heroes: Vec<&str>,
) -> String {
let heroes_count = heroes.len() as u64;
let invited_joined = invited_member_count + joined_member_count;
let invited_joined_minus_one = invited_joined.saturating_sub(1);
let names = if heroes_count >= invited_joined_minus_one {
let mut names = heroes;
// stabilize ordering
names.sort_unstable();
names.join(", ")
} else if heroes_count < invited_joined_minus_one && invited_joined > 1 {
let mut names = heroes;
names.sort_unstable();
// TODO: What length does the spec want us to use here and in
// the `else`?
format!("{}, and {} others", names.join(", "), (invited_joined - heroes_count))
} else {
"".to_string()
};
// User is alone.
if invited_joined <= 1 {
if names.is_empty() {
"Empty room".to_string()
} else {
format!("Empty room (was {})", names)
}
} else {
names
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_calculate_room_name() {
let mut actual = calculate_room_name(2, 0, vec!["a"]);
assert_eq!("a", actual);
actual = calculate_room_name(3, 0, vec!["a", "b"]);
assert_eq!("a, b", actual);
actual = calculate_room_name(4, 0, vec!["a", "b", "c"]);
assert_eq!("a, b, c", actual);
actual = calculate_room_name(5, 0, vec!["a", "b", "c"]);
assert_eq!("a, b, c, and 2 others", actual);
actual = calculate_room_name(0, 0, vec![]);
assert_eq!("Empty room", actual);
actual = calculate_room_name(1, 0, vec![]);
assert_eq!("Empty room", actual);
actual = calculate_room_name(0, 1, vec![]);
assert_eq!("Empty room", actual);
actual = calculate_room_name(1, 0, vec!["a"]);
assert_eq!("Empty room (was a)", actual);
actual = calculate_room_name(1, 0, vec!["a", "b"]);
assert_eq!("Empty room (was a, b)", actual);
actual = calculate_room_name(1, 0, vec!["a", "b", "c"]);
assert_eq!("Empty room (was a, b, c)", actual);
}
}

View File

@ -24,6 +24,7 @@ use futures::{
use matrix_sdk_common::{
api::r0::sync::sync_events::RoomSummary as RumaSummary,
events::{
receipt::Receipt,
room::{
create::CreateEventContent, encryption::EncryptionEventContent,
guest_access::GuestAccess, history_visibility::HistoryVisibility, join_rules::JoinRule,
@ -32,7 +33,8 @@ use matrix_sdk_common::{
tag::Tags,
AnyRoomAccountDataEvent, AnyStateEventContent, AnySyncStateEvent, EventType,
},
identifiers::{MxcUri, RoomAliasId, RoomId, UserId},
identifiers::{EventId, MxcUri, RoomAliasId, RoomId, UserId},
receipt::ReceiptType,
};
use serde::{Deserialize, Serialize};
use tracing::info;
@ -449,6 +451,24 @@ impl Room {
Ok(None)
}
}
/// Get the read receipt as a `EventId` and `Receipt` tuple for the given
/// `user_id` in this room.
pub async fn user_read_receipt(
&self,
user_id: &UserId,
) -> StoreResult<Option<(EventId, Receipt)>> {
self.store.get_user_room_receipt_event(self.room_id(), ReceiptType::Read, user_id).await
}
/// Get the read receipts as a list of `UserId` and `Receipt` tuples for the
/// given `event_id` in this room.
pub async fn event_read_receipts(
&self,
event_id: &EventId,
) -> StoreResult<Vec<(UserId, Receipt)>> {
self.store.get_event_room_receipt_events(self.room_id(), ReceiptType::Read, event_id).await
}
}
/// The underlying pure data structure for joined and left rooms.
@ -466,7 +486,7 @@ pub struct RoomInfo {
pub summary: RoomSummary,
/// Flag remembering if the room members are synced.
pub members_synced: bool,
/// The prev batch of this room we received durring the last sync.
/// The prev batch of this room we received during the last sync.
pub last_prev_batch: Option<String>,
/// Base room info which holds some basic event contents important for the
/// room state.

View File

@ -18,22 +18,29 @@ use std::{
};
use dashmap::{DashMap, DashSet};
use lru::LruCache;
use matrix_sdk_common::{
async_trait,
events::{
presence::PresenceEvent,
receipt::Receipt,
room::member::{MemberEventContent, MembershipState},
AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent,
AnySyncStateEvent, EventType,
},
identifiers::{RoomId, UserId},
identifiers::{EventId, MxcUri, RoomId, UserId},
instant::Instant,
locks::Mutex,
receipt::ReceiptType,
Raw,
};
use tracing::info;
use super::{Result, RoomInfo, StateChanges, StateStore};
use crate::deserialized_responses::{MemberEvent, StrippedMemberEvent};
use crate::{
deserialized_responses::{MemberEvent, StrippedMemberEvent},
media::{MediaRequest, UniqueKey},
};
#[derive(Debug, Clone)]
pub struct MemoryStore {
@ -55,6 +62,12 @@ pub struct MemoryStore {
Arc<DashMap<RoomId, DashMap<String, DashMap<String, Raw<AnyStrippedStateEvent>>>>>,
stripped_members: Arc<DashMap<RoomId, DashMap<UserId, StrippedMemberEvent>>>,
presence: Arc<DashMap<UserId, Raw<PresenceEvent>>>,
#[allow(clippy::type_complexity)]
room_user_receipts: Arc<DashMap<RoomId, DashMap<String, DashMap<UserId, (EventId, Receipt)>>>>,
#[allow(clippy::type_complexity)]
room_event_receipts:
Arc<DashMap<RoomId, DashMap<String, DashMap<EventId, DashMap<UserId, Receipt>>>>>,
media: Arc<Mutex<LruCache<String, Vec<u8>>>>,
}
impl MemoryStore {
@ -76,6 +89,9 @@ impl MemoryStore {
stripped_room_state: DashMap::new().into(),
stripped_members: DashMap::new().into(),
presence: DashMap::new().into(),
room_user_receipts: DashMap::new().into(),
room_event_receipts: DashMap::new().into(),
media: Arc::new(Mutex::new(LruCache::new(100))),
}
}
@ -220,6 +236,43 @@ impl MemoryStore {
}
}
for (room, content) in &changes.receipts {
for (event_id, receipts) in &content.0 {
for (receipt_type, receipts) in receipts {
for (user_id, receipt) in receipts {
// Add the receipt to the room user receipts
if let Some((old_event, _)) = self
.room_user_receipts
.entry(room.clone())
.or_insert_with(DashMap::new)
.entry(receipt_type.to_string())
.or_insert_with(DashMap::new)
.insert(user_id.clone(), (event_id.clone(), receipt.clone()))
{
// Remove the old receipt from the room event receipts
if let Some(receipt_map) = self.room_event_receipts.get(room) {
if let Some(event_map) = receipt_map.get(receipt_type.as_ref()) {
if let Some(user_map) = event_map.get_mut(&old_event) {
user_map.remove(user_id);
}
}
}
}
// Add the receipt to the room event receipts
self.room_event_receipts
.entry(room.clone())
.or_insert_with(DashMap::new)
.entry(receipt_type.to_string())
.or_insert_with(DashMap::new)
.entry(event_id.clone())
.or_insert_with(DashMap::new)
.insert(user_id.clone(), receipt.clone());
}
}
}
}
info!("Saved changes in {:?}", now.elapsed());
Ok(())
@ -311,6 +364,68 @@ impl MemoryStore {
.get(room_id)
.and_then(|m| m.get(event_type.as_ref()).map(|e| e.clone())))
}
async fn get_user_room_receipt_event(
&self,
room_id: &RoomId,
receipt_type: ReceiptType,
user_id: &UserId,
) -> Result<Option<(EventId, Receipt)>> {
Ok(self.room_user_receipts.get(room_id).and_then(|m| {
m.get(receipt_type.as_ref()).and_then(|m| m.get(user_id).map(|r| r.clone()))
}))
}
async fn get_event_room_receipt_events(
&self,
room_id: &RoomId,
receipt_type: ReceiptType,
event_id: &EventId,
) -> Result<Vec<(UserId, Receipt)>> {
Ok(self
.room_event_receipts
.get(room_id)
.and_then(|m| {
m.get(receipt_type.as_ref()).and_then(|m| {
m.get(event_id)
.map(|m| m.iter().map(|r| (r.key().clone(), r.value().clone())).collect())
})
})
.unwrap_or_else(Vec::new))
}
async fn add_media_content(&self, request: &MediaRequest, data: Vec<u8>) -> Result<()> {
self.media.lock().await.put(request.unique_key(), data);
Ok(())
}
async fn get_media_content(&self, request: &MediaRequest) -> Result<Option<Vec<u8>>> {
Ok(self.media.lock().await.get(&request.unique_key()).cloned())
}
async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> {
self.media.lock().await.pop(&request.unique_key());
Ok(())
}
async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> {
let mut media_store = self.media.lock().await;
let keys: Vec<String> = media_store
.iter()
.filter_map(
|(key, _)| if key.starts_with(&uri.to_string()) { Some(key.clone()) } else { None },
)
.collect();
for key in keys {
media_store.pop(&key);
}
Ok(())
}
}
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
@ -408,4 +523,191 @@ impl StateStore for MemoryStore {
) -> Result<Option<Raw<AnyRoomAccountDataEvent>>> {
self.get_room_account_data_event(room_id, event_type).await
}
async fn get_user_room_receipt_event(
&self,
room_id: &RoomId,
receipt_type: ReceiptType,
user_id: &UserId,
) -> Result<Option<(EventId, Receipt)>> {
self.get_user_room_receipt_event(room_id, receipt_type, user_id).await
}
async fn get_event_room_receipt_events(
&self,
room_id: &RoomId,
receipt_type: ReceiptType,
event_id: &EventId,
) -> Result<Vec<(UserId, Receipt)>> {
self.get_event_room_receipt_events(room_id, receipt_type, event_id).await
}
async fn add_media_content(&self, request: &MediaRequest, data: Vec<u8>) -> Result<()> {
self.add_media_content(request, data).await
}
async fn get_media_content(&self, request: &MediaRequest) -> Result<Option<Vec<u8>>> {
self.get_media_content(request).await
}
async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> {
self.remove_media_content(request).await
}
async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> {
self.remove_media_content_for_uri(uri).await
}
}
#[cfg(test)]
#[cfg(not(feature = "sled_state_store"))]
mod test {
use matrix_sdk_common::{
api::r0::media::get_content_thumbnail::Method,
identifiers::{event_id, mxc_uri, room_id, user_id, UserId},
receipt::ReceiptType,
uint,
};
use matrix_sdk_test::async_test;
use serde_json::json;
use super::{MemoryStore, StateChanges};
use crate::media::{MediaFormat, MediaRequest, MediaThumbnailSize, MediaType};
fn user_id() -> UserId {
user_id!("@example:localhost")
}
#[async_test]
async fn test_receipts_saving() {
let store = MemoryStore::new();
let room_id = room_id!("!test:localhost");
let first_event_id = event_id!("$1435641916114394fHBLK:matrix.org");
let second_event_id = event_id!("$fHBLK1435641916114394:matrix.org");
let first_receipt_event = serde_json::from_value(json!({
first_event_id.clone(): {
"m.read": {
user_id(): {
"ts": 1436451550453u64
}
}
}
}))
.unwrap();
let second_receipt_event = serde_json::from_value(json!({
second_event_id.clone(): {
"m.read": {
user_id(): {
"ts": 1436451551453u64
}
}
}
}))
.unwrap();
assert!(store
.get_user_room_receipt_event(&room_id, ReceiptType::Read, &user_id())
.await
.unwrap()
.is_none());
assert!(store
.get_event_room_receipt_events(&room_id, ReceiptType::Read, &first_event_id)
.await
.unwrap()
.is_empty());
assert!(store
.get_event_room_receipt_events(&room_id, ReceiptType::Read, &second_event_id)
.await
.unwrap()
.is_empty());
let mut changes = StateChanges::default();
changes.add_receipts(&room_id, first_receipt_event);
store.save_changes(&changes).await.unwrap();
assert!(store
.get_user_room_receipt_event(&room_id, ReceiptType::Read, &user_id())
.await
.unwrap()
.is_some(),);
assert_eq!(
store
.get_event_room_receipt_events(&room_id, ReceiptType::Read, &first_event_id)
.await
.unwrap()
.len(),
1
);
assert!(store
.get_event_room_receipt_events(&room_id, ReceiptType::Read, &second_event_id)
.await
.unwrap()
.is_empty());
let mut changes = StateChanges::default();
changes.add_receipts(&room_id, second_receipt_event);
store.save_changes(&changes).await.unwrap();
assert!(store
.get_user_room_receipt_event(&room_id, ReceiptType::Read, &user_id())
.await
.unwrap()
.is_some());
assert!(store
.get_event_room_receipt_events(&room_id, ReceiptType::Read, &first_event_id)
.await
.unwrap()
.is_empty());
assert_eq!(
store
.get_event_room_receipt_events(&room_id, ReceiptType::Read, &second_event_id)
.await
.unwrap()
.len(),
1
);
}
#[async_test]
async fn test_media_content() {
let store = MemoryStore::new();
let uri = mxc_uri!("mxc://localhost/media");
let content: Vec<u8> = "somebinarydata".into();
let request_file =
MediaRequest { media_type: MediaType::Uri(uri.clone()), format: MediaFormat::File };
let request_thumbnail = MediaRequest {
media_type: MediaType::Uri(uri.clone()),
format: MediaFormat::Thumbnail(MediaThumbnailSize {
method: Method::Crop,
width: uint!(100),
height: uint!(100),
}),
};
assert!(store.get_media_content(&request_file).await.unwrap().is_none());
assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none());
store.add_media_content(&request_file, content.clone()).await.unwrap();
assert!(store.get_media_content(&request_file).await.unwrap().is_some());
store.remove_media_content(&request_file).await.unwrap();
assert!(store.get_media_content(&request_file).await.unwrap().is_none());
store.add_media_content(&request_file, content.clone()).await.unwrap();
assert!(store.get_media_content(&request_file).await.unwrap().is_some());
store.add_media_content(&request_thumbnail, content.clone()).await.unwrap();
assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_some());
store.remove_media_content_for_uri(&uri).await.unwrap();
assert!(store.get_media_content(&request_file).await.unwrap().is_none());
assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none());
}
}

View File

@ -25,11 +25,15 @@ use matrix_sdk_common::{
api::r0::push::get_notifications::Notification,
async_trait,
events::{
presence::PresenceEvent, room::member::MemberEventContent, AnyGlobalAccountDataEvent,
AnyRoomAccountDataEvent, AnyStrippedStateEvent, AnySyncStateEvent, EventContent, EventType,
presence::PresenceEvent,
receipt::{Receipt, ReceiptEventContent},
room::member::MemberEventContent,
AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent,
AnySyncStateEvent, EventContent, EventType,
},
identifiers::{RoomId, UserId},
identifiers::{EventId, MxcUri, RoomId, UserId},
locks::RwLock,
receipt::ReceiptType,
AsyncTraitDeps, Raw,
};
#[cfg(feature = "sled_state_store")]
@ -37,6 +41,7 @@ use sled::Db;
use crate::{
deserialized_responses::{MemberEvent, StrippedMemberEvent},
media::MediaRequest,
rooms::{RoomInfo, RoomType},
Room, Session,
};
@ -210,6 +215,72 @@ pub trait StateStore: AsyncTraitDeps {
room_id: &RoomId,
event_type: EventType,
) -> Result<Option<Raw<AnyRoomAccountDataEvent>>>;
/// Get an event out of the user room receipt store.
///
/// # Arguments
///
/// * `room_id` - The id of the room for which the receipt should be
/// fetched.
///
/// * `receipt_type` - The type of the receipt.
///
/// * `user_id` - The id of the user for who the receipt should be fetched.
async fn get_user_room_receipt_event(
&self,
room_id: &RoomId,
receipt_type: ReceiptType,
user_id: &UserId,
) -> Result<Option<(EventId, Receipt)>>;
/// Get events out of the event room receipt store.
///
/// # Arguments
///
/// * `room_id` - The id of the room for which the receipts should be
/// fetched.
///
/// * `receipt_type` - The type of the receipts.
///
/// * `event_id` - The id of the event for which the receipts should be
/// fetched.
async fn get_event_room_receipt_events(
&self,
room_id: &RoomId,
receipt_type: ReceiptType,
event_id: &EventId,
) -> Result<Vec<(UserId, Receipt)>>;
/// Add a media file's content in the media store.
///
/// # Arguments
///
/// * `request` - The `MediaRequest` of the file.
///
/// * `content` - The content of the file.
async fn add_media_content(&self, request: &MediaRequest, content: Vec<u8>) -> Result<()>;
/// Get a media file's content out of the media store.
///
/// # Arguments
///
/// * `request` - The `MediaRequest` of the file.
async fn get_media_content(&self, request: &MediaRequest) -> Result<Option<Vec<u8>>>;
/// Removes a media file's content from the media store.
///
/// # Arguments
///
/// * `request` - The `MediaRequest` of the file.
async fn remove_media_content(&self, request: &MediaRequest) -> Result<()>;
/// Removes all the media files' content associated to an `MxcUri` from the
/// media store.
///
/// # Arguments
///
/// * `uri` - The `MxcUri` of the media files.
async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()>;
}
/// A state store wrapper for the SDK.
@ -291,11 +362,6 @@ impl Store {
Ok((Self::new(Box::new(inner.clone())), inner.inner))
}
pub(crate) fn get_bare_room(&self, room_id: &RoomId) -> Option<Room> {
#[allow(clippy::map_clone)]
self.rooms.get(room_id).map(|r| r.clone())
}
/// Get all the rooms this store knows about.
pub fn get_rooms(&self) -> Vec<Room> {
self.rooms.iter().filter_map(|r| self.get_room(r.key())).collect()
@ -303,15 +369,17 @@ impl Store {
/// Get the room with the given room id.
pub fn get_room(&self, room_id: &RoomId) -> Option<Room> {
self.get_bare_room(room_id).and_then(|r| match r.room_type() {
RoomType::Joined => Some(r),
RoomType::Left => Some(r),
self.rooms
.get(room_id)
.and_then(|r| match r.room_type() {
RoomType::Joined => Some(r.clone()),
RoomType::Left => Some(r.clone()),
RoomType::Invited => self.get_stripped_room(room_id),
})
.or_else(|| self.get_stripped_room(room_id))
}
fn get_stripped_room(&self, room_id: &RoomId) -> Option<Room> {
#[allow(clippy::map_clone)]
self.stripped_rooms.get(room_id).map(|r| r.clone())
}
@ -369,6 +437,8 @@ pub struct StateChanges {
pub room_account_data: BTreeMap<RoomId, BTreeMap<String, Raw<AnyRoomAccountDataEvent>>>,
/// A map of `RoomId` to `RoomInfo`.
pub room_infos: BTreeMap<RoomId, RoomInfo>,
/// A map of `RoomId` to `ReceiptEventContent`.
pub receipts: BTreeMap<RoomId, ReceiptEventContent>,
/// A mapping of `RoomId` to a map of event type to a map of state key to
/// `AnyStrippedStateEvent`.
@ -404,7 +474,7 @@ impl StateChanges {
/// Update the `StateChanges` struct with the given `RoomInfo`.
pub fn add_stripped_room(&mut self, room: RoomInfo) {
self.invited_room_info.insert(room.room_id.as_ref().to_owned(), room);
self.room_infos.insert(room.room_id.as_ref().to_owned(), room);
}
/// Update the `StateChanges` struct with the given `AnyBasicEvent`.
@ -462,4 +532,10 @@ impl StateChanges {
pub fn add_notification(&mut self, room_id: &RoomId, notification: Notification) {
self.notifications.entry(room_id.to_owned()).or_insert_with(Vec::new).push(notification);
}
/// Update the `StateChanges` struct with the given room with a new
/// `Receipts`.
pub fn add_receipts(&mut self, room_id: &RoomId, event: ReceiptEventContent) {
self.receipts.insert(room_id.to_owned(), event);
}
}

View File

@ -16,7 +16,7 @@ mod store_key;
use std::{
collections::BTreeSet,
convert::TryFrom,
convert::{TryFrom, TryInto},
path::{Path, PathBuf},
sync::Arc,
time::Instant,
@ -30,10 +30,12 @@ use matrix_sdk_common::{
async_trait,
events::{
presence::PresenceEvent,
receipt::Receipt,
room::member::{MemberEventContent, MembershipState},
AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnySyncStateEvent, EventType,
},
identifiers::{RoomId, UserId},
identifiers::{EventId, MxcUri, RoomId, UserId},
receipt::ReceiptType,
Raw,
};
use serde::{Deserialize, Serialize};
@ -45,7 +47,10 @@ use tracing::info;
use self::store_key::{EncryptedEvent, StoreKey};
use super::{Result, RoomInfo, StateChanges, StateStore, StoreError};
use crate::deserialized_responses::MemberEvent;
use crate::{
deserialized_responses::MemberEvent,
media::{MediaRequest, UniqueKey},
};
#[derive(Debug, Serialize, Deserialize)]
pub enum DatabaseType {
@ -127,12 +132,41 @@ impl EncodeKey for (&str, &str, &str) {
}
}
impl EncodeKey for (&str, &str, &str, &str) {
fn encode(&self) -> Vec<u8> {
[
self.0.as_bytes(),
&[ENCODE_SEPARATOR],
self.1.as_bytes(),
&[ENCODE_SEPARATOR],
self.2.as_bytes(),
&[ENCODE_SEPARATOR],
self.3.as_bytes(),
&[ENCODE_SEPARATOR],
]
.concat()
}
}
impl EncodeKey for EventType {
fn encode(&self) -> Vec<u8> {
self.as_str().encode()
}
}
/// Get the value at `position` in encoded `key`.
///
/// The key must have been encoded with the `EncodeKey` trait. `position`
/// corresponds to the position in the tuple before the key was encoded. If it
/// wasn't encoded in a tuple, use `0`.
///
/// Returns `None` if there is no key at `position`.
pub fn decode_key_value(key: &[u8], position: usize) -> Option<String> {
let values: Vec<&[u8]> = key.split(|v| *v == ENCODE_SEPARATOR).collect();
values.get(position).map(|s| String::from_utf8_lossy(s).to_string())
}
#[derive(Clone)]
pub struct SledStore {
path: Option<PathBuf>,
@ -152,6 +186,9 @@ pub struct SledStore {
stripped_room_state: Tree,
stripped_members: Tree,
presence: Tree,
room_user_receipts: Tree,
room_event_receipts: Tree,
media: Tree,
}
impl std::fmt::Debug for SledStore {
@ -184,6 +221,11 @@ impl SledStore {
let stripped_members = db.open_tree("stripped_members")?;
let stripped_room_state = db.open_tree("stripped_room_state")?;
let room_user_receipts = db.open_tree("room_user_receipts")?;
let room_event_receipts = db.open_tree("room_event_receipts")?;
let media = db.open_tree("media")?;
Ok(Self {
path,
inner: db,
@ -202,6 +244,9 @@ impl SledStore {
stripped_room_info,
stripped_members,
stripped_room_state,
room_user_receipts,
room_event_receipts,
media,
})
}
@ -459,6 +504,58 @@ impl SledStore {
ret?;
let ret: Result<(), TransactionError<SerializationError>> =
(&self.room_user_receipts, &self.room_event_receipts).transaction(
|(room_user_receipts, room_event_receipts)| {
for (room, content) in &changes.receipts {
for (event_id, receipts) in &content.0 {
for (receipt_type, receipts) in receipts {
for (user_id, receipt) in receipts {
// Add the receipt to the room user receipts
if let Some(old) = room_user_receipts.insert(
(room.as_str(), receipt_type.as_ref(), user_id.as_str())
.encode(),
self.serialize_event(&(event_id, receipt))
.map_err(ConflictableTransactionError::Abort)?,
)? {
// Remove the old receipt from the room event receipts
let (old_event, _): (EventId, Receipt) = self
.deserialize_event(&old)
.map_err(ConflictableTransactionError::Abort)?;
room_event_receipts.remove(
(
room.as_str(),
receipt_type.as_ref(),
old_event.as_str(),
user_id.as_str(),
)
.encode(),
)?;
}
// Add the receipt to the room event receipts
room_event_receipts.insert(
(
room.as_str(),
receipt_type.as_ref(),
event_id.as_str(),
user_id.as_str(),
)
.encode(),
self.serialize_event(receipt)
.map_err(ConflictableTransactionError::Abort)?,
)?;
}
}
}
}
Ok(())
},
);
ret?;
self.inner.flush_async().await?;
info!("Saved changes in {:?}", now.elapsed());
@ -598,6 +695,79 @@ impl SledStore {
.map(|m| self.deserialize_event(&m))
.transpose()?)
}
async fn get_user_room_receipt_event(
&self,
room_id: &RoomId,
receipt_type: ReceiptType,
user_id: &UserId,
) -> Result<Option<(EventId, Receipt)>> {
Ok(self
.room_user_receipts
.get((room_id.as_str(), receipt_type.as_ref(), user_id.as_str()).encode())?
.map(|m| self.deserialize_event(&m))
.transpose()?)
}
async fn get_event_room_receipt_events(
&self,
room_id: &RoomId,
receipt_type: ReceiptType,
event_id: &EventId,
) -> Result<Vec<(UserId, Receipt)>> {
self.room_event_receipts
.scan_prefix((room_id.as_str(), receipt_type.as_ref(), event_id.as_str()).encode())
.map(|u| {
u.map_err(StoreError::Sled).and_then(|(key, value)| {
self.deserialize_event(&value)
.map(|receipt| {
(decode_key_value(&key, 3).unwrap().try_into().unwrap(), receipt)
})
.map_err(Into::into)
})
})
.collect()
}
async fn add_media_content(&self, request: &MediaRequest, data: Vec<u8>) -> Result<()> {
self.media.insert(
(request.media_type.unique_key().as_str(), request.format.unique_key().as_str())
.encode(),
data,
)?;
Ok(())
}
async fn get_media_content(&self, request: &MediaRequest) -> Result<Option<Vec<u8>>> {
Ok(self
.media
.get(
(request.media_type.unique_key().as_str(), request.format.unique_key().as_str())
.encode(),
)?
.map(|m| m.to_vec()))
}
async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> {
self.media.remove(
(request.media_type.unique_key().as_str(), request.format.unique_key().as_str())
.encode(),
)?;
Ok(())
}
async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> {
let keys = self.media.scan_prefix(uri.as_str().encode()).keys();
let mut batch = sled::Batch::default();
for key in keys {
batch.remove(key?);
}
Ok(self.media.apply_batch(batch)?)
}
}
#[async_trait]
@ -689,6 +859,40 @@ impl StateStore for SledStore {
) -> Result<Option<Raw<AnyRoomAccountDataEvent>>> {
self.get_room_account_data_event(room_id, event_type).await
}
async fn get_user_room_receipt_event(
&self,
room_id: &RoomId,
receipt_type: ReceiptType,
user_id: &UserId,
) -> Result<Option<(EventId, Receipt)>> {
self.get_user_room_receipt_event(room_id, receipt_type, user_id).await
}
async fn get_event_room_receipt_events(
&self,
room_id: &RoomId,
receipt_type: ReceiptType,
event_id: &EventId,
) -> Result<Vec<(UserId, Receipt)>> {
self.get_event_room_receipt_events(room_id, receipt_type, event_id).await
}
async fn add_media_content(&self, request: &MediaRequest, data: Vec<u8>) -> Result<()> {
self.add_media_content(request, data).await
}
async fn get_media_content(&self, request: &MediaRequest) -> Result<Option<Vec<u8>>> {
self.get_media_content(request).await
}
async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> {
self.remove_media_content(request).await
}
async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> {
self.remove_media_content_for_uri(uri).await
}
}
#[cfg(test)]
@ -696,6 +900,7 @@ mod test {
use std::convert::TryFrom;
use matrix_sdk_common::{
api::r0::media::get_content_thumbnail::Method,
events::{
room::{
member::{MemberEventContent, MembershipState},
@ -703,14 +908,19 @@ mod test {
},
AnySyncStateEvent, EventType, Unsigned,
},
identifiers::{room_id, user_id, EventId, UserId},
MilliSecondsSinceUnixEpoch, Raw,
identifiers::{event_id, mxc_uri, room_id, user_id, EventId, UserId},
receipt::ReceiptType,
uint, MilliSecondsSinceUnixEpoch, Raw,
};
use matrix_sdk_test::async_test;
use serde_json::json;
use super::{SledStore, StateChanges};
use crate::{deserialized_responses::MemberEvent, StateStore};
use crate::{
deserialized_responses::MemberEvent,
media::{MediaFormat, MediaRequest, MediaThumbnailSize, MediaType},
StateStore,
};
fn user_id() -> UserId {
user_id!("@example:localhost")
@ -788,4 +998,137 @@ mod test {
.unwrap()
.is_some());
}
#[async_test]
async fn test_receipts_saving() {
let store = SledStore::open().unwrap();
let room_id = room_id!("!test:localhost");
let first_event_id = event_id!("$1435641916114394fHBLK:matrix.org");
let second_event_id = event_id!("$fHBLK1435641916114394:matrix.org");
let first_receipt_event = serde_json::from_value(json!({
first_event_id.clone(): {
"m.read": {
user_id(): {
"ts": 1436451550453u64
}
}
}
}))
.unwrap();
let second_receipt_event = serde_json::from_value(json!({
second_event_id.clone(): {
"m.read": {
user_id(): {
"ts": 1436451551453u64
}
}
}
}))
.unwrap();
assert!(store
.get_user_room_receipt_event(&room_id, ReceiptType::Read, &user_id())
.await
.unwrap()
.is_none());
assert!(store
.get_event_room_receipt_events(&room_id, ReceiptType::Read, &first_event_id)
.await
.unwrap()
.is_empty());
assert!(store
.get_event_room_receipt_events(&room_id, ReceiptType::Read, &second_event_id)
.await
.unwrap()
.is_empty());
let mut changes = StateChanges::default();
changes.add_receipts(&room_id, first_receipt_event);
store.save_changes(&changes).await.unwrap();
assert!(store
.get_user_room_receipt_event(&room_id, ReceiptType::Read, &user_id())
.await
.unwrap()
.is_some(),);
assert_eq!(
store
.get_event_room_receipt_events(&room_id, ReceiptType::Read, &first_event_id)
.await
.unwrap()
.len(),
1
);
assert!(store
.get_event_room_receipt_events(&room_id, ReceiptType::Read, &second_event_id)
.await
.unwrap()
.is_empty());
let mut changes = StateChanges::default();
changes.add_receipts(&room_id, second_receipt_event);
store.save_changes(&changes).await.unwrap();
assert!(store
.get_user_room_receipt_event(&room_id, ReceiptType::Read, &user_id())
.await
.unwrap()
.is_some());
assert!(store
.get_event_room_receipt_events(&room_id, ReceiptType::Read, &first_event_id)
.await
.unwrap()
.is_empty());
assert_eq!(
store
.get_event_room_receipt_events(&room_id, ReceiptType::Read, &second_event_id)
.await
.unwrap()
.len(),
1
);
}
#[async_test]
async fn test_media_content() {
let store = SledStore::open().unwrap();
let uri = mxc_uri!("mxc://localhost/media");
let content: Vec<u8> = "somebinarydata".into();
let request_file =
MediaRequest { media_type: MediaType::Uri(uri.clone()), format: MediaFormat::File };
let request_thumbnail = MediaRequest {
media_type: MediaType::Uri(uri.clone()),
format: MediaFormat::Thumbnail(MediaThumbnailSize {
method: Method::Crop,
width: uint!(100),
height: uint!(100),
}),
};
assert!(store.get_media_content(&request_file).await.unwrap().is_none());
assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none());
store.add_media_content(&request_file, content.clone()).await.unwrap();
assert!(store.get_media_content(&request_file).await.unwrap().is_some());
store.remove_media_content(&request_file).await.unwrap();
assert!(store.get_media_content(&request_file).await.unwrap().is_none());
store.add_media_content(&request_file, content.clone()).await.unwrap();
assert!(store.get_media_content(&request_file).await.unwrap().is_some());
store.add_media_content(&request_thumbnail, content.clone()).await.unwrap();
assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_some());
store.remove_media_content_for_uri(&uri).await.unwrap();
assert!(store.get_media_content(&request_file).await.unwrap().is_none());
assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none());
}
}

View File

@ -23,7 +23,7 @@ use aes_ctr::{
};
use base64::DecodeError;
use getrandom::getrandom;
use matrix_sdk_common::events::room::{JsonWebKey, JsonWebKeyInit};
use matrix_sdk_common::events::room::{EncryptedFile, JsonWebKey, JsonWebKeyInit};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use thiserror::Error;
@ -252,6 +252,12 @@ pub struct EncryptionInfo {
pub hashes: BTreeMap<String, String>,
}
impl From<EncryptedFile> for EncryptionInfo {
fn from(file: EncryptedFile) -> Self {
Self { version: file.v, web_key: file.key, iv: file.iv, hashes: file.hashes }
}
}
#[cfg(test)]
mod test {
use std::io::{Cursor, Read};

View File

@ -364,7 +364,7 @@ impl EventBuilder {
}
}
/// Embedded sync reponse files
/// Embedded sync response files
pub enum SyncResponseFile {
All,
Default,

View File

@ -42,3 +42,29 @@ lazy_static! {
]
});
}
lazy_static! {
pub static ref WELL_KNOWN: JsonValue = json!({
"m.homeserver": {
"base_url": "HOMESERVER_URL"
}
});
}
lazy_static! {
pub static ref VERSIONS: JsonValue = json!({
"versions": [
"r0.0.1",
"r0.1.0",
"r0.2.0",
"r0.3.0",
"r0.4.0",
"r0.5.0",
"r0.6.0"
],
"unstable_features": {
"org.matrix.label_based_filtering":true,
"org.matrix.e2e_cross_signing":true
}
});
}

View File

@ -720,7 +720,7 @@ lazy_static! {
lazy_static! {
pub static ref INVITE_SYNC: JsonValue = json!({
"device_one_time_keys_count": {},
"next_batch": "s526_47314_0_7_1_1_1_11444_1",
"next_batch": "s526_47314_0_7_1_1_1_11444_2",
"device_lists": {
"changed": [
"@example:example.org"