diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index c88e097f..73562cb7 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -17,6 +17,7 @@ use std::collections::BTreeMap; use std::collections::HashMap; use std::convert::{TryFrom, TryInto}; +use std::fmt::{self, Debug}; use std::path::Path; use std::result::Result as StdResult; use std::sync::Arc; @@ -66,8 +67,8 @@ pub struct Client { } #[cfg_attr(tarpaulin, skip)] -impl std::fmt::Debug for Client { - fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> StdResult<(), std::fmt::Error> { +impl Debug for Client { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> StdResult<(), fmt::Error> { write!(fmt, "Client {{ homeserver: {} }}", self.homeserver) } } @@ -106,8 +107,8 @@ pub struct ClientConfig { } #[cfg_attr(tarpaulin, skip)] -impl std::fmt::Debug for ClientConfig { - fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> StdResult<(), std::fmt::Error> { +impl Debug for ClientConfig { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { let mut res = fmt.debug_struct("ClientConfig"); #[cfg(not(target_arch = "wasm32"))] @@ -246,6 +247,7 @@ impl SyncSettings { } } +use api::r0::account::register; #[cfg(feature = "encryption")] use api::r0::keys::{claim_keys, get_keys, upload_keys, KeyAlgorithm}; use api::r0::membership::{ @@ -263,6 +265,7 @@ use api::r0::sync::sync_events; #[cfg(feature = "encryption")] use api::r0::to_device::send_event_to_device; use api::r0::typing::create_typing_event; +use api::r0::uiaa::UiaaResponse; impl Client { /// Creates a new client for making HTTP requests to the given homeserver. @@ -413,7 +416,7 @@ impl Client { /// device_id from a previous login call. Note that this should be done /// only if the client also holds the encryption keys for this device. #[instrument(skip(password))] - pub async fn login + std::fmt::Debug>( + pub async fn login + Debug>( &self, user: S, password: S, @@ -447,6 +450,43 @@ impl Client { Ok(self.base_client.restore_login(session).await?) } + /// Register a user to the server. + /// + /// # Arguments + /// + /// * `registration` - The easiest way to create this request is using the `RegistrationBuilder`. + /// + /// + /// # Examples + /// ``` + /// # use std::convert::TryFrom; + /// # use matrix_sdk::{Client, RegistrationBuilder}; + /// # use matrix_sdk::api::r0::account::register::RegistrationKind; + /// # use matrix_sdk::identifiers::DeviceId; + /// # use url::Url; + /// # let homeserver = Url::parse("http://example.com").unwrap(); + /// # let mut rt = tokio::runtime::Runtime::new().unwrap(); + /// # rt.block_on(async { + /// let mut builder = RegistrationBuilder::default(); + /// builder.password("pass") + /// .username("user") + /// .kind(RegistrationKind::User); + /// let mut client = Client::new(homeserver).unwrap(); + /// client.register_user(builder).await; + /// # }) + /// ``` + #[instrument(skip(registration))] + pub async fn register_user>( + &self, + registration: R, + ) -> Result { + info!("Registering to {}", self.homeserver); + + let request = registration.into(); + println!("{:#?}", request); + self.send_uiaa(request).await + } + /// Join a room by `RoomId`. /// /// Returns a `join_room_by_id::Response` consisting of the @@ -873,7 +913,79 @@ impl Client { Ok(response) } - /// Send an arbitrary request to the server, without updating client state + async fn send_request( + &self, + requires_auth: bool, + method: HttpMethod, + request: http::Request>, + ) -> Result { + let url = request.uri(); + let path_and_query = url.path_and_query().unwrap(); + let mut url = self.homeserver.clone(); + + url.set_path(path_and_query.path()); + url.set_query(path_and_query.query()); + + let request_builder = match method { + HttpMethod::GET => self.http_client.get(url), + HttpMethod::POST => { + let body = request.body().clone(); + self.http_client + .post(url) + .body(body) + .header(reqwest::header::CONTENT_TYPE, "application/json") + } + HttpMethod::PUT => { + let body = request.body().clone(); + self.http_client + .put(url) + .body(body) + .header(reqwest::header::CONTENT_TYPE, "application/json") + } + HttpMethod::DELETE => { + let body = request.body().clone(); + self.http_client + .delete(url) + .body(body) + .header(reqwest::header::CONTENT_TYPE, "application/json") + } + method => panic!("Unsuported method {}", method), + }; + + let request_builder = if requires_auth { + let session = self.base_client.session().read().await; + + if let Some(session) = session.as_ref() { + let header_value = format!("Bearer {}", &session.access_token); + request_builder.header(AUTHORIZATION, header_value) + } else { + return Err(Error::AuthenticationRequired); + } + } else { + request_builder + }; + + Ok(request_builder.send().await?) + } + + async fn response_to_http_response( + &self, + mut response: reqwest::Response, + ) -> Result>> { + let status = response.status(); + let mut http_builder = HttpResponse::builder().status(status); + let headers = http_builder.headers_mut().unwrap(); + + for (k, v) in response.headers_mut().drain() { + if let Some(key) = k { + headers.insert(key, v); + } + } + let body = response.bytes().await?.as_ref().to_owned(); + Ok(http_builder.body(body).unwrap()) + } + + /// Send an arbitrary request to the server, without updating client state. /// /// **Warning:** Because this method *does not* update the client state, it is /// important to make sure than you account for this yourself, and use wrapper methods @@ -911,69 +1023,79 @@ impl Client { /// // returned /// # }) /// ``` - pub async fn send + std::fmt::Debug>( + pub async fn send + Debug>( &self, request: Request, ) -> Result { let request: http::Request> = request.try_into()?; - let url = request.uri(); - let path_and_query = url.path_and_query().unwrap(); - let mut url = self.homeserver.clone(); - - url.set_path(path_and_query.path()); - url.set_query(path_and_query.query()); - - trace!("Doing request {:?}", url); - - let request_builder = match Request::METADATA.method { - HttpMethod::GET => self.http_client.get(url), - HttpMethod::POST => { - let body = request.body().clone(); - self.http_client - .post(url) - .body(body) - .header(reqwest::header::CONTENT_TYPE, "application/json") - } - HttpMethod::PUT => { - let body = request.body().clone(); - self.http_client - .put(url) - .body(body) - .header(reqwest::header::CONTENT_TYPE, "application/json") - } - HttpMethod::DELETE => unimplemented!(), - _ => panic!("Unsuported method"), - }; - - let request_builder = if Request::METADATA.requires_authentication { - let session = self.base_client.session().read().await; - - if let Some(session) = session.as_ref() { - let header_value = format!("Bearer {}", &session.access_token); - request_builder.header(AUTHORIZATION, header_value) - } else { - return Err(Error::AuthenticationRequired); - } - } else { - request_builder - }; - let mut response = request_builder.send().await?; + let response = self + .send_request( + Request::METADATA.requires_authentication, + Request::METADATA.method, + request, + ) + .await?; trace!("Got response: {:?}", response); - let status = response.status(); - let mut http_builder = HttpResponse::builder().status(status); - let headers = http_builder.headers_mut().unwrap(); + let response = self.response_to_http_response(response).await?; - for (k, v) in response.headers_mut().drain() { - if let Some(key) = k { - headers.insert(key, v); - } - } - let body = response.bytes().await?.as_ref().to_owned(); - let http_response = http_builder.body(body).unwrap(); + Ok(::try_from(response)?) + } - Ok(::try_from(http_response)?) + /// Send an arbitrary request to the server, without updating client state. + /// + /// This version allows the client to make registration requests. + /// + /// **Warning:** Because this method *does not* update the client state, it is + /// important to make sure than you account for this yourself, and use wrapper methods + /// where available. This method should *only* be used if a wrapper method for the + /// endpoint you'd like to use is not available. + /// + /// # Arguments + /// + /// * `request` - This version of send is for dealing with types that return + /// a `UiaaResponse` as the `Endpoint` associated type. + /// + /// # Examples + /// ``` + /// # use std::convert::TryFrom; + /// # use matrix_sdk::{Client, RegistrationBuilder}; + /// # use matrix_sdk::api::r0::account::register::{RegistrationKind, Request}; + /// # use matrix_sdk::identifiers::DeviceId; + /// # use url::Url; + /// # let homeserver = Url::parse("http://example.com").unwrap(); + /// # let mut rt = tokio::runtime::Runtime::new().unwrap(); + /// # rt.block_on(async { + /// let mut builder = RegistrationBuilder::default(); + /// builder.password("pass") + /// .username("user") + /// .kind(RegistrationKind::User); + /// let mut client = Client::new(homeserver).unwrap(); + /// let req: Request = builder.into(); + /// client.send_uiaa(req).await; + /// # }) + /// ``` + pub async fn send_uiaa + Debug>( + &self, + request: Request, + ) -> Result { + let request: http::Request> = request.try_into()?; + let response = self + .send_request( + Request::METADATA.requires_authentication, + Request::METADATA.method, + request, + ) + .await?; + + trace!("Got response: {:?}", response); + + let response = self.response_to_http_response(response).await?; + + let uiaa: Result<_> = ::try_from(response).map_err(Into::into); + + Ok(uiaa?) } /// Synchronize the client's state with the latest state on the server. @@ -1271,14 +1393,16 @@ impl Client { #[cfg(test)] mod test { use super::{ - ban_user, create_receipt, create_typing_event, forget_room, invite_user, kick_user, - leave_room, set_read_marker, Invite3pid, MessageEventContent, + api::r0::uiaa::AuthData, ban_user, create_receipt, create_typing_event, forget_room, + invite_user, kick_user, leave_room, register::RegistrationKind, set_read_marker, + Invite3pid, MessageEventContent, }; use super::{Client, ClientConfig, Session, SyncSettings, Url}; use crate::events::collections::all::RoomEvent; use crate::events::room::member::MembershipState; use crate::events::room::message::TextMessageEventContent; use crate::identifiers::{EventId, RoomId, RoomIdOrAliasId, UserId}; + use crate::RegistrationBuilder; use matrix_sdk_base::JsonStore; use matrix_sdk_test::{EventBuilder, EventsFile}; @@ -1459,6 +1583,44 @@ mod test { } } + #[tokio::test] + async fn register_error() { + let homeserver = Url::from_str(&mockito::server_url()).unwrap(); + + let _m = mock("POST", "/_matrix/client/r0/register") + .with_status(403) + .with_body_from_file("../test_data/registration_response_error.json") + .create(); + + let mut user = RegistrationBuilder::default(); + + user.username("user") + .password("password") + .auth(AuthData::FallbackAcknowledgement { + session: "foobar".to_string(), + }) + .kind(RegistrationKind::User); + + let client = Client::new(homeserver).unwrap(); + + if let Err(err) = client.register_user(user).await { + if let crate::Error::UiaaError(crate::FromHttpResponseError::Http( + // TODO this should be a UiaaError need to investigate + crate::ServerError::Unknown(e), + )) = err + { + assert!(e.to_string().starts_with("EOF while parsing")) + } else { + panic!( + "found the wrong `Error` type {:#?}, expected `ServerError::Unknown", + err + ); + } + } else { + panic!("this request should return an `Err` variant") + } + } + #[tokio::test] async fn join_room_by_id() { let homeserver = Url::from_str(&mockito::server_url()).unwrap(); diff --git a/matrix_sdk/src/error.rs b/matrix_sdk/src/error.rs index 2670a989..c293797a 100644 --- a/matrix_sdk/src/error.rs +++ b/matrix_sdk/src/error.rs @@ -20,6 +20,7 @@ use thiserror::Error; use matrix_sdk_base::Error as MatrixError; +use crate::api::r0::uiaa::UiaaResponse as UiaaError; use crate::api::Error as RumaClientError; use crate::FromHttpResponseError as RumaResponseError; use crate::IntoHttpError as RumaIntoHttpError; @@ -50,9 +51,23 @@ pub enum Error { #[error("can't convert between ruma_client_api and hyper types.")] IntoHttp(RumaIntoHttpError), - /// An error occured in the Matrix client library. + /// An error occurred in the Matrix client library. #[error(transparent)] MatrixError(#[from] MatrixError), + + /// An error occurred while authenticating. + /// + /// When registering or authenticating the Matrix server can send a `UiaaResponse` + /// as the error type, this is a User-Interactive Authentication API response. This + /// represents an error with information about how to authenticate the user. + #[error("User-Interactive Authentication required.")] + UiaaError(RumaResponseError), +} + +impl From> for Error { + fn from(error: RumaResponseError) -> Self { + Self::UiaaError(error) + } } impl From> for Error { diff --git a/matrix_sdk/src/lib.rs b/matrix_sdk/src/lib.rs index 3834a7f4..579a74d7 100644 --- a/matrix_sdk/src/lib.rs +++ b/matrix_sdk/src/lib.rs @@ -51,7 +51,7 @@ mod error; mod request_builder; pub use client::{Client, ClientConfig, SyncSettings}; pub use error::{Error, Result}; -pub use request_builder::{MessagesRequestBuilder, RoomBuilder}; +pub use request_builder::{MessagesRequestBuilder, RegistrationBuilder, RoomBuilder}; #[cfg(not(target_arch = "wasm32"))] pub(crate) const VERSION: &str = env!("CARGO_PKG_VERSION"); diff --git a/matrix_sdk/src/request_builder.rs b/matrix_sdk/src/request_builder.rs index d7c52615..e76b2644 100644 --- a/matrix_sdk/src/request_builder.rs +++ b/matrix_sdk/src/request_builder.rs @@ -1,7 +1,9 @@ use crate::api; use crate::events::room::power_levels::PowerLevelsEventContent; use crate::events::EventJson; -use crate::identifiers::{RoomId, UserId}; +use crate::identifiers::{DeviceId, RoomId, UserId}; +use api::r0::account::register; +use api::r0::account::register::RegistrationKind; use api::r0::filter::RoomEventFilter; use api::r0::membership::Invite3pid; use api::r0::message::get_message_events::{self, Direction}; @@ -9,6 +11,7 @@ use api::r0::room::{ create_room::{self, CreationContent, InitialStateEvent, RoomPreset}, Visibility, }; +use api::r0::uiaa::AuthData; use crate::js_int::UInt; @@ -288,6 +291,120 @@ impl Into for MessagesRequestBuilder { } } +/// A builder used to register users. +/// +/// # Examples +/// ``` +/// # use std::convert::TryFrom; +/// # use matrix_sdk::{Client, RegistrationBuilder}; +/// # use matrix_sdk::api::r0::account::register::RegistrationKind; +/// # use matrix_sdk::identifiers::DeviceId; +/// # use url::Url; +/// # let homeserver = Url::parse("http://example.com").unwrap(); +/// # let mut rt = tokio::runtime::Runtime::new().unwrap(); +/// # rt.block_on(async { +/// let mut builder = RegistrationBuilder::default(); +/// builder.password("pass") +/// .username("user") +/// .kind(RegistrationKind::User); +/// let mut client = Client::new(homeserver).unwrap(); +/// client.register_user(builder).await; +/// # }) +/// ``` +#[derive(Clone, Debug, Default)] +pub struct RegistrationBuilder { + password: Option, + username: Option, + device_id: Option, + initial_device_display_name: Option, + auth: Option, + kind: Option, + inhibit_login: bool, +} + +impl RegistrationBuilder { + /// Create a `RegistrationBuilder` builder to make a `register::Request`. + /// + /// The `room_id` and `from`` fields **need to be set** to create the request. + pub fn new() -> Self { + Self::default() + } + + /// The desired password for the account. + /// + /// May be empty for accounts that should not be able to log in again + /// with a password, e.g., for guest or application service accounts. + pub fn password(&mut self, password: &str) -> &mut Self { + self.password = Some(password.to_string()); + self + } + + /// local part of the desired Matrix ID. + /// + /// If omitted, the homeserver MUST generate a Matrix ID local part. + pub fn username(&mut self, username: &str) -> &mut Self { + self.username = Some(username.to_string()); + self + } + + /// ID of the client device. + /// + /// If this does not correspond to a known client device, a new device will be created. + /// The server will auto-generate a device_id if this is not specified. + pub fn device_id(&mut self, device_id: &str) -> &mut Self { + self.device_id = Some(device_id.to_string()); + self + } + + /// A display name to assign to the newly-created device. + /// + /// Ignored if `device_id` corresponds to a known device. + pub fn initial_device_display_name(&mut self, initial_device_display_name: &str) -> &mut Self { + self.initial_device_display_name = Some(initial_device_display_name.to_string()); + self + } + + /// Additional authentication information for the user-interactive authentication API. + /// + /// Note that this information is not used to define how the registered user should be + /// authenticated, but is instead used to authenticate the register call itself. + /// It should be left empty, or omitted, unless an earlier call returned an response + /// with status code 401. + pub fn auth(&mut self, auth: AuthData) -> &mut Self { + self.auth = Some(auth); + self + } + + /// Kind of account to register + /// + /// Defaults to `User` if omitted. + pub fn kind(&mut self, kind: RegistrationKind) -> &mut Self { + self.kind = Some(kind); + self + } + + /// If `true`, an `access_token` and `device_id` should not be returned + /// from this call, therefore preventing an automatic login. + pub fn inhibit_login(&mut self, inhibit_login: bool) -> &mut Self { + self.inhibit_login = inhibit_login; + self + } +} + +impl Into for RegistrationBuilder { + fn into(self) -> register::Request { + register::Request { + password: self.password, + username: self.username, + device_id: self.device_id, + initial_device_display_name: self.initial_device_display_name, + auth: self.auth, + kind: self.kind, + inhibit_login: self.inhibit_login, + } + } +} + #[cfg(test)] mod test { use std::collections::BTreeMap; diff --git a/test_data/registration_response_error.json b/test_data/registration_response_error.json new file mode 100644 index 00000000..b34244c1 --- /dev/null +++ b/test_data/registration_response_error.json @@ -0,0 +1,19 @@ +{ + "errcode": "M_FORBIDDEN", + "error": "Invalid password", + "completed": ["example.type.foo"], + "flows": [ + { + "stages": ["example.type.foo", "example.type.bar"] + }, + { + "stages": ["example.type.foo", "example.type.baz"] + } + ], + "params": { + "example.type.baz": { + "example_key": "foobar" + } + }, + "session": "xxxxxx" +}