Merge branch 'master' into sas-longer-flow

master
Damir Jelić 2021-05-18 09:07:50 +02:00
commit 110b8eb8dd
20 changed files with 332 additions and 327 deletions

View File

@ -2393,7 +2393,7 @@ mod test {
api::r0::{ api::r0::{
account::register::Request as RegistrationRequest, account::register::Request as RegistrationRequest,
directory::get_public_rooms_filtered::Request as PublicRoomsFilterRequest, directory::get_public_rooms_filtered::Request as PublicRoomsFilterRequest,
membership::Invite3pid, session::get_login_types::LoginType, uiaa::AuthData, membership::Invite3pidInit, session::get_login_types::LoginType, uiaa::AuthData,
}, },
assign, assign,
directory::Filter, directory::Filter,
@ -2791,12 +2791,15 @@ mod test {
let room = client.get_joined_room(&room_id!("!SVkFJHzfwvuaIEawgC:localhost")).unwrap(); let room = client.get_joined_room(&room_id!("!SVkFJHzfwvuaIEawgC:localhost")).unwrap();
room.invite_user_by_3pid(Invite3pid { room.invite_user_by_3pid(
Invite3pidInit {
id_server: "example.org", id_server: "example.org",
id_access_token: "IdToken", id_access_token: "IdToken",
medium: thirdparty::Medium::Email, medium: thirdparty::Medium::Email,
address: "address", address: "address",
}) }
.into(),
)
.await .await
.unwrap(); .unwrap();
} }
@ -3052,13 +3055,9 @@ mod test {
let room = client.get_joined_room(&room_id).unwrap(); let room = client.get_joined_room(&room_id).unwrap();
let avatar_url = mxc_uri!("mxc://example.org/avA7ar"); let avatar_url = mxc_uri!("mxc://example.org/avA7ar");
let member_event = MemberEventContent { let member_event = assign!(MemberEventContent::new(MembershipState::Join), {
avatar_url: Some(avatar_url), avatar_url: Some(avatar_url)
membership: MembershipState::Join, });
is_direct: None,
displayname: None,
third_party_invite: None,
};
let content = AnyStateEventContent::RoomMember(member_event); let content = AnyStateEventContent::RoomMember(member_event);
let response = room.send_state_event(content, "").await.unwrap(); let response = room.send_state_event(content, "").await.unwrap();
assert_eq!(event_id!("$h29iv0s8:example.com"), response.event_id); assert_eq!(event_id!("$h29iv0s8:example.com"), response.event_id);

View File

@ -245,7 +245,7 @@ impl Common {
pub async fn joined_members_no_sync(&self) -> Result<Vec<RoomMember>> { pub async fn joined_members_no_sync(&self) -> Result<Vec<RoomMember>> {
Ok(self Ok(self
.inner .inner
.members() .joined_members()
.await? .await?
.into_iter() .into_iter()
.map(|member| RoomMember::new(self.client.clone(), member)) .map(|member| RoomMember::new(self.client.clone(), member))

View File

@ -4,8 +4,6 @@ use std::{io::Read, ops::Deref};
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use matrix_sdk_base::crypto::AttachmentEncryptor; use matrix_sdk_base::crypto::AttachmentEncryptor;
#[cfg(feature = "encryption")]
use matrix_sdk_common::locks::Mutex;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::{ api::r0::{
membership::{ membership::{
@ -36,6 +34,8 @@ use matrix_sdk_common::{
receipt::ReceiptType, receipt::ReceiptType,
uuid::Uuid, uuid::Uuid,
}; };
#[cfg(feature = "encryption")]
use matrix_sdk_common::{events::room::EncryptedFileInit, locks::Mutex};
use mime::{self, Mime}; use mime::{self, Mime};
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use tracing::instrument; use tracing::instrument;
@ -462,15 +462,18 @@ impl Joined {
let response = self.client.upload(&content_type, &mut reader).await?; let response = self.client.upload(&content_type, &mut reader).await?;
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
let keys = { let keys: Option<Box<EncryptedFile>> = {
let keys = reader.finish(); let keys = reader.finish();
Some(Box::new(EncryptedFile { Some(Box::new(
EncryptedFileInit {
url: response.content_uri.clone(), url: response.content_uri.clone(),
key: keys.web_key, key: keys.web_key,
iv: keys.iv, iv: keys.iv,
hashes: keys.hashes, hashes: keys.hashes,
v: keys.version, v: keys.version,
})) }
.into(),
))
}; };
#[cfg(not(feature = "encryption"))] #[cfg(not(feature = "encryption"))]
let keys: Option<Box<EncryptedFile>> = None; let keys: Option<Box<EncryptedFile>> = None;
@ -486,32 +489,23 @@ impl Joined {
let content = match content_type.type_() { let content = match content_type.type_() {
mime::IMAGE => { mime::IMAGE => {
// TODO create a thumbnail using the image crate?. // TODO create a thumbnail using the image crate?.
MessageType::Image(ImageMessageEventContent { MessageType::Image(assign!(
body: body.to_owned(), ImageMessageEventContent::plain(body.to_owned(), url, None),
info: None, { file: encrypted_file }
url: Some(url), ))
file: encrypted_file,
})
} }
mime::AUDIO => MessageType::Audio(AudioMessageEventContent { mime::AUDIO => MessageType::Audio(assign!(
body: body.to_owned(), AudioMessageEventContent::plain(body.to_owned(), url, None),
info: None, { file: encrypted_file }
url: Some(url), )),
file: encrypted_file, mime::VIDEO => MessageType::Video(assign!(
}), VideoMessageEventContent::plain(body.to_owned(), url, None),
mime::VIDEO => MessageType::Video(VideoMessageEventContent { { file: encrypted_file }
body: body.to_owned(), )),
info: None, _ => MessageType::File(assign!(
url: Some(url), FileMessageEventContent::plain(body.to_owned(), url, None),
file: encrypted_file, { file: encrypted_file }
}), )),
_ => MessageType::File(FileMessageEventContent {
filename: None,
body: body.to_owned(),
info: None,
url: Some(url),
file: encrypted_file,
}),
}; };
self.send(AnyMessageEventContent::RoomMessage(MessageEventContent::new(content)), txn_id) self.send(AnyMessageEventContent::RoomMessage(MessageEventContent::new(content)), txn_id)
@ -540,6 +534,7 @@ impl Joined {
/// room::member::{MemberEventContent, MembershipState}, /// room::member::{MemberEventContent, MembershipState},
/// }, /// },
/// identifiers::mxc_uri, /// identifiers::mxc_uri,
/// assign,
/// }; /// };
/// # futures::executor::block_on(async { /// # futures::executor::block_on(async {
/// # let homeserver = url::Url::parse("http://localhost:8080").unwrap(); /// # let homeserver = url::Url::parse("http://localhost:8080").unwrap();
@ -547,13 +542,9 @@ impl Joined {
/// # let room_id = matrix_sdk::identifiers::room_id!("!test:localhost"); /// # let room_id = matrix_sdk::identifiers::room_id!("!test:localhost");
/// ///
/// let avatar_url = mxc_uri!("mxc://example.org/avatar"); /// let avatar_url = mxc_uri!("mxc://example.org/avatar");
/// let member_event = MemberEventContent { /// let member_event = assign!(MemberEventContent::new(MembershipState::Join), {
/// avatar_url: Some(avatar_url), /// avatar_url: Some(avatar_url),
/// membership: MembershipState::Join, /// });
/// is_direct: None,
/// displayname: None,
/// third_party_invite: None,
/// };
/// # let room = client /// # let room = client
/// # .get_joined_room(&room_id) /// # .get_joined_room(&room_id)
/// # .unwrap(); /// # .unwrap();

View File

@ -34,7 +34,9 @@ impl EventHandler for AppserviceEventHandler {
if let MembershipState::Invite = event.content.membership { if let MembershipState::Invite = event.content.membership {
let user_id = UserId::try_from(event.state_key.clone()).unwrap(); let user_id = UserId::try_from(event.state_key.clone()).unwrap();
let client = self.appservice.client_with_localpart(user_id.localpart()).await.unwrap(); self.appservice.register(user_id.localpart()).await.unwrap();
let client = self.appservice.client(Some(user_id.localpart())).await.unwrap();
client.join_room_by_id(room.room_id()).await.unwrap(); client.join_room_by_id(room.room_id()).await.unwrap();
} }
@ -55,7 +57,7 @@ pub async fn main() -> std::io::Result<()> {
let event_handler = AppserviceEventHandler::new(appservice.clone()); let event_handler = AppserviceEventHandler::new(appservice.clone());
appservice.client().set_event_handler(Box::new(event_handler)).await; appservice.set_event_handler(Box::new(event_handler)).await.unwrap();
HttpServer::new(move || App::new().service(appservice.actix_service())) HttpServer::new(move || App::new().service(appservice.actix_service()))
.bind(("0.0.0.0", 8090))? .bind(("0.0.0.0", 8090))?

View File

@ -61,11 +61,11 @@ async fn push_transactions(
request: IncomingRequest<api::event::push_events::v1::IncomingRequest>, request: IncomingRequest<api::event::push_events::v1::IncomingRequest>,
appservice: Data<Appservice>, appservice: Data<Appservice>,
) -> Result<HttpResponse, Error> { ) -> Result<HttpResponse, Error> {
if !appservice.hs_token_matches(request.access_token) { if !appservice.compare_hs_token(request.access_token) {
return Ok(HttpResponse::Unauthorized().finish()); return Ok(HttpResponse::Unauthorized().finish());
} }
appservice.client().receive_transaction(request.incoming).await.unwrap(); appservice.client(None).await?.receive_transaction(request.incoming).await?;
Ok(HttpResponse::Ok().json("{}")) Ok(HttpResponse::Ok().json("{}"))
} }
@ -76,7 +76,7 @@ async fn query_user_id(
request: IncomingRequest<api::query::query_user_id::v1::IncomingRequest>, request: IncomingRequest<api::query::query_user_id::v1::IncomingRequest>,
appservice: Data<Appservice>, appservice: Data<Appservice>,
) -> Result<HttpResponse, Error> { ) -> Result<HttpResponse, Error> {
if !appservice.hs_token_matches(request.access_token) { if !appservice.compare_hs_token(request.access_token) {
return Ok(HttpResponse::Unauthorized().finish()); return Ok(HttpResponse::Unauthorized().finish());
} }
@ -89,7 +89,7 @@ async fn query_room_alias(
request: IncomingRequest<api::query::query_room_alias::v1::IncomingRequest>, request: IncomingRequest<api::query::query_room_alias::v1::IncomingRequest>,
appservice: Data<Appservice>, appservice: Data<Appservice>,
) -> Result<HttpResponse, Error> { ) -> Result<HttpResponse, Error> {
if !appservice.hs_token_matches(request.access_token) { if !appservice.compare_hs_token(request.access_token) {
return Ok(HttpResponse::Unauthorized().finish()); return Ok(HttpResponse::Unauthorized().finish());
} }

View File

@ -20,13 +20,26 @@
//! the webserver for you //! the webserver for you
//! * receive and validate requests from the homeserver correctly //! * receive and validate requests from the homeserver correctly
//! * allow calling the homeserver with proper virtual user identity assertion //! * allow calling the homeserver with proper virtual user identity assertion
//! * have the goal to have a consistent room state available by leveraging the //! * have consistent room state by leveraging matrix-sdk's state store
//! stores that the matrix-sdk provides //! * provide E2EE support by leveraging matrix-sdk's crypto store
//!
//! # Status
//!
//! The crate is in an experimental state. Follow
//! [matrix-org/matrix-rust-sdk#228] for progress.
//! //!
//! # Quickstart //! # Quickstart
//! //!
//! ```no_run //! ```no_run
//! # async { //! # async {
//! #
//! # use matrix_sdk::{async_trait, EventHandler};
//! #
//! # struct AppserviceEventHandler;
//! #
//! # #[async_trait]
//! # impl EventHandler for AppserviceEventHandler {}
//! #
//! use matrix_sdk_appservice::{Appservice, AppserviceRegistration}; //! use matrix_sdk_appservice::{Appservice, AppserviceRegistration};
//! //!
//! let homeserver_url = "http://127.0.0.1:8008"; //! let homeserver_url = "http://127.0.0.1:8008";
@ -42,17 +55,23 @@
//! users: //! users:
//! - exclusive: true //! - exclusive: true
//! regex: '@_appservice_.*' //! regex: '@_appservice_.*'
//! ") //! ")?;
//! .unwrap();
//! //!
//! let appservice = Appservice::new(homeserver_url, server_name, registration).await.unwrap(); //! let appservice = Appservice::new(homeserver_url, server_name, registration).await?;
//! // set event handler with `appservice.client().set_event_handler()` here //! appservice.set_event_handler(Box::new(AppserviceEventHandler)).await?;
//! let (host, port) = appservice.get_host_and_port_from_registration().unwrap(); //!
//! appservice.run(host, port).await.unwrap(); //! let (host, port) = appservice.registration().get_host_and_port()?;
//! appservice.run(host, port).await?;
//! #
//! # Ok::<(), Box<dyn std::error::Error + 'static>>(())
//! # }; //! # };
//! ``` //! ```
//! //!
//! Check the [examples directory] for fully working examples.
//!
//! [Application Service]: https://matrix.org/docs/spec/application_service/r0.1.2 //! [Application Service]: https://matrix.org/docs/spec/application_service/r0.1.2
//! [matrix-org/matrix-rust-sdk#228]: https://github.com/matrix-org/matrix-rust-sdk/issues/228
//! [examples directory]: https://github.com/matrix-org/matrix-rust-sdk/tree/master/matrix_sdk_appservice/examples
#[cfg(not(any(feature = "actix",)))] #[cfg(not(any(feature = "actix",)))]
compile_error!("one webserver feature must be enabled. available ones: `actix`"); compile_error!("one webserver feature must be enabled. available ones: `actix`");
@ -79,11 +98,10 @@ use matrix_sdk::{
assign, assign,
identifiers::{self, DeviceId, ServerNameBox, UserId}, identifiers::{self, DeviceId, ServerNameBox, UserId},
reqwest::Url, reqwest::Url,
Client, ClientConfig, FromHttpResponseError, HttpError, RequestConfig, ServerError, Session, Client, ClientConfig, EventHandler, FromHttpResponseError, HttpError, RequestConfig,
ServerError, Session,
}; };
use regex::Regex; use regex::Regex;
#[cfg(not(feature = "actix"))]
use tracing::error;
use tracing::warn; use tracing::warn;
#[cfg(feature = "actix")] #[cfg(feature = "actix")]
@ -96,6 +114,8 @@ pub type Host = String;
pub type Port = u16; pub type Port = u16;
/// Appservice Registration /// Appservice Registration
///
/// Wrapper around [`Registration`]
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct AppserviceRegistration { pub struct AppserviceRegistration {
inner: Registration, inner: Registration,
@ -117,6 +137,26 @@ impl AppserviceRegistration {
Ok(Self { inner: serde_yaml::from_reader(file)? }) Ok(Self { inner: serde_yaml::from_reader(file)? })
} }
/// Get the host and port from the registration URL
///
/// If no port is found it falls back to scheme defaults: 80 for http and
/// 443 for https
pub fn get_host_and_port(&self) -> Result<(Host, Port)> {
let uri = Uri::try_from(&self.inner.url)?;
let host = uri.host().ok_or(Error::MissingRegistrationHost)?.to_owned();
let port = match uri.port() {
Some(port) => Ok(port.as_u16()),
None => match uri.scheme_str() {
Some("http") => Ok(80),
Some("https") => Ok(443),
_ => Err(Error::MissingRegistrationPort),
},
}?;
Ok((host, port))
}
} }
impl From<Registration> for AppserviceRegistration { impl From<Registration> for AppserviceRegistration {
@ -133,31 +173,20 @@ impl Deref for AppserviceRegistration {
} }
} }
async fn create_client( async fn client_session_with_login_restore(
homeserver_url: &Url, client: &Client,
server_name: &ServerNameBox,
registration: &AppserviceRegistration, registration: &AppserviceRegistration,
localpart: Option<&str>, localpart: impl AsRef<str> + Into<Box<str>>,
) -> Result<Client> { server_name: &ServerNameBox,
let client = if localpart.is_some() { ) -> Result<()> {
let request_config = RequestConfig::default().assert_identity();
let config = ClientConfig::default().request_config(request_config);
Client::new_with_config(homeserver_url.clone(), config)?
} else {
Client::new(homeserver_url.clone())?
};
let session = Session { let session = Session {
access_token: registration.as_token.clone(), access_token: registration.as_token.clone(),
user_id: UserId::parse_with_server_name( user_id: UserId::parse_with_server_name(localpart, server_name)?,
localpart.unwrap_or(&registration.sender_localpart),
&server_name,
)?,
device_id: DeviceId::new(), device_id: DeviceId::new(),
}; };
client.restore_login(session).await?; client.restore_login(session).await?;
Ok(client) Ok(())
} }
/// Appservice /// Appservice
@ -189,60 +218,82 @@ impl Appservice {
let homeserver_url = homeserver_url.try_into()?; let homeserver_url = homeserver_url.try_into()?;
let server_name = server_name.try_into()?; let server_name = server_name.try_into()?;
let client = create_client(&homeserver_url, &server_name, &registration, None).await?; let client_sender_localpart = Client::new(homeserver_url.clone())?;
Ok(Appservice { client_session_with_login_restore(
homeserver_url, &client_sender_localpart,
server_name, &registration,
registration, registration.sender_localpart.as_ref(),
client_sender_localpart: client, &server_name,
})
}
/// Get `Client` for the user associated with the application service
/// (`sender_localpart` of the [registration])
///
/// [registration]: https://matrix.org/docs/spec/application_service/r0.1.2#registration
pub fn client(&self) -> Client {
self.client_sender_localpart.clone()
}
/// Get `Client` for the given `localpart`
///
/// If the `localpart` is covered by the `namespaces` in the [registration]
/// all requests to the homeserver will [assert the identity] to the
/// according virtual user.
///
/// [registration]: https://matrix.org/docs/spec/application_service/r0.1.2#registration
/// [assert the identity]:
/// https://matrix.org/docs/spec/application_service/r0.1.2#identity-assertion
pub async fn client_with_localpart(
&self,
localpart: impl AsRef<str> + Into<Box<str>>,
) -> Result<Client> {
let user_id = UserId::parse_with_server_name(localpart, &self.server_name)?;
let localpart = user_id.localpart().to_owned();
let client = create_client(
&self.homeserver_url,
&self.server_name,
&self.registration,
Some(&localpart),
) )
.await?; .await?;
self.ensure_registered(localpart).await?; Ok(Appservice { homeserver_url, server_name, registration, client_sender_localpart })
}
/// Get a [`Client`]
///
/// Will return a `Client` that's configured to [assert the identity] on all
/// outgoing homeserver requests if `localpart` is given. If not given
/// the `Client` will use the main user associated with this appservice,
/// that is the `sender_localpart` in the [`AppserviceRegistration`]
///
/// # Arguments
///
/// * `localpart` - The localpart of the user we want assert our identity to
///
/// [registration]: https://matrix.org/docs/spec/application_service/r0.1.2#registration
/// [assert the identity]: https://matrix.org/docs/spec/application_service/r0.1.2#identity-assertion
pub async fn client(&self, localpart: Option<&str>) -> Result<Client> {
let localpart = localpart.unwrap_or_else(|| self.registration.sender_localpart.as_ref());
// The `as_token` in the `Session` maps to the main appservice user
// (`sender_localpart`) by default, so we don't need to assert identity
// in that case
let client = if localpart == self.registration.sender_localpart {
self.client_sender_localpart.clone()
} else {
let request_config = RequestConfig::default().assert_identity();
let config = ClientConfig::default().request_config(request_config);
let client = Client::new_with_config(self.homeserver_url.clone(), config)?;
client_session_with_login_restore(
&client,
&self.registration,
localpart,
&self.server_name,
)
.await?;
client
};
Ok(client) Ok(client)
} }
async fn ensure_registered(&self, localpart: impl AsRef<str>) -> Result<()> { /// Convenience wrapper around [`Client::set_event_handler()`]
pub async fn set_event_handler(&self, handler: Box<dyn EventHandler>) -> Result<()> {
let client = self.client(None).await?;
client.set_event_handler(handler).await;
Ok(())
}
/// Register a virtual user by sending a [`RegistrationRequest`] to the
/// homeserver
///
/// # Arguments
///
/// * `localpart` - The localpart of the user to register. Must be covered
/// by the namespaces in the [`Registration`] in order to succeed.
pub async fn register(&self, localpart: impl AsRef<str>) -> Result<()> {
let request = assign!(RegistrationRequest::new(), { let request = assign!(RegistrationRequest::new(), {
username: Some(localpart.as_ref()), username: Some(localpart.as_ref()),
login_type: Some(&LoginType::ApplicationService), login_type: Some(&LoginType::ApplicationService),
}); });
match self.client().register(request).await { let client = self.client(None).await?;
match client.register(request).await {
Ok(_) => (), Ok(_) => (),
Err(error) => match error { Err(error) => match error {
matrix_sdk::Error::Http(HttpError::UiaaError(FromHttpResponseError::Http( matrix_sdk::Error::Http(HttpError::UiaaError(FromHttpResponseError::Http(
@ -266,14 +317,14 @@ impl Appservice {
/// Get the Appservice [registration] /// Get the Appservice [registration]
/// ///
/// [registration]: https://matrix.org/docs/spec/application_service/r0.1.2#registration /// [registration]: https://matrix.org/docs/spec/application_service/r0.1.2#registration
pub fn registration(&self) -> &Registration { pub fn registration(&self) -> &AppserviceRegistration {
&self.registration &self.registration
} }
/// Compare the given `hs_token` against `registration.hs_token` /// Compare the given `hs_token` against `registration.hs_token`
/// ///
/// Returns `true` if the tokens match, `false` otherwise. /// Returns `true` if the tokens match, `false` otherwise.
pub fn hs_token_matches(&self, hs_token: impl AsRef<str>) -> bool { pub fn compare_hs_token(&self, hs_token: impl AsRef<str>) -> bool {
self.registration.hs_token == hs_token.as_ref() self.registration.hs_token == hs_token.as_ref()
} }
@ -290,26 +341,6 @@ impl Appservice {
Ok(false) Ok(false)
} }
/// Get the host and port from the registration URL
///
/// If no port is found it falls back to scheme defaults: 80 for http and
/// 443 for https
pub fn get_host_and_port_from_registration(&self) -> Result<(Host, Port)> {
let uri = Uri::try_from(&self.registration.url)?;
let host = uri.host().ok_or(Error::MissingRegistrationHost)?.to_owned();
let port = match uri.port() {
Some(port) => Ok(port.as_u16()),
None => match uri.scheme_str() {
Some("http") => Ok(80),
Some("https") => Ok(443),
_ => Err(Error::MissingRegistrationPort),
},
}?;
Ok((host, port))
}
/// Service to register on an Actix `App` /// Service to register on an Actix `App`
#[cfg(feature = "actix")] #[cfg(feature = "actix")]
#[cfg_attr(docs, doc(cfg(feature = "actix")))] #[cfg_attr(docs, doc(cfg(feature = "actix")))]

View File

@ -76,7 +76,7 @@ async fn test_event_handler() -> Result<()> {
} }
} }
appservice.client().set_event_handler(Box::new(Example::new())).await; appservice.set_event_handler(Box::new(Example::new())).await?;
let event = serde_json::from_value::<AnyStateEvent>(member_json()).unwrap(); let event = serde_json::from_value::<AnyStateEvent>(member_json()).unwrap();
let event: Raw<AnyRoomEvent> = AnyRoomEvent::State(event).into(); let event: Raw<AnyRoomEvent> = AnyRoomEvent::State(event).into();
@ -87,7 +87,7 @@ async fn test_event_handler() -> Result<()> {
events, events,
); );
appservice.client().receive_transaction(incoming).await?; appservice.client(None).await?.receive_transaction(incoming).await?;
Ok(()) Ok(())
} }
@ -105,7 +105,7 @@ async fn test_transaction() -> Result<()> {
events, events,
); );
appservice.client().receive_transaction(incoming).await?; appservice.client(None).await?.receive_transaction(incoming).await?;
Ok(()) Ok(())
} }
@ -116,7 +116,7 @@ async fn test_verify_hs_token() -> Result<()> {
let registration = appservice.registration(); let registration = appservice.registration();
assert!(appservice.hs_token_matches(&registration.hs_token)); assert!(appservice.compare_hs_token(&registration.hs_token));
Ok(()) Ok(())
} }

View File

@ -20,7 +20,6 @@ use std::{
path::{Path, PathBuf}, path::{Path, PathBuf},
result::Result as StdResult, result::Result as StdResult,
sync::Arc, sync::Arc,
time::SystemTime,
}; };
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
@ -49,7 +48,7 @@ use matrix_sdk_common::{
instant::Instant, instant::Instant,
locks::RwLock, locks::RwLock,
push::{Action, PushConditionRoomCtx, Ruleset}, push::{Action, PushConditionRoomCtx, Ruleset},
Raw, UInt, MilliSecondsSinceUnixEpoch, Raw, UInt,
}; };
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use matrix_sdk_crypto::{ use matrix_sdk_crypto::{
@ -100,8 +99,7 @@ pub struct AdditionalUnsignedData {
pub fn hoist_and_deserialize_state_event( pub fn hoist_and_deserialize_state_event(
event: &Raw<AnySyncStateEvent>, event: &Raw<AnySyncStateEvent>,
) -> StdResult<AnySyncStateEvent, serde_json::Error> { ) -> StdResult<AnySyncStateEvent, serde_json::Error> {
let prev_content = let prev_content = event.deserialize_as::<AdditionalEventData>()?.unsigned.prev_content;
serde_json::from_str::<AdditionalEventData>(event.json().get())?.unsigned.prev_content;
let mut ev = event.deserialize()?; let mut ev = event.deserialize()?;
@ -117,8 +115,7 @@ pub fn hoist_and_deserialize_state_event(
fn hoist_member_event( fn hoist_member_event(
event: &Raw<StateEvent<MemberEventContent>>, event: &Raw<StateEvent<MemberEventContent>>,
) -> StdResult<StateEvent<MemberEventContent>, serde_json::Error> { ) -> StdResult<StateEvent<MemberEventContent>, serde_json::Error> {
let prev_content = let prev_content = event.deserialize_as::<AdditionalEventData>()?.unsigned.prev_content;
serde_json::from_str::<AdditionalEventData>(event.json().get())?.unsigned.prev_content;
let mut e = event.deserialize()?; let mut e = event.deserialize()?;
@ -132,7 +129,8 @@ fn hoist_member_event(
fn hoist_room_event_prev_content( fn hoist_room_event_prev_content(
event: &Raw<AnySyncRoomEvent>, event: &Raw<AnySyncRoomEvent>,
) -> StdResult<AnySyncRoomEvent, serde_json::Error> { ) -> StdResult<AnySyncRoomEvent, serde_json::Error> {
let prev_content = serde_json::from_str::<AdditionalEventData>(event.json().get()) let prev_content = event
.deserialize_as::<AdditionalEventData>()
.map(|more_unsigned| more_unsigned.unsigned) .map(|more_unsigned| more_unsigned.unsigned)
.map(|additional| additional.prev_content)? .map(|additional| additional.prev_content)?
.and_then(|p| p.deserialize().ok()); .and_then(|p| p.deserialize().ok());
@ -515,7 +513,7 @@ impl BaseClient {
event.event.clone(), event.event.clone(),
false, false,
room_id.clone(), room_id.clone(),
SystemTime::now(), MilliSecondsSinceUnixEpoch::now(),
), ),
); );
} }

View File

@ -19,7 +19,7 @@ use std::{
convert::TryFrom, convert::TryFrom,
path::{Path, PathBuf}, path::{Path, PathBuf},
sync::Arc, sync::Arc,
time::SystemTime, time::Instant,
}; };
use futures::{ use futures::{
@ -83,8 +83,9 @@ impl From<SerializationError> for StoreError {
} }
} }
const ENCODE_SEPARATOR: u8 = 0xff;
trait EncodeKey { trait EncodeKey {
const SEPARATOR: u8 = 0xff;
fn encode(&self) -> Vec<u8>; fn encode(&self) -> Vec<u8>;
} }
@ -102,13 +103,13 @@ impl EncodeKey for &RoomId {
impl EncodeKey for &str { impl EncodeKey for &str {
fn encode(&self) -> Vec<u8> { fn encode(&self) -> Vec<u8> {
[self.as_bytes(), &[Self::SEPARATOR]].concat() [self.as_bytes(), &[ENCODE_SEPARATOR]].concat()
} }
} }
impl EncodeKey for (&str, &str) { impl EncodeKey for (&str, &str) {
fn encode(&self) -> Vec<u8> { fn encode(&self) -> Vec<u8> {
[self.0.as_bytes(), &[Self::SEPARATOR], self.1.as_bytes(), &[Self::SEPARATOR]].concat() [self.0.as_bytes(), &[ENCODE_SEPARATOR], self.1.as_bytes(), &[ENCODE_SEPARATOR]].concat()
} }
} }
@ -116,11 +117,11 @@ impl EncodeKey for (&str, &str, &str) {
fn encode(&self) -> Vec<u8> { fn encode(&self) -> Vec<u8> {
[ [
self.0.as_bytes(), self.0.as_bytes(),
&[Self::SEPARATOR], &[ENCODE_SEPARATOR],
self.1.as_bytes(), self.1.as_bytes(),
&[Self::SEPARATOR], &[ENCODE_SEPARATOR],
self.2.as_bytes(), self.2.as_bytes(),
&[Self::SEPARATOR], &[ENCODE_SEPARATOR],
] ]
.concat() .concat()
} }
@ -286,7 +287,7 @@ impl SledStore {
} }
pub async fn save_changes(&self, changes: &StateChanges) -> Result<()> { pub async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
let now = SystemTime::now(); let now = Instant::now();
let ret: Result<(), TransactionError<SerializationError>> = ( let ret: Result<(), TransactionError<SerializationError>> = (
&self.session, &self.session,
@ -506,11 +507,22 @@ impl SledStore {
.transpose()?) .transpose()?)
} }
pub async fn get_user_ids(&self, room_id: &RoomId) -> impl Stream<Item = Result<UserId>> { pub async fn get_user_ids_stream(
stream::iter(self.members.scan_prefix(room_id.encode()).map(|u| { &self,
UserId::try_from(String::from_utf8_lossy(&u?.1).to_string()) room_id: &RoomId,
.map_err(StoreError::Identifier) ) -> impl Stream<Item = Result<UserId>> {
})) let decode = |key: &[u8]| -> Result<UserId> {
let mut iter = key.split(|c| c == &ENCODE_SEPARATOR);
// Our key is a the room id separated from the user id by a null
// byte, discard the first value of the split.
iter.next();
let user_id = iter.next().expect("User ids weren't properly encoded");
Ok(UserId::try_from(String::from_utf8_lossy(user_id).to_string())?)
};
stream::iter(self.members.scan_prefix(room_id.encode()).map(move |u| decode(&u?.0)))
} }
pub async fn get_invited_user_ids( pub async fn get_invited_user_ids(
@ -636,7 +648,7 @@ impl StateStore for SledStore {
} }
async fn get_user_ids(&self, room_id: &RoomId) -> Result<Vec<UserId>> { async fn get_user_ids(&self, room_id: &RoomId) -> Result<Vec<UserId>> {
self.get_user_ids(room_id).await.try_collect().await self.get_user_ids_stream(room_id).await.try_collect().await
} }
async fn get_invited_user_ids(&self, room_id: &RoomId) -> Result<Vec<UserId>> { async fn get_invited_user_ids(&self, room_id: &RoomId) -> Result<Vec<UserId>> {
@ -681,7 +693,7 @@ impl StateStore for SledStore {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use std::{convert::TryFrom, time::SystemTime}; use std::convert::TryFrom;
use matrix_sdk_common::{ use matrix_sdk_common::{
events::{ events::{
@ -692,13 +704,13 @@ mod test {
AnySyncStateEvent, EventType, Unsigned, AnySyncStateEvent, EventType, Unsigned,
}, },
identifiers::{room_id, user_id, EventId, UserId}, identifiers::{room_id, user_id, EventId, UserId},
Raw, MilliSecondsSinceUnixEpoch, Raw,
}; };
use matrix_sdk_test::async_test; use matrix_sdk_test::async_test;
use serde_json::json; use serde_json::json;
use super::{SledStore, StateChanges}; use super::{SledStore, StateChanges};
use crate::deserialized_responses::MemberEvent; use crate::{deserialized_responses::MemberEvent, StateStore};
fn user_id() -> UserId { fn user_id() -> UserId {
user_id!("@example:localhost") user_id!("@example:localhost")
@ -721,19 +733,11 @@ mod test {
} }
fn membership_event() -> MemberEvent { fn membership_event() -> MemberEvent {
let content = MemberEventContent {
avatar_url: None,
displayname: None,
is_direct: None,
third_party_invite: None,
membership: MembershipState::Join,
};
MemberEvent { MemberEvent {
event_id: EventId::try_from("$h29iv0s8:example.com").unwrap(), event_id: EventId::try_from("$h29iv0s8:example.com").unwrap(),
content, content: MemberEventContent::new(MembershipState::Join),
sender: user_id(), sender: user_id(),
origin_server_ts: SystemTime::now(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(),
state_key: user_id(), state_key: user_id(),
prev_content: None, prev_content: None,
unsigned: Unsigned::default(), unsigned: Unsigned::default(),
@ -756,6 +760,9 @@ mod test {
store.save_changes(&changes).await.unwrap(); store.save_changes(&changes).await.unwrap();
assert!(store.get_member_event(&room_id, &user_id).await.unwrap().is_some()); assert!(store.get_member_event(&room_id, &user_id).await.unwrap().is_some());
let members = store.get_user_ids(&room_id).await.unwrap();
assert!(!members.is_empty())
} }
#[async_test] #[async_test]

View File

@ -20,9 +20,7 @@ serde = "1.0.122"
async-trait = "0.1.42" async-trait = "0.1.42"
[dependencies.ruma] [dependencies.ruma]
version = "0.0.3" version = "0.1.0"
git = "https://github.com/ruma/ruma"
rev = "3bdead1cf207e3ab9c8fcbfc454c054c726ba6f5"
features = ["client-api-c", "compat", "unstable-pre-spec"] features = ["client-api-c", "compat", "unstable-pre-spec"]
[target.'cfg(not(target_arch = "wasm32"))'.dependencies] [target.'cfg(not(target_arch = "wasm32"))'.dependencies]

View File

@ -1,4 +1,4 @@
use std::{collections::BTreeMap, convert::TryFrom, time::SystemTime}; use std::{collections::BTreeMap, convert::TryFrom};
use ruma::{ use ruma::{
api::client::r0::sync::sync_events::{ api::client::r0::sync::sync_events::{
@ -22,6 +22,7 @@ use super::{
SyncStateEvent, Unsigned, SyncStateEvent, Unsigned,
}, },
identifiers::{DeviceKeyAlgorithm, EventId, RoomId, UserId}, identifiers::{DeviceKeyAlgorithm, EventId, RoomId, UserId},
MilliSecondsSinceUnixEpoch,
}; };
/// A change in ambiguity of room members that an `m.room.member` event /// A change in ambiguity of room members that an `m.room.member` event
@ -249,7 +250,7 @@ impl Timeline {
pub struct MemberEvent { pub struct MemberEvent {
pub content: MemberEventContent, pub content: MemberEventContent,
pub event_id: EventId, pub event_id: EventId,
pub origin_server_ts: SystemTime, pub origin_server_ts: MilliSecondsSinceUnixEpoch,
pub prev_content: Option<MemberEventContent>, pub prev_content: Option<MemberEventContent>,
pub sender: UserId, pub sender: UserId,
pub state_key: UserId, pub state_key: UserId,

View File

@ -15,7 +15,7 @@ pub use ruma::{
}, },
assign, directory, encryption, events, identifiers, int, presence, push, receipt, assign, directory, encryption, events, identifiers, int, presence, push, receipt,
serde::{CanonicalJsonValue, Raw}, serde::{CanonicalJsonValue, Raw},
thirdparty, uint, Int, Outgoing, UInt, thirdparty, uint, Int, MilliSecondsSinceUnixEpoch, Outgoing, SecondsSinceUnixEpoch, UInt,
}; };
pub use uuid; pub use uuid;

View File

@ -23,7 +23,7 @@ use aes_ctr::{
}; };
use base64::DecodeError; use base64::DecodeError;
use getrandom::getrandom; use getrandom::getrandom;
use matrix_sdk_common::events::room::JsonWebKey; use matrix_sdk_common::events::room::{JsonWebKey, JsonWebKeyInit};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
use thiserror::Error; use thiserror::Error;
@ -201,13 +201,13 @@ impl<'a, R: Read + 'a> AttachmentEncryptor<'a, R> {
// initialized. // initialized.
getrandom(&mut iv[0..8]).expect("Can't generate randomness"); getrandom(&mut iv[0..8]).expect("Can't generate randomness");
let web_key = JsonWebKey { let web_key = JsonWebKey::from(JsonWebKeyInit {
kty: "oct".to_owned(), kty: "oct".to_owned(),
key_ops: vec!["encrypt".to_owned(), "decrypt".to_owned()], key_ops: vec!["encrypt".to_owned(), "decrypt".to_owned()],
alg: "A256CTR".to_owned(), alg: "A256CTR".to_owned(),
k: encode_url_safe(&*key), k: encode_url_safe(&*key),
ext: true, ext: true,
}; });
let encoded_iv = encode(&*iv); let encoded_iv = encode(&*iv);
let aes = Aes256Ctr::new_var(&*key, &*iv).expect("Cannot create AES encryption object."); let aes = Aes256Ctr::new_var(&*key, &*iv).expect("Cannot create AES encryption object.");

View File

@ -151,12 +151,12 @@ pub struct OutgoingKeyRequest {
impl OutgoingKeyRequest { impl OutgoingKeyRequest {
fn to_request(&self, own_device_id: &DeviceId) -> Result<OutgoingRequest, serde_json::Error> { fn to_request(&self, own_device_id: &DeviceId) -> Result<OutgoingRequest, serde_json::Error> {
let content = RoomKeyRequestToDeviceEventContent { let content = RoomKeyRequestToDeviceEventContent::new(
action: Action::Request, Action::Request,
request_id: self.request_id.to_string(), Some(self.info.clone()),
requesting_device_id: own_device_id.to_owned(), own_device_id.to_owned(),
body: Some(self.info.clone()), self.request_id.to_string(),
}; );
wrap_key_request_content(self.request_recipient.clone(), self.request_id, &content) wrap_key_request_content(self.request_recipient.clone(), self.request_id, &content)
} }
@ -165,12 +165,12 @@ impl OutgoingKeyRequest {
&self, &self,
own_device_id: &DeviceId, own_device_id: &DeviceId,
) -> Result<OutgoingRequest, serde_json::Error> { ) -> Result<OutgoingRequest, serde_json::Error> {
let content = RoomKeyRequestToDeviceEventContent { let content = RoomKeyRequestToDeviceEventContent::new(
action: Action::CancelRequest, Action::CancelRequest,
request_id: self.request_id.to_string(), None,
requesting_device_id: own_device_id.to_owned(), own_device_id.to_owned(),
body: None, self.request_id.to_string(),
}; );
let id = Uuid::new_v4(); let id = Uuid::new_v4();
wrap_key_request_content(self.request_recipient.clone(), id, &content) wrap_key_request_content(self.request_recipient.clone(), id, &content)
@ -584,12 +584,12 @@ impl KeyRequestMachine {
sender_key: &str, sender_key: &str,
session_id: &str, session_id: &str,
) -> Result<(Option<OutgoingRequest>, OutgoingRequest), CryptoStoreError> { ) -> Result<(Option<OutgoingRequest>, OutgoingRequest), CryptoStoreError> {
let key_info = RequestedKeyInfo { let key_info = RequestedKeyInfo::new(
algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2, EventEncryptionAlgorithm::MegolmV1AesSha2,
room_id: room_id.to_owned(), room_id.to_owned(),
sender_key: sender_key.to_owned(), sender_key.to_owned(),
session_id: session_id.to_owned(), session_id.to_owned(),
}; );
let request = self.store.get_key_request_by_info(&key_info).await?; let request = self.store.get_key_request_by_info(&key_info).await?;
@ -644,12 +644,12 @@ impl KeyRequestMachine {
sender_key: &str, sender_key: &str,
session_id: &str, session_id: &str,
) -> Result<(), CryptoStoreError> { ) -> Result<(), CryptoStoreError> {
let key_info = RequestedKeyInfo { let key_info = RequestedKeyInfo::new(
algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2, EventEncryptionAlgorithm::MegolmV1AesSha2,
room_id: room_id.to_owned(), room_id.to_owned(),
sender_key: sender_key.to_owned(), sender_key.to_owned(),
session_id: session_id.to_owned(), session_id.to_owned(),
}; );
if self.should_request_key(&key_info).await? { if self.should_request_key(&key_info).await? {
self.request_key_helper(key_info).await?; self.request_key_helper(key_info).await?;
@ -675,12 +675,12 @@ impl KeyRequestMachine {
&self, &self,
content: &ForwardedRoomKeyToDeviceEventContent, content: &ForwardedRoomKeyToDeviceEventContent,
) -> Result<Option<OutgoingKeyRequest>, CryptoStoreError> { ) -> Result<Option<OutgoingKeyRequest>, CryptoStoreError> {
let info = RequestedKeyInfo { let info = RequestedKeyInfo::new(
algorithm: content.algorithm.clone(), content.algorithm.clone(),
room_id: content.room_id.clone(), content.room_id.clone(),
sender_key: content.sender_key.clone(), content.sender_key.clone(),
session_id: content.session_id.clone(), content.session_id.clone(),
}; );
self.store.get_key_request_by_info(&info).await self.store.get_key_request_by_info(&info).await
} }

View File

@ -1216,7 +1216,6 @@ pub(crate) mod test {
collections::BTreeMap, collections::BTreeMap,
convert::{TryFrom, TryInto}, convert::{TryFrom, TryInto},
sync::Arc, sync::Arc,
time::SystemTime,
}; };
use http::Response; use http::Response;
@ -1233,7 +1232,7 @@ pub(crate) mod test {
identifiers::{ identifiers::{
event_id, room_id, user_id, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, UserId, event_id, room_id, user_id, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, UserId,
}, },
IncomingResponse, Raw, IncomingResponse, MilliSecondsSinceUnixEpoch, Raw,
}; };
use matrix_sdk_test::test_json; use matrix_sdk_test::test_json;
use serde_json::json; use serde_json::json;
@ -1680,7 +1679,7 @@ pub(crate) mod test {
let event = SyncMessageEvent { let event = SyncMessageEvent {
event_id: event_id!("$xxxxx:example.org"), event_id: event_id!("$xxxxx:example.org"),
origin_server_ts: SystemTime::now(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(),
sender: alice.user_id().clone(), sender: alice.user_id().clone(),
content: encrypted_content, content: encrypted_content,
unsigned: Unsigned::default(), unsigned: Unsigned::default(),

View File

@ -12,12 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{ use std::{collections::BTreeMap, convert::TryFrom, fmt, mem, sync::Arc};
collections::BTreeMap,
convert::{TryFrom, TryInto},
fmt, mem,
sync::Arc,
};
use matrix_sdk_common::{ use matrix_sdk_common::{
events::{ events::{
@ -310,13 +305,7 @@ impl InboundGroupSession {
let mut decrypted_value = serde_json::from_str::<Value>(&plaintext)?; let mut decrypted_value = serde_json::from_str::<Value>(&plaintext)?;
let decrypted_object = decrypted_value.as_object_mut().ok_or(EventError::NotAnObject)?; let decrypted_object = decrypted_value.as_object_mut().ok_or(EventError::NotAnObject)?;
// TODO better number conversion here. let server_ts: i64 = event.origin_server_ts.0.into();
let server_ts = event
.origin_server_ts
.duration_since(std::time::SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_millis();
let server_ts: i64 = server_ts.try_into().unwrap_or_default();
decrypted_object.insert("sender".to_owned(), event.sender.to_string().into()); decrypted_object.insert("sender".to_owned(), event.sender.to_string().into());
decrypted_object.insert("event_id".to_owned(), event.event_id.to_string().into()); decrypted_object.insert("event_id".to_owned(), event.event_id.to_string().into());

View File

@ -1205,12 +1205,12 @@ mod test {
let (account, store, _dir) = get_loaded_store().await; let (account, store, _dir) = get_loaded_store().await;
let id = Uuid::new_v4(); let id = Uuid::new_v4();
let info = RequestedKeyInfo { let info = RequestedKeyInfo::new(
algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2, EventEncryptionAlgorithm::MegolmV1AesSha2,
room_id: room_id!("!test:localhost"), room_id!("!test:localhost"),
sender_key: "test_sender_key".to_string(), "test_sender_key".to_string(),
session_id: "test_session_id".to_string(), "test_session_id".to_string(),
}; );
let request = OutgoingKeyRequest { let request = OutgoingKeyRequest {
request_recipient: account.user_id().to_owned(), request_recipient: account.user_id().to_owned(),

View File

@ -184,17 +184,17 @@ impl VerificationRequest {
own_device_id: &DeviceId, own_device_id: &DeviceId,
other_user_id: &UserId, other_user_id: &UserId,
) -> KeyVerificationRequestEventContent { ) -> KeyVerificationRequestEventContent {
KeyVerificationRequestEventContent { KeyVerificationRequestEventContent::new(
body: format!( format!(
"{} is requesting to verify your key, but your client does not \ "{} is requesting to verify your key, but your client does not \
support in-chat key verification. You will need to use legacy \ support in-chat key verification. You will need to use legacy \
key verification to verify keys.", key verification to verify keys.",
own_user_id own_user_id
), ),
methods: SUPPORTED_METHODS.to_vec(), SUPPORTED_METHODS.to_vec(),
from_device: own_device_id.into(), own_device_id.into(),
to: other_user_id.to_owned(), other_user_id.to_owned(),
} )
} }
/// The id of the other user that is participating in this verification /// The id of the other user that is participating in this verification
@ -515,21 +515,17 @@ impl RequestState<Requested> {
}; };
let content = match self.state.flow_id { let content = match self.state.flow_id {
FlowId::ToDevice(i) => { FlowId::ToDevice(i) => AnyToDeviceEventContent::KeyVerificationReady(
AnyToDeviceEventContent::KeyVerificationReady(ReadyToDeviceEventContent { ReadyToDeviceEventContent::new(self.own_device_id, self.state.methods, i),
from_device: self.own_device_id, )
methods: self.state.methods, .into(),
transaction_id: i,
})
.into()
}
FlowId::InRoom(r, e) => ( FlowId::InRoom(r, e) => (
r, r,
AnyMessageEventContent::KeyVerificationReady(ReadyEventContent { AnyMessageEventContent::KeyVerificationReady(ReadyEventContent::new(
from_device: self.own_device_id, self.own_device_id,
methods: self.state.methods, self.state.methods,
relation: Relation { event_id: e }, Relation::new(e),
}), )),
) )
.into(), .into(),
}; };
@ -608,11 +604,12 @@ struct Passive {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use std::{convert::TryFrom, time::SystemTime}; use std::convert::TryFrom;
use matrix_sdk_common::{ use matrix_sdk_common::{
events::{SyncMessageEvent, Unsigned}, events::{SyncMessageEvent, Unsigned},
identifiers::{event_id, room_id, DeviceIdBox, UserId}, identifiers::{event_id, room_id, DeviceIdBox, UserId},
MilliSecondsSinceUnixEpoch,
}; };
use matrix_sdk_test::async_test; use matrix_sdk_test::async_test;
@ -738,7 +735,7 @@ mod test {
content: c, content: c,
event_id: event_id.clone(), event_id: event_id.clone(),
sender: bob_id(), sender: bob_id(),
origin_server_ts: SystemTime::now(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(),
unsigned: Unsigned::default(), unsigned: Unsigned::default(),
} }
} else { } else {

View File

@ -315,12 +315,9 @@ pub fn get_mac_content(sas: &OlmSas, ids: &SasIds, flow_id: &FlowId) -> MacConte
.expect("Can't calculate SAS MAC"); .expect("Can't calculate SAS MAC");
match flow_id { match flow_id {
FlowId::ToDevice(s) => { FlowId::ToDevice(s) => MacToDeviceEventContent::new(s.to_string(), mac, keys).into(),
MacToDeviceEventContent { transaction_id: s.to_string(), keys, mac }.into()
}
FlowId::InRoom(r, e) => { FlowId::InRoom(r, e) => {
(r.clone(), MacEventContent { mac, keys, relation: Relation { event_id: e.clone() } }) (r.clone(), MacEventContent::new(mac, keys, Relation::new(e.clone()))).into()
.into()
} }
} }
} }

View File

@ -23,13 +23,13 @@ use matrix_sdk_common::{
events::key::verification::{ events::key::verification::{
accept::{ accept::{
AcceptEventContent, AcceptMethod, AcceptToDeviceEventContent, AcceptEventContent, AcceptMethod, AcceptToDeviceEventContent,
MSasV1Content as AcceptV1Content, MSasV1ContentInit as AcceptV1ContentInit, SasV1Content as AcceptV1Content, SasV1ContentInit as AcceptV1ContentInit,
}, },
cancel::{CancelCode, CancelEventContent, CancelToDeviceEventContent}, cancel::{CancelCode, CancelEventContent, CancelToDeviceEventContent},
done::DoneEventContent, done::DoneEventContent,
key::{KeyEventContent, KeyToDeviceEventContent}, key::{KeyEventContent, KeyToDeviceEventContent},
start::{ start::{
MSasV1Content, MSasV1ContentInit, StartEventContent, StartMethod, SasV1Content, SasV1ContentInit, StartEventContent, StartMethod,
StartToDeviceEventContent, StartToDeviceEventContent,
}, },
HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode, Relation, HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode, Relation,
@ -105,10 +105,10 @@ impl TryFrom<AcceptV1Content> for AcceptedProtocols {
} }
} }
impl TryFrom<&MSasV1Content> for AcceptedProtocols { impl TryFrom<&SasV1Content> for AcceptedProtocols {
type Error = CancelCode; type Error = CancelCode;
fn try_from(method_content: &MSasV1Content) -> Result<Self, Self::Error> { fn try_from(method_content: &SasV1Content) -> Result<Self, Self::Error> {
if !method_content if !method_content
.key_agreement_protocols .key_agreement_protocols
.contains(&KeyAgreementProtocol::Curve25519HkdfSha256) .contains(&KeyAgreementProtocol::Curve25519HkdfSha256)
@ -212,7 +212,7 @@ impl<S: Clone + std::fmt::Debug> std::fmt::Debug for SasState<S> {
/// The initial SAS state. /// The initial SAS state.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Created { pub struct Created {
protocol_definitions: MSasV1ContentInit, protocol_definitions: SasV1ContentInit,
} }
/// The initial SAS state if the other side started the SAS verification. /// The initial SAS state if the other side started the SAS verification.
@ -403,7 +403,7 @@ impl SasState<Created> {
last_event_time: Arc::new(Instant::now()), last_event_time: Arc::new(Instant::now()),
state: Arc::new(Created { state: Arc::new(Created {
protocol_definitions: MSasV1ContentInit { protocol_definitions: SasV1ContentInit {
short_authentication_string: STRINGS.to_vec(), short_authentication_string: STRINGS.to_vec(),
key_agreement_protocols: KEY_AGREEMENT_PROTOCOLS.to_vec(), key_agreement_protocols: KEY_AGREEMENT_PROTOCOLS.to_vec(),
message_authentication_codes: MACS.to_vec(), message_authentication_codes: MACS.to_vec(),
@ -415,24 +415,24 @@ impl SasState<Created> {
pub fn as_content(&self) -> StartContent { pub fn as_content(&self) -> StartContent {
match self.verification_flow_id.as_ref() { match self.verification_flow_id.as_ref() {
FlowId::ToDevice(s) => StartContent::ToDevice(StartToDeviceEventContent { FlowId::ToDevice(s) => StartContent::ToDevice(StartToDeviceEventContent::new(
transaction_id: s.to_string(), self.device_id().into(),
from_device: self.device_id().into(), s.to_string(),
method: StartMethod::MSasV1( StartMethod::SasV1(
MSasV1Content::new(self.state.protocol_definitions.clone()) SasV1Content::new(self.state.protocol_definitions.clone())
.expect("Invalid initial protocol definitions."), .expect("Invalid initial protocol definitions."),
), ),
}), )),
FlowId::InRoom(r, e) => StartContent::Room( FlowId::InRoom(r, e) => StartContent::Room(
r.clone(), r.clone(),
StartEventContent { StartEventContent::new(
from_device: self.device_id().into(), self.device_id().into(),
method: StartMethod::MSasV1( StartMethod::SasV1(
MSasV1Content::new(self.state.protocol_definitions.clone()) SasV1Content::new(self.state.protocol_definitions.clone())
.expect("Invalid initial protocol definitions."), .expect("Invalid initial protocol definitions."),
), ),
relation: Relation { event_id: e.clone() }, Relation::new(e.clone()),
}, ),
), ),
} }
} }
@ -522,7 +522,7 @@ impl SasState<Started> {
state: Arc::new(Canceled::new(CancelCode::UnknownMethod)), state: Arc::new(Canceled::new(CancelCode::UnknownMethod)),
}; };
if let StartMethod::MSasV1(method_content) = content.method() { if let StartMethod::SasV1(method_content) = content.method() {
let sas = OlmSas::new(); let sas = OlmSas::new();
let pubkey = sas.public_key(); let pubkey = sas.public_key();
@ -589,14 +589,10 @@ impl SasState<Started> {
); );
match self.verification_flow_id.as_ref() { match self.verification_flow_id.as_ref() {
FlowId::ToDevice(s) => { FlowId::ToDevice(s) => AcceptToDeviceEventContent::new(s.to_string(), method).into(),
AcceptToDeviceEventContent { transaction_id: s.to_string(), method }.into() FlowId::InRoom(r, e) => {
(r.clone(), AcceptEventContent::new(method, Relation::new(e.clone()))).into()
} }
FlowId::InRoom(r, e) => (
r.clone(),
AcceptEventContent { method, relation: Relation { event_id: e.clone() } },
)
.into(),
} }
} }
@ -701,10 +697,10 @@ impl SasState<Accepted> {
.into(), .into(),
FlowId::InRoom(r, e) => ( FlowId::InRoom(r, e) => (
r.clone(), r.clone(),
KeyEventContent { KeyEventContent::new(
key: self.inner.lock().unwrap().public_key(), self.inner.lock().unwrap().public_key(),
relation: Relation { event_id: e.clone() }, Relation::new(e.clone()),
}, ),
) )
.into(), .into(),
} }
@ -725,10 +721,10 @@ impl SasState<KeyReceived> {
.into(), .into(),
FlowId::InRoom(r, e) => ( FlowId::InRoom(r, e) => (
r.clone(), r.clone(),
KeyEventContent { KeyEventContent::new(
key: self.inner.lock().unwrap().public_key(), self.inner.lock().unwrap().public_key(),
relation: Relation { event_id: e.clone() }, Relation::new(e.clone()),
}, ),
) )
.into(), .into(),
} }
@ -1024,7 +1020,7 @@ impl SasState<WaitingForDone> {
unreachable!("The done content isn't supported yet for to-device verifications") unreachable!("The done content isn't supported yet for to-device verifications")
} }
FlowId::InRoom(r, e) => { FlowId::InRoom(r, e) => {
(r.clone(), DoneEventContent { relation: Relation { event_id: e.clone() } }).into() (r.clone(), DoneEventContent::new(Relation::new(e.clone()))).into()
} }
} }
} }
@ -1076,7 +1072,7 @@ impl SasState<Done> {
unreachable!("The done content isn't supported yet for to-device verifications") unreachable!("The done content isn't supported yet for to-device verifications")
} }
FlowId::InRoom(r, e) => { FlowId::InRoom(r, e) => {
(r.clone(), DoneEventContent { relation: Relation { event_id: e.clone() } }).into() (r.clone(), DoneEventContent::new(Relation::new(e.clone()))).into()
} }
} }
} }
@ -1120,20 +1116,20 @@ impl Canceled {
impl SasState<Canceled> { impl SasState<Canceled> {
pub fn as_content(&self) -> CancelContent { pub fn as_content(&self) -> CancelContent {
match self.verification_flow_id.as_ref() { match self.verification_flow_id.as_ref() {
FlowId::ToDevice(s) => CancelToDeviceEventContent { FlowId::ToDevice(s) => CancelToDeviceEventContent::new(
transaction_id: s.clone(), s.clone(),
reason: self.state.reason.to_string(), self.state.reason.to_string(),
code: self.state.cancel_code.clone(), self.state.cancel_code.clone(),
} )
.into(), .into(),
FlowId::InRoom(r, e) => ( FlowId::InRoom(r, e) => (
r.clone(), r.clone(),
CancelEventContent { CancelEventContent::new(
reason: self.state.reason.to_string(), self.state.reason.to_string(),
code: self.state.cancel_code.clone(), self.state.cancel_code.clone(),
relation: Relation { event_id: e.clone() }, Relation::new(e.clone()),
}, ),
) )
.into(), .into(),
} }
@ -1360,7 +1356,7 @@ mod test {
}; };
match method { match method {
StartMethod::MSasV1(ref mut c) => { StartMethod::SasV1(ref mut c) => {
c.message_authentication_codes = vec![]; c.message_authentication_codes = vec![];
} }
_ => panic!("Unknown SAS start method"), _ => panic!("Unknown SAS start method"),