Merge branch 'master' into sas-verification

master
Damir Jelić 2020-07-22 11:18:26 +02:00
commit e612326714
40 changed files with 3092 additions and 2182 deletions

View File

@ -7,26 +7,46 @@ addons:
jobs: jobs:
allow_failures: allow_failures:
- os: windows - os: osx
- os: linux name: macOS 10.15
name: wasm32-unknown-unknown
include: include:
- stage: Lint - stage: Format
os: linux os: linux
before_script: before_script:
- rustup component add rustfmt - rustup component add rustfmt
script: script:
- cargo fmt --all -- --check - cargo fmt --all -- --check
- stage: Clippy
os: linux
before_script:
- rustup component add clippy
script:
- cargo clippy --all-targets --all-features -- -D warnings
- stage: Test - stage: Test
os: linux os: linux
dist: bionic
- os: windows - os: windows
script:
- cd matrix_sdk
- cargo test --no-default-features --features "messages"
- cd ../matrix_sdk_base
- cargo test --no-default-features --features "messages"
- os: osx - os: osx
- os: linux
name: Minimal build
script:
- cd matrix_sdk
- cargo build --no-default-features
- os: osx
name: macOS 10.15
osx_image: xcode12
- os: linux - os: linux
name: Coverage name: Coverage
before_script: before_script:

View File

@ -21,7 +21,7 @@ http = "0.2.1"
reqwest = "0.10.6" reqwest = "0.10.6"
serde_json = "1.0.56" serde_json = "1.0.56"
thiserror = "1.0.20" thiserror = "1.0.20"
tracing = "0.1.15" tracing = "0.1.16"
url = "2.1.1" url = "2.1.1"
matrix-sdk-common-macros = { version = "0.1.0", path = "../matrix_sdk_common_macros" } matrix-sdk-common-macros = { version = "0.1.0", path = "../matrix_sdk_common_macros" }
@ -37,8 +37,8 @@ version = "0.2.4"
default-features = false default-features = false
features = ["std", "std-future"] features = ["std", "std-future"]
[target.'cfg(not(target_arch = "wasm32"))'.dependencies.futures-timer] [target.'cfg(not(target_arch = "wasm32"))'.dependencies]
version = "3.0.2" futures-timer = "3.0.2"
[target.'cfg(target_arch = "wasm32")'.dependencies.futures-timer] [target.'cfg(target_arch = "wasm32")'.dependencies.futures-timer]
version = "3.0.2" version = "3.0.2"

View File

@ -2,7 +2,7 @@ use std::{env, process::exit};
use matrix_sdk::{ use matrix_sdk::{
self, self,
events::{room::member::MemberEventContent, StrippedStateEventStub}, events::{room::member::MemberEventContent, StrippedStateEvent},
Client, ClientConfig, EventEmitter, SyncRoom, SyncSettings, Client, ClientConfig, EventEmitter, SyncRoom, SyncSettings,
}; };
use matrix_sdk_common_macros::async_trait; use matrix_sdk_common_macros::async_trait;
@ -23,7 +23,7 @@ impl EventEmitter for AutoJoinBot {
async fn on_stripped_state_member( async fn on_stripped_state_member(
&self, &self,
room: SyncRoom, room: SyncRoom,
room_member: &StrippedStateEventStub<MemberEventContent>, room_member: &StrippedStateEvent<MemberEventContent>,
_: Option<MemberEventContent>, _: Option<MemberEventContent>,
) { ) {
if room_member.state_key != self.client.user_id().await.unwrap() { if room_member.state_key != self.client.user_id().await.unwrap() {

View File

@ -4,7 +4,7 @@ use matrix_sdk::{
self, self,
events::{ events::{
room::message::{MessageEventContent, TextMessageEventContent}, room::message::{MessageEventContent, TextMessageEventContent},
MessageEventStub, SyncMessageEvent,
}, },
Client, ClientConfig, EventEmitter, JsonStore, SyncRoom, SyncSettings, Client, ClientConfig, EventEmitter, JsonStore, SyncRoom, SyncSettings,
}; };
@ -25,9 +25,9 @@ impl CommandBot {
#[async_trait] #[async_trait]
impl EventEmitter for CommandBot { impl EventEmitter for CommandBot {
async fn on_room_message(&self, room: SyncRoom, event: &MessageEventStub<MessageEventContent>) { async fn on_room_message(&self, room: SyncRoom, event: &SyncMessageEvent<MessageEventContent>) {
if let SyncRoom::Joined(room) = room { if let SyncRoom::Joined(room) = room {
let msg_body = if let MessageEventStub { let msg_body = if let SyncMessageEvent {
content: MessageEventContent::Text(TextMessageEventContent { body: msg_body, .. }), content: MessageEventContent::Text(TextMessageEventContent { body: msg_body, .. }),
.. ..
} = event } = event

View File

@ -5,7 +5,7 @@ use matrix_sdk::{
self, self,
events::{ events::{
room::message::{MessageEventContent, TextMessageEventContent}, room::message::{MessageEventContent, TextMessageEventContent},
MessageEventStub, SyncMessageEvent,
}, },
Client, ClientConfig, EventEmitter, SyncRoom, SyncSettings, Client, ClientConfig, EventEmitter, SyncRoom, SyncSettings,
}; };
@ -15,9 +15,9 @@ struct EventCallback;
#[async_trait] #[async_trait]
impl EventEmitter for EventCallback { impl EventEmitter for EventCallback {
async fn on_room_message(&self, room: SyncRoom, event: &MessageEventStub<MessageEventContent>) { async fn on_room_message(&self, room: SyncRoom, event: &SyncMessageEvent<MessageEventContent>) {
if let SyncRoom::Joined(room) = room { if let SyncRoom::Joined(room) = room {
if let MessageEventStub { if let SyncMessageEvent {
content: MessageEventContent::Text(TextMessageEventContent { body: msg_body, .. }), content: MessageEventContent::Text(TextMessageEventContent { body: msg_body, .. }),
sender, sender,
.. ..

View File

@ -2,7 +2,7 @@ use matrix_sdk::{
api::r0::sync::sync_events::Response as SyncResponse, api::r0::sync::sync_events::Response as SyncResponse,
events::{ events::{
room::message::{MessageEventContent, TextMessageEventContent}, room::message::{MessageEventContent, TextMessageEventContent},
AnyMessageEventStub, AnyRoomEventStub, MessageEventStub, AnySyncMessageEvent, AnySyncRoomEvent, SyncMessageEvent,
}, },
identifiers::RoomId, identifiers::RoomId,
Client, ClientConfig, SyncSettings, Client, ClientConfig, SyncSettings,
@ -17,9 +17,9 @@ impl WasmBot {
async fn on_room_message( async fn on_room_message(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
event: MessageEventStub<MessageEventContent>, event: SyncMessageEvent<MessageEventContent>,
) { ) {
let msg_body = if let MessageEventStub { let msg_body = if let SyncMessageEvent {
content: MessageEventContent::Text(TextMessageEventContent { body: msg_body, .. }), content: MessageEventContent::Text(TextMessageEventContent { body: msg_body, .. }),
.. ..
} = event } = event
@ -45,7 +45,7 @@ impl WasmBot {
for (room_id, room) in response.rooms.join { for (room_id, room) in response.rooms.join {
for event in room.timeline.events { for event in room.timeline.events {
if let Ok(event) = event.deserialize() { if let Ok(event) = event.deserialize() {
if let AnyRoomEventStub::Message(AnyMessageEventStub::RoomMessage(ev)) = event { if let AnySyncRoomEvent::Message(AnySyncMessageEvent::RoomMessage(ev)) = event {
self.on_room_message(&room_id, ev).await self.on_room_message(&room_id, ev).await
} }
} }

View File

@ -31,7 +31,7 @@ use futures_timer::Delay as sleep;
use std::future::Future; use std::future::Future;
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use tracing::{debug, warn}; use tracing::{debug, warn};
use tracing::{info, instrument, trace}; use tracing::{error, info, instrument, trace};
use http::Method as HttpMethod; use http::Method as HttpMethod;
use http::Response as HttpResponse; use http::Response as HttpResponse;
@ -105,6 +105,7 @@ pub struct ClientConfig {
user_agent: Option<HeaderValue>, user_agent: Option<HeaderValue>,
disable_ssl_verification: bool, disable_ssl_verification: bool,
base_config: BaseClientConfig, base_config: BaseClientConfig,
timeout: Option<Duration>,
} }
// #[cfg_attr(tarpaulin, skip)] // #[cfg_attr(tarpaulin, skip)]
@ -198,11 +199,18 @@ impl ClientConfig {
self.base_config = self.base_config.passphrase(passphrase); self.base_config = self.base_config.passphrase(passphrase);
self self
} }
/// Set a timeout duration for all HTTP requests. The default is no timeout.
pub fn timeout(mut self, timeout: Duration) -> Self {
self.timeout = Some(timeout);
self
}
} }
#[derive(Debug, Default, Clone)] #[derive(Debug, Default, Clone)]
/// Settings for a sync call. /// Settings for a sync call.
pub struct SyncSettings { pub struct SyncSettings {
pub(crate) filter: Option<sync_events::Filter>,
pub(crate) timeout: Option<Duration>, pub(crate) timeout: Option<Duration>,
pub(crate) token: Option<String>, pub(crate) token: Option<String>,
pub(crate) full_state: bool, pub(crate) full_state: bool,
@ -235,6 +243,17 @@ impl SyncSettings {
self self
} }
/// Set the sync filter.
/// It can be either the filter ID, or the definition for the filter.
///
/// # Arguments
///
/// * `filter` - The filter configuration that should be used for the sync call.
pub fn filter(mut self, filter: sync_events::Filter) -> Self {
self.filter = Some(filter);
self
}
/// Should the server return the full state from the start of the timeline. /// Should the server return the full state from the start of the timeline.
/// ///
/// This does nothing if no sync token is set. /// This does nothing if no sync token is set.
@ -302,6 +321,11 @@ impl Client {
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
let http_client = { let http_client = {
let http_client = match config.timeout {
Some(x) => http_client.timeout(x),
None => http_client,
};
let http_client = if config.disable_ssl_verification { let http_client = if config.disable_ssl_verification {
http_client.danger_accept_invalid_certs(true) http_client.danger_accept_invalid_certs(true)
} else { } else {
@ -448,7 +472,7 @@ impl Client {
login_info: login::LoginInfo::Password { login_info: login::LoginInfo::Password {
password: password.into(), password: password.into(),
}, },
device_id: device_id.map(|d| d.into()), device_id: device_id.map(|d| d.into().into_boxed_str()),
initial_device_display_name: initial_device_display_name.map(|d| d.into()), initial_device_display_name: initial_device_display_name.map(|d| d.into()),
}; };
@ -1222,7 +1246,7 @@ impl Client {
#[instrument] #[instrument]
pub async fn sync(&self, sync_settings: SyncSettings) -> Result<sync_events::Response> { pub async fn sync(&self, sync_settings: SyncSettings) -> Result<sync_events::Response> {
let request = sync_events::Request { let request = sync_events::Request {
filter: None, filter: sync_settings.filter,
since: sync_settings.token, since: sync_settings.token,
full_state: sync_settings.full_state, full_state: sync_settings.full_state,
set_presence: sync_events::SetPresence::Online, set_presence: sync_events::SetPresence::Online,
@ -1302,6 +1326,7 @@ impl Client {
C: Future<Output = ()>, C: Future<Output = ()>,
{ {
let mut sync_settings = sync_settings; let mut sync_settings = sync_settings;
let filter = sync_settings.filter.clone();
let mut last_sync_time: Option<Instant> = None; let mut last_sync_time: Option<Instant> = None;
if sync_settings.token.is_none() { if sync_settings.token.is_none() {
@ -1311,12 +1336,13 @@ impl Client {
loop { loop {
let response = self.sync(sync_settings.clone()).await; let response = self.sync(sync_settings.clone()).await;
let response = if let Ok(r) = response { let response = match response {
r Ok(r) => r,
} else { Err(e) => {
error!("Received an invalid response: {}", e);
sleep::new(Duration::from_secs(1)).await; sleep::new(Duration::from_secs(1)).await;
continue; continue;
}
}; };
// TODO send out to-device messages here // TODO send out to-device messages here
@ -1360,6 +1386,9 @@ impl Client {
.await .await
.expect("No sync token found after initial sync"), .expect("No sync token found after initial sync"),
); );
if let Some(f) = filter.as_ref() {
sync_settings = sync_settings.filter(f.clone());
}
} }
} }
@ -1378,7 +1407,7 @@ impl Client {
#[instrument] #[instrument]
async fn claim_one_time_keys( async fn claim_one_time_keys(
&self, &self,
one_time_keys: BTreeMap<UserId, BTreeMap<DeviceId, KeyAlgorithm>>, one_time_keys: BTreeMap<UserId, BTreeMap<Box<DeviceId>, KeyAlgorithm>>,
) -> Result<claim_keys::Response> { ) -> Result<claim_keys::Response> {
let request = claim_keys::Request { let request = claim_keys::Request {
timeout: None, timeout: None,
@ -1482,7 +1511,7 @@ impl Client {
users_for_query users_for_query
); );
let mut device_keys: BTreeMap<UserId, Vec<DeviceId>> = BTreeMap::new(); let mut device_keys: BTreeMap<UserId, Vec<Box<DeviceId>>> = BTreeMap::new();
for user in users_for_query.drain() { for user in users_for_query.drain() {
device_keys.insert(user, Vec::new()); device_keys.insert(user, Vec::new());

View File

@ -40,6 +40,8 @@ pub use matrix_sdk_base::Error as BaseError;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
pub use matrix_sdk_base::JsonStore; pub use matrix_sdk_base::JsonStore;
pub use matrix_sdk_base::{CustomOrRawEvent, EventEmitter, Room, Session, SyncRoom}; pub use matrix_sdk_base::{CustomOrRawEvent, EventEmitter, Room, Session, SyncRoom};
#[cfg(feature = "messages")]
pub use matrix_sdk_base::{MessageQueue, MessageWrapper, PossiblyRedactedExt};
pub use matrix_sdk_base::{RoomState, StateStore}; pub use matrix_sdk_base::{RoomState, StateStore};
pub use matrix_sdk_common::*; pub use matrix_sdk_common::*;
pub use reqwest::header::InvalidHeaderValue; pub use reqwest::header::InvalidHeaderValue;

View File

@ -152,6 +152,12 @@ impl RoomBuilder {
} }
} }
impl Default for RoomBuilder {
fn default() -> Self {
Self::new()
}
}
impl Into<create_room::Request> for RoomBuilder { impl Into<create_room::Request> for RoomBuilder {
fn into(mut self) -> create_room::Request { fn into(mut self) -> create_room::Request {
self.req.creation_content = Some(self.creation_content); self.req.creation_content = Some(self.creation_content);
@ -269,7 +275,7 @@ impl Into<get_message_events::Request> for MessagesRequestBuilder {
pub struct RegistrationBuilder { pub struct RegistrationBuilder {
password: Option<String>, password: Option<String>,
username: Option<String>, username: Option<String>,
device_id: Option<DeviceId>, device_id: Option<Box<DeviceId>>,
initial_device_display_name: Option<String>, initial_device_display_name: Option<String>,
auth: Option<AuthData>, auth: Option<AuthData>,
kind: Option<RegistrationKind>, kind: Option<RegistrationKind>,
@ -303,7 +309,7 @@ impl RegistrationBuilder {
/// ///
/// If this does not correspond to a known client device, a new device will be created. /// If this does not correspond to a known client device, a new device will be created.
/// The server will auto-generate a device_id if this is not specified. /// The server will auto-generate a device_id if this is not specified.
pub fn device_id<S: Into<String>>(&mut self, device_id: S) -> &mut Self { pub fn device_id<S: Into<Box<str>>>(&mut self, device_id: S) -> &mut Self {
self.device_id = Some(device_id.into()); self.device_id = Some(device_id.into());
self self
} }

View File

@ -21,6 +21,7 @@ async-trait = "0.1.36"
serde = "1.0.114" serde = "1.0.114"
serde_json = "1.0.56" serde_json = "1.0.56"
zeroize = "1.1.0" zeroize = "1.1.0"
tracing = "0.1.16"
matrix-sdk-common-macros = { version = "0.1.0", path = "../matrix_sdk_common_macros" } matrix-sdk-common-macros = { version = "0.1.0", path = "../matrix_sdk_common_macros" }
matrix-sdk-common = { version = "0.1.0", path = "../matrix_sdk_common" } matrix-sdk-common = { version = "0.1.0", path = "../matrix_sdk_common" }
@ -45,4 +46,4 @@ mockito = "0.26.0"
tokio = { version = "0.2.21", features = ["rt-threaded", "macros"] } tokio = { version = "0.2.21", features = ["rt-threaded", "macros"] }
[target.'cfg(target_arch = "wasm32")'.dev-dependencies] [target.'cfg(target_arch = "wasm32")'.dev-dependencies]
wasm-bindgen-test = "0.3.14" wasm-bindgen-test = "0.3.15"

View File

@ -38,8 +38,8 @@ use crate::session::Session;
use crate::state::{AllRooms, ClientState, StateStore}; use crate::state::{AllRooms, ClientState, StateStore};
use crate::EventEmitter; use crate::EventEmitter;
use matrix_sdk_common::events::{ use matrix_sdk_common::events::{
AnyBasicEvent, AnyEphemeralRoomEventStub, AnyMessageEventStub, AnyRoomEventStub, AnyBasicEvent, AnyStrippedStateEvent, AnySyncEphemeralRoomEvent, AnySyncMessageEvent,
AnyStateEventStub, AnyStrippedStateEventStub, EventJson, AnySyncRoomEvent, AnySyncStateEvent, EventJson,
}; };
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
@ -94,8 +94,8 @@ pub struct AdditionalUnsignedData {
/// [synapse-bug]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668> /// [synapse-bug]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668>
/// [discussion]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668> /// [discussion]: <https://github.com/matrix-org/matrix-doc/issues/684#issuecomment-641182668>
fn hoist_room_event_prev_content( fn hoist_room_event_prev_content(
event: &EventJson<AnyRoomEventStub>, event: &EventJson<AnySyncRoomEvent>,
) -> Option<EventJson<AnyRoomEventStub>> { ) -> Option<EventJson<AnySyncRoomEvent>> {
let prev_content = serde_json::from_str::<AdditionalEventData>(event.json().get()) let prev_content = serde_json::from_str::<AdditionalEventData>(event.json().get())
.map(|more_unsigned| more_unsigned.unsigned) .map(|more_unsigned| more_unsigned.unsigned)
.map(|additional| additional.prev_content) .map(|additional| additional.prev_content)
@ -105,7 +105,7 @@ fn hoist_room_event_prev_content(
let mut ev = event.deserialize().ok()?; let mut ev = event.deserialize().ok()?;
match &mut ev { match &mut ev {
AnyRoomEventStub::State(AnyStateEventStub::RoomMember(ref mut member)) AnySyncRoomEvent::State(AnySyncStateEvent::RoomMember(ref mut member))
if member.prev_content.is_none() => if member.prev_content.is_none() =>
{ {
if let Ok(prev) = prev_content.deserialize() { if let Ok(prev) = prev_content.deserialize() {
@ -122,8 +122,8 @@ fn hoist_room_event_prev_content(
/// ///
/// See comment of `hoist_room_event_prev_content`. /// See comment of `hoist_room_event_prev_content`.
fn hoist_state_event_prev_content( fn hoist_state_event_prev_content(
event: &EventJson<AnyStateEventStub>, event: &EventJson<AnySyncStateEvent>,
) -> Option<EventJson<AnyStateEventStub>> { ) -> Option<EventJson<AnySyncStateEvent>> {
let prev_content = serde_json::from_str::<AdditionalEventData>(event.json().get()) let prev_content = serde_json::from_str::<AdditionalEventData>(event.json().get())
.map(|more_unsigned| more_unsigned.unsigned) .map(|more_unsigned| more_unsigned.unsigned)
.map(|additional| additional.prev_content) .map(|additional| additional.prev_content)
@ -132,7 +132,7 @@ fn hoist_state_event_prev_content(
let mut ev = event.deserialize().ok()?; let mut ev = event.deserialize().ok()?;
match &mut ev { match &mut ev {
AnyStateEventStub::RoomMember(ref mut member) if member.prev_content.is_none() => { AnySyncStateEvent::RoomMember(ref mut member) if member.prev_content.is_none() => {
member.prev_content = Some(prev_content.deserialize().ok()?); member.prev_content = Some(prev_content.deserialize().ok()?);
Some(EventJson::from(ev)) Some(EventJson::from(ev))
} }
@ -141,7 +141,7 @@ fn hoist_state_event_prev_content(
} }
fn stripped_deserialize_prev_content( fn stripped_deserialize_prev_content(
event: &EventJson<AnyStrippedStateEventStub>, event: &EventJson<AnyStrippedStateEvent>,
) -> Option<AdditionalUnsignedData> { ) -> Option<AdditionalUnsignedData> {
serde_json::from_str::<AdditionalEventData>(event.json().get()) serde_json::from_str::<AdditionalEventData>(event.json().get())
.map(|more_unsigned| more_unsigned.unsigned) .map(|more_unsigned| more_unsigned.unsigned)
@ -488,7 +488,7 @@ impl BaseClient {
*olm = Some( *olm = Some(
OlmMachine::new_with_store( OlmMachine::new_with_store(
session.user_id.to_owned(), session.user_id.to_owned(),
session.device_id.to_owned(), session.device_id.as_str().into(),
store, store,
) )
.await .await
@ -713,14 +713,14 @@ impl BaseClient {
pub async fn receive_joined_timeline_event( pub async fn receive_joined_timeline_event(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
event: &mut EventJson<AnyRoomEventStub>, event: &mut EventJson<AnySyncRoomEvent>,
) -> Result<bool> { ) -> Result<bool> {
match event.deserialize() { match event.deserialize() {
#[allow(unused_mut)] #[allow(unused_mut)]
Ok(mut e) => { Ok(mut e) => {
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
{ {
if let AnyRoomEventStub::Message(AnyMessageEventStub::RoomEncrypted( if let AnySyncRoomEvent::Message(AnySyncMessageEvent::RoomEncrypted(
ref mut encrypted_event, ref mut encrypted_event,
)) = e )) = e
{ {
@ -742,8 +742,8 @@ impl BaseClient {
let room_lock = self.get_or_create_joined_room(&room_id).await?; let room_lock = self.get_or_create_joined_room(&room_id).await?;
let mut room = room_lock.write().await; let mut room = room_lock.write().await;
if let AnyRoomEventStub::State(AnyStateEventStub::RoomMember(mem_event)) = &mut e { if let AnySyncRoomEvent::State(AnySyncStateEvent::RoomMember(mem_event)) = &mut e {
let changed = room.handle_membership(mem_event); let (changed, _) = room.handle_membership(mem_event, false);
// The memberlist of the room changed, invalidate the group session // The memberlist of the room changed, invalidate the group session
// of the room. // of the room.
@ -774,13 +774,13 @@ impl BaseClient {
pub async fn receive_joined_state_event( pub async fn receive_joined_state_event(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
event: &AnyStateEventStub, event: &AnySyncStateEvent,
) -> Result<bool> { ) -> Result<bool> {
let room_lock = self.get_or_create_joined_room(room_id).await?; let room_lock = self.get_or_create_joined_room(room_id).await?;
let mut room = room_lock.write().await; let mut room = room_lock.write().await;
if let AnyStateEventStub::RoomMember(e) = event { if let AnySyncStateEvent::RoomMember(e) = event {
let changed = room.handle_membership(e); let (changed, _) = room.handle_membership(e, true);
// The memberlist of the room changed, invalidate the group session // The memberlist of the room changed, invalidate the group session
// of the room. // of the room.
@ -808,7 +808,7 @@ impl BaseClient {
pub async fn receive_invite_state_event( pub async fn receive_invite_state_event(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
event: &AnyStrippedStateEventStub, event: &AnyStrippedStateEvent,
) -> Result<bool> { ) -> Result<bool> {
let room_lock = self.get_or_create_invited_room(room_id).await?; let room_lock = self.get_or_create_invited_room(room_id).await?;
let mut room = room_lock.write().await; let mut room = room_lock.write().await;
@ -828,7 +828,7 @@ impl BaseClient {
pub async fn receive_left_timeline_event( pub async fn receive_left_timeline_event(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
event: &EventJson<AnyRoomEventStub>, event: &EventJson<AnySyncRoomEvent>,
) -> Result<bool> { ) -> Result<bool> {
match event.deserialize() { match event.deserialize() {
Ok(e) => { Ok(e) => {
@ -853,7 +853,7 @@ impl BaseClient {
pub async fn receive_left_state_event( pub async fn receive_left_state_event(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
event: &AnyStateEventStub, event: &AnySyncStateEvent,
) -> Result<bool> { ) -> Result<bool> {
let room_lock = self.get_or_create_left_room(room_id).await?; let room_lock = self.get_or_create_left_room(room_id).await?;
let mut room = room_lock.write().await; let mut room = room_lock.write().await;
@ -906,11 +906,11 @@ impl BaseClient {
/// * `room_id` - The unique id of the room the event belongs to. /// * `room_id` - The unique id of the room the event belongs to.
/// ///
/// * `event` - The presence event for a specified room member. /// * `event` - The presence event for a specified room member.
pub async fn receive_ephemeral_event(&self, event: &AnyEphemeralRoomEventStub) -> bool { pub async fn receive_ephemeral_event(&self, event: &AnySyncEphemeralRoomEvent) -> bool {
match event { match event {
AnyEphemeralRoomEventStub::FullyRead(_) => {} AnySyncEphemeralRoomEvent::FullyRead(_) => {}
AnyEphemeralRoomEventStub::Receipt(_) => {} AnySyncEphemeralRoomEvent::Receipt(_) => {}
AnyEphemeralRoomEventStub::Typing(_) => {} AnySyncEphemeralRoomEvent::Typing(_) => {}
_ => {} _ => {}
}; };
false false
@ -1197,7 +1197,7 @@ impl BaseClient {
if let Ok(mut e) = event.deserialize() { if let Ok(mut e) = event.deserialize() {
// if the event is a m.room.member event the server will sometimes // if the event is a m.room.member event the server will sometimes
// send the `prev_content` field as part of the unsigned field. // send the `prev_content` field as part of the unsigned field.
if let AnyStrippedStateEventStub::RoomMember(_) = &mut e { if let AnyStrippedStateEvent::RoomMember(_) = &mut e {
if let Some(raw_content) = stripped_deserialize_prev_content(event) { if let Some(raw_content) = stripped_deserialize_prev_content(event) {
let prev_content = match raw_content.prev_content { let prev_content = match raw_content.prev_content {
Some(json) => json.deserialize().ok(), Some(json) => json.deserialize().ok(),
@ -1280,7 +1280,7 @@ impl BaseClient {
pub async fn get_missing_sessions( pub async fn get_missing_sessions(
&self, &self,
users: impl Iterator<Item = &UserId>, users: impl Iterator<Item = &UserId>,
) -> Result<BTreeMap<UserId, BTreeMap<DeviceId, KeyAlgorithm>>> { ) -> Result<BTreeMap<UserId, BTreeMap<Box<DeviceId>, KeyAlgorithm>>> {
let mut olm = self.olm.lock().await; let mut olm = self.olm.lock().await;
match &mut *olm { match &mut *olm {
@ -1437,7 +1437,7 @@ impl BaseClient {
pub(crate) async fn emit_timeline_event( pub(crate) async fn emit_timeline_event(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
event: &AnyRoomEventStub, event: &AnySyncRoomEvent,
room_state: RoomStateType, room_state: RoomStateType,
) { ) {
let lock = self.event_emitter.read().await; let lock = self.event_emitter.read().await;
@ -1472,52 +1472,54 @@ impl BaseClient {
}; };
match event { match event {
AnyRoomEventStub::State(event) => match event { AnySyncRoomEvent::State(event) => match event {
AnyStateEventStub::RoomMember(e) => event_emitter.on_room_member(room, e).await, AnySyncStateEvent::RoomMember(e) => event_emitter.on_room_member(room, e).await,
AnyStateEventStub::RoomName(e) => event_emitter.on_room_name(room, e).await, AnySyncStateEvent::RoomName(e) => event_emitter.on_room_name(room, e).await,
AnyStateEventStub::RoomCanonicalAlias(e) => { AnySyncStateEvent::RoomCanonicalAlias(e) => {
event_emitter.on_room_canonical_alias(room, e).await event_emitter.on_room_canonical_alias(room, e).await
} }
AnyStateEventStub::RoomAliases(e) => event_emitter.on_room_aliases(room, e).await, AnySyncStateEvent::RoomAliases(e) => event_emitter.on_room_aliases(room, e).await,
AnyStateEventStub::RoomAvatar(e) => event_emitter.on_room_avatar(room, e).await, AnySyncStateEvent::RoomAvatar(e) => event_emitter.on_room_avatar(room, e).await,
AnyStateEventStub::RoomPowerLevels(e) => { AnySyncStateEvent::RoomPowerLevels(e) => {
event_emitter.on_room_power_levels(room, e).await event_emitter.on_room_power_levels(room, e).await
} }
AnyStateEventStub::RoomTombstone(e) => { AnySyncStateEvent::RoomTombstone(e) => {
event_emitter.on_room_tombstone(room, e).await event_emitter.on_room_tombstone(room, e).await
} }
AnyStateEventStub::RoomJoinRules(e) => { AnySyncStateEvent::RoomJoinRules(e) => {
event_emitter.on_room_join_rules(room, e).await event_emitter.on_room_join_rules(room, e).await
} }
AnyStateEventStub::Custom(e) => { AnySyncStateEvent::Custom(e) => {
event_emitter event_emitter
.on_unrecognized_event(room, &CustomOrRawEvent::State(e)) .on_unrecognized_event(room, &CustomOrRawEvent::State(e))
.await .await
} }
_ => {} _ => {}
}, },
AnyRoomEventStub::Message(event) => match event { AnySyncRoomEvent::Message(event) => match event {
AnyMessageEventStub::RoomMessage(e) => event_emitter.on_room_message(room, e).await, AnySyncMessageEvent::RoomMessage(e) => event_emitter.on_room_message(room, e).await,
AnyMessageEventStub::RoomMessageFeedback(e) => { AnySyncMessageEvent::RoomMessageFeedback(e) => {
event_emitter.on_room_message_feedback(room, e).await event_emitter.on_room_message_feedback(room, e).await
} }
AnyMessageEventStub::RoomRedaction(e) => { AnySyncMessageEvent::RoomRedaction(e) => {
event_emitter.on_room_redaction(room, e).await event_emitter.on_room_redaction(room, e).await
} }
AnyMessageEventStub::Custom(e) => { AnySyncMessageEvent::Custom(e) => {
event_emitter event_emitter
.on_unrecognized_event(room, &CustomOrRawEvent::Message(e)) .on_unrecognized_event(room, &CustomOrRawEvent::Message(e))
.await .await
} }
_ => {} _ => {}
}, },
AnySyncRoomEvent::RedactedState(_event) => {}
AnySyncRoomEvent::RedactedMessage(_event) => {}
} }
} }
pub(crate) async fn emit_state_event( pub(crate) async fn emit_state_event(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
event: &AnyStateEventStub, event: &AnySyncStateEvent,
room_state: RoomStateType, room_state: RoomStateType,
) { ) {
let lock = self.event_emitter.read().await; let lock = self.event_emitter.read().await;
@ -1552,32 +1554,32 @@ impl BaseClient {
}; };
match event { match event {
AnyStateEventStub::RoomMember(member) => { AnySyncStateEvent::RoomMember(member) => {
event_emitter.on_state_member(room, &member).await event_emitter.on_state_member(room, &member).await
} }
AnyStateEventStub::RoomName(name) => event_emitter.on_state_name(room, &name).await, AnySyncStateEvent::RoomName(name) => event_emitter.on_state_name(room, &name).await,
AnyStateEventStub::RoomCanonicalAlias(canonical) => { AnySyncStateEvent::RoomCanonicalAlias(canonical) => {
event_emitter event_emitter
.on_state_canonical_alias(room, &canonical) .on_state_canonical_alias(room, &canonical)
.await .await
} }
AnyStateEventStub::RoomAliases(aliases) => { AnySyncStateEvent::RoomAliases(aliases) => {
event_emitter.on_state_aliases(room, &aliases).await event_emitter.on_state_aliases(room, &aliases).await
} }
AnyStateEventStub::RoomAvatar(avatar) => { AnySyncStateEvent::RoomAvatar(avatar) => {
event_emitter.on_state_avatar(room, &avatar).await event_emitter.on_state_avatar(room, &avatar).await
} }
AnyStateEventStub::RoomPowerLevels(power) => { AnySyncStateEvent::RoomPowerLevels(power) => {
event_emitter.on_state_power_levels(room, &power).await event_emitter.on_state_power_levels(room, &power).await
} }
AnyStateEventStub::RoomJoinRules(rules) => { AnySyncStateEvent::RoomJoinRules(rules) => {
event_emitter.on_state_join_rules(room, &rules).await event_emitter.on_state_join_rules(room, &rules).await
} }
AnyStateEventStub::RoomTombstone(tomb) => { AnySyncStateEvent::RoomTombstone(tomb) => {
// TODO make `on_state_tombstone` method // TODO make `on_state_tombstone` method
event_emitter.on_room_tombstone(room, &tomb).await event_emitter.on_room_tombstone(room, &tomb).await
} }
AnyStateEventStub::Custom(custom) => { AnySyncStateEvent::Custom(custom) => {
event_emitter event_emitter
.on_unrecognized_event(room, &CustomOrRawEvent::State(custom)) .on_unrecognized_event(room, &CustomOrRawEvent::State(custom))
.await .await
@ -1589,7 +1591,7 @@ impl BaseClient {
pub(crate) async fn emit_stripped_state_event( pub(crate) async fn emit_stripped_state_event(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
event: &AnyStrippedStateEventStub, event: &AnyStrippedStateEvent,
prev_content: Option<MemberEventContent>, prev_content: Option<MemberEventContent>,
room_state: RoomStateType, room_state: RoomStateType,
) { ) {
@ -1625,33 +1627,33 @@ impl BaseClient {
}; };
match event { match event {
AnyStrippedStateEventStub::RoomMember(member) => { AnyStrippedStateEvent::RoomMember(member) => {
event_emitter event_emitter
.on_stripped_state_member(room, &member, prev_content) .on_stripped_state_member(room, &member, prev_content)
.await .await
} }
AnyStrippedStateEventStub::RoomName(name) => { AnyStrippedStateEvent::RoomName(name) => {
event_emitter.on_stripped_state_name(room, &name).await event_emitter.on_stripped_state_name(room, &name).await
} }
AnyStrippedStateEventStub::RoomCanonicalAlias(canonical) => { AnyStrippedStateEvent::RoomCanonicalAlias(canonical) => {
event_emitter event_emitter
.on_stripped_state_canonical_alias(room, &canonical) .on_stripped_state_canonical_alias(room, &canonical)
.await .await
} }
AnyStrippedStateEventStub::RoomAliases(aliases) => { AnyStrippedStateEvent::RoomAliases(aliases) => {
event_emitter event_emitter
.on_stripped_state_aliases(room, &aliases) .on_stripped_state_aliases(room, &aliases)
.await .await
} }
AnyStrippedStateEventStub::RoomAvatar(avatar) => { AnyStrippedStateEvent::RoomAvatar(avatar) => {
event_emitter.on_stripped_state_avatar(room, &avatar).await event_emitter.on_stripped_state_avatar(room, &avatar).await
} }
AnyStrippedStateEventStub::RoomPowerLevels(power) => { AnyStrippedStateEvent::RoomPowerLevels(power) => {
event_emitter event_emitter
.on_stripped_state_power_levels(room, &power) .on_stripped_state_power_levels(room, &power)
.await .await
} }
AnyStrippedStateEventStub::RoomJoinRules(rules) => { AnyStrippedStateEvent::RoomJoinRules(rules) => {
event_emitter event_emitter
.on_stripped_state_join_rules(room, &rules) .on_stripped_state_join_rules(room, &rules)
.await .await
@ -1716,7 +1718,7 @@ impl BaseClient {
pub(crate) async fn emit_ephemeral_event( pub(crate) async fn emit_ephemeral_event(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
event: &AnyEphemeralRoomEventStub, event: &AnySyncEphemeralRoomEvent,
room_state: RoomStateType, room_state: RoomStateType,
) { ) {
let lock = self.event_emitter.read().await; let lock = self.event_emitter.read().await;
@ -1751,13 +1753,13 @@ impl BaseClient {
}; };
match event { match event {
AnyEphemeralRoomEventStub::FullyRead(full_read) => { AnySyncEphemeralRoomEvent::FullyRead(full_read) => {
event_emitter.on_non_room_fully_read(room, &full_read).await event_emitter.on_non_room_fully_read(room, &full_read).await
} }
AnyEphemeralRoomEventStub::Typing(typing) => { AnySyncEphemeralRoomEvent::Typing(typing) => {
event_emitter.on_non_room_typing(room, &typing).await event_emitter.on_non_room_typing(room, &typing).await
} }
AnyEphemeralRoomEventStub::Receipt(receipt) => { AnySyncEphemeralRoomEvent::Receipt(receipt) => {
event_emitter.on_non_room_receipt(room, &receipt).await event_emitter.on_non_room_receipt(room, &receipt).await
} }
_ => {} _ => {}
@ -1837,17 +1839,19 @@ impl BaseClient {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::identifiers::{RoomId, UserId}; use crate::identifiers::{RoomId, UserId};
use crate::{BaseClient, BaseClientConfig, Session}; #[cfg(feature = "messages")]
use matrix_sdk_common::events::{AnyRoomEventStub, EventJson}; use crate::{
events::{AnySyncRoomEvent, EventJson},
identifiers::EventId,
BaseClientConfig, JsonStore,
};
use crate::{BaseClient, Session};
use matrix_sdk_common_macros::async_trait; use matrix_sdk_common_macros::async_trait;
use matrix_sdk_test::{async_test, test_json, EventBuilder, EventsJson}; use matrix_sdk_test::{async_test, test_json, EventBuilder, EventsJson};
use serde_json::json; use serde_json::json;
use std::convert::TryFrom; use std::convert::TryFrom;
use tempfile::tempdir; use tempfile::tempdir;
#[cfg(feature = "messages")]
use crate::JsonStore;
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::*; use wasm_bindgen_test::*;
@ -2006,7 +2010,7 @@ mod test {
use crate::{EventEmitter, SyncRoom}; use crate::{EventEmitter, SyncRoom};
use matrix_sdk_common::events::{ use matrix_sdk_common::events::{
room::member::{MemberEventContent, MembershipChange}, room::member::{MemberEventContent, MembershipChange},
StateEventStub, SyncStateEvent,
}; };
use matrix_sdk_common::locks::RwLock; use matrix_sdk_common::locks::RwLock;
use std::sync::{ use std::sync::{
@ -2020,7 +2024,7 @@ mod test {
async fn on_room_member( async fn on_room_member(
&self, &self,
room: SyncRoom, room: SyncRoom,
event: &StateEventStub<MemberEventContent>, event: &SyncStateEvent<MemberEventContent>,
) { ) {
if let SyncRoom::Joined(_) = room { if let SyncRoom::Joined(_) = room {
if let MembershipChange::Joined = event.membership_change() { if let MembershipChange::Joined = event.membership_change() {
@ -2157,8 +2161,8 @@ mod test {
let member = room.joined_members.get(&user_id).unwrap(); let member = room.joined_members.get(&user_id).unwrap();
assert_eq!(*member.display_name.as_ref().unwrap(), "changed"); assert_eq!(*member.display_name.as_ref().unwrap(), "changed");
// The second part tests that the event is emitted correctly. If `prev_content` was // The second part tests that the event is emitted correctly. If `prev_content` were
// missing, this bool is reset to false. // missing, this bool would had been flipped.
assert!(passed.load(Ordering::SeqCst)) assert!(passed.load(Ordering::SeqCst))
} }
@ -2393,7 +2397,7 @@ mod test {
"type": "m.room.redaction", "type": "m.room.redaction",
"redacts": "$152037280074GZeOm:localhost" "redacts": "$152037280074GZeOm:localhost"
}); });
let mut event: EventJson<AnyRoomEventStub> = serde_json::from_value(json).unwrap(); let mut event: EventJson<AnySyncRoomEvent> = serde_json::from_value(json).unwrap();
client client
.receive_joined_timeline_event(&room_id, &mut event) .receive_joined_timeline_event(&room_id, &mut event)
.await .await
@ -2402,16 +2406,22 @@ mod test {
// check that the message has actually been redacted // check that the message has actually been redacted
for room in client.joined_rooms().read().await.values() { for room in client.joined_rooms().read().await.values() {
let queue = &room.read().await.messages; let queue = &room.read().await.messages;
if let crate::events::AnyMessageEventContent::RoomRedaction(content) = if let crate::events::AnyPossiblyRedactedSyncMessageEvent::Redacted(
&queue.msgs[0].content crate::events::AnyRedactedSyncMessageEvent::RoomMessage(event),
) = &queue.msgs[0].deref()
{ {
assert_eq!(content.reason, Some("😀".to_string())); // this is the id from the message event in the sync response
assert_eq!(
event.event_id,
EventId::try_from("$152037280074GZeOm:localhost").unwrap()
)
} else { } else {
panic!("[pre store sync] message event in message queue should be redacted") panic!("message event in message queue should be redacted")
} }
} }
// `receive_joined_timeline_event` does not save the state to the store so we must // `receive_joined_timeline_event` does not save the state to the store
// so we must do it ourselves
client.store_room_state(&room_id).await.unwrap(); client.store_room_state(&room_id).await.unwrap();
// we load state from the store only // we load state from the store only
@ -2424,10 +2434,15 @@ mod test {
// properly // properly
for room in client.joined_rooms().read().await.values() { for room in client.joined_rooms().read().await.values() {
let queue = &room.read().await.messages; let queue = &room.read().await.messages;
if let crate::events::AnyMessageEventContent::RoomRedaction(content) = if let crate::events::AnyPossiblyRedactedSyncMessageEvent::Redacted(
&queue.msgs[0].content crate::events::AnyRedactedSyncMessageEvent::RoomMessage(event),
) = &queue.msgs[0].deref()
{ {
assert_eq!(content.reason, Some("😀".to_string())); // this is the id from the message event in the sync response
assert_eq!(
event.event_id,
EventId::try_from("$152037280074GZeOm:localhost").unwrap()
)
} else { } else {
panic!("[post store sync] message event in message queue should be redacted") panic!("[post store sync] message event in message queue should be redacted")
} }

View File

@ -33,11 +33,11 @@ use crate::events::{
message::{feedback::FeedbackEventContent, MessageEventContent as MsgEventContent}, message::{feedback::FeedbackEventContent, MessageEventContent as MsgEventContent},
name::NameEventContent, name::NameEventContent,
power_levels::PowerLevelsEventContent, power_levels::PowerLevelsEventContent,
redaction::RedactionEventStub, redaction::SyncRedactionEvent,
tombstone::TombstoneEventContent, tombstone::TombstoneEventContent,
}, },
typing::TypingEventContent, typing::TypingEventContent,
BasicEvent, EphemeralRoomEvent, MessageEventStub, StateEventStub, StrippedStateEventStub, BasicEvent, EphemeralRoomEvent, StrippedStateEvent, SyncMessageEvent, SyncStateEvent,
}; };
use crate::{Room, RoomState}; use crate::{Room, RoomState};
use matrix_sdk_common_macros::async_trait; use matrix_sdk_common_macros::async_trait;
@ -55,11 +55,11 @@ pub enum CustomOrRawEvent<'c> {
/// A custom basic event. /// A custom basic event.
EphemeralRoom(&'c EphemeralRoomEvent<CustomEventContent>), EphemeralRoom(&'c EphemeralRoomEvent<CustomEventContent>),
/// A custom room event. /// A custom room event.
Message(&'c MessageEventStub<CustomEventContent>), Message(&'c SyncMessageEvent<CustomEventContent>),
/// A custom state event. /// A custom state event.
State(&'c StateEventStub<CustomEventContent>), State(&'c SyncStateEvent<CustomEventContent>),
/// A custom stripped state event. /// A custom stripped state event.
StrippedState(&'c StrippedStateEventStub<CustomEventContent>), StrippedState(&'c StrippedStateEvent<CustomEventContent>),
} }
/// This trait allows any type implementing `EventEmitter` to specify event callbacks for each event. /// This trait allows any type implementing `EventEmitter` to specify event callbacks for each event.
@ -74,7 +74,7 @@ pub enum CustomOrRawEvent<'c> {
/// # self, /// # self,
/// # events::{ /// # events::{
/// # room::message::{MessageEventContent, TextMessageEventContent}, /// # room::message::{MessageEventContent, TextMessageEventContent},
/// # MessageEventStub /// # SyncMessageEvent
/// # }, /// # },
/// # EventEmitter, SyncRoom /// # EventEmitter, SyncRoom
/// # }; /// # };
@ -85,9 +85,9 @@ pub enum CustomOrRawEvent<'c> {
/// ///
/// #[async_trait] /// #[async_trait]
/// impl EventEmitter for EventCallback { /// impl EventEmitter for EventCallback {
/// async fn on_room_message(&self, room: SyncRoom, event: &MessageEventStub<MessageEventContent>) { /// async fn on_room_message(&self, room: SyncRoom, event: &SyncMessageEvent<MessageEventContent>) {
/// if let SyncRoom::Joined(room) = room { /// if let SyncRoom::Joined(room) = room {
/// if let MessageEventStub { /// if let SyncMessageEvent {
/// content: MessageEventContent::Text(TextMessageEventContent { body: msg_body, .. }), /// content: MessageEventContent::Text(TextMessageEventContent { body: msg_body, .. }),
/// sender, /// sender,
/// .. /// ..
@ -112,114 +112,109 @@ pub enum CustomOrRawEvent<'c> {
pub trait EventEmitter: Send + Sync { pub trait EventEmitter: Send + Sync {
// ROOM EVENTS from `IncomingTimeline` // ROOM EVENTS from `IncomingTimeline`
/// Fires when `Client` receives a `RoomEvent::RoomMember` event. /// Fires when `Client` receives a `RoomEvent::RoomMember` event.
async fn on_room_member(&self, _: SyncRoom, _: &StateEventStub<MemberEventContent>) {} async fn on_room_member(&self, _: SyncRoom, _: &SyncStateEvent<MemberEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::RoomName` event. /// Fires when `Client` receives a `RoomEvent::RoomName` event.
async fn on_room_name(&self, _: SyncRoom, _: &StateEventStub<NameEventContent>) {} async fn on_room_name(&self, _: SyncRoom, _: &SyncStateEvent<NameEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::RoomCanonicalAlias` event. /// Fires when `Client` receives a `RoomEvent::RoomCanonicalAlias` event.
async fn on_room_canonical_alias( async fn on_room_canonical_alias(
&self, &self,
_: SyncRoom, _: SyncRoom,
_: &StateEventStub<CanonicalAliasEventContent>, _: &SyncStateEvent<CanonicalAliasEventContent>,
) { ) {
} }
/// Fires when `Client` receives a `RoomEvent::RoomAliases` event. /// Fires when `Client` receives a `RoomEvent::RoomAliases` event.
async fn on_room_aliases(&self, _: SyncRoom, _: &StateEventStub<AliasesEventContent>) {} async fn on_room_aliases(&self, _: SyncRoom, _: &SyncStateEvent<AliasesEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::RoomAvatar` event. /// Fires when `Client` receives a `RoomEvent::RoomAvatar` event.
async fn on_room_avatar(&self, _: SyncRoom, _: &StateEventStub<AvatarEventContent>) {} async fn on_room_avatar(&self, _: SyncRoom, _: &SyncStateEvent<AvatarEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::RoomMessage` event. /// Fires when `Client` receives a `RoomEvent::RoomMessage` event.
async fn on_room_message(&self, _: SyncRoom, _: &MessageEventStub<MsgEventContent>) {} async fn on_room_message(&self, _: SyncRoom, _: &SyncMessageEvent<MsgEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::RoomMessageFeedback` event. /// Fires when `Client` receives a `RoomEvent::RoomMessageFeedback` event.
async fn on_room_message_feedback( async fn on_room_message_feedback(
&self, &self,
_: SyncRoom, _: SyncRoom,
_: &MessageEventStub<FeedbackEventContent>, _: &SyncMessageEvent<FeedbackEventContent>,
) { ) {
} }
/// Fires when `Client` receives a `RoomEvent::RoomRedaction` event. /// Fires when `Client` receives a `RoomEvent::RoomRedaction` event.
async fn on_room_redaction(&self, _: SyncRoom, _: &RedactionEventStub) {} async fn on_room_redaction(&self, _: SyncRoom, _: &SyncRedactionEvent) {}
/// Fires when `Client` receives a `RoomEvent::RoomPowerLevels` event. /// Fires when `Client` receives a `RoomEvent::RoomPowerLevels` event.
async fn on_room_power_levels(&self, _: SyncRoom, _: &StateEventStub<PowerLevelsEventContent>) { async fn on_room_power_levels(&self, _: SyncRoom, _: &SyncStateEvent<PowerLevelsEventContent>) {
} }
/// Fires when `Client` receives a `RoomEvent::Tombstone` event. /// Fires when `Client` receives a `RoomEvent::Tombstone` event.
async fn on_room_join_rules(&self, _: SyncRoom, _: &StateEventStub<JoinRulesEventContent>) {} async fn on_room_join_rules(&self, _: SyncRoom, _: &SyncStateEvent<JoinRulesEventContent>) {}
/// Fires when `Client` receives a `RoomEvent::Tombstone` event. /// Fires when `Client` receives a `RoomEvent::Tombstone` event.
async fn on_room_tombstone(&self, _: SyncRoom, _: &StateEventStub<TombstoneEventContent>) {} async fn on_room_tombstone(&self, _: SyncRoom, _: &SyncStateEvent<TombstoneEventContent>) {}
// `RoomEvent`s from `IncomingState` // `RoomEvent`s from `IncomingState`
/// Fires when `Client` receives a `StateEvent::RoomMember` event. /// Fires when `Client` receives a `StateEvent::RoomMember` event.
async fn on_state_member(&self, _: SyncRoom, _: &StateEventStub<MemberEventContent>) {} async fn on_state_member(&self, _: SyncRoom, _: &SyncStateEvent<MemberEventContent>) {}
/// Fires when `Client` receives a `StateEvent::RoomName` event. /// Fires when `Client` receives a `StateEvent::RoomName` event.
async fn on_state_name(&self, _: SyncRoom, _: &StateEventStub<NameEventContent>) {} async fn on_state_name(&self, _: SyncRoom, _: &SyncStateEvent<NameEventContent>) {}
/// Fires when `Client` receives a `StateEvent::RoomCanonicalAlias` event. /// Fires when `Client` receives a `StateEvent::RoomCanonicalAlias` event.
async fn on_state_canonical_alias( async fn on_state_canonical_alias(
&self, &self,
_: SyncRoom, _: SyncRoom,
_: &StateEventStub<CanonicalAliasEventContent>, _: &SyncStateEvent<CanonicalAliasEventContent>,
) { ) {
} }
/// Fires when `Client` receives a `StateEvent::RoomAliases` event. /// Fires when `Client` receives a `StateEvent::RoomAliases` event.
async fn on_state_aliases(&self, _: SyncRoom, _: &StateEventStub<AliasesEventContent>) {} async fn on_state_aliases(&self, _: SyncRoom, _: &SyncStateEvent<AliasesEventContent>) {}
/// Fires when `Client` receives a `StateEvent::RoomAvatar` event. /// Fires when `Client` receives a `StateEvent::RoomAvatar` event.
async fn on_state_avatar(&self, _: SyncRoom, _: &StateEventStub<AvatarEventContent>) {} async fn on_state_avatar(&self, _: SyncRoom, _: &SyncStateEvent<AvatarEventContent>) {}
/// Fires when `Client` receives a `StateEvent::RoomPowerLevels` event. /// Fires when `Client` receives a `StateEvent::RoomPowerLevels` event.
async fn on_state_power_levels( async fn on_state_power_levels(
&self, &self,
_: SyncRoom, _: SyncRoom,
_: &StateEventStub<PowerLevelsEventContent>, _: &SyncStateEvent<PowerLevelsEventContent>,
) { ) {
} }
/// Fires when `Client` receives a `StateEvent::RoomJoinRules` event. /// Fires when `Client` receives a `StateEvent::RoomJoinRules` event.
async fn on_state_join_rules(&self, _: SyncRoom, _: &StateEventStub<JoinRulesEventContent>) {} async fn on_state_join_rules(&self, _: SyncRoom, _: &SyncStateEvent<JoinRulesEventContent>) {}
// `AnyStrippedStateEvent`s // `AnyStrippedStateEvent`s
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomMember` event. /// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomMember` event.
async fn on_stripped_state_member( async fn on_stripped_state_member(
&self, &self,
_: SyncRoom, _: SyncRoom,
_: &StrippedStateEventStub<MemberEventContent>, _: &StrippedStateEvent<MemberEventContent>,
_: Option<MemberEventContent>, _: Option<MemberEventContent>,
) { ) {
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomName` event. /// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomName` event.
async fn on_stripped_state_name( async fn on_stripped_state_name(&self, _: SyncRoom, _: &StrippedStateEvent<NameEventContent>) {}
&self,
_: SyncRoom,
_: &StrippedStateEventStub<NameEventContent>,
) {
}
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event. /// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomCanonicalAlias` event.
async fn on_stripped_state_canonical_alias( async fn on_stripped_state_canonical_alias(
&self, &self,
_: SyncRoom, _: SyncRoom,
_: &StrippedStateEventStub<CanonicalAliasEventContent>, _: &StrippedStateEvent<CanonicalAliasEventContent>,
) { ) {
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAliases` event. /// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAliases` event.
async fn on_stripped_state_aliases( async fn on_stripped_state_aliases(
&self, &self,
_: SyncRoom, _: SyncRoom,
_: &StrippedStateEventStub<AliasesEventContent>, _: &StrippedStateEvent<AliasesEventContent>,
) { ) {
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAvatar` event. /// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomAvatar` event.
async fn on_stripped_state_avatar( async fn on_stripped_state_avatar(
&self, &self,
_: SyncRoom, _: SyncRoom,
_: &StrippedStateEventStub<AvatarEventContent>, _: &StrippedStateEvent<AvatarEventContent>,
) { ) {
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomPowerLevels` event. /// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomPowerLevels` event.
async fn on_stripped_state_power_levels( async fn on_stripped_state_power_levels(
&self, &self,
_: SyncRoom, _: SyncRoom,
_: &StrippedStateEventStub<PowerLevelsEventContent>, _: &StrippedStateEvent<PowerLevelsEventContent>,
) { ) {
} }
/// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomJoinRules` event. /// Fires when `Client` receives a `AnyStrippedStateEvent::StrippedRoomJoinRules` event.
async fn on_stripped_state_join_rules( async fn on_stripped_state_join_rules(
&self, &self,
_: SyncRoom, _: SyncRoom,
_: &StrippedStateEventStub<JoinRulesEventContent>, _: &StrippedStateEvent<JoinRulesEventContent>,
) { ) {
} }
@ -276,79 +271,79 @@ mod test {
#[async_trait] #[async_trait]
impl EventEmitter for EvEmitterTest { impl EventEmitter for EvEmitterTest {
async fn on_room_member(&self, _: SyncRoom, _: &StateEventStub<MemberEventContent>) { async fn on_room_member(&self, _: SyncRoom, _: &SyncStateEvent<MemberEventContent>) {
self.0.lock().await.push("member".to_string()) self.0.lock().await.push("member".to_string())
} }
async fn on_room_name(&self, _: SyncRoom, _: &StateEventStub<NameEventContent>) { async fn on_room_name(&self, _: SyncRoom, _: &SyncStateEvent<NameEventContent>) {
self.0.lock().await.push("name".to_string()) self.0.lock().await.push("name".to_string())
} }
async fn on_room_canonical_alias( async fn on_room_canonical_alias(
&self, &self,
_: SyncRoom, _: SyncRoom,
_: &StateEventStub<CanonicalAliasEventContent>, _: &SyncStateEvent<CanonicalAliasEventContent>,
) { ) {
self.0.lock().await.push("canonical".to_string()) self.0.lock().await.push("canonical".to_string())
} }
async fn on_room_aliases(&self, _: SyncRoom, _: &StateEventStub<AliasesEventContent>) { async fn on_room_aliases(&self, _: SyncRoom, _: &SyncStateEvent<AliasesEventContent>) {
self.0.lock().await.push("aliases".to_string()) self.0.lock().await.push("aliases".to_string())
} }
async fn on_room_avatar(&self, _: SyncRoom, _: &StateEventStub<AvatarEventContent>) { async fn on_room_avatar(&self, _: SyncRoom, _: &SyncStateEvent<AvatarEventContent>) {
self.0.lock().await.push("avatar".to_string()) self.0.lock().await.push("avatar".to_string())
} }
async fn on_room_message(&self, _: SyncRoom, _: &MessageEventStub<MsgEventContent>) { async fn on_room_message(&self, _: SyncRoom, _: &SyncMessageEvent<MsgEventContent>) {
self.0.lock().await.push("message".to_string()) self.0.lock().await.push("message".to_string())
} }
async fn on_room_message_feedback( async fn on_room_message_feedback(
&self, &self,
_: SyncRoom, _: SyncRoom,
_: &MessageEventStub<FeedbackEventContent>, _: &SyncMessageEvent<FeedbackEventContent>,
) { ) {
self.0.lock().await.push("feedback".to_string()) self.0.lock().await.push("feedback".to_string())
} }
async fn on_room_redaction(&self, _: SyncRoom, _: &RedactionEventStub) { async fn on_room_redaction(&self, _: SyncRoom, _: &SyncRedactionEvent) {
self.0.lock().await.push("redaction".to_string()) self.0.lock().await.push("redaction".to_string())
} }
async fn on_room_power_levels( async fn on_room_power_levels(
&self, &self,
_: SyncRoom, _: SyncRoom,
_: &StateEventStub<PowerLevelsEventContent>, _: &SyncStateEvent<PowerLevelsEventContent>,
) { ) {
self.0.lock().await.push("power".to_string()) self.0.lock().await.push("power".to_string())
} }
async fn on_room_tombstone(&self, _: SyncRoom, _: &StateEventStub<TombstoneEventContent>) { async fn on_room_tombstone(&self, _: SyncRoom, _: &SyncStateEvent<TombstoneEventContent>) {
self.0.lock().await.push("tombstone".to_string()) self.0.lock().await.push("tombstone".to_string())
} }
async fn on_state_member(&self, _: SyncRoom, _: &StateEventStub<MemberEventContent>) { async fn on_state_member(&self, _: SyncRoom, _: &SyncStateEvent<MemberEventContent>) {
self.0.lock().await.push("state member".to_string()) self.0.lock().await.push("state member".to_string())
} }
async fn on_state_name(&self, _: SyncRoom, _: &StateEventStub<NameEventContent>) { async fn on_state_name(&self, _: SyncRoom, _: &SyncStateEvent<NameEventContent>) {
self.0.lock().await.push("state name".to_string()) self.0.lock().await.push("state name".to_string())
} }
async fn on_state_canonical_alias( async fn on_state_canonical_alias(
&self, &self,
_: SyncRoom, _: SyncRoom,
_: &StateEventStub<CanonicalAliasEventContent>, _: &SyncStateEvent<CanonicalAliasEventContent>,
) { ) {
self.0.lock().await.push("state canonical".to_string()) self.0.lock().await.push("state canonical".to_string())
} }
async fn on_state_aliases(&self, _: SyncRoom, _: &StateEventStub<AliasesEventContent>) { async fn on_state_aliases(&self, _: SyncRoom, _: &SyncStateEvent<AliasesEventContent>) {
self.0.lock().await.push("state aliases".to_string()) self.0.lock().await.push("state aliases".to_string())
} }
async fn on_state_avatar(&self, _: SyncRoom, _: &StateEventStub<AvatarEventContent>) { async fn on_state_avatar(&self, _: SyncRoom, _: &SyncStateEvent<AvatarEventContent>) {
self.0.lock().await.push("state avatar".to_string()) self.0.lock().await.push("state avatar".to_string())
} }
async fn on_state_power_levels( async fn on_state_power_levels(
&self, &self,
_: SyncRoom, _: SyncRoom,
_: &StateEventStub<PowerLevelsEventContent>, _: &SyncStateEvent<PowerLevelsEventContent>,
) { ) {
self.0.lock().await.push("state power".to_string()) self.0.lock().await.push("state power".to_string())
} }
async fn on_state_join_rules( async fn on_state_join_rules(
&self, &self,
_: SyncRoom, _: SyncRoom,
_: &StateEventStub<JoinRulesEventContent>, _: &SyncStateEvent<JoinRulesEventContent>,
) { ) {
self.0.lock().await.push("state rules".to_string()) self.0.lock().await.push("state rules".to_string())
} }
@ -358,7 +353,7 @@ mod test {
async fn on_stripped_state_member( async fn on_stripped_state_member(
&self, &self,
_: SyncRoom, _: SyncRoom,
_: &StrippedStateEventStub<MemberEventContent>, _: &StrippedStateEvent<MemberEventContent>,
_: Option<MemberEventContent>, _: Option<MemberEventContent>,
) { ) {
self.0 self.0
@ -370,7 +365,7 @@ mod test {
async fn on_stripped_state_name( async fn on_stripped_state_name(
&self, &self,
_: SyncRoom, _: SyncRoom,
_: &StrippedStateEventStub<NameEventContent>, _: &StrippedStateEvent<NameEventContent>,
) { ) {
self.0.lock().await.push("stripped state name".to_string()) self.0.lock().await.push("stripped state name".to_string())
} }
@ -378,7 +373,7 @@ mod test {
async fn on_stripped_state_canonical_alias( async fn on_stripped_state_canonical_alias(
&self, &self,
_: SyncRoom, _: SyncRoom,
_: &StrippedStateEventStub<CanonicalAliasEventContent>, _: &StrippedStateEvent<CanonicalAliasEventContent>,
) { ) {
self.0 self.0
.lock() .lock()
@ -389,7 +384,7 @@ mod test {
async fn on_stripped_state_aliases( async fn on_stripped_state_aliases(
&self, &self,
_: SyncRoom, _: SyncRoom,
_: &StrippedStateEventStub<AliasesEventContent>, _: &StrippedStateEvent<AliasesEventContent>,
) { ) {
self.0 self.0
.lock() .lock()
@ -400,7 +395,7 @@ mod test {
async fn on_stripped_state_avatar( async fn on_stripped_state_avatar(
&self, &self,
_: SyncRoom, _: SyncRoom,
_: &StrippedStateEventStub<AvatarEventContent>, _: &StrippedStateEvent<AvatarEventContent>,
) { ) {
self.0 self.0
.lock() .lock()
@ -411,7 +406,7 @@ mod test {
async fn on_stripped_state_power_levels( async fn on_stripped_state_power_levels(
&self, &self,
_: SyncRoom, _: SyncRoom,
_: &StrippedStateEventStub<PowerLevelsEventContent>, _: &StrippedStateEvent<PowerLevelsEventContent>,
) { ) {
self.0.lock().await.push("stripped state power".to_string()) self.0.lock().await.push("stripped state power".to_string())
} }
@ -419,7 +414,7 @@ mod test {
async fn on_stripped_state_join_rules( async fn on_stripped_state_join_rules(
&self, &self,
_: SyncRoom, _: SyncRoom,
_: &StrippedStateEventStub<JoinRulesEventContent>, _: &StrippedStateEvent<JoinRulesEventContent>,
) { ) {
self.0.lock().await.push("stripped state rules".to_string()) self.0.lock().await.push("stripped state rules".to_string())
} }
@ -581,7 +576,7 @@ mod test {
"unrecognized event", "unrecognized event",
"redaction", "redaction",
"unrecognized event", "unrecognized event",
"unrecognized event", // "unrecognized event", this is actually a redacted "m.room.messages" event
"receipt event", "receipt event",
"typing event" "typing event"
], ],

View File

@ -47,11 +47,16 @@ mod state;
pub use client::{BaseClient, BaseClientConfig, RoomState, RoomStateType}; pub use client::{BaseClient, BaseClientConfig, RoomState, RoomStateType};
pub use event_emitter::{CustomOrRawEvent, EventEmitter, SyncRoom}; pub use event_emitter::{CustomOrRawEvent, EventEmitter, SyncRoom};
pub use models::Room;
pub use state::{AllRooms, ClientState};
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
pub use matrix_sdk_crypto::{Device, TrustState}; pub use matrix_sdk_crypto::{Device, TrustState};
pub use models::Room;
pub use state::AllRooms; #[cfg(feature = "messages")]
pub use state::ClientState; #[cfg_attr(docsrs, doc(cfg(feature = "messages")))]
pub use models::{MessageQueue, MessageWrapper, PossiblyRedactedExt};
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
pub use state::JsonStore; pub use state::JsonStore;
pub use state::StateStore; pub use state::StateStore;

View File

@ -3,17 +3,50 @@
//! The `Room` struct optionally holds a `MessageQueue` if the "messages" //! The `Room` struct optionally holds a `MessageQueue` if the "messages"
//! feature is enabled. //! feature is enabled.
use std::cmp::Ordering; use std::{
use std::ops::{Deref, DerefMut}; cmp::Ordering,
use std::vec::IntoIter; ops::{Deref, DerefMut},
time::SystemTime,
use crate::events::{AnyMessageEventContent, AnyMessageEventStub, MessageEventStub}; vec::IntoIter,
};
use matrix_sdk_common::identifiers::EventId;
use serde::{de, ser, Serialize}; use serde::{de, ser, Serialize};
use crate::events::AnyPossiblyRedactedSyncMessageEvent;
/// Exposes some of the field access methods found in the event held by
/// `AnyPossiblyRedacted*` enums.
///
/// This is just an extension trait to aid the ease of use of certain event enums.
pub trait PossiblyRedactedExt {
/// Access the redacted or full events `event_id` field.
fn event_id(&self) -> &EventId;
/// Access the redacted or full events `origin_server_ts` field.
fn origin_server_ts(&self) -> &SystemTime;
}
impl PossiblyRedactedExt for AnyPossiblyRedactedSyncMessageEvent {
/// Access the underlying events `event_id`.
fn event_id(&self) -> &EventId {
match self {
Self::Regular(e) => e.event_id(),
Self::Redacted(e) => e.event_id(),
}
}
/// Access the underlying events `origin_server_ts`.
fn origin_server_ts(&self) -> &SystemTime {
match self {
Self::Regular(e) => e.origin_server_ts(),
Self::Redacted(e) => e.origin_server_ts(),
}
}
}
const MESSAGE_QUEUE_CAP: usize = 35; const MESSAGE_QUEUE_CAP: usize = 35;
pub type SyncMessageEvent = MessageEventStub<AnyMessageEventContent>; pub type SyncMessageEvent = AnyPossiblyRedactedSyncMessageEvent;
/// A queue that holds the 35 most recent messages received from the server. /// A queue that holds the 35 most recent messages received from the server.
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]
@ -29,18 +62,6 @@ pub struct MessageQueue {
#[derive(Clone, Debug, Serialize)] #[derive(Clone, Debug, Serialize)]
pub struct MessageWrapper(pub SyncMessageEvent); pub struct MessageWrapper(pub SyncMessageEvent);
impl MessageWrapper {
pub fn clone_into_any_content(event: &AnyMessageEventStub) -> SyncMessageEvent {
MessageEventStub {
content: event.content(),
sender: event.sender().clone(),
origin_server_ts: *event.origin_server_ts(),
event_id: event.event_id().clone(),
unsigned: event.unsigned().clone(),
}
}
}
impl Deref for MessageWrapper { impl Deref for MessageWrapper {
type Target = SyncMessageEvent; type Target = SyncMessageEvent;
@ -57,7 +78,7 @@ impl DerefMut for MessageWrapper {
impl PartialEq for MessageWrapper { impl PartialEq for MessageWrapper {
fn eq(&self, other: &MessageWrapper) -> bool { fn eq(&self, other: &MessageWrapper) -> bool {
self.0.event_id == other.0.event_id self.0.event_id() == other.0.event_id()
} }
} }
@ -65,7 +86,7 @@ impl Eq for MessageWrapper {}
impl PartialOrd for MessageWrapper { impl PartialOrd for MessageWrapper {
fn partial_cmp(&self, other: &MessageWrapper) -> Option<Ordering> { fn partial_cmp(&self, other: &MessageWrapper) -> Option<Ordering> {
Some(self.0.origin_server_ts.cmp(&other.0.origin_server_ts)) Some(self.0.origin_server_ts().cmp(&other.0.origin_server_ts()))
} }
} }
@ -82,7 +103,7 @@ impl PartialEq for MessageQueue {
.msgs .msgs
.iter() .iter()
.zip(other.msgs.iter()) .zip(other.msgs.iter())
.all(|(msg_a, msg_b)| msg_a.event_id == msg_b.event_id) .all(|(msg_a, msg_b)| msg_a.event_id() == msg_b.event_id())
} }
} }
@ -100,7 +121,7 @@ impl MessageQueue {
pub fn push(&mut self, msg: SyncMessageEvent) -> bool { pub fn push(&mut self, msg: SyncMessageEvent) -> bool {
// only push new messages into the queue // only push new messages into the queue
if let Some(latest) = self.msgs.last() { if let Some(latest) = self.msgs.last() {
if msg.origin_server_ts < latest.origin_server_ts && self.msgs.len() >= 10 { if msg.origin_server_ts() < latest.origin_server_ts() && self.msgs.len() >= 10 {
return false; return false;
} }
} }
@ -120,10 +141,12 @@ impl MessageQueue {
true true
} }
/// Iterate over the messages in the queue.
pub fn iter(&self) -> impl Iterator<Item = &MessageWrapper> { pub fn iter(&self) -> impl Iterator<Item = &MessageWrapper> {
self.msgs.iter() self.msgs.iter()
} }
/// Iterate over each message mutably.
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut MessageWrapper> { pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut MessageWrapper> {
self.msgs.iter_mut() self.msgs.iter_mut()
} }
@ -183,17 +206,18 @@ pub(crate) mod ser_deser {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::*;
use std::collections::HashMap; use std::collections::HashMap;
use std::convert::TryFrom; use std::convert::TryFrom;
use matrix_sdk_common::{
events::{AnyPossiblyRedactedSyncMessageEvent, AnySyncMessageEvent},
identifiers::{RoomId, UserId},
};
use matrix_sdk_test::test_json;
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::*; use wasm_bindgen_test::*;
use matrix_sdk_test::test_json; use super::*;
use crate::identifiers::{RoomId, UserId};
use crate::Room; use crate::Room;
#[test] #[test]
@ -204,7 +228,9 @@ mod test {
let mut room = Room::new(&id, &user); let mut room = Room::new(&id, &user);
let json: &serde_json::Value = &test_json::MESSAGE_TEXT; let json: &serde_json::Value = &test_json::MESSAGE_TEXT;
let msg = serde_json::from_value::<SyncMessageEvent>(json.clone()).unwrap(); let msg = AnyPossiblyRedactedSyncMessageEvent::Regular(
serde_json::from_value::<AnySyncMessageEvent>(json.clone()).unwrap(),
);
let mut msgs = MessageQueue::new(); let mut msgs = MessageQueue::new();
msgs.push(msg.clone()); msgs.push(msg.clone());
@ -216,7 +242,6 @@ mod test {
serde_json::json!({ serde_json::json!({
"!roomid:example.com": { "!roomid:example.com": {
"room_id": "!roomid:example.com", "room_id": "!roomid:example.com",
"disambiguated_display_names": {},
"room_name": { "room_name": {
"name": null, "name": null,
"canonical_alias": null, "canonical_alias": null,
@ -250,7 +275,9 @@ mod test {
let mut room = Room::new(&id, &user); let mut room = Room::new(&id, &user);
let json: &serde_json::Value = &test_json::MESSAGE_TEXT; let json: &serde_json::Value = &test_json::MESSAGE_TEXT;
let msg = serde_json::from_value::<SyncMessageEvent>(json.clone()).unwrap(); let msg = AnyPossiblyRedactedSyncMessageEvent::Regular(
serde_json::from_value::<AnySyncMessageEvent>(json.clone()).unwrap(),
);
let mut msgs = MessageQueue::new(); let mut msgs = MessageQueue::new();
msgs.push(msg.clone()); msgs.push(msg.clone());
@ -262,7 +289,6 @@ mod test {
let json = serde_json::json!({ let json = serde_json::json!({
"!roomid:example.com": { "!roomid:example.com": {
"room_id": "!roomid:example.com", "room_id": "!roomid:example.com",
"disambiguated_display_names": {},
"room_name": { "room_name": {
"name": null, "name": null,
"canonical_alias": null, "canonical_alias": null,

View File

@ -4,5 +4,8 @@ mod message;
mod room; mod room;
mod room_member; mod room_member;
#[cfg(feature = "messages")]
#[cfg_attr(docsrs, doc(cfg(feature = "messages")))]
pub use message::{MessageQueue, MessageWrapper, PossiblyRedactedExt};
pub use room::{Room, RoomName}; pub use room::{Room, RoomName};
pub use room_member::RoomMember; pub use room_member::RoomMember;

File diff suppressed because it is too large Load Diff

View File

@ -15,16 +15,17 @@
use std::convert::TryFrom; use std::convert::TryFrom;
use crate::events::presence::{PresenceEvent, PresenceEventContent, PresenceState}; use matrix_sdk_common::{
use crate::events::room::{ events::{
member::{MemberEventContent, MembershipChange, MembershipState}, presence::{PresenceEvent, PresenceState},
power_levels::PowerLevelsEventContent, room::member::MemberEventContent,
SyncStateEvent,
},
identifiers::{RoomId, UserId},
js_int::{Int, UInt},
}; };
use crate::events::StateEventStub;
use crate::identifiers::{RoomId, UserId};
use crate::js_int::{int, Int, UInt};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
// Notes: if Alice invites Bob into a room we will get an event with the sender as Alice and the state key as Bob. // Notes: if Alice invites Bob into a room we will get an event with the sender as Alice and the state key as Bob.
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
@ -34,6 +35,9 @@ pub struct RoomMember {
pub user_id: UserId, pub user_id: UserId,
/// The human readable name of the user. /// The human readable name of the user.
pub display_name: Option<String>, pub display_name: Option<String>,
/// Whether the member's display name is ambiguous due to being shared with
/// other members.
pub display_name_ambiguous: bool,
/// The matrix url of the users avatar. /// The matrix url of the users avatar.
pub avatar_url: Option<String>, pub avatar_url: Option<String>,
/// The time, in ms, since the user interacted with the server. /// The time, in ms, since the user interacted with the server.
@ -52,10 +56,18 @@ pub struct RoomMember {
pub power_level: Option<Int>, pub power_level: Option<Int>,
/// The normalized power level of this `RoomMember` (0-100). /// The normalized power level of this `RoomMember` (0-100).
pub power_level_norm: Option<Int>, pub power_level_norm: Option<Int>,
/// The `MembershipState` of this `RoomMember`.
pub membership: MembershipState,
/// The human readable name of this room member. /// The human readable name of this room member.
pub name: String, pub name: String,
// FIXME: The docstring below is currently a lie since we only store the initial event that
// creates the member (the one we pass to RoomMember::new).
//
// The intent of this field is to keep the last (or last few?) state events related to the room
// member cached so we can quickly go back to the previous one in case some of them get
// redacted. Keeping all state for each room member is probably too much.
//
// Needs design.
/// The events that created the state of this room member.
pub events: Vec<SyncStateEvent<MemberEventContent>>,
/// The `PresenceEvent`s connected to this user. /// The `PresenceEvent`s connected to this user.
pub presence_events: Vec<PresenceEvent>, pub presence_events: Vec<PresenceEvent>,
} }
@ -67,18 +79,20 @@ impl PartialEq for RoomMember {
&& self.user_id == other.user_id && self.user_id == other.user_id
&& self.name == other.name && self.name == other.name
&& self.display_name == other.display_name && self.display_name == other.display_name
&& self.display_name_ambiguous == other.display_name_ambiguous
&& self.avatar_url == other.avatar_url && self.avatar_url == other.avatar_url
&& self.last_active_ago == other.last_active_ago && self.last_active_ago == other.last_active_ago
} }
} }
impl RoomMember { impl RoomMember {
pub fn new(event: &StateEventStub<MemberEventContent>, room_id: &RoomId) -> Self { pub fn new(event: &SyncStateEvent<MemberEventContent>, room_id: &RoomId) -> Self {
Self { Self {
name: event.state_key.clone(), name: event.state_key.clone(),
room_id: room_id.clone(), room_id: room_id.clone(),
user_id: UserId::try_from(event.state_key.as_str()).unwrap(), user_id: UserId::try_from(event.state_key.as_str()).unwrap(),
display_name: event.content.displayname.clone(), display_name: event.content.displayname.clone(),
display_name_ambiguous: false,
avatar_url: event.content.avatar_url.clone(), avatar_url: event.content.avatar_url.clone(),
presence: None, presence: None,
status_msg: None, status_msg: None,
@ -87,12 +101,13 @@ impl RoomMember {
typing: None, typing: None,
power_level: None, power_level: None,
power_level_norm: None, power_level_norm: None,
membership: event.content.membership, presence_events: Vec::default(),
presence_events: vec![], events: vec![event.clone()],
} }
} }
/// Returns the most ergonomic name available for the member. /// Returns the most ergonomic (but potentially ambiguous/non-unique) name
/// available for the member.
/// ///
/// This is the member's display name if it is set, otherwise their MXID. /// This is the member's display name if it is set, otherwise their MXID.
pub fn name(&self) -> String { pub fn name(&self) -> String {
@ -101,10 +116,11 @@ impl RoomMember {
.unwrap_or_else(|| format!("{}", self.user_id)) .unwrap_or_else(|| format!("{}", self.user_id))
} }
/// Returns a name for the member which is guaranteed to be unique. /// Returns a name for the member which is guaranteed to be unique, but not
/// necessarily the most ergonomic.
/// ///
/// This is either of the format "DISPLAY_NAME (MXID)" if the display name is set for the /// This is either a name in the format "DISPLAY_NAME (MXID)" if the
/// member, or simply "MXID" if not. /// member's display name is set, or simply "MXID" if not.
pub fn unique_name(&self) -> String { pub fn unique_name(&self) -> String {
self.display_name self.display_name
.clone() .clone()
@ -112,100 +128,28 @@ impl RoomMember {
.unwrap_or_else(|| format!("{}", self.user_id)) .unwrap_or_else(|| format!("{}", self.user_id))
} }
/// Handle profile updates. /// Get the disambiguated display name for the member which is as ergonomic
pub(crate) fn update_profile(&mut self, event: &StateEventStub<MemberEventContent>) -> bool { /// as possible while still guaranteeing it is unique.
use MembershipChange::*; ///
/// If the member's display name is currently ambiguous (i.e. shared by
match event.membership_change() { /// other room members), this method will return the same result as
// we assume that the profile has changed /// `RoomMember::unique_name`. Otherwise, this method will return the same
ProfileChanged { .. } => { /// result as `RoomMember::name`.
self.display_name = event.content.displayname.clone(); ///
self.avatar_url = event.content.avatar_url.clone(); /// This is usually the name you want when showing room messages from the
true /// member or when showing the member in the member list.
} ///
/// **Warning**: When displaying a room member's display name, clients
// We're only interested in profile changes here. /// *must* use a disambiguated name, so they *must not* use
_ => false, /// `RoomMember::display_name` directly. Clients *should* use this method to
} /// obtain the name, but an acceptable alternative is to use
} /// `RoomMember::unique_name` in certain situations.
pub fn disambiguated_name(&self) -> String {
pub fn update_power( if self.display_name_ambiguous {
&mut self, self.unique_name()
event: &StateEventStub<PowerLevelsEventContent>,
max_power: Int,
) -> bool {
let changed;
if let Some(user_power) = event.content.users.get(&self.user_id) {
changed = self.power_level != Some(*user_power);
self.power_level = Some(*user_power);
} else { } else {
changed = self.power_level != Some(event.content.users_default); self.name()
self.power_level = Some(event.content.users_default);
} }
if max_power > int!(0) {
self.power_level_norm = Some((self.power_level.unwrap() * int!(100)) / max_power);
}
changed
}
/// If the current `PresenceEvent` updated the state of this `User`.
///
/// Returns true if the specific users presence has changed, false otherwise.
///
/// # Arguments
///
/// * `presence` - The presence event for a this room member.
pub fn did_update_presence(&self, presence: &PresenceEvent) -> bool {
let PresenceEvent {
content:
PresenceEventContent {
avatar_url,
currently_active,
displayname,
last_active_ago,
presence,
status_msg,
},
..
} = presence;
self.display_name == *displayname
&& self.avatar_url == *avatar_url
&& self.presence.as_ref() == Some(presence)
&& self.status_msg == *status_msg
&& self.last_active_ago == *last_active_ago
&& self.currently_active == *currently_active
}
/// Updates the `User`s presence.
///
/// This should only be used if `did_update_presence` was true.
///
/// # Arguments
///
/// * `presence` - The presence event for a this room member.
pub fn update_presence(&mut self, presence_ev: &PresenceEvent) {
let PresenceEvent {
content:
PresenceEventContent {
avatar_url,
currently_active,
displayname,
last_active_ago,
presence,
status_msg,
},
..
} = presence_ev;
self.presence_events.push(presence_ev.clone());
self.avatar_url = avatar_url.clone();
self.currently_active = *currently_active;
self.display_name = displayname.clone();
self.last_active_ago = *last_active_ago;
self.presence = Some(*presence);
self.status_msg = status_msg.clone();
} }
} }
@ -234,7 +178,9 @@ mod test {
client client
} }
fn get_room_id() -> RoomId { // TODO: Move this to EventBuilder since it's a magic room ID used in EventBuilder's example
// events.
fn test_room_id() -> RoomId {
RoomId::try_from("!SVkFJHzfwvuaIEawgC:localhost").unwrap() RoomId::try_from("!SVkFJHzfwvuaIEawgC:localhost").unwrap()
} }
@ -242,11 +188,11 @@ mod test {
async fn room_member_events() { async fn room_member_events() {
let client = get_client().await; let client = get_client().await;
let room_id = get_room_id(); let room_id = test_room_id();
let mut response = EventBuilder::default() let mut response = EventBuilder::default()
.add_state_event(EventsJson::Member) .add_room_event(EventsJson::Member)
.add_state_event(EventsJson::PowerLevels) .add_room_event(EventsJson::PowerLevels)
.build_sync_response(); .build_sync_response();
client.receive_sync_response(&mut response).await.unwrap(); client.receive_sync_response(&mut response).await.unwrap();
@ -261,15 +207,65 @@ mod test {
assert_eq!(member.power_level, Some(int!(100))); assert_eq!(member.power_level, Some(int!(100)));
} }
#[async_test]
async fn room_member_display_name_change() {
let client = get_client().await;
let room_id = test_room_id();
let mut builder = EventBuilder::default();
let mut initial_response = builder
.add_room_event(EventsJson::Member)
.build_sync_response();
let mut name_change_response = builder
.add_room_event(EventsJson::MemberNameChange)
.build_sync_response();
client
.receive_sync_response(&mut initial_response)
.await
.unwrap();
let room = client.get_joined_room(&room_id).await.unwrap();
// Initially, the display name is "example".
{
let room = room.read().await;
let member = room
.joined_members
.get(&UserId::try_from("@example:localhost").unwrap())
.unwrap();
assert_eq!(member.display_name.as_ref().unwrap(), "example");
}
client
.receive_sync_response(&mut name_change_response)
.await
.unwrap();
// Afterwards, the display name is "changed".
{
let room = room.read().await;
let member = room
.joined_members
.get(&UserId::try_from("@example:localhost").unwrap())
.unwrap();
assert_eq!(member.display_name.as_ref().unwrap(), "changed");
}
}
#[async_test] #[async_test]
async fn member_presence_events() { async fn member_presence_events() {
let client = get_client().await; let client = get_client().await;
let room_id = get_room_id(); let room_id = test_room_id();
let mut response = EventBuilder::default() let mut response = EventBuilder::default()
.add_state_event(EventsJson::Member) .add_room_event(EventsJson::Member)
.add_state_event(EventsJson::PowerLevels) .add_room_event(EventsJson::PowerLevels)
.add_presence_event(EventsJson::Presence) .add_presence_event(EventsJson::Presence)
.build_sync_response(); .build_sync_response();

View File

@ -39,6 +39,36 @@ impl JsonStore {
user_path_set: AtomicBool::new(false), user_path_set: AtomicBool::new(false),
}) })
} }
/// Build a path for a file where the Room state to be stored in.
async fn build_room_path(&self, room_state: &str, room_id: &RoomId) -> PathBuf {
let mut path = self.path.read().await.clone();
path.push("rooms");
path.push(room_state);
path.push(JsonStore::sanitize_room_id(room_id));
path.set_extension("json");
path
}
/// Build a path for the file where the Client state to be stored in.
async fn build_client_path(&self) -> PathBuf {
let mut path = self.path.read().await.clone();
path.push("client");
path.set_extension("json");
path
}
/// Replace common characters that can't be used in a file name with an
/// underscore.
fn sanitize_room_id(room_id: &RoomId) -> String {
room_id.as_str().replace(
&['.', ':', '<', '>', '"', '/', '\\', '|', '?', '*'][..],
"_",
)
}
} }
impl fmt::Debug for JsonStore { impl fmt::Debug for JsonStore {
@ -57,8 +87,7 @@ impl StateStore for JsonStore {
self.path.write().await.push(sess.user_id.localpart()) self.path.write().await.push(sess.user_id.localpart())
} }
let mut path = self.path.read().await.clone(); let path = self.build_client_path().await;
path.push("client.json");
let json = async_fs::read_to_string(path) let json = async_fs::read_to_string(path)
.await .await
@ -114,8 +143,7 @@ impl StateStore for JsonStore {
} }
async fn store_client_state(&self, state: ClientState) -> Result<()> { async fn store_client_state(&self, state: ClientState) -> Result<()> {
let mut path = self.path.read().await.clone(); let path = self.build_client_path().await;
path.push("client.json");
if !path.exists() { if !path.exists() {
let mut dir = path.clone(); let mut dir = path.clone();
@ -146,9 +174,7 @@ impl StateStore for JsonStore {
self.path.write().await.push(room.own_user_id.localpart()) self.path.write().await.push(room.own_user_id.localpart())
} }
let mut path = self.path.read().await.clone(); let path = self.build_room_path(room_state, &room.room_id).await;
path.push("rooms");
path.push(&format!("{}/{}.json", room_state, room.room_id));
if !path.exists() { if !path.exists() {
let mut dir = path.clone(); let mut dir = path.clone();
@ -178,15 +204,13 @@ impl StateStore for JsonStore {
return Err(Error::StateStore("path for JsonStore not set".into())); return Err(Error::StateStore("path for JsonStore not set".into()));
} }
let mut to_del = self.path.read().await.clone(); let path = self.build_room_path(room_state, room_id).await;
to_del.push("rooms");
to_del.push(&format!("{}/{}.json", room_state, room_id));
if !to_del.exists() { if !path.exists() {
return Err(Error::StateStore(format!("file {:?} not found", to_del))); return Err(Error::StateStore(format!("file {:?} not found", path)));
} }
tokio::fs::remove_file(to_del).await.map_err(Error::from) tokio::fs::remove_file(path).await.map_err(Error::from)
} }
} }

View File

@ -159,7 +159,6 @@ mod test {
"creator": null, "creator": null,
"joined_members": {}, "joined_members": {},
"invited_members": {}, "invited_members": {},
"disambiguated_display_names": {},
"typing_users": [], "typing_users": [],
"power_levels": null, "power_levels": null,
"encrypted": null, "encrypted": null,
@ -176,7 +175,6 @@ mod test {
serde_json::json!({ serde_json::json!({
"!roomid:example.com": { "!roomid:example.com": {
"room_id": "!roomid:example.com", "room_id": "!roomid:example.com",
"disambiguated_display_names": {},
"room_name": { "room_name": {
"name": null, "name": null,
"canonical_alias": null, "canonical_alias": null,

View File

@ -11,13 +11,13 @@ repository = "https://github.com/matrix-org/matrix-rust-sdk"
version = "0.1.0" version = "0.1.0"
[dependencies] [dependencies]
instant = { version = "0.1.4", features = ["wasm-bindgen", "now"] } instant = { version = "0.1.6", features = ["wasm-bindgen", "now"] }
js_int = "0.1.8" js_int = "0.1.8"
[dependencies.ruma] [dependencies.ruma]
path = "/home/poljar/werk/priv/ruma/ruma" git = "https://github.com/ruma/ruma"
features = ["client-api"] features = ["client-api"]
rev = "c19bcaab" rev = "848b22568106d05c5444f3fe46070d5aa16e422b"
[target.'cfg(not(target_arch = "wasm32"))'.dependencies] [target.'cfg(not(target_arch = "wasm32"))'.dependencies]
uuid = { version = "0.8.1", features = ["v4"] } uuid = { version = "0.8.1", features = ["v4"] }

View File

@ -14,5 +14,5 @@ version = "0.1.0"
proc-macro = true proc-macro = true
[dependencies] [dependencies]
syn = "1.0.33" syn = "1.0.34"
quote = "1.0.7" quote = "1.0.7"

View File

@ -29,7 +29,7 @@ url = "2.1.1"
# Misc dependencies # Misc dependencies
thiserror = "1.0.20" thiserror = "1.0.20"
tracing = "0.1.15" tracing = "0.1.16"
atomic = "0.4.6" atomic = "0.4.6"
dashmap = "3.11.7" dashmap = "3.11.7"

View File

@ -33,9 +33,10 @@ use crate::verify_json;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Device { pub struct Device {
user_id: Arc<UserId>, user_id: Arc<UserId>,
device_id: Arc<DeviceId>, device_id: Arc<Box<DeviceId>>,
algorithms: Arc<Vec<Algorithm>>, algorithms: Arc<Vec<Algorithm>>,
keys: Arc<BTreeMap<AlgorithmAndDeviceId, String>>, keys: Arc<BTreeMap<AlgorithmAndDeviceId, String>>,
signatures: Arc<BTreeMap<UserId, BTreeMap<AlgorithmAndDeviceId, String>>>,
display_name: Arc<Option<String>>, display_name: Arc<Option<String>>,
deleted: Arc<AtomicBool>, deleted: Arc<AtomicBool>,
trust_state: Arc<Atomic<TrustState>>, trust_state: Arc<Atomic<TrustState>>,
@ -70,17 +71,19 @@ impl Device {
/// Create a new Device. /// Create a new Device.
pub fn new( pub fn new(
user_id: UserId, user_id: UserId,
device_id: DeviceId, device_id: Box<DeviceId>,
display_name: Option<String>, display_name: Option<String>,
trust_state: TrustState, trust_state: TrustState,
algorithms: Vec<Algorithm>, algorithms: Vec<Algorithm>,
keys: BTreeMap<AlgorithmAndDeviceId, String>, keys: BTreeMap<AlgorithmAndDeviceId, String>,
signatures: BTreeMap<UserId, BTreeMap<AlgorithmAndDeviceId, String>>,
) -> Self { ) -> Self {
Device { Device {
user_id: Arc::new(user_id), user_id: Arc::new(user_id),
device_id: Arc::new(device_id), device_id: Arc::new(device_id),
display_name: Arc::new(display_name), display_name: Arc::new(display_name),
trust_state: Arc::new(Atomic::new(trust_state)), trust_state: Arc::new(Atomic::new(trust_state)),
signatures: Arc::new(signatures),
algorithms: Arc::new(algorithms), algorithms: Arc::new(algorithms),
keys: Arc::new(keys), keys: Arc::new(keys),
deleted: Arc::new(AtomicBool::new(false)), deleted: Arc::new(AtomicBool::new(false)),
@ -104,8 +107,10 @@ impl Device {
/// Get the key of the given key algorithm belonging to this device. /// Get the key of the given key algorithm belonging to this device.
pub fn get_key(&self, algorithm: KeyAlgorithm) -> Option<&String> { pub fn get_key(&self, algorithm: KeyAlgorithm) -> Option<&String> {
self.keys self.keys.get(&AlgorithmAndDeviceId(
.get(&AlgorithmAndDeviceId(algorithm, self.device_id.to_string())) algorithm,
self.device_id.as_ref().clone(),
))
} }
/// Get a map containing all the device keys. /// Get a map containing all the device keys.
@ -113,6 +118,11 @@ impl Device {
&self.keys &self.keys
} }
/// Get a map containing all the device signatures.
pub fn signatures(&self) -> &BTreeMap<UserId, BTreeMap<AlgorithmAndDeviceId, String>> {
&self.signatures
}
/// Get the trust state of the device. /// Get the trust state of the device.
pub fn trust_state(&self) -> TrustState { pub fn trust_state(&self) -> TrustState {
self.trust_state.load(Ordering::Relaxed) self.trust_state.load(Ordering::Relaxed)
@ -142,6 +152,7 @@ impl Device {
self.algorithms = Arc::new(device_keys.algorithms.clone()); self.algorithms = Arc::new(device_keys.algorithms.clone());
self.keys = Arc::new(device_keys.keys.clone()); self.keys = Arc::new(device_keys.keys.clone());
self.signatures = Arc::new(device_keys.signatures.clone());
self.display_name = display_name; self.display_name = display_name;
Ok(()) Ok(())
@ -173,37 +184,11 @@ impl Device {
pub(crate) fn mark_as_deleted(&self) { pub(crate) fn mark_as_deleted(&self) {
self.deleted.store(true, Ordering::Relaxed); self.deleted.store(true, Ordering::Relaxed);
} }
}
#[cfg(test)] #[cfg(test)]
impl From<&OlmMachine> for Device { pub async fn from_machine(machine: &OlmMachine) -> Device {
fn from(machine: &OlmMachine) -> Self { let device_keys = machine.account.device_keys().await;
Device { Device::try_from(&device_keys).unwrap()
user_id: Arc::new(machine.user_id().clone()),
device_id: Arc::new(machine.device_id().clone()),
algorithms: Arc::new(vec![
Algorithm::MegolmV1AesSha2,
Algorithm::OlmV1Curve25519AesSha2,
]),
keys: Arc::new(
machine
.identity_keys()
.iter()
.map(|(key, value)| {
(
AlgorithmAndDeviceId(
KeyAlgorithm::try_from(key.as_ref()).unwrap(),
machine.device_id().clone(),
),
value.to_owned(),
)
})
.collect(),
),
display_name: Arc::new(None),
deleted: Arc::new(AtomicBool::new(false)),
trust_state: Arc::new(Atomic::new(TrustState::Unset)),
}
} }
} }
@ -215,6 +200,7 @@ impl TryFrom<&DeviceKeys> for Device {
user_id: Arc::new(device_keys.user_id.clone()), user_id: Arc::new(device_keys.user_id.clone()),
device_id: Arc::new(device_keys.device_id.clone()), device_id: Arc::new(device_keys.device_id.clone()),
algorithms: Arc::new(device_keys.algorithms.clone()), algorithms: Arc::new(device_keys.algorithms.clone()),
signatures: Arc::new(device_keys.signatures.clone()),
keys: Arc::new(device_keys.keys.clone()), keys: Arc::new(device_keys.keys.clone()),
display_name: Arc::new( display_name: Arc::new(
device_keys device_keys

View File

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
use cjson::Error as CjsonError; use cjson::Error as CjsonError;
use matrix_sdk_common::identifiers::{DeviceId, UserId};
use olm_rs::errors::{OlmGroupSessionError, OlmSessionError}; use olm_rs::errors::{OlmGroupSessionError, OlmSessionError};
use serde_json::Error as SerdeError; use serde_json::Error as SerdeError;
use thiserror::Error; use thiserror::Error;
@ -49,6 +50,14 @@ pub enum OlmError {
/// The session with a device has become corrupted. /// The session with a device has become corrupted.
#[error("decryption failed likely because a Olm session was wedged")] #[error("decryption failed likely because a Olm session was wedged")]
SessionWedged, SessionWedged,
/// Encryption failed because the device does not have a valid Olm session
/// with us.
#[error(
"encryption failed because the device does not \
have a valid Olm session with us"
)]
MissingSession,
} }
/// Error representing a failure during a group encryption operation. /// Error representing a failure during a group encryption operation.
@ -93,6 +102,9 @@ pub enum EventError {
#[error("the Encrypted message is missing the signing key of the sender")] #[error("the Encrypted message is missing the signing key of the sender")]
MissingSigningKey, MissingSigningKey,
#[error("the Encrypted message is missing the sender key")]
MissingSenderKey,
#[error("the Encrypted message is missing the field {0}")] #[error("the Encrypted message is missing the field {0}")]
MissingField(String), MissingField(String),
@ -121,6 +133,29 @@ pub enum SignatureError {
VerificationError, VerificationError,
} }
#[derive(Error, Debug)]
pub(crate) enum SessionCreationError {
#[error(
"Failed to create a new Olm session for {0} {1}, the requested \
one-time key isn't a signed curve key"
)]
OneTimeKeyNotSigned(UserId, Box<DeviceId>),
#[error(
"Tried to create a new Olm session for {0} {1}, but the signed \
one-time key is missing"
)]
OneTimeKeyMissing(UserId, Box<DeviceId>),
#[error("Failed to verify the one-time key signatures for {0} {1}: {2:?}")]
InvalidSignature(UserId, Box<DeviceId>, SignatureError),
#[error(
"Tried to create an Olm session for {0} {1}, but the device is missing \
a curve25519 key"
)]
DeviceMissingCurveKey(UserId, Box<DeviceId>),
#[error("Error creating new Olm session for {0} {1}: {2:?}")]
OlmError(UserId, Box<DeviceId>, OlmSessionError),
}
impl From<CjsonError> for SignatureError { impl From<CjsonError> for SignatureError {
fn from(error: CjsonError) -> Self { fn from(error: CjsonError) -> Self {
Self::CanonicalJsonError(error) Self::CanonicalJsonError(error)

View File

@ -38,7 +38,7 @@ pub use device::{Device, TrustState};
pub use error::{MegolmError, OlmError}; pub use error::{MegolmError, OlmError};
pub use machine::{OlmMachine, OneTimeKeys}; pub use machine::{OlmMachine, OneTimeKeys};
pub use memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices}; pub use memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices};
pub use olm::{Account, InboundGroupSession, OutboundGroupSession, Session}; pub use olm::{Account, IdentityKeys, InboundGroupSession, OutboundGroupSession, Session};
#[cfg(feature = "sqlite-cryptostore")] #[cfg(feature = "sqlite-cryptostore")]
pub use store::sqlite::SqliteStore; pub use store::sqlite::SqliteStore;
pub use store::{CryptoStore, CryptoStoreError}; pub use store::{CryptoStore, CryptoStoreError};
@ -83,7 +83,7 @@ pub(crate) fn verify_json(
json_object.insert("unsigned".to_string(), u); json_object.insert("unsigned".to_string(), u);
} }
let key_id = AlgorithmAndDeviceId(KeyAlgorithm::Ed25519, key_id.to_string()); let key_id = AlgorithmAndDeviceId(KeyAlgorithm::Ed25519, key_id.into());
let signatures = signatures.ok_or(SignatureError::NoSignatureFound)?; let signatures = signatures.ok_or(SignatureError::NoSignatureFound)?;
let signature_object = signatures let signature_object = signatures

View File

@ -23,7 +23,6 @@ use std::result::Result as StdResult;
use super::error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult}; use super::error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult};
use super::olm::{ use super::olm::{
Account, GroupSessionKey, IdentityKeys, InboundGroupSession, OlmMessage, OutboundGroupSession, Account, GroupSessionKey, IdentityKeys, InboundGroupSession, OlmMessage, OutboundGroupSession,
Session,
}; };
use super::store::memorystore::MemoryStore; use super::store::memorystore::MemoryStore;
#[cfg(feature = "sqlite-cryptostore")] #[cfg(feature = "sqlite-cryptostore")]
@ -32,13 +31,10 @@ use super::{device::Device, store::Result as StoreResult, CryptoStore};
use matrix_sdk_common::api; use matrix_sdk_common::api;
use matrix_sdk_common::events::{ use matrix_sdk_common::events::{
forwarded_room_key::ForwardedRoomKeyEventContent, forwarded_room_key::ForwardedRoomKeyEventContent, room::encrypted::EncryptedEventContent,
room::encrypted::{CiphertextInfo, EncryptedEventContent, OlmV1Curve25519AesSha2Content}, room::message::MessageEventContent, room_key::RoomKeyEventContent,
room::message::MessageEventContent, room_key_request::RoomKeyRequestEventContent, Algorithm, AnySyncRoomEvent, AnyToDeviceEvent,
room_key::RoomKeyEventContent, EventJson, EventType, SyncMessageEvent, ToDeviceEvent,
room_key_request::RoomKeyRequestEventContent,
Algorithm, AnyRoomEventStub, AnyToDeviceEvent, EventJson, EventType, MessageEventStub,
ToDeviceEvent,
}; };
use matrix_sdk_common::identifiers::{DeviceId, RoomId, UserId}; use matrix_sdk_common::identifiers::{DeviceId, RoomId, UserId};
use matrix_sdk_common::uuid::Uuid; use matrix_sdk_common::uuid::Uuid;
@ -50,7 +46,7 @@ use api::r0::{
to_device::{send_event_to_device::Request as ToDeviceRequest, DeviceIdOrAllDevices}, to_device::{send_event_to_device::Request as ToDeviceRequest, DeviceIdOrAllDevices},
}; };
use serde_json::{json, Value}; use serde_json::Value;
use tracing::{debug, error, info, instrument, trace, warn}; use tracing::{debug, error, info, instrument, trace, warn};
/// A map from the algorithm and device id to a one-time key. /// A map from the algorithm and device id to a one-time key.
@ -64,9 +60,9 @@ pub struct OlmMachine {
/// The unique user id that owns this account. /// The unique user id that owns this account.
user_id: UserId, user_id: UserId,
/// The unique device id of the device that holds this account. /// The unique device id of the device that holds this account.
device_id: DeviceId, device_id: Box<DeviceId>,
/// Our underlying Olm Account holding our identity keys. /// Our underlying Olm Account holding our identity keys.
account: Account, pub(crate) account: Account,
/// Store for the encryption keys. /// Store for the encryption keys.
/// Persists all the encryption keys so a client can resume the session /// Persists all the encryption keys so a client can resume the session
/// without the need to create new keys. /// without the need to create new keys.
@ -98,10 +94,11 @@ impl OlmMachine {
/// * `user_id` - The unique id of the user that owns this machine. /// * `user_id` - The unique id of the user that owns this machine.
/// ///
/// * `device_id` - The unique id of the device that owns this machine. /// * `device_id` - The unique id of the device that owns this machine.
#[allow(clippy::ptr_arg)]
pub fn new(user_id: &UserId, device_id: &DeviceId) -> Self { pub fn new(user_id: &UserId, device_id: &DeviceId) -> Self {
OlmMachine { OlmMachine {
user_id: user_id.clone(), user_id: user_id.clone(),
device_id: device_id.to_owned(), device_id: device_id.into(),
account: Account::new(user_id, &device_id), account: Account::new(user_id, &device_id),
store: Box::new(MemoryStore::new()), store: Box::new(MemoryStore::new()),
outbound_group_sessions: HashMap::new(), outbound_group_sessions: HashMap::new(),
@ -127,7 +124,7 @@ impl OlmMachine {
/// the encryption keys. /// the encryption keys.
pub async fn new_with_store( pub async fn new_with_store(
user_id: UserId, user_id: UserId,
device_id: String, device_id: Box<DeviceId>,
mut store: Box<dyn CryptoStore>, mut store: Box<dyn CryptoStore>,
) -> StoreResult<Self> { ) -> StoreResult<Self> {
let account = match store.load_account().await? { let account = match store.load_account().await? {
@ -163,14 +160,14 @@ impl OlmMachine {
/// * `device_id` - The unique id of the device that owns this machine. /// * `device_id` - The unique id of the device that owns this machine.
pub async fn new_with_default_store<P: AsRef<Path>>( pub async fn new_with_default_store<P: AsRef<Path>>(
user_id: &UserId, user_id: &UserId,
device_id: &str, device_id: &DeviceId,
path: P, path: P,
passphrase: &str, passphrase: &str,
) -> StoreResult<Self> { ) -> StoreResult<Self> {
let store = let store =
SqliteStore::open_with_passphrase(&user_id, device_id, path, passphrase).await?; SqliteStore::open_with_passphrase(&user_id, device_id, path, passphrase).await?;
OlmMachine::new_with_store(user_id.to_owned(), device_id.to_owned(), Box::new(store)).await OlmMachine::new_with_store(user_id.to_owned(), device_id.into(), Box::new(store)).await
} }
/// The unique user id that owns this identity. /// The unique user id that owns this identity.
@ -254,7 +251,7 @@ impl OlmMachine {
pub async fn get_missing_sessions( pub async fn get_missing_sessions(
&mut self, &mut self,
users: impl Iterator<Item = &UserId>, users: impl Iterator<Item = &UserId>,
) -> OlmResult<BTreeMap<UserId, BTreeMap<DeviceId, KeyAlgorithm>>> { ) -> OlmResult<BTreeMap<UserId, BTreeMap<Box<DeviceId>, KeyAlgorithm>>> {
let mut missing = BTreeMap::new(); let mut missing = BTreeMap::new();
for user_id in users { for user_id in users {
@ -281,10 +278,8 @@ impl OlmMachine {
} }
let user_map = missing.get_mut(user_id).unwrap(); let user_map = missing.get_mut(user_id).unwrap();
let _ = user_map.insert( let _ =
device.device_id().to_owned(), user_map.insert(device.device_id().into(), KeyAlgorithm::SignedCurve25519);
KeyAlgorithm::SignedCurve25519,
);
} }
} }
} }
@ -306,76 +301,35 @@ impl OlmMachine {
for (user_id, user_devices) in &response.one_time_keys { for (user_id, user_devices) in &response.one_time_keys {
for (device_id, key_map) in user_devices { for (device_id, key_map) in user_devices {
let device = if let Some(d) = self let device: Device = match self.store.get_device(&user_id, device_id).await {
.store Ok(d) => {
.get_device(&user_id, device_id) if let Some(d) = d {
.await
.expect("Can't get devices")
{
d d
} else { } else {
warn!( warn!(
"Tried to create an Olm session for {} {}, but the device is unknown", "Tried to create an Olm session for {} {}, but \
user_id, device_id the device is unknown",
);
continue;
};
// TODO move this logic into the account, pass the device to the
// account when creating an outbound session.
let one_time_key = if let Some(k) = key_map.values().next() {
match k {
OneTimeKey::SignedKey(k) => k,
OneTimeKey::Key(_) => {
warn!(
"Tried to create an Olm session for {} {}, but
the requested key isn't a signed curve key",
user_id, device_id user_id, device_id
); );
continue; continue;
} }
} }
} else { Err(e) => {
warn!( warn!(
"Tried to create an Olm session for {} {}, but the "Tried to create an Olm session for {} {}, but \
signed one-time key is missing", can't fetch the device from the store {:?}",
user_id, device_id
);
continue;
};
if let Err(e) = device.verify_one_time_key(&one_time_key) {
warn!(
"Failed to verify the one-time key signatures for {} {}: {:?}",
user_id, device_id, e user_id, device_id, e
); );
continue; continue;
} }
let curve_key = if let Some(k) = device.get_key(KeyAlgorithm::Curve25519) {
k
} else {
warn!(
"Tried to create an Olm session for {} {}, but the
device is missing the curve key",
user_id, device_id
);
continue;
}; };
info!("Creating outbound Session for {} {}", user_id, device_id); info!("Creating outbound Session for {} {}", user_id, device_id);
let session = match self let session = match self.account.create_outbound_session(device, &key_map).await {
.account
.create_outbound_session(curve_key, &one_time_key)
.await
{
Ok(s) => s, Ok(s) => s,
Err(e) => { Err(e) => {
warn!( warn!("{:?}", e);
"Error creating new Olm session for {} {}: {}",
user_id, device_id, e
);
continue; continue;
} }
}; };
@ -396,7 +350,7 @@ impl OlmMachine {
async fn handle_devices_from_key_query( async fn handle_devices_from_key_query(
&mut self, &mut self,
device_keys_map: &BTreeMap<UserId, BTreeMap<DeviceId, DeviceKeys>>, device_keys_map: &BTreeMap<UserId, BTreeMap<Box<DeviceId>, DeviceKeys>>,
) -> StoreResult<Vec<Device>> { ) -> StoreResult<Vec<Device>> {
let mut changed_devices = Vec::new(); let mut changed_devices = Vec::new();
@ -446,7 +400,8 @@ impl OlmMachine {
changed_devices.push(device); changed_devices.push(device);
} }
let current_devices: HashSet<&DeviceId> = device_map.keys().collect(); let current_devices: HashSet<&DeviceId> =
device_map.keys().map(|id| id.as_ref()).collect();
let stored_devices = self.store.get_user_devices(&user_id).await.unwrap(); let stored_devices = self.store.get_user_devices(&user_id).await.unwrap();
let stored_devices_set: HashSet<&DeviceId> = stored_devices.keys().collect(); let stored_devices_set: HashSet<&DeviceId> = stored_devices.keys().collect();
@ -840,66 +795,51 @@ impl OlmMachine {
Ok(session.encrypt(content).await) Ok(session.encrypt(content).await)
} }
/// Encrypt some JSON content using the given Olm session. /// Encrypt the given event for the given Device
///
/// # Arguments
///
/// * `reciepient_device` - The device that the event should be encrypted
/// for.
///
/// * `event_type` - The type of the event.
///
/// * `content` - The content of the event that should be encrypted.
async fn olm_encrypt( async fn olm_encrypt(
&mut self, &mut self,
mut session: Session,
recipient_device: &Device, recipient_device: &Device,
event_type: EventType, event_type: EventType,
content: Value, content: Value,
) -> OlmResult<EncryptedEventContent> { ) -> OlmResult<EncryptedEventContent> {
let identity_keys = self.account.identity_keys(); let sender_key = if let Some(k) = recipient_device.get_key(KeyAlgorithm::Curve25519) {
k
// TODO most of this could go into the session, the session already } else {
// stores the curve key of the device, if we also store the ed25519 key warn!(
// with the session we'll only need to pass in the account to the "Trying to encrypt a Megolm session for user {} on device {}, \
// session and all of this can live in the session. but the device doesn't have a curve25519 key",
recipient_device.user_id(),
let recipient_signing_key = recipient_device recipient_device.device_id()
.get_key(KeyAlgorithm::Ed25519) );
.ok_or(EventError::MissingSigningKey)?; return Err(EventError::MissingSenderKey.into());
let recipient_sender_key = recipient_device
.get_key(KeyAlgorithm::Curve25519)
.ok_or(EventError::MissingSigningKey)?;
let payload = json!({
"sender": self.user_id,
"sender_device": self.device_id,
"keys": {
"ed25519": identity_keys.ed25519(),
},
"recipient": recipient_device.user_id(),
"recipient_keys": {
"ed25519": recipient_signing_key,
},
"type": event_type,
"content": content,
});
let plaintext = cjson::to_string(&payload)
.unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", payload)));
let ciphertext = session.encrypt(&plaintext).await.to_tuple();
let message_type: usize = ciphertext.0.into();
let ciphertext = CiphertextInfo {
body: ciphertext.1,
message_type: (message_type as u32).into(),
}; };
let mut content = BTreeMap::new(); let mut session = if let Some(s) = self.store.get_sessions(sender_key).await? {
let session = &s.lock().await[0];
content.insert(recipient_sender_key.to_owned(), ciphertext); session.clone()
} else {
warn!(
"Trying to encrypt a Megolm session for user {} on device {}, \
but no Olm session is found",
recipient_device.user_id(),
recipient_device.device_id()
);
return Err(OlmError::MissingSession);
};
let message = session.encrypt(recipient_device, event_type, content).await;
self.store.save_sessions(&[session]).await?; self.store.save_sessions(&[session]).await?;
Ok(EncryptedEventContent::OlmV1Curve25519AesSha2( message
OlmV1Curve25519AesSha2Content {
sender_key: identity_keys.curve25519().to_owned(),
ciphertext: content,
},
))
} }
/// Should the client share a group session for the given room. /// Should the client share a group session for the given room.
@ -946,102 +886,67 @@ impl OlmMachine {
I: IntoIterator<Item = &'a UserId>, I: IntoIterator<Item = &'a UserId>,
{ {
self.create_outbound_group_session(room_id).await?; self.create_outbound_group_session(room_id).await?;
let megolm_session = self.outbound_group_sessions.get(room_id).unwrap(); let session = self.outbound_group_sessions.get(room_id).unwrap();
if megolm_session.shared() { if session.shared() {
panic!("Session is already shared"); panic!("Session is already shared");
} }
let session_id = megolm_session.session_id().to_owned();
// TODO don't mark the session as shared automatically only, when all // TODO don't mark the session as shared automatically only, when all
// the requests are done, failure to send these requests will likely end // the requests are done, failure to send these requests will likely end
// up in wedged sessions. We'll need to store the requests and let the // up in wedged sessions. We'll need to store the requests and let the
// caller mark them as sent using an UUID. // caller mark them as sent using an UUID.
megolm_session.mark_as_shared(); session.mark_as_shared();
// TODO the key content creation can go into the OutboundGroupSession let mut devices = Vec::new();
// struct.
let key_content = json!({
"algorithm": Algorithm::MegolmV1AesSha2,
"room_id": room_id,
"session_id": session_id.clone(),
"session_key": megolm_session.session_key().await,
"chain_index": megolm_session.message_index().await,
});
let mut user_map = Vec::new();
for user_id in users { for user_id in users {
for device in self.store.get_user_devices(user_id).await?.devices() { for device in self.store.get_user_devices(user_id).await?.devices() {
let sender_key = if let Some(k) = device.get_key(KeyAlgorithm::Curve25519) {
k
} else {
warn!(
"The device {} of user {} doesn't have a curve 25519 key",
user_id,
device.device_id()
);
// TODO mark the user for a key query.
continue;
};
// TODO abort if the device isn't verified // TODO abort if the device isn't verified
let sessions = self.store.get_sessions(sender_key).await?; devices.push(device.clone());
if let Some(s) = sessions {
let session = &s.lock().await[0];
// TODO once the session has the all the device info, we
// won't need the device anymore to encrypt stuff with the
// session.
user_map.push((session.clone(), device.clone()));
} else {
warn!(
"Trying to encrypt a Megolm session for user
{} on device {}, but no Olm session is found",
user_id,
device.device_id()
);
}
} }
} }
let mut message_vec = Vec::new(); let mut requests = Vec::new();
let key_content = session.as_json().await;
for user_map_chunk in user_map.chunks(OlmMachine::MAX_TO_DEVICE_MESSAGES) { for device_map_chunk in devices.chunks(OlmMachine::MAX_TO_DEVICE_MESSAGES) {
let mut messages = BTreeMap::new(); let mut messages = BTreeMap::new();
for (session, device) in user_map_chunk { for device in device_map_chunk {
let encrypted = self
.olm_encrypt(&device, EventType::RoomKey, key_content.clone())
.await;
let encrypted = match encrypted {
Ok(c) => c,
Err(OlmError::MissingSession)
| Err(OlmError::EventError(EventError::MissingSenderKey)) => {
continue;
}
Err(e) => return Err(e),
};
if !messages.contains_key(device.user_id()) { if !messages.contains_key(device.user_id()) {
messages.insert(device.user_id().clone(), BTreeMap::new()); messages.insert(device.user_id().clone(), BTreeMap::new());
}; };
let user_messages = messages.get_mut(device.user_id()).unwrap(); let user_messages = messages.get_mut(device.user_id()).unwrap();
let encrypted_content = self
.olm_encrypt(
session.clone(),
&device,
EventType::RoomKey,
key_content.clone(),
)
.await?;
user_messages.insert( user_messages.insert(
DeviceIdOrAllDevices::DeviceId(device.device_id().clone()), DeviceIdOrAllDevices::DeviceId(device.device_id().into()),
serde_json::value::to_raw_value(&encrypted_content)?, serde_json::value::to_raw_value(&encrypted)?,
); );
} }
message_vec.push(ToDeviceRequest { requests.push(ToDeviceRequest {
event_type: EventType::RoomEncrypted, event_type: EventType::RoomEncrypted,
txn_id: Uuid::new_v4().to_string(), txn_id: Uuid::new_v4().to_string(),
messages, messages,
}); });
} }
Ok(message_vec) Ok(requests)
} }
fn add_forwarded_room_key( fn add_forwarded_room_key(
@ -1150,8 +1055,6 @@ impl OlmMachine {
} }
}; };
// TODO make sure private keys are cleared from the event
// before we replace the result.
*event_result = decrypted_event; *event_result = decrypted_event;
} }
AnyToDeviceEvent::RoomKeyRequest(e) => self.handle_room_key_request(e), AnyToDeviceEvent::RoomKeyRequest(e) => self.handle_room_key_request(e),
@ -1177,9 +1080,9 @@ impl OlmMachine {
/// * `room_id` - The ID of the room where the event was sent to. /// * `room_id` - The ID of the room where the event was sent to.
pub async fn decrypt_room_event( pub async fn decrypt_room_event(
&mut self, &mut self,
event: &MessageEventStub<EncryptedEventContent>, event: &SyncMessageEvent<EncryptedEventContent>,
room_id: &RoomId, room_id: &RoomId,
) -> MegolmResult<EventJson<AnyRoomEventStub>> { ) -> MegolmResult<EventJson<AnySyncRoomEvent>> {
let content = match &event.content { let content = match &event.content {
EncryptedEventContent::MegolmV1AesSha2(c) => c, EncryptedEventContent::MegolmV1AesSha2(c) => c,
_ => return Err(EventError::UnsupportedAlgorithm.into()), _ => return Err(EventError::UnsupportedAlgorithm.into()),
@ -1286,8 +1189,8 @@ mod test {
encrypted::EncryptedEventContent, encrypted::EncryptedEventContent,
message::{MessageEventContent, TextMessageEventContent}, message::{MessageEventContent, TextMessageEventContent},
}, },
AnyMessageEventStub, AnyRoomEventStub, AnyToDeviceEvent, EventJson, EventType, AnySyncMessageEvent, AnySyncRoomEvent, AnyToDeviceEvent, EventJson, EventType,
MessageEventStub, ToDeviceEvent, UnsignedData, SyncMessageEvent, ToDeviceEvent, Unsigned,
}; };
use matrix_sdk_common::identifiers::{DeviceId, EventId, RoomId, UserId}; use matrix_sdk_common::identifiers::{DeviceId, EventId, RoomId, UserId};
use matrix_sdk_test::test_json; use matrix_sdk_test::test_json;
@ -1296,8 +1199,8 @@ mod test {
UserId::try_from("@alice:example.org").unwrap() UserId::try_from("@alice:example.org").unwrap()
} }
fn alice_device_id() -> DeviceId { fn alice_device_id() -> Box<DeviceId> {
"JLAFKJWSCS".to_string() "JLAFKJWSCS".into()
} }
fn user_id() -> UserId { fn user_id() -> UserId {
@ -1375,8 +1278,8 @@ mod test {
let alice_device = alice_device_id(); let alice_device = alice_device_id();
let alice = OlmMachine::new(&alice_id, &alice_device); let alice = OlmMachine::new(&alice_id, &alice_device);
let alice_deivce = Device::from(&alice); let alice_deivce = Device::from_machine(&alice).await;
let bob_device = Device::from(&bob); let bob_device = Device::from_machine(&bob).await;
alice.store.save_devices(&[bob_device]).await.unwrap(); alice.store.save_devices(&[bob_device]).await.unwrap();
bob.store.save_devices(&[alice_deivce]).await.unwrap(); bob.store.save_devices(&[alice_deivce]).await.unwrap();
@ -1409,16 +1312,6 @@ mod test {
async fn get_machine_pair_with_setup_sessions() -> (OlmMachine, OlmMachine) { async fn get_machine_pair_with_setup_sessions() -> (OlmMachine, OlmMachine) {
let (mut alice, mut bob) = get_machine_pair_with_session().await; let (mut alice, mut bob) = get_machine_pair_with_session().await;
let session = alice
.store
.get_sessions(bob.account.identity_keys().curve25519())
.await
.unwrap()
.unwrap()
.lock()
.await[0]
.clone();
let bob_device = alice let bob_device = alice
.store .store
.get_device(&bob.user_id, &bob.device_id) .get_device(&bob.user_id, &bob.device_id)
@ -1429,7 +1322,7 @@ mod test {
let event = ToDeviceEvent { let event = ToDeviceEvent {
sender: alice.user_id.clone(), sender: alice.user_id.clone(),
content: alice content: alice
.olm_encrypt(session, &bob_device, EventType::Dummy, json!({})) .olm_encrypt(&bob_device, EventType::Dummy, json!({}))
.await .await
.unwrap(), .unwrap(),
}; };
@ -1698,16 +1591,6 @@ mod test {
async fn test_olm_encryption() { async fn test_olm_encryption() {
let (mut alice, mut bob) = get_machine_pair_with_session().await; let (mut alice, mut bob) = get_machine_pair_with_session().await;
let session = alice
.store
.get_sessions(bob.account.identity_keys().curve25519())
.await
.unwrap()
.unwrap()
.lock()
.await[0]
.clone();
let bob_device = alice let bob_device = alice
.store .store
.get_device(&bob.user_id, &bob.device_id) .get_device(&bob.user_id, &bob.device_id)
@ -1718,7 +1601,7 @@ mod test {
let event = ToDeviceEvent { let event = ToDeviceEvent {
sender: alice.user_id.clone(), sender: alice.user_id.clone(),
content: alice content: alice
.olm_encrypt(session, &bob_device, EventType::Dummy, json!({})) .olm_encrypt(&bob_device, EventType::Dummy, json!({}))
.await .await
.unwrap(), .unwrap(),
}; };
@ -1804,12 +1687,12 @@ mod test {
let encrypted_content = alice.encrypt(&room_id, content.clone()).await.unwrap(); let encrypted_content = alice.encrypt(&room_id, content.clone()).await.unwrap();
let event = MessageEventStub { let event = SyncMessageEvent {
event_id: EventId::try_from("$xxxxx:example.org").unwrap(), event_id: EventId::try_from("$xxxxx:example.org").unwrap(),
origin_server_ts: SystemTime::now(), origin_server_ts: SystemTime::now(),
sender: alice.user_id().clone(), sender: alice.user_id().clone(),
content: encrypted_content, content: encrypted_content,
unsigned: UnsignedData::default(), unsigned: Unsigned::default(),
}; };
let decrypted_event = bob let decrypted_event = bob
@ -1820,7 +1703,7 @@ mod test {
.unwrap(); .unwrap();
match decrypted_event { match decrypted_event {
AnyRoomEventStub::Message(AnyMessageEventStub::RoomMessage(MessageEventStub { AnySyncRoomEvent::Message(AnySyncMessageEvent::RoomMessage(SyncMessageEvent {
sender, sender,
content, content,
.. ..

View File

@ -129,24 +129,24 @@ impl GroupSessionStore {
/// In-memory store holding the devices of users. /// In-memory store holding the devices of users.
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]
pub struct DeviceStore { pub struct DeviceStore {
entries: Arc<DashMap<UserId, DashMap<String, Device>>>, entries: Arc<DashMap<UserId, DashMap<Box<DeviceId>, Device>>>,
} }
/// A read only view over all devices belonging to a user. /// A read only view over all devices belonging to a user.
#[derive(Debug)] #[derive(Debug)]
pub struct UserDevices { pub struct UserDevices {
entries: ReadOnlyView<DeviceId, Device>, entries: ReadOnlyView<Box<DeviceId>, Device>,
} }
impl UserDevices { impl UserDevices {
/// Get the specific device with the given device id. /// Get the specific device with the given device id.
pub fn get(&self, device_id: &str) -> Option<Device> { pub fn get(&self, device_id: &DeviceId) -> Option<Device> {
self.entries.get(device_id).cloned() self.entries.get(device_id).cloned()
} }
/// Iterator over all the device ids of the user devices. /// Iterator over all the device ids of the user devices.
pub fn keys(&self) -> impl Iterator<Item = &DeviceId> { pub fn keys(&self) -> impl Iterator<Item = &DeviceId> {
self.entries.keys() self.entries.keys().map(|id| id.as_ref())
} }
/// Iterator over all the devices of the user devices. /// Iterator over all the devices of the user devices.
@ -175,12 +175,12 @@ impl DeviceStore {
let device_map = self.entries.get_mut(&user_id).unwrap(); let device_map = self.entries.get_mut(&user_id).unwrap();
device_map device_map
.insert(device.device_id().to_owned(), device) .insert(device.device_id().into(), device)
.is_none() .is_none()
} }
/// Get the device with the given device_id and belonging to the given user. /// Get the device with the given device_id and belonging to the given user.
pub fn get(&self, user_id: &UserId, device_id: &str) -> Option<Device> { pub fn get(&self, user_id: &UserId, device_id: &DeviceId) -> Option<Device> {
self.entries self.entries
.get(user_id) .get(user_id)
.and_then(|m| m.get(device_id).map(|d| d.value().clone())) .and_then(|m| m.get(device_id).map(|d| d.value().clone()))
@ -189,7 +189,7 @@ impl DeviceStore {
/// Remove the device with the given device_id and belonging to the given user. /// Remove the device with the given device_id and belonging to the given user.
/// ///
/// Returns the device if it was removed, None if it wasn't in the store. /// Returns the device if it was removed, None if it wasn't in the store.
pub fn remove(&self, user_id: &UserId, device_id: &str) -> Option<Device> { pub fn remove(&self, user_id: &UserId, device_id: &DeviceId) -> Option<Device> {
self.entries self.entries
.get(user_id) .get(user_id)
.and_then(|m| m.remove(device_id)) .and_then(|m| m.remove(device_id))
@ -292,8 +292,8 @@ mod test {
let user_devices = store.user_devices(device.user_id()); let user_devices = store.user_devices(device.user_id());
assert_eq!(user_devices.keys().nth(0).unwrap(), device.device_id()); assert_eq!(user_devices.keys().next().unwrap(), device.device_id());
assert_eq!(user_devices.devices().nth(0).unwrap(), &device); assert_eq!(user_devices.devices().next().unwrap(), &device);
let loaded_device = user_devices.get(device.device_id()).unwrap(); let loaded_device = user_devices.get(device.device_id()).unwrap();

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,550 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use matrix_sdk_common::instant::Instant;
use std::convert::TryFrom;
use std::convert::TryInto;
use std::fmt;
use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
use std::sync::Arc;
use matrix_sdk_common::locks::Mutex;
use serde_json::{json, Value};
use std::collections::BTreeMap;
pub use olm_rs::account::IdentityKeys;
use olm_rs::account::{OlmAccount, OneTimeKeys};
use olm_rs::errors::{OlmAccountError, OlmSessionError};
use olm_rs::PicklingMode;
use crate::device::Device;
use crate::error::SessionCreationError;
pub use olm_rs::{
session::{OlmMessage, PreKeyMessage},
utility::OlmUtility,
};
use matrix_sdk_common::{
api::r0::keys::{AlgorithmAndDeviceId, DeviceKeys, KeyAlgorithm, OneTimeKey, SignedKey},
events::Algorithm,
identifiers::{DeviceId, RoomId, UserId},
};
use super::{InboundGroupSession, OutboundGroupSession, Session};
/// Account holding identity keys for which sessions can be created.
///
/// An account is the central identity for encrypted communication between two
/// devices.
#[derive(Clone)]
pub struct Account {
pub(crate) user_id: Arc<UserId>,
pub(crate) device_id: Arc<Box<DeviceId>>,
inner: Arc<Mutex<OlmAccount>>,
pub(crate) identity_keys: Arc<IdentityKeys>,
shared: Arc<AtomicBool>,
/// The number of signed one-time keys we have uploaded to the server. If
/// this is None, no action will be taken. After a sync request the client
/// needs to set this for us, depending on the count we will suggest the
/// client to upload new keys.
uploaded_signed_key_count: Arc<AtomicI64>,
}
// #[cfg_attr(tarpaulin, skip)]
impl fmt::Debug for Account {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Account")
.field("identity_keys", self.identity_keys())
.field("shared", &self.shared())
.finish()
}
}
impl Account {
const ALGORITHMS: &'static [&'static Algorithm] = &[
&Algorithm::OlmV1Curve25519AesSha2,
&Algorithm::MegolmV1AesSha2,
];
/// Create a fresh new account, this will generate the identity key-pair.
#[allow(clippy::ptr_arg)]
pub fn new(user_id: &UserId, device_id: &DeviceId) -> Self {
let account = OlmAccount::new();
let identity_keys = account.parsed_identity_keys();
Account {
user_id: Arc::new(user_id.to_owned()),
device_id: Arc::new(device_id.into()),
inner: Arc::new(Mutex::new(account)),
identity_keys: Arc::new(identity_keys),
shared: Arc::new(AtomicBool::new(false)),
uploaded_signed_key_count: Arc::new(AtomicI64::new(0)),
}
}
/// Get the public parts of the identity keys for the account.
pub fn identity_keys(&self) -> &IdentityKeys {
&self.identity_keys
}
/// Update the uploaded key count.
///
/// # Arguments
///
/// * `new_count` - The new count that was reported by the server.
pub(crate) fn update_uploaded_key_count(&self, new_count: u64) {
let key_count = i64::try_from(new_count).unwrap_or(i64::MAX);
self.uploaded_signed_key_count
.store(key_count, Ordering::Relaxed);
}
/// Get the currently known uploaded key count.
pub fn uploaded_key_count(&self) -> i64 {
self.uploaded_signed_key_count.load(Ordering::Relaxed)
}
/// Has the account been shared with the server.
pub fn shared(&self) -> bool {
self.shared.load(Ordering::Relaxed)
}
/// Mark the account as shared.
///
/// Messages shouldn't be encrypted with the session before it has been
/// shared.
pub(crate) fn mark_as_shared(&self) {
self.shared.store(true, Ordering::Relaxed);
}
/// Get the one-time keys of the account.
///
/// This can be empty, keys need to be generated first.
pub(crate) async fn one_time_keys(&self) -> OneTimeKeys {
self.inner.lock().await.parsed_one_time_keys()
}
/// Generate count number of one-time keys.
pub(crate) async fn generate_one_time_keys_helper(&self, count: usize) {
self.inner.lock().await.generate_one_time_keys(count);
}
/// Get the maximum number of one-time keys the account can hold.
pub(crate) async fn max_one_time_keys(&self) -> usize {
self.inner.lock().await.max_number_of_one_time_keys()
}
/// Get a tuple of device and one-time keys that need to be uploaded.
///
/// Returns an empty error if no keys need to be uploaded.
pub(crate) async fn generate_one_time_keys(&self) -> Result<u64, ()> {
let count = self.uploaded_key_count() as u64;
let max_keys = self.max_one_time_keys().await;
let max_on_server = (max_keys as u64) / 2;
if count >= (max_on_server) {
return Err(());
}
let key_count = (max_on_server) - count;
let key_count: usize = key_count.try_into().unwrap_or(max_keys);
self.generate_one_time_keys_helper(key_count).await;
Ok(key_count as u64)
}
/// Should account or one-time keys be uploaded to the server.
pub(crate) async fn should_upload_keys(&self) -> bool {
if !self.shared() {
return true;
}
let count = self.uploaded_key_count() as u64;
// If we have a known key count, check that we have more than
// max_one_time_Keys() / 2, otherwise tell the client to upload more.
let max_keys = self.max_one_time_keys().await as u64;
// If there are more keys already uploaded than max_key / 2
// bail out returning false, this also avoids overflow.
if count > (max_keys / 2) {
return false;
}
let key_count = (max_keys / 2) - count;
key_count > 0
}
/// Get a tuple of device and one-time keys that need to be uploaded.
///
/// Returns an empty error if no keys need to be uploaded.
pub(crate) async fn keys_for_upload(
&self,
) -> Result<
(
Option<DeviceKeys>,
Option<BTreeMap<AlgorithmAndDeviceId, OneTimeKey>>,
),
(),
> {
if !self.should_upload_keys().await {
return Err(());
}
let device_keys = if !self.shared() {
Some(self.device_keys().await)
} else {
None
};
let one_time_keys = self.signed_one_time_keys().await.ok();
Ok((device_keys, one_time_keys))
}
/// Mark the current set of one-time keys as being published.
pub(crate) async fn mark_keys_as_published(&self) {
self.inner.lock().await.mark_keys_as_published();
}
/// Sign the given string using the accounts signing key.
///
/// Returns the signature as a base64 encoded string.
pub async fn sign(&self, string: &str) -> String {
self.inner.lock().await.sign(string)
}
/// Store the account as a base64 encoded string.
///
/// # Arguments
///
/// * `pickle_mode` - The mode that was used to pickle the account, either an
/// unencrypted mode or an encrypted using passphrase.
pub async fn pickle(&self, pickle_mode: PicklingMode) -> String {
self.inner.lock().await.pickle(pickle_mode)
}
/// Restore an account from a previously pickled string.
///
/// # Arguments
///
/// * `pickle` - The pickled string of the account.
///
/// * `pickle_mode` - The mode that was used to pickle the account, either an
/// unencrypted mode or an encrypted using passphrase.
///
/// * `shared` - Boolean determining if the account was uploaded to the
/// server.
#[allow(clippy::ptr_arg)]
pub fn from_pickle(
pickle: String,
pickle_mode: PicklingMode,
shared: bool,
uploaded_signed_key_count: i64,
user_id: &UserId,
device_id: &DeviceId,
) -> Result<Self, OlmAccountError> {
let account = OlmAccount::unpickle(pickle, pickle_mode)?;
let identity_keys = account.parsed_identity_keys();
Ok(Account {
user_id: Arc::new(user_id.to_owned()),
device_id: Arc::new(device_id.into()),
inner: Arc::new(Mutex::new(account)),
identity_keys: Arc::new(identity_keys),
shared: Arc::new(AtomicBool::from(shared)),
uploaded_signed_key_count: Arc::new(AtomicI64::new(uploaded_signed_key_count)),
})
}
/// Sign the device keys of the account and return them so they can be
/// uploaded.
pub(crate) async fn device_keys(&self) -> DeviceKeys {
let identity_keys = self.identity_keys();
let mut keys = BTreeMap::new();
keys.insert(
AlgorithmAndDeviceId(KeyAlgorithm::Curve25519, (*self.device_id).clone()),
identity_keys.curve25519().to_owned(),
);
keys.insert(
AlgorithmAndDeviceId(KeyAlgorithm::Ed25519, (*self.device_id).clone()),
identity_keys.ed25519().to_owned(),
);
let device_keys = json!({
"user_id": (*self.user_id).clone(),
"device_id": (*self.device_id).clone(),
"algorithms": Account::ALGORITHMS,
"keys": keys,
});
let mut signatures = BTreeMap::new();
let mut signature = BTreeMap::new();
signature.insert(
AlgorithmAndDeviceId(KeyAlgorithm::Ed25519, (*self.device_id).clone()),
self.sign_json(&device_keys).await,
);
signatures.insert((*self.user_id).clone(), signature);
DeviceKeys {
user_id: (*self.user_id).clone(),
device_id: (*self.device_id).clone(),
algorithms: vec![
Algorithm::OlmV1Curve25519AesSha2,
Algorithm::MegolmV1AesSha2,
],
keys,
signatures,
unsigned: None,
}
}
/// Convert a JSON value to the canonical representation and sign the JSON
/// string.
///
/// # Arguments
///
/// * `json` - The value that should be converted into a canonical JSON
/// string.
///
/// # Panic
///
/// Panics if the json value can't be serialized.
pub async fn sign_json(&self, json: &Value) -> String {
let canonical_json = cjson::to_string(json)
.unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", json)));
self.sign(&canonical_json).await
}
/// Generate, sign and prepare one-time keys to be uploaded.
///
/// If no one-time keys need to be uploaded returns an empty error.
pub(crate) async fn signed_one_time_keys(
&self,
) -> Result<BTreeMap<AlgorithmAndDeviceId, OneTimeKey>, ()> {
let _ = self.generate_one_time_keys().await?;
let one_time_keys = self.one_time_keys().await;
let mut one_time_key_map = BTreeMap::new();
for (key_id, key) in one_time_keys.curve25519().iter() {
let key_json = json!({
"key": key,
});
let signature = self.sign_json(&key_json).await;
let mut signature_map = BTreeMap::new();
signature_map.insert(
AlgorithmAndDeviceId(KeyAlgorithm::Ed25519, (*self.device_id).clone()),
signature,
);
let mut signatures = BTreeMap::new();
signatures.insert((*self.user_id).clone(), signature_map);
let signed_key = SignedKey {
key: key.to_owned(),
signatures,
};
one_time_key_map.insert(
AlgorithmAndDeviceId(KeyAlgorithm::SignedCurve25519, key_id.as_str().into()),
OneTimeKey::SignedKey(signed_key),
);
}
Ok(one_time_key_map)
}
/// Create a new session with another account given a one-time key.
///
/// Returns the newly created session or a `OlmSessionError` if creating a
/// session failed.
///
/// # Arguments
/// * `their_identity_key` - The other account's identity/curve25519 key.
///
/// * `their_one_time_key` - A signed one-time key that the other account
/// created and shared with us.
pub(crate) async fn create_outbound_session_helper(
&self,
their_identity_key: &str,
their_one_time_key: &SignedKey,
) -> Result<Session, OlmSessionError> {
let session = self
.inner
.lock()
.await
.create_outbound_session(their_identity_key, &their_one_time_key.key)?;
let now = Instant::now();
let session_id = session.session_id();
Ok(Session {
user_id: self.user_id.clone(),
device_id: self.device_id.clone(),
our_identity_keys: self.identity_keys.clone(),
inner: Arc::new(Mutex::new(session)),
session_id: Arc::new(session_id),
sender_key: Arc::new(their_identity_key.to_owned()),
creation_time: Arc::new(now),
last_use_time: Arc::new(now),
})
}
/// Create a new session with another account given a one-time key and a
/// device.
///
/// Returns the newly created session or a `OlmSessionError` if creating a
/// session failed.
///
/// # Arguments
/// * `device` - The other account's device.
///
/// * `key_map` - A map from the algorithm and device id to the one-time
/// key that the other account created and shared with us.
pub(crate) async fn create_outbound_session(
&self,
device: Device,
key_map: &BTreeMap<AlgorithmAndDeviceId, OneTimeKey>,
) -> Result<Session, SessionCreationError> {
let one_time_key = key_map.values().next().ok_or_else(|| {
SessionCreationError::OneTimeKeyMissing(
device.user_id().to_owned(),
device.device_id().into(),
)
})?;
let one_time_key = match one_time_key {
OneTimeKey::SignedKey(k) => k,
OneTimeKey::Key(_) => {
return Err(SessionCreationError::OneTimeKeyNotSigned(
device.user_id().to_owned(),
device.device_id().into(),
));
}
};
device.verify_one_time_key(&one_time_key).map_err(|e| {
SessionCreationError::InvalidSignature(
device.user_id().to_owned(),
device.device_id().into(),
e,
)
})?;
let curve_key = device.get_key(KeyAlgorithm::Curve25519).ok_or_else(|| {
SessionCreationError::DeviceMissingCurveKey(
device.user_id().to_owned(),
device.device_id().into(),
)
})?;
self.create_outbound_session_helper(curve_key, &one_time_key)
.await
.map_err(|e| {
SessionCreationError::OlmError(
device.user_id().to_owned(),
device.device_id().into(),
e,
)
})
}
/// Create a new session with another account given a pre-key Olm message.
///
/// Returns the newly created session or a `OlmSessionError` if creating a
/// session failed.
///
/// # Arguments
/// * `their_identity_key` - The other account's identitiy/curve25519 key.
///
/// * `message` - A pre-key Olm message that was sent to us by the other
/// account.
pub(crate) async fn create_inbound_session(
&self,
their_identity_key: &str,
message: PreKeyMessage,
) -> Result<Session, OlmSessionError> {
let session = self
.inner
.lock()
.await
.create_inbound_session_from(their_identity_key, message)?;
self.inner
.lock()
.await
.remove_one_time_keys(&session)
.expect(
"Session was successfully created but the account doesn't hold a matching one-time key",
);
let now = Instant::now();
let session_id = session.session_id();
Ok(Session {
user_id: self.user_id.clone(),
device_id: self.device_id.clone(),
our_identity_keys: self.identity_keys.clone(),
inner: Arc::new(Mutex::new(session)),
session_id: Arc::new(session_id),
sender_key: Arc::new(their_identity_key.to_owned()),
creation_time: Arc::new(now),
last_use_time: Arc::new(now),
})
}
/// Create a group session pair.
///
/// This session pair can be used to encrypt and decrypt messages meant for
/// a large group of participants.
///
/// The outbound session is used to encrypt messages while the inbound one
/// is used to decrypt messages encrypted by the outbound one.
///
/// # Arguments
///
/// * `room_id` - The ID of the room where the group session will be used.
pub(crate) async fn create_group_session_pair(
&self,
room_id: &RoomId,
) -> (OutboundGroupSession, InboundGroupSession) {
let outbound =
OutboundGroupSession::new(self.device_id.clone(), self.identity_keys.clone(), room_id);
let identity_keys = self.identity_keys();
let sender_key = identity_keys.curve25519();
let signing_key = identity_keys.ed25519();
let inbound = InboundGroupSession::new(
sender_key,
signing_key,
&room_id,
outbound.session_key().await,
)
.expect("Can't create inbound group session from a newly created outbound group session");
(outbound, inbound)
}
}
impl PartialEq for Account {
fn eq(&self, other: &Self) -> bool {
self.identity_keys() == other.identity_keys() && self.shared() == other.shared()
}
}

View File

@ -0,0 +1,412 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use matrix_sdk_common::instant::Instant;
use std::convert::TryInto;
use std::fmt;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use matrix_sdk_common::locks::Mutex;
use serde::Serialize;
use serde_json::{json, Value};
use zeroize::Zeroize;
pub use olm_rs::account::IdentityKeys;
use olm_rs::errors::OlmGroupSessionError;
use olm_rs::inbound_group_session::OlmInboundGroupSession;
use olm_rs::outbound_group_session::OlmOutboundGroupSession;
use olm_rs::PicklingMode;
use crate::error::{EventError, MegolmResult};
pub use olm_rs::{
session::{OlmMessage, PreKeyMessage},
utility::OlmUtility,
};
use matrix_sdk_common::{
events::{
room::{
encrypted::{EncryptedEventContent, MegolmV1AesSha2Content},
message::MessageEventContent,
},
Algorithm, AnySyncRoomEvent, EventJson, EventType, SyncMessageEvent,
},
identifiers::{DeviceId, RoomId},
};
/// The private session key of a group session.
/// Can be used to create a new inbound group session.
#[derive(Clone, Debug, Serialize, Zeroize)]
#[zeroize(drop)]
pub struct GroupSessionKey(pub String);
/// Inbound group session.
///
/// Inbound group sessions are used to exchange room messages between a group of
/// participants. Inbound group sessions are used to decrypt the room messages.
#[derive(Clone)]
pub struct InboundGroupSession {
inner: Arc<Mutex<OlmInboundGroupSession>>,
session_id: Arc<String>,
pub(crate) sender_key: Arc<String>,
pub(crate) signing_key: Arc<String>,
pub(crate) room_id: Arc<RoomId>,
forwarding_chains: Arc<Mutex<Option<Vec<String>>>>,
}
impl InboundGroupSession {
/// Create a new inbound group session for the given room.
///
/// These sessions are used to decrypt room messages.
///
/// # Arguments
///
/// * `sender_key` - The public curve25519 key of the account that
/// sent us the session
///
/// * `signing_key` - The public ed25519 key of the account that
/// sent us the session.
///
/// * `room_id` - The id of the room that the session is used in.
///
/// * `session_key` - The private session key that is used to decrypt
/// messages.
pub fn new(
sender_key: &str,
signing_key: &str,
room_id: &RoomId,
session_key: GroupSessionKey,
) -> Result<Self, OlmGroupSessionError> {
let session = OlmInboundGroupSession::new(&session_key.0)?;
let session_id = session.session_id();
Ok(InboundGroupSession {
inner: Arc::new(Mutex::new(session)),
session_id: Arc::new(session_id),
sender_key: Arc::new(sender_key.to_owned()),
signing_key: Arc::new(signing_key.to_owned()),
room_id: Arc::new(room_id.clone()),
forwarding_chains: Arc::new(Mutex::new(None)),
})
}
/// Store the group session as a base64 encoded string.
///
/// # Arguments
///
/// * `pickle_mode` - The mode that was used to pickle the group session,
/// either an unencrypted mode or an encrypted using passphrase.
pub async fn pickle(&self, pickle_mode: PicklingMode) -> String {
self.inner.lock().await.pickle(pickle_mode)
}
/// Restore a Session from a previously pickled string.
///
/// Returns the restored group session or a `OlmGroupSessionError` if there
/// was an error.
///
/// # Arguments
///
/// * `pickle` - The pickled string of the group session session.
///
/// * `pickle_mode` - The mode that was used to pickle the session, either
/// an unencrypted mode or an encrypted using passphrase.
///
/// * `sender_key` - The public curve25519 key of the account that
/// sent us the session
///
/// * `signing_key` - The public ed25519 key of the account that
/// sent us the session.
///
/// * `room_id` - The id of the room that the session is used in.
pub fn from_pickle(
pickle: String,
pickle_mode: PicklingMode,
sender_key: String,
signing_key: String,
room_id: RoomId,
) -> Result<Self, OlmGroupSessionError> {
let session = OlmInboundGroupSession::unpickle(pickle, pickle_mode)?;
let session_id = session.session_id();
Ok(InboundGroupSession {
inner: Arc::new(Mutex::new(session)),
session_id: Arc::new(session_id),
sender_key: Arc::new(sender_key),
signing_key: Arc::new(signing_key),
room_id: Arc::new(room_id),
forwarding_chains: Arc::new(Mutex::new(None)),
})
}
/// Returns the unique identifier for this session.
pub fn session_id(&self) -> &str {
&self.session_id
}
/// Get the first message index we know how to decrypt.
pub async fn first_known_index(&self) -> u32 {
self.inner.lock().await.first_known_index()
}
/// Decrypt the given ciphertext.
///
/// Returns the decrypted plaintext or an `OlmGroupSessionError` if
/// decryption failed.
///
/// # Arguments
///
/// * `message` - The message that should be decrypted.
pub async fn decrypt_helper(
&self,
message: String,
) -> Result<(String, u32), OlmGroupSessionError> {
self.inner.lock().await.decrypt(message)
}
/// Decrypt an event from a room timeline.
///
/// # Arguments
///
/// * `event` - The event that should be decrypted.
pub async fn decrypt(
&self,
event: &SyncMessageEvent<EncryptedEventContent>,
) -> MegolmResult<(EventJson<AnySyncRoomEvent>, u32)> {
let content = match &event.content {
EncryptedEventContent::MegolmV1AesSha2(c) => c,
_ => return Err(EventError::UnsupportedAlgorithm.into()),
};
let (plaintext, message_index) = self.decrypt_helper(content.ciphertext.clone()).await?;
let mut decrypted_value = serde_json::from_str::<Value>(&plaintext)?;
let decrypted_object = decrypted_value
.as_object_mut()
.ok_or(EventError::NotAnObject)?;
// TODO better number conversion here.
let server_ts = event
.origin_server_ts
.duration_since(std::time::SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_millis();
let server_ts: i64 = server_ts.try_into().unwrap_or_default();
decrypted_object.insert("sender".to_owned(), event.sender.to_string().into());
decrypted_object.insert("event_id".to_owned(), event.event_id.to_string().into());
decrypted_object.insert("origin_server_ts".to_owned(), server_ts.into());
decrypted_object.insert(
"unsigned".to_owned(),
serde_json::to_value(&event.unsigned).unwrap_or_default(),
);
Ok((
serde_json::from_value::<EventJson<AnySyncRoomEvent>>(decrypted_value)?,
message_index,
))
}
}
// #[cfg_attr(tarpaulin, skip)]
impl fmt::Debug for InboundGroupSession {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("InboundGroupSession")
.field("session_id", &self.session_id())
.finish()
}
}
impl PartialEq for InboundGroupSession {
fn eq(&self, other: &Self) -> bool {
self.session_id() == other.session_id()
}
}
/// Outbound group session.
///
/// Outbound group sessions are used to exchange room messages between a group
/// of participants. Outbound group sessions are used to encrypt the room
/// messages.
#[derive(Clone)]
pub struct OutboundGroupSession {
inner: Arc<Mutex<OlmOutboundGroupSession>>,
device_id: Arc<Box<DeviceId>>,
account_identity_keys: Arc<IdentityKeys>,
session_id: Arc<String>,
room_id: Arc<RoomId>,
creation_time: Arc<Instant>,
message_count: Arc<AtomicUsize>,
shared: Arc<AtomicBool>,
}
impl OutboundGroupSession {
/// Create a new outbound group session for the given room.
///
/// Outbound group sessions are used to encrypt room messages.
///
/// # Arguments
///
/// * `device_id` - The id of the device that created this session.
///
/// * `identity_keys` - The identity keys of the account that created this
/// session.
///
/// * `room_id` - The id of the room that the session is used in.
pub fn new(
device_id: Arc<Box<DeviceId>>,
identity_keys: Arc<IdentityKeys>,
room_id: &RoomId,
) -> Self {
let session = OlmOutboundGroupSession::new();
let session_id = session.session_id();
OutboundGroupSession {
inner: Arc::new(Mutex::new(session)),
room_id: Arc::new(room_id.to_owned()),
device_id,
account_identity_keys: identity_keys,
session_id: Arc::new(session_id),
creation_time: Arc::new(Instant::now()),
message_count: Arc::new(AtomicUsize::new(0)),
shared: Arc::new(AtomicBool::new(false)),
}
}
/// Encrypt the given plaintext using this session.
///
/// Returns the encrypted ciphertext.
///
/// # Arguments
///
/// * `plaintext` - The plaintext that should be encrypted.
pub(crate) async fn encrypt_helper(&self, plaintext: String) -> String {
let session = self.inner.lock().await;
session.encrypt(plaintext)
}
/// Encrypt a room message for the given room.
///
/// Beware that a group session needs to be shared before this method can be
/// called using the `share_group_session()` method.
///
/// Since group sessions can expire or become invalid if the room membership
/// changes client authors should check with the
/// `should_share_group_session()` method if a new group session needs to
/// be shared.
///
/// # Arguments
///
/// * `content` - The plaintext content of the message that should be
/// encrypted.
///
/// # Panics
///
/// Panics if the content can't be serialized.
pub async fn encrypt(&self, content: MessageEventContent) -> EncryptedEventContent {
let json_content = json!({
"content": content,
"room_id": &*self.room_id,
"type": EventType::RoomMessage,
});
let plaintext = cjson::to_string(&json_content).unwrap_or_else(|_| {
panic!(format!(
"Can't serialize {} to canonical JSON",
json_content
))
});
let ciphertext = self.encrypt_helper(plaintext).await;
EncryptedEventContent::MegolmV1AesSha2(MegolmV1AesSha2Content::new(
matrix_sdk_common::events::room::encrypted::MegolmV1AesSha2ContentInit {
ciphertext,
sender_key: self.account_identity_keys.curve25519().to_owned(),
session_id: self.session_id().to_owned(),
device_id: (&*self.device_id).to_owned(),
},
))
}
/// Check if the session has expired and if it should be rotated.
///
/// A session will expire after some time or if enough messages have been
/// encrypted using it.
pub fn expired(&self) -> bool {
// TODO implement this.
false
}
/// Mark the session as shared.
///
/// Messages shouldn't be encrypted with the session before it has been
/// shared.
pub fn mark_as_shared(&self) {
self.shared.store(true, Ordering::Relaxed);
}
/// Check if the session has been marked as shared.
pub fn shared(&self) -> bool {
self.shared.load(Ordering::Relaxed)
}
/// Get the session key of this session.
///
/// A session key can be used to to create an `InboundGroupSession`.
pub async fn session_key(&self) -> GroupSessionKey {
let session = self.inner.lock().await;
GroupSessionKey(session.session_key())
}
/// Returns the unique identifier for this session.
pub fn session_id(&self) -> &str {
&self.session_id
}
/// Get the current message index for this session.
///
/// Each message is sent with an increasing index. This returns the
/// message index that will be used for the next encrypted message.
pub async fn message_index(&self) -> u32 {
let session = self.inner.lock().await;
session.session_message_index()
}
/// Get the outbound group session key as a json value that can be sent as a
/// m.room_key.
pub async fn as_json(&self) -> Value {
json!({
"algorithm": Algorithm::MegolmV1AesSha2,
"room_id": &*self.room_id,
"session_id": &*self.session_id,
"session_key": self.session_key().await,
"chain_index": self.message_index().await,
})
}
}
// #[cfg_attr(tarpaulin, skip)]
impl std::fmt::Debug for OutboundGroupSession {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("OutboundGroupSession")
.field("session_id", &self.session_id)
.field("room_id", &self.room_id)
.field("creation_time", &self.creation_time)
.field("message_count", &self.message_count)
.finish()
}
}

View File

@ -0,0 +1,208 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod account;
mod group_sessions;
mod session;
pub use account::{Account, IdentityKeys};
pub use group_sessions::{GroupSessionKey, InboundGroupSession, OutboundGroupSession};
pub use session::{OlmMessage, Session};
#[cfg(test)]
pub(crate) mod test {
use crate::olm::{Account, InboundGroupSession, Session};
use matrix_sdk_common::api::r0::keys::SignedKey;
use matrix_sdk_common::identifiers::{DeviceId, RoomId, UserId};
use olm_rs::session::OlmMessage;
use std::collections::BTreeMap;
use std::convert::TryFrom;
fn alice_id() -> UserId {
UserId::try_from("@alice:example.org").unwrap()
}
fn alice_device_id() -> Box<DeviceId> {
"ALICEDEVICE".into()
}
fn bob_id() -> UserId {
UserId::try_from("@bob:example.org").unwrap()
}
fn bob_device_id() -> Box<DeviceId> {
"BOBDEVICE".into()
}
pub(crate) async fn get_account_and_session() -> (Account, Session) {
let alice = Account::new(&alice_id(), &alice_device_id());
let bob = Account::new(&bob_id(), &bob_device_id());
bob.generate_one_time_keys_helper(1).await;
let one_time_key = bob
.one_time_keys()
.await
.curve25519()
.iter()
.next()
.unwrap()
.1
.to_owned();
let one_time_key = SignedKey {
key: one_time_key,
signatures: BTreeMap::new(),
};
let sender_key = bob.identity_keys().curve25519().to_owned();
let session = alice
.create_outbound_session_helper(&sender_key, &one_time_key)
.await
.unwrap();
(alice, session)
}
#[test]
fn account_creation() {
let account = Account::new(&alice_id(), &alice_device_id());
let identyty_keys = account.identity_keys();
assert!(!account.shared());
assert!(!identyty_keys.ed25519().is_empty());
assert_ne!(identyty_keys.values().len(), 0);
assert_ne!(identyty_keys.keys().len(), 0);
assert_ne!(identyty_keys.iter().len(), 0);
assert!(identyty_keys.contains_key("ed25519"));
assert_eq!(
identyty_keys.ed25519(),
identyty_keys.get("ed25519").unwrap()
);
assert!(!identyty_keys.curve25519().is_empty());
account.mark_as_shared();
assert!(account.shared());
}
#[tokio::test]
async fn one_time_keys_creation() {
let account = Account::new(&alice_id(), &alice_device_id());
let one_time_keys = account.one_time_keys().await;
assert!(one_time_keys.curve25519().is_empty());
assert_ne!(account.max_one_time_keys().await, 0);
account.generate_one_time_keys_helper(10).await;
let one_time_keys = account.one_time_keys().await;
assert!(!one_time_keys.curve25519().is_empty());
assert_ne!(one_time_keys.values().len(), 0);
assert_ne!(one_time_keys.keys().len(), 0);
assert_ne!(one_time_keys.iter().len(), 0);
assert!(one_time_keys.contains_key("curve25519"));
assert_eq!(one_time_keys.curve25519().keys().len(), 10);
assert_eq!(
one_time_keys.curve25519(),
one_time_keys.get("curve25519").unwrap()
);
account.mark_keys_as_published().await;
let one_time_keys = account.one_time_keys().await;
assert!(one_time_keys.curve25519().is_empty());
}
#[tokio::test]
async fn session_creation() {
let alice = Account::new(&alice_id(), &alice_device_id());
let bob = Account::new(&bob_id(), &bob_device_id());
let alice_keys = alice.identity_keys();
alice.generate_one_time_keys_helper(1).await;
let one_time_keys = alice.one_time_keys().await;
alice.mark_keys_as_published().await;
let one_time_key = one_time_keys
.curve25519()
.iter()
.next()
.unwrap()
.1
.to_owned();
let one_time_key = SignedKey {
key: one_time_key,
signatures: BTreeMap::new(),
};
let mut bob_session = bob
.create_outbound_session_helper(alice_keys.curve25519(), &one_time_key)
.await
.unwrap();
let plaintext = "Hello world";
let message = bob_session.encrypt_helper(plaintext).await;
let prekey_message = match message.clone() {
OlmMessage::PreKey(m) => m,
OlmMessage::Message(_) => panic!("Incorrect message type"),
};
let bob_keys = bob.identity_keys();
let mut alice_session = alice
.create_inbound_session(bob_keys.curve25519(), prekey_message.clone())
.await
.unwrap();
assert!(alice_session
.matches(bob_keys.curve25519(), prekey_message)
.await
.unwrap());
assert_eq!(bob_session.session_id(), alice_session.session_id());
let decyrpted = alice_session.decrypt(message).await.unwrap();
assert_eq!(plaintext, decyrpted);
}
#[tokio::test]
async fn group_session_creation() {
let alice = Account::new(&alice_id(), &alice_device_id());
let room_id = RoomId::try_from("!test:localhost").unwrap();
let (outbound, _) = alice.create_group_session_pair(&room_id).await;
assert_eq!(0, outbound.message_index().await);
assert!(!outbound.shared());
outbound.mark_as_shared();
assert!(outbound.shared());
let inbound = InboundGroupSession::new(
"test_key",
"test_key",
&room_id,
outbound.session_key().await,
)
.unwrap();
assert_eq!(0, inbound.first_known_index().await);
assert_eq!(outbound.session_id(), inbound.session_id());
let plaintext = "This is a secret to everybody".to_owned();
let ciphertext = outbound.encrypt_helper(plaintext.clone()).await;
assert_eq!(
plaintext,
inbound.decrypt_helper(ciphertext).await.unwrap().0
);
}
}

View File

@ -0,0 +1,246 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::BTreeMap;
use std::fmt;
use std::sync::Arc;
use olm_rs::errors::OlmSessionError;
use olm_rs::session::OlmSession;
use olm_rs::PicklingMode;
use serde_json::{json, Value};
pub use olm_rs::{
session::{OlmMessage, PreKeyMessage},
utility::OlmUtility,
};
use super::IdentityKeys;
use crate::error::{EventError, OlmResult};
use crate::Device;
use matrix_sdk_common::{
api::r0::keys::KeyAlgorithm,
events::{
room::encrypted::{CiphertextInfo, EncryptedEventContent, OlmV1Curve25519AesSha2Content},
EventType,
},
identifiers::{DeviceId, UserId},
instant::Instant,
locks::Mutex,
};
/// Cryptographic session that enables secure communication between two
/// `Account`s
#[derive(Clone)]
pub struct Session {
pub(crate) user_id: Arc<UserId>,
pub(crate) device_id: Arc<Box<DeviceId>>,
pub(crate) our_identity_keys: Arc<IdentityKeys>,
pub(crate) inner: Arc<Mutex<OlmSession>>,
pub(crate) session_id: Arc<String>,
pub(crate) sender_key: Arc<String>,
pub(crate) creation_time: Arc<Instant>,
pub(crate) last_use_time: Arc<Instant>,
}
// #[cfg_attr(tarpaulin, skip)]
impl fmt::Debug for Session {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Session")
.field("session_id", &self.session_id())
.field("sender_key", &self.sender_key)
.finish()
}
}
impl Session {
/// Decrypt the given Olm message.
///
/// Returns the decrypted plaintext or an `OlmSessionError` if decryption
/// failed.
///
/// # Arguments
///
/// * `message` - The Olm message that should be decrypted.
pub async fn decrypt(&mut self, message: OlmMessage) -> Result<String, OlmSessionError> {
let plaintext = self.inner.lock().await.decrypt(message)?;
self.last_use_time = Arc::new(Instant::now());
Ok(plaintext)
}
/// Encrypt the given plaintext as a OlmMessage.
///
/// Returns the encrypted Olm message.
///
/// # Arguments
///
/// * `plaintext` - The plaintext that should be encrypted.
pub(crate) async fn encrypt_helper(&mut self, plaintext: &str) -> OlmMessage {
let message = self.inner.lock().await.encrypt(plaintext);
self.last_use_time = Arc::new(Instant::now());
message
}
/// Encrypt the given event content content as an m.room.encrypted event
/// content.
///
/// # Arguments
///
/// * `recipient_device` - The device for which this message is going to be
/// encrypted, this needs to be the device that was used to create this
/// session with.
///
/// * `event_type` - The type of the event.
///
/// * `content` - The content of the event.
pub async fn encrypt(
&mut self,
recipient_device: &Device,
event_type: EventType,
content: Value,
) -> OlmResult<EncryptedEventContent> {
let recipient_signing_key = recipient_device
.get_key(KeyAlgorithm::Ed25519)
.ok_or(EventError::MissingSigningKey)?;
let payload = json!({
"sender": self.user_id.as_str(),
"sender_device": self.device_id.as_ref(),
"keys": {
"ed25519": self.our_identity_keys.ed25519(),
},
"recipient": recipient_device.user_id(),
"recipient_keys": {
"ed25519": recipient_signing_key,
},
"type": event_type,
"content": content,
});
let plaintext = cjson::to_string(&payload)
.unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", payload)));
let ciphertext = self.encrypt_helper(&plaintext).await.to_tuple();
let message_type = ciphertext.0;
let ciphertext = CiphertextInfo::new(ciphertext.1, (message_type as u32).into());
let mut content = BTreeMap::new();
content.insert((&*self.sender_key).to_owned(), ciphertext);
Ok(EncryptedEventContent::OlmV1Curve25519AesSha2(
OlmV1Curve25519AesSha2Content::new(
content,
self.our_identity_keys.curve25519().to_string(),
),
))
}
/// Check if a pre-key Olm message was encrypted for this session.
///
/// Returns true if it matches, false if not and a OlmSessionError if there
/// was an error checking if it matches.
///
/// # Arguments
///
/// * `their_identity_key` - The identity/curve25519 key of the account
/// that encrypted this Olm message.
///
/// * `message` - The pre-key Olm message that should be checked.
pub async fn matches(
&self,
their_identity_key: &str,
message: PreKeyMessage,
) -> Result<bool, OlmSessionError> {
self.inner
.lock()
.await
.matches_inbound_session_from(their_identity_key, message)
}
/// Returns the unique identifier for this session.
pub fn session_id(&self) -> &str {
&self.session_id
}
/// Store the session as a base64 encoded string.
///
/// # Arguments
///
/// * `pickle_mode` - The mode that was used to pickle the session, either
/// an unencrypted mode or an encrypted using passphrase.
pub async fn pickle(&self, pickle_mode: PicklingMode) -> String {
self.inner.lock().await.pickle(pickle_mode)
}
/// Restore a Session from a previously pickled string.
///
/// Returns the restored Olm Session or a `OlmSessionError` if there was an
/// error.
///
/// # Arguments
///
/// * `user_id` - Our own user id that the session belongs to.
///
/// * `device_id` - Our own device id that the session belongs to.
///
/// * `our_idenity_keys` - An clone of the Arc to our own identity keys.
///
/// * `pickle` - The pickled string of the session.
///
/// * `pickle_mode` - The mode that was used to pickle the session, either
/// an unencrypted mode or an encrypted using passphrase.
///
/// * `sender_key` - The public curve25519 key of the account that
/// established the session with us.
///
/// * `creation_time` - The timestamp that marks when the session was
/// created.
///
/// * `last_use_time` - The timestamp that marks when the session was
/// last used to encrypt or decrypt an Olm message.
#[allow(clippy::too_many_arguments)]
pub fn from_pickle(
user_id: Arc<UserId>,
device_id: Arc<Box<DeviceId>>,
our_identity_keys: Arc<IdentityKeys>,
pickle: String,
pickle_mode: PicklingMode,
sender_key: String,
creation_time: Instant,
last_use_time: Instant,
) -> Result<Self, OlmSessionError> {
let session = OlmSession::unpickle(pickle, pickle_mode)?;
let session_id = session.session_id();
Ok(Session {
user_id,
device_id,
our_identity_keys,
inner: Arc::new(Mutex::new(session)),
session_id: Arc::new(session_id),
sender_key: Arc::new(sender_key),
creation_time: Arc::new(creation_time),
last_use_time: Arc::new(last_use_time),
})
}
}
impl PartialEq for Session {
fn eq(&self, other: &Self) -> bool {
self.session_id() == other.session_id()
}
}

View File

@ -200,8 +200,8 @@ mod test {
let user_devices = store.get_user_devices(device.user_id()).await.unwrap(); let user_devices = store.get_user_devices(device.user_id()).await.unwrap();
assert_eq!(user_devices.keys().nth(0).unwrap(), device.device_id()); assert_eq!(user_devices.keys().next().unwrap(), device.device_id());
assert_eq!(user_devices.devices().nth(0).unwrap(), &device); assert_eq!(user_devices.devices().next().unwrap(), &device);
let loaded_device = user_devices.get(device.device_id()).unwrap(); let loaded_device = user_devices.get(device.device_id()).unwrap();

View File

@ -87,7 +87,7 @@ pub enum CryptoStoreError {
pub type Result<T> = std::result::Result<T, CryptoStoreError>; pub type Result<T> = std::result::Result<T, CryptoStoreError>;
#[async_trait] #[async_trait]
#[warn(clippy::type_complexity)] #[allow(clippy::type_complexity)]
#[cfg_attr(not(target_arch = "wasm32"), send_sync)] #[cfg_attr(not(target_arch = "wasm32"), send_sync)]
/// Trait abstracting a store that the `OlmMachine` uses to store cryptographic /// Trait abstracting a store that the `OlmMachine` uses to store cryptographic
/// keys. /// keys.

View File

@ -26,9 +26,10 @@ use olm_rs::PicklingMode;
use sqlx::{query, query_as, sqlite::SqliteQueryAs, Connect, Executor, SqliteConnection}; use sqlx::{query, query_as, sqlite::SqliteQueryAs, Connect, Executor, SqliteConnection};
use zeroize::Zeroizing; use zeroize::Zeroizing;
use super::{Account, CryptoStore, CryptoStoreError, InboundGroupSession, Result, Session}; use super::{CryptoStore, CryptoStoreError, Result};
use crate::device::{Device, TrustState}; use crate::device::{Device, TrustState};
use crate::memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices}; use crate::memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices};
use crate::{Account, IdentityKeys, InboundGroupSession, Session};
use matrix_sdk_common::api::r0::keys::{AlgorithmAndDeviceId, KeyAlgorithm}; use matrix_sdk_common::api::r0::keys::{AlgorithmAndDeviceId, KeyAlgorithm};
use matrix_sdk_common::events::Algorithm; use matrix_sdk_common::events::Algorithm;
use matrix_sdk_common::identifiers::{DeviceId, RoomId, UserId}; use matrix_sdk_common::identifiers::{DeviceId, RoomId, UserId};
@ -36,8 +37,8 @@ use matrix_sdk_common::identifiers::{DeviceId, RoomId, UserId};
/// SQLite based implementation of a `CryptoStore`. /// SQLite based implementation of a `CryptoStore`.
pub struct SqliteStore { pub struct SqliteStore {
user_id: Arc<UserId>, user_id: Arc<UserId>,
device_id: Arc<String>, device_id: Arc<Box<DeviceId>>,
account_id: Option<i64>, account_info: Option<AccountInfo>,
path: PathBuf, path: PathBuf,
sessions: SessionStore, sessions: SessionStore,
@ -50,6 +51,11 @@ pub struct SqliteStore {
pickle_passphrase: Option<Zeroizing<String>>, pickle_passphrase: Option<Zeroizing<String>>,
} }
struct AccountInfo {
account_id: i64,
identity_keys: Arc<IdentityKeys>,
}
static DATABASE_NAME: &str = "matrix-sdk-crypto.db"; static DATABASE_NAME: &str = "matrix-sdk-crypto.db";
impl SqliteStore { impl SqliteStore {
@ -66,7 +72,7 @@ impl SqliteStore {
/// * `path` - The path where the database file should reside in. /// * `path` - The path where the database file should reside in.
pub async fn open<P: AsRef<Path>>( pub async fn open<P: AsRef<Path>>(
user_id: &UserId, user_id: &UserId,
device_id: &str, device_id: &DeviceId,
path: P, path: P,
) -> Result<SqliteStore> { ) -> Result<SqliteStore> {
SqliteStore::open_helper(user_id, device_id, path, None).await SqliteStore::open_helper(user_id, device_id, path, None).await
@ -88,7 +94,7 @@ impl SqliteStore {
/// the encryption keys. /// the encryption keys.
pub async fn open_with_passphrase<P: AsRef<Path>>( pub async fn open_with_passphrase<P: AsRef<Path>>(
user_id: &UserId, user_id: &UserId,
device_id: &str, device_id: &DeviceId,
path: P, path: P,
passphrase: &str, passphrase: &str,
) -> Result<SqliteStore> { ) -> Result<SqliteStore> {
@ -109,7 +115,7 @@ impl SqliteStore {
async fn open_helper<P: AsRef<Path>>( async fn open_helper<P: AsRef<Path>>(
user_id: &UserId, user_id: &UserId,
device_id: &str, device_id: &DeviceId,
path: P, path: P,
passphrase: Option<Zeroizing<String>>, passphrase: Option<Zeroizing<String>>,
) -> Result<SqliteStore> { ) -> Result<SqliteStore> {
@ -118,8 +124,8 @@ impl SqliteStore {
let connection = SqliteConnection::connect(url.as_ref()).await?; let connection = SqliteConnection::connect(url.as_ref()).await?;
let store = SqliteStore { let store = SqliteStore {
user_id: Arc::new(user_id.to_owned()), user_id: Arc::new(user_id.to_owned()),
device_id: Arc::new(device_id.to_owned()), device_id: Arc::new(device_id.into()),
account_id: None, account_info: None,
sessions: SessionStore::new(), sessions: SessionStore::new(),
inbound_group_sessions: GroupSessionStore::new(), inbound_group_sessions: GroupSessionStore::new(),
devices: DeviceStore::new(), devices: DeviceStore::new(),
@ -133,6 +139,10 @@ impl SqliteStore {
Ok(store) Ok(store)
} }
fn account_id(&self) -> Option<i64> {
self.account_info.as_ref().map(|i| i.account_id)
}
async fn create_tables(&self) -> Result<()> { async fn create_tables(&self) -> Result<()> {
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
connection connection
@ -262,6 +272,25 @@ impl SqliteStore {
) )
.await?; .await?;
connection
.execute(
r#"
CREATE TABLE IF NOT EXISTS device_signatures (
"id" INTEGER NOT NULL PRIMARY KEY,
"device_id" INTEGER NOT NULL,
"user_id" TEXT NOT NULL,
"key_algorithm" TEXT NOT NULL,
"signature" TEXT NOT NULL,
FOREIGN KEY ("device_id") REFERENCES "devices" ("id")
ON DELETE CASCADE
UNIQUE(device_id, user_id, key_algorithm)
);
CREATE INDEX IF NOT EXISTS "device_keys_device_id" ON "device_keys" ("device_id");
"#,
)
.await?;
Ok(()) Ok(())
} }
@ -288,14 +317,17 @@ impl SqliteStore {
} }
async fn load_sessions_for(&mut self, sender_key: &str) -> Result<Vec<Session>> { async fn load_sessions_for(&mut self, sender_key: &str) -> Result<Vec<Session>> {
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; let account_info = self
.account_info
.as_ref()
.ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
let rows: Vec<(String, String, String, String)> = query_as( let rows: Vec<(String, String, String, String)> = query_as(
"SELECT pickle, sender_key, creation_time, last_use_time "SELECT pickle, sender_key, creation_time, last_use_time
FROM sessions WHERE account_id = ? and sender_key = ?", FROM sessions WHERE account_id = ? and sender_key = ?",
) )
.bind(account_id) .bind(account_info.account_id)
.bind(sender_key) .bind(sender_key)
.fetch_all(&mut *connection) .fetch_all(&mut *connection)
.await?; .await?;
@ -315,6 +347,9 @@ impl SqliteStore {
.ok_or(CryptoStoreError::SessionTimestampError)?; .ok_or(CryptoStoreError::SessionTimestampError)?;
Ok(Session::from_pickle( Ok(Session::from_pickle(
self.user_id.clone(),
self.device_id.clone(),
account_info.identity_keys.clone(),
pickle.to_string(), pickle.to_string(),
self.get_pickle_mode(), self.get_pickle_mode(),
sender_key.to_string(), sender_key.to_string(),
@ -326,7 +361,7 @@ impl SqliteStore {
} }
async fn load_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> { async fn load_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
let rows: Vec<(String, String, String, String)> = query_as( let rows: Vec<(String, String, String, String)> = query_as(
@ -357,7 +392,7 @@ impl SqliteStore {
} }
async fn save_tracked_user(&self, user: &UserId, dirty: bool) -> Result<()> { async fn save_tracked_user(&self, user: &UserId, dirty: bool) -> Result<()> {
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
query( query(
@ -378,7 +413,7 @@ impl SqliteStore {
} }
async fn load_tracked_users(&self) -> Result<(HashSet<UserId>, HashSet<UserId>)> { async fn load_tracked_users(&self) -> Result<(HashSet<UserId>, HashSet<UserId>)> {
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
let rows: Vec<(String, bool)> = query_as( let rows: Vec<(String, bool)> = query_as(
@ -410,7 +445,7 @@ impl SqliteStore {
} }
async fn load_devices(&self) -> Result<DeviceStore> { async fn load_devices(&self) -> Result<DeviceStore> {
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
let rows: Vec<(i64, String, String, Option<String>, i64)> = query_as( let rows: Vec<(i64, String, String, Option<String>, i64)> = query_as(
@ -456,31 +491,64 @@ impl SqliteStore {
.fetch_all(&mut *connection) .fetch_all(&mut *connection)
.await?; .await?;
let mut keys = BTreeMap::new(); let keys: BTreeMap<AlgorithmAndDeviceId, String> = key_rows
.into_iter()
.filter_map(|row| {
let algorithm = KeyAlgorithm::try_from(&*row.0).ok()?;
let key = row.1;
for row in key_rows { Some((
let algorithm: &str = &row.0; AlgorithmAndDeviceId(algorithm, device_id.as_str().into()),
let algorithm = if let Ok(a) = KeyAlgorithm::try_from(algorithm) { key,
a ))
})
.collect();
let signature_rows: Vec<(String, String, String)> = query_as(
"SELECT user_id, key_algorithm, signature
FROM device_signatures WHERE device_id = ?",
)
.bind(device_row_id)
.fetch_all(&mut *connection)
.await?;
let mut signatures: BTreeMap<UserId, BTreeMap<AlgorithmAndDeviceId, String>> =
BTreeMap::new();
for row in signature_rows {
let user_id = if let Ok(u) = UserId::try_from(&*row.0) {
u
} else { } else {
continue; continue;
}; };
let key = &row.1; let key_algorithm = if let Ok(k) = KeyAlgorithm::try_from(&*row.1) {
k
} else {
continue;
};
keys.insert( let signature = row.2;
AlgorithmAndDeviceId(algorithm, device_id.clone()),
key.to_owned(), if !signatures.contains_key(&user_id) {
let _ = signatures.insert(user_id.clone(), BTreeMap::new());
}
let user_map = signatures.get_mut(&user_id).unwrap();
user_map.insert(
AlgorithmAndDeviceId(key_algorithm, device_id.as_str().into()),
signature.to_owned(),
); );
} }
let device = Device::new( let device = Device::new(
user_id, user_id,
device_id.to_owned(), device_id.as_str().into(),
display_name.clone(), display_name.clone(),
trust_state, trust_state,
algorithms, algorithms,
keys, keys,
signatures,
); );
store.add(device); store.add(device);
@ -490,7 +558,7 @@ impl SqliteStore {
} }
async fn save_device_helper(&self, device: Device) -> Result<()> { async fn save_device_helper(&self, device: Device) -> Result<()> {
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
@ -550,6 +618,23 @@ impl SqliteStore {
.await?; .await?;
} }
for (user_id, signature_map) in device.signatures() {
for (key_id, signature) in signature_map {
query(
"INSERT OR IGNORE INTO device_signatures (
device_id, user_id, key_algorithm, signature
) VALUES (?1, ?2, ?3, ?4)
",
)
.bind(device_row_id)
.bind(user_id.as_str())
.bind(key_id.0.to_string())
.bind(signature)
.execute(&mut *connection)
.await?;
}
}
Ok(()) Ok(())
} }
@ -573,20 +658,26 @@ impl CryptoStore for SqliteStore {
WHERE user_id = ? and device_id = ?", WHERE user_id = ? and device_id = ?",
) )
.bind(self.user_id.as_str()) .bind(self.user_id.as_str())
.bind(&*self.device_id) .bind((&*self.device_id).as_ref())
.fetch_optional(&mut *connection) .fetch_optional(&mut *connection)
.await?; .await?;
let result = if let Some((id, pickle, shared, uploaded_key_count)) = row { let result = if let Some((id, pickle, shared, uploaded_key_count)) = row {
self.account_id = Some(id); let account = Account::from_pickle(
Some(Account::from_pickle(
pickle, pickle,
self.get_pickle_mode(), self.get_pickle_mode(),
shared, shared,
uploaded_key_count, uploaded_key_count,
&self.user_id, &self.user_id,
&self.device_id, &self.device_id,
)?) )?;
self.account_info = Some(AccountInfo {
account_id: id,
identity_keys: account.identity_keys.clone(),
});
Some(account)
} else { } else {
return Ok(None); return Ok(None);
}; };
@ -640,19 +731,22 @@ impl CryptoStore for SqliteStore {
.fetch_one(&mut *connection) .fetch_one(&mut *connection)
.await?; .await?;
self.account_id = Some(account_id.0); self.account_info = Some(AccountInfo {
account_id: account_id.0,
identity_keys: account.identity_keys.clone(),
});
Ok(()) Ok(())
} }
async fn save_sessions(&mut self, sessions: &[Session]) -> Result<()> { async fn save_sessions(&mut self, sessions: &[Session]) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
// TODO turn this into a transaction // TODO turn this into a transaction
for session in sessions { for session in sessions {
self.lazy_load_sessions(&session.sender_key).await?; self.lazy_load_sessions(&session.sender_key).await?;
self.sessions.add(session.clone()).await; self.sessions.add(session.clone()).await;
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
let session_id = session.session_id(); let session_id = session.session_id();
let creation_time = serde_json::to_string(&session.creation_time.elapsed())?; let creation_time = serde_json::to_string(&session.creation_time.elapsed())?;
let last_use_time = serde_json::to_string(&session.last_use_time.elapsed())?; let last_use_time = serde_json::to_string(&session.last_use_time.elapsed())?;
@ -683,7 +777,7 @@ impl CryptoStore for SqliteStore {
} }
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool> { async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool> {
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let pickle = session.pickle(self.get_pickle_mode()).await; let pickle = session.pickle(self.get_pickle_mode()).await;
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
let session_id = session.session_id(); let session_id = session.session_id();
@ -753,7 +847,7 @@ impl CryptoStore for SqliteStore {
} }
async fn delete_device(&self, device: Device) -> Result<()> { async fn delete_device(&self, device: Device) -> Result<()> {
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
query( query(
@ -804,7 +898,7 @@ mod test {
use super::{Account, CryptoStore, InboundGroupSession, RoomId, Session, SqliteStore, TryFrom}; use super::{Account, CryptoStore, InboundGroupSession, RoomId, Session, SqliteStore, TryFrom};
static USER_ID: &str = "@example:localhost"; static USER_ID: &str = "@example:localhost";
static DEVICE_ID: &str = "DEVICEID"; static DEVICE_ID: &DeviceId = "DEVICEID";
async fn get_store(passphrase: Option<&str>) -> (SqliteStore, tempfile::TempDir) { async fn get_store(passphrase: Option<&str>) -> (SqliteStore, tempfile::TempDir) {
let tmpdir = tempdir().unwrap(); let tmpdir = tempdir().unwrap();
@ -840,16 +934,16 @@ mod test {
UserId::try_from("@alice:example.org").unwrap() UserId::try_from("@alice:example.org").unwrap()
} }
fn alice_device_id() -> DeviceId { fn alice_device_id() -> Box<DeviceId> {
"ALICEDEVICE".to_string() "ALICEDEVICE".into()
} }
fn bob_id() -> UserId { fn bob_id() -> UserId {
UserId::try_from("@bob:example.org").unwrap() UserId::try_from("@bob:example.org").unwrap()
} }
fn bob_device_id() -> DeviceId { fn bob_device_id() -> Box<DeviceId> {
"BOBDEVICE".to_string() "BOBDEVICE".into()
} }
fn get_account() -> Account { fn get_account() -> Account {
@ -866,7 +960,7 @@ mod test {
.await .await
.curve25519() .curve25519()
.iter() .iter()
.nth(0) .next()
.unwrap() .unwrap()
.1 .1
.to_owned(); .to_owned();
@ -876,7 +970,7 @@ mod test {
}; };
let sender_key = bob.identity_keys().curve25519().to_owned(); let sender_key = bob.identity_keys().curve25519().to_owned();
let session = alice let session = alice
.create_outbound_session(&sender_key, &one_time_key) .create_outbound_session_helper(&sender_key, &one_time_key)
.await .await
.unwrap(); .unwrap();
@ -1165,8 +1259,8 @@ mod test {
assert_eq!(device.keys(), loaded_device.keys()); assert_eq!(device.keys(), loaded_device.keys());
let user_devices = store.get_user_devices(device.user_id()).await.unwrap(); let user_devices = store.get_user_devices(device.user_id()).await.unwrap();
assert_eq!(user_devices.keys().nth(0).unwrap(), device.device_id()); assert_eq!(user_devices.keys().next().unwrap(), device.device_id());
assert_eq!(user_devices.devices().nth(0).unwrap(), &device); assert_eq!(user_devices.devices().next().unwrap(), &device);
} }
#[tokio::test] #[tokio::test]

View File

@ -1,8 +1,8 @@
use crate::Device; use crate::Device;
use matrix_sdk_common::events::key::verification::{ use matrix_sdk_common::events::key::verification::{
start::{StartEvent, StartEventContent},
accept::AcceptEvent, accept::AcceptEvent,
start::{StartEvent, StartEventContent},
HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode, ShortAuthenticationString, HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode, ShortAuthenticationString,
VerificationMethod, VerificationMethod,
}; };
@ -11,7 +11,7 @@ use matrix_sdk_common::uuid::Uuid;
struct SasIds { struct SasIds {
own_user_id: UserId, own_user_id: UserId,
own_device_id: DeviceId, own_device_id: Box<DeviceId>,
other_device: Device, other_device: Device,
} }
@ -27,7 +27,7 @@ struct AcceptedProtocols {
key_agreement_protocol: KeyAgreementProtocol, key_agreement_protocol: KeyAgreementProtocol,
hash: HashAlgorithm, hash: HashAlgorithm,
message_auth_code: MessageAuthenticationCode, message_auth_code: MessageAuthenticationCode,
short_auth_string: Vec<ShortAuthenticationString> short_auth_string: Vec<ShortAuthenticationString>,
} }
struct Sas<S> { struct Sas<S> {
@ -38,11 +38,11 @@ struct Sas<S> {
} }
impl Sas<Created> { impl Sas<Created> {
fn new(own_user_id: UserId, own_device_id: DeviceId, other_device: Device) -> Sas<Created> { fn new(own_user_id: UserId, own_device_id: &DeviceId, other_device: Device) -> Sas<Created> {
Sas { Sas {
ids: SasIds { ids: SasIds {
own_user_id, own_user_id,
own_device_id, own_device_id: own_device_id.into(),
other_device, other_device,
}, },
verification_flow_id: Uuid::new_v4(), verification_flow_id: Uuid::new_v4(),
@ -76,7 +76,7 @@ impl Sas<Created> {
key_agreement_protocol: content.key_agreement_protocol, key_agreement_protocol: content.key_agreement_protocol,
message_auth_code: content.message_authentication_code, message_auth_code: content.message_authentication_code,
short_auth_string: content.short_authentication_string.clone(), short_auth_string: content.short_authentication_string.clone(),
} },
}, },
} }
} }
@ -89,7 +89,7 @@ struct Started {}
impl Sas<Started> { impl Sas<Started> {
fn from_start_event( fn from_start_event(
own_user_id: UserId, own_user_id: UserId,
own_device_id: DeviceId, own_device_id: &DeviceId,
other_device: Device, other_device: Device,
event: &StartEvent, event: &StartEvent,
) -> Sas<Started> { ) -> Sas<Started> {
@ -102,7 +102,7 @@ impl Sas<Started> {
Sas { Sas {
ids: SasIds { ids: SasIds {
own_user_id, own_user_id,
own_device_id, own_device_id: own_device_id.into(),
other_device, other_device,
}, },
verification_flow_id: Uuid::new_v4(), verification_flow_id: Uuid::new_v4(),

View File

@ -16,4 +16,4 @@ http = "0.2.1"
matrix-sdk-common = { version = "0.1.0", path = "../matrix_sdk_common" } matrix-sdk-common = { version = "0.1.0", path = "../matrix_sdk_common" }
matrix-sdk-test-macros = { version = "0.1.0", path = "../matrix_sdk_test_macros" } matrix-sdk-test-macros = { version = "0.1.0", path = "../matrix_sdk_test_macros" }
lazy_static = "1.4.0" lazy_static = "1.4.0"
serde = "1.0.111" serde = "1.0.114"

View File

@ -6,8 +6,8 @@ use http::Response;
use matrix_sdk_common::api::r0::sync::sync_events::Response as SyncResponse; use matrix_sdk_common::api::r0::sync::sync_events::Response as SyncResponse;
use matrix_sdk_common::events::{ use matrix_sdk_common::events::{
presence::PresenceEvent, AnyBasicEvent, AnyEphemeralRoomEventStub, AnyRoomEventStub, presence::PresenceEvent, AnyBasicEvent, AnySyncEphemeralRoomEvent, AnySyncRoomEvent,
AnyStateEventStub, AnySyncStateEvent,
}; };
use matrix_sdk_common::identifiers::RoomId; use matrix_sdk_common::identifiers::RoomId;
use serde_json::Value as JsonValue; use serde_json::Value as JsonValue;
@ -26,6 +26,7 @@ pub enum EventsJson {
HistoryVisibility, HistoryVisibility,
JoinRules, JoinRules,
Member, Member,
MemberNameChange,
MessageEmote, MessageEmote,
MessageNotice, MessageNotice,
MessageText, MessageText,
@ -42,21 +43,54 @@ pub enum EventsJson {
Typing, Typing,
} }
/// Easily create events to stream into either a Client or a `Room` for testing. /// The `EventBuilder` struct can be used to easily generate valid sync responses for testing.
/// These can be then fed into either `Client` or `Room`.
///
/// It supports generated a number of canned events, such as a member entering a room, his power
/// level and display name changing and similar. It also supports insertion of custom events in the
/// form of `EventsJson` values.
///
/// **Important** You *must* use the *same* builder when sending multiple sync responses to
/// a single client. Otherwise, the subsequent responses will be *ignored* by the client because
/// the `next_batch` sync token will not be rotated properly.
///
/// # Example usage
///
/// ```rust
/// use matrix_sdk_test::{EventBuilder, EventsJson};
///
/// let mut builder = EventBuilder::new();
///
/// // response1 now contains events that add an example member to the room and change their power
/// // level
/// let response1 = builder
/// .add_room_event(EventsJson::Member)
/// .add_room_event(EventsJson::PowerLevels)
/// .build_sync_response();
///
/// // response2 is now empty (nothing changed)
/// let response2 = builder.build_sync_response();
///
/// // response3 contains a display name change for member example
/// let response3 = builder
/// .add_room_event(EventsJson::MemberNameChange)
/// .build_sync_response();
/// ```
#[derive(Default)] #[derive(Default)]
pub struct EventBuilder { pub struct EventBuilder {
/// The events that determine the state of a `Room`. /// The events that determine the state of a `Room`.
joined_room_events: HashMap<RoomId, Vec<AnyRoomEventStub>>, joined_room_events: HashMap<RoomId, Vec<AnySyncRoomEvent>>,
/// The events that determine the state of a `Room`. /// The events that determine the state of a `Room`.
invited_room_events: HashMap<RoomId, Vec<AnyStateEventStub>>, invited_room_events: HashMap<RoomId, Vec<AnySyncStateEvent>>,
/// The events that determine the state of a `Room`. /// The events that determine the state of a `Room`.
left_room_events: HashMap<RoomId, Vec<AnyRoomEventStub>>, left_room_events: HashMap<RoomId, Vec<AnySyncRoomEvent>>,
/// The presence events that determine the presence state of a `RoomMember`. /// The presence events that determine the presence state of a `RoomMember`.
presence_events: Vec<PresenceEvent>, presence_events: Vec<PresenceEvent>,
/// The state events that determine the state of a `Room`. /// The state events that determine the state of a `Room`.
state_events: Vec<AnyStateEventStub>, state_events: Vec<AnySyncStateEvent>,
/// The ephemeral room events that determine the state of a `Room`. /// The ephemeral room events that determine the state of a `Room`.
ephemeral: Vec<AnyEphemeralRoomEventStub>, ephemeral: Vec<AnySyncEphemeralRoomEvent>,
/// The account data events that determine the state of a `Room`. /// The account data events that determine the state of a `Room`.
account_data: Vec<AnyBasicEvent>, account_data: Vec<AnyBasicEvent>,
/// Internal counter to enable the `prev_batch` and `next_batch` of each sync response to vary. /// Internal counter to enable the `prev_batch` and `next_batch` of each sync response to vary.
@ -76,7 +110,7 @@ impl EventBuilder {
_ => panic!("unknown ephemeral event {:?}", json), _ => panic!("unknown ephemeral event {:?}", json),
}; };
let event = serde_json::from_value::<AnyEphemeralRoomEventStub>(val.clone()).unwrap(); let event = serde_json::from_value::<AnySyncEphemeralRoomEvent>(val.clone()).unwrap();
self.ephemeral.push(event); self.ephemeral.push(event);
self self
} }
@ -97,11 +131,12 @@ impl EventBuilder {
pub fn add_room_event(&mut self, json: EventsJson) -> &mut Self { pub fn add_room_event(&mut self, json: EventsJson) -> &mut Self {
let val: &JsonValue = match json { let val: &JsonValue = match json {
EventsJson::Member => &test_json::MEMBER, EventsJson::Member => &test_json::MEMBER,
EventsJson::MemberNameChange => &test_json::MEMBER_NAME_CHANGE,
EventsJson::PowerLevels => &test_json::POWER_LEVELS, EventsJson::PowerLevels => &test_json::POWER_LEVELS,
_ => panic!("unknown room event json {:?}", json), _ => panic!("unknown room event json {:?}", json),
}; };
let event = serde_json::from_value::<AnyRoomEventStub>(val.clone()).unwrap(); let event = serde_json::from_value::<AnySyncRoomEvent>(val.clone()).unwrap();
self.add_joined_event( self.add_joined_event(
&RoomId::try_from("!SVkFJHzfwvuaIEawgC:localhost").unwrap(), &RoomId::try_from("!SVkFJHzfwvuaIEawgC:localhost").unwrap(),
@ -115,12 +150,12 @@ impl EventBuilder {
room_id: &RoomId, room_id: &RoomId,
event: serde_json::Value, event: serde_json::Value,
) -> &mut Self { ) -> &mut Self {
let event = serde_json::from_value::<AnyRoomEventStub>(event).unwrap(); let event = serde_json::from_value::<AnySyncRoomEvent>(event).unwrap();
self.add_joined_event(room_id, event); self.add_joined_event(room_id, event);
self self
} }
fn add_joined_event(&mut self, room_id: &RoomId, event: AnyRoomEventStub) { fn add_joined_event(&mut self, room_id: &RoomId, event: AnySyncRoomEvent) {
self.joined_room_events self.joined_room_events
.entry(room_id.clone()) .entry(room_id.clone())
.or_insert_with(Vec::new) .or_insert_with(Vec::new)
@ -132,7 +167,7 @@ impl EventBuilder {
room_id: &RoomId, room_id: &RoomId,
event: serde_json::Value, event: serde_json::Value,
) -> &mut Self { ) -> &mut Self {
let event = serde_json::from_value::<AnyStateEventStub>(event).unwrap(); let event = serde_json::from_value::<AnySyncStateEvent>(event).unwrap();
self.invited_room_events self.invited_room_events
.entry(room_id.clone()) .entry(room_id.clone())
.or_insert_with(Vec::new) .or_insert_with(Vec::new)
@ -145,7 +180,7 @@ impl EventBuilder {
room_id: &RoomId, room_id: &RoomId,
event: serde_json::Value, event: serde_json::Value,
) -> &mut Self { ) -> &mut Self {
let event = serde_json::from_value::<AnyRoomEventStub>(event).unwrap(); let event = serde_json::from_value::<AnySyncRoomEvent>(event).unwrap();
self.left_room_events self.left_room_events
.entry(room_id.clone()) .entry(room_id.clone())
.or_insert_with(Vec::new) .or_insert_with(Vec::new)
@ -164,7 +199,7 @@ impl EventBuilder {
_ => panic!("unknown state event {:?}", json), _ => panic!("unknown state event {:?}", json),
}; };
let event = serde_json::from_value::<AnyStateEventStub>(val.clone()).unwrap(); let event = serde_json::from_value::<AnySyncStateEvent>(val.clone()).unwrap();
self.state_events.push(event); self.state_events.push(event);
self self
} }
@ -181,7 +216,8 @@ impl EventBuilder {
self self
} }
/// Consumes `ResponseBuilder` and returns `SyncResponse`. /// Builds a `SyncResponse` containing the events we queued so far. The next response returned
/// by `build_sync_response` will then be empty if no further events were queued.
pub fn build_sync_response(&mut self) -> SyncResponse { pub fn build_sync_response(&mut self) -> SyncResponse {
let main_room_id = RoomId::try_from("!SVkFJHzfwvuaIEawgC:localhost").unwrap(); let main_room_id = RoomId::try_from("!SVkFJHzfwvuaIEawgC:localhost").unwrap();
@ -293,12 +329,26 @@ impl EventBuilder {
let response = Response::builder() let response = Response::builder()
.body(serde_json::to_vec(&body).unwrap()) .body(serde_json::to_vec(&body).unwrap())
.unwrap(); .unwrap();
// Clear state so that the next sync response will be empty if nothing was added.
self.clear();
SyncResponse::try_from(response).unwrap() SyncResponse::try_from(response).unwrap()
} }
fn generate_sync_token(&self) -> String { fn generate_sync_token(&self) -> String {
format!("t392-516_47314_0_7_1_1_1_11444_{}", self.batch_counter) format!("t392-516_47314_0_7_1_1_1_11444_{}", self.batch_counter)
} }
pub fn clear(&mut self) {
self.account_data.clear();
self.ephemeral.clear();
self.invited_room_events.clear();
self.joined_room_events.clear();
self.left_room_events.clear();
self.presence_events.clear();
self.state_events.clear();
}
} }
/// Embedded sync reponse files /// Embedded sync reponse files

View File

@ -208,6 +208,7 @@ lazy_static! {
}); });
} }
// TODO: Move `prev_content` into `unsigned` once ruma supports it
lazy_static! { lazy_static! {
pub static ref MEMBER: JsonValue = json!({ pub static ref MEMBER: JsonValue = json!({
"content": { "content": {
@ -221,14 +222,40 @@ lazy_static! {
"sender": "@example:localhost", "sender": "@example:localhost",
"state_key": "@example:localhost", "state_key": "@example:localhost",
"type": "m.room.member", "type": "m.room.member",
"unsigned": {
"age": 297036,
"replaces_state": "$151800111315tsynI:localhost",
"prev_content": { "prev_content": {
"avatar_url": null, "avatar_url": null,
"displayname": "example", "displayname": "example",
"membership": "invite" "membership": "invite"
},
"unsigned": {
"age": 297036,
"replaces_state": "$151800111315tsynI:localhost"
} }
});
}
// TODO: Move `prev_content` into `unsigned` once ruma supports it
lazy_static! {
pub static ref MEMBER_NAME_CHANGE: JsonValue = json!({
"content": {
"avatar_url": null,
"displayname": "changed",
"membership": "join"
},
"event_id": "$151800234427abgho:localhost",
"membership": "join",
"origin_server_ts": 151800152,
"sender": "@example:localhost",
"state_key": "@example:localhost",
"type": "m.room.member",
"prev_content": {
"avatar_url": null,
"displayname": "example",
"membership": "join"
},
"unsigned": {
"age": 297032,
"replaces_state": "$151800140517rfvjc:localhost"
} }
}); });
} }
@ -552,6 +579,7 @@ lazy_static! {
}); });
} }
// TODO: Move `prev_content` into `unsigned` once ruma supports it
lazy_static! { lazy_static! {
pub static ref TOPIC: JsonValue = json!({ pub static ref TOPIC: JsonValue = json!({
"content": { "content": {
@ -562,11 +590,11 @@ lazy_static! {
"sender": "@example:localhost", "sender": "@example:localhost",
"state_key": "", "state_key": "",
"type": "m.room.topic", "type": "m.room.topic",
"unsigned": {
"age": 1392989,
"prev_content": { "prev_content": {
"topic": "test" "topic": "test"
}, },
"unsigned": {
"age": 1392989,
"prev_sender": "@example:localhost", "prev_sender": "@example:localhost",
"replaces_state": "$151957069225EVYKm:localhost" "replaces_state": "$151957069225EVYKm:localhost"
} }

View File

@ -9,8 +9,8 @@ pub mod sync;
pub use events::{ pub use events::{
ALIAS, ALIASES, EVENT_ID, KEYS_QUERY, KEYS_UPLOAD, LOGIN, LOGIN_RESPONSE_ERR, LOGOUT, MEMBER, ALIAS, ALIASES, EVENT_ID, KEYS_QUERY, KEYS_UPLOAD, LOGIN, LOGIN_RESPONSE_ERR, LOGOUT, MEMBER,
MESSAGE_EDIT, MESSAGE_TEXT, NAME, POWER_LEVELS, PRESENCE, PUBLIC_ROOMS, REACTION, REDACTED, MEMBER_NAME_CHANGE, MESSAGE_EDIT, MESSAGE_TEXT, NAME, POWER_LEVELS, PRESENCE, PUBLIC_ROOMS,
REDACTED_INVALID, REDACTED_STATE, REDACTION, REGISTRATION_RESPONSE_ERR, ROOM_ID, ROOM_MESSAGES, REACTION, REDACTED, REDACTED_INVALID, REDACTED_STATE, REDACTION, REGISTRATION_RESPONSE_ERR,
TYPING, ROOM_ID, ROOM_MESSAGES, TYPING,
}; };
pub use sync::{DEFAULT_SYNC_SUMMARY, INVITE_SYNC, LEAVE_SYNC, LEAVE_SYNC_EVENT, MORE_SYNC, SYNC}; pub use sync::{DEFAULT_SYNC_SUMMARY, INVITE_SYNC, LEAVE_SYNC, LEAVE_SYNC_EVENT, MORE_SYNC, SYNC};