diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index d0bd4445..3f2c8e81 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -75,29 +75,6 @@ pub enum LoopCtrl { Break, } -use matrix_sdk_common::{ - api::r0::{ - account::register, - device::{delete_devices, get_devices}, - directory::{get_public_rooms, get_public_rooms_filtered}, - filter::{create_filter::Request as FilterUploadRequest, FilterDefinition}, - media::{create_content, get_content, get_content_thumbnail}, - membership::{join_room_by_id, join_room_by_id_or_alias}, - message::send_message_event, - profile::{get_avatar_url, get_display_name, set_avatar_url, set_display_name}, - room::create_room, - session::{get_login_types, login, sso_login}, - sync::sync_events, - uiaa::AuthData, - }, - assign, - identifiers::{DeviceIdBox, RoomId, RoomIdOrAliasId, ServerName, UserId}, - instant::{Duration, Instant}, - locks::{Mutex, RwLock}, - presence::PresenceState, - uuid::Uuid, - FromHttpResponseError, UInt, -}; #[cfg(feature = "encryption")] use matrix_sdk_common::{ api::r0::{ @@ -108,6 +85,32 @@ use matrix_sdk_common::{ }, identifiers::EventId, }; +use matrix_sdk_common::{ + api::{ + r0::{ + account::register, + device::{delete_devices, get_devices}, + directory::{get_public_rooms, get_public_rooms_filtered}, + filter::{create_filter::Request as FilterUploadRequest, FilterDefinition}, + media::{create_content, get_content, get_content_thumbnail}, + membership::{join_room_by_id, join_room_by_id_or_alias}, + message::send_message_event, + profile::{get_avatar_url, get_display_name, set_avatar_url, set_display_name}, + room::create_room, + session::{get_login_types, login, sso_login}, + sync::sync_events, + uiaa::AuthData, + }, + unversioned::{discover_homeserver, get_supported_versions}, + }, + assign, + identifiers::{DeviceIdBox, RoomId, RoomIdOrAliasId, ServerName, UserId}, + instant::{Duration, Instant}, + locks::{Mutex, RwLock}, + presence::PresenceState, + uuid::Uuid, + FromHttpResponseError, UInt, +}; #[cfg(feature = "encryption")] use crate::{ @@ -142,7 +145,7 @@ const SSO_SERVER_BIND_TRIES: u8 = 10; #[derive(Clone)] pub struct Client { /// The URL of the homeserver to connect to. - homeserver: Arc, + homeserver: Arc>, /// The underlying HTTP client. http_client: HttpClient, /// User session data. @@ -164,7 +167,7 @@ pub struct Client { #[cfg(not(tarpaulin_include))] impl Debug for Client { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> StdResult<(), fmt::Error> { - write!(fmt, "Client {{ homeserver: {} }}", self.homeserver) + write!(fmt, "Client") } } @@ -502,7 +505,7 @@ impl Client { /// /// * `config` - Configuration for the client. pub fn new_with_config(homeserver_url: Url, config: ClientConfig) -> Result { - let homeserver = Arc::new(homeserver_url); + let homeserver = Arc::new(RwLock::new(homeserver_url)); let client = if let Some(client) = config.client { client @@ -513,12 +516,8 @@ impl Client { let base_client = BaseClient::new_with_config(config.base_config)?; let session = base_client.session().clone(); - let http_client = HttpClient { - homeserver: homeserver.clone(), - inner: client, - session, - request_config: config.request_config, - }; + let http_client = + HttpClient::new(client, homeserver.clone(), session, config.request_config); Ok(Self { homeserver, @@ -534,6 +533,89 @@ impl Client { }) } + /// Creates a new client for making HTTP requests to the homeserver of the + /// given user. Follows homeserver discovery directions described + /// [here](https://spec.matrix.org/unstable/client-server-api/#well-known-uri). + /// + /// # Arguments + /// + /// * `user_id` - The id of the user whose homeserver the client should + /// connect to. + /// + /// # Example + /// ```no_run + /// # use std::convert::TryFrom; + /// # use matrix_sdk::{Client, identifiers::UserId}; + /// # use futures::executor::block_on; + /// let alice = UserId::try_from("@alice:example.org").unwrap(); + /// # block_on(async { + /// let client = Client::new_from_user_id(alice.clone()).await.unwrap(); + /// client.login(alice.localpart(), "password", None, None).await.unwrap(); + /// # }); + /// ``` + pub async fn new_from_user_id(user_id: UserId) -> Result { + let config = ClientConfig::new(); + Client::new_from_user_id_with_config(user_id, config).await + } + + /// Creates a new client for making HTTP requests to the homeserver of the + /// given user and configuration. Follows homeserver discovery directions + /// described [here](https://spec.matrix.org/unstable/client-server-api/#well-known-uri). + /// + /// # Arguments + /// + /// * `user_id` - The id of the user whose homeserver the client should + /// connect to. + /// + /// * `config` - Configuration for the client. + pub async fn new_from_user_id_with_config( + user_id: UserId, + config: ClientConfig, + ) -> Result { + let homeserver = Client::homeserver_from_user_id(user_id)?; + let mut client = Client::new_with_config(homeserver, config)?; + + let well_known = client.discover_homeserver().await?; + let well_known = Url::parse(well_known.homeserver.base_url.as_ref())?; + client.set_homeserver(well_known).await; + client.get_supported_versions().await?; + Ok(client) + } + + fn homeserver_from_user_id(user_id: UserId) -> Result { + let homeserver = format!("https://{}", user_id.server_name()); + #[allow(unused_mut)] + let mut result = Url::parse(homeserver.as_str())?; + // Mockito only knows how to test http endpoints: + // https://github.com/lipanski/mockito/issues/127 + #[cfg(test)] + let _ = result.set_scheme("http"); + Ok(result) + } + + async fn discover_homeserver(&self) -> Result { + self.send(discover_homeserver::Request::new(), Some(RequestConfig::new().disable_retry())) + .await + } + + /// Change the homeserver URL used by this client. + /// + /// # Arguments + /// + /// * `homeserver_url` - The new URL to use. + pub async fn set_homeserver(&mut self, homeserver_url: Url) { + let mut homeserver = self.homeserver.write().await; + *homeserver = homeserver_url; + } + + async fn get_supported_versions(&self) -> Result { + self.send( + get_supported_versions::Request::new(), + Some(RequestConfig::new().disable_retry()), + ) + .await + } + /// Process a [transaction] received from the homeserver /// /// # Arguments @@ -566,8 +648,8 @@ impl Client { } /// The Homeserver of the client. - pub fn homeserver(&self) -> &Url { - &self.homeserver + pub async fn homeserver(&self) -> Url { + self.homeserver.read().await.clone() } /// Get the user id of the current owner of the client. @@ -866,8 +948,8 @@ impl Client { /// successful SSO login. /// /// [`login_with_token`]: #method.login_with_token - pub fn get_sso_login_url(&self, redirect_url: &str) -> Result { - let homeserver = self.homeserver(); + pub async fn get_sso_login_url(&self, redirect_url: &str) -> Result { + let homeserver = self.homeserver().await; let request = sso_login::Request::new(redirect_url) .try_into_http_request::>(homeserver.as_str(), SendAccessToken::None); match request { @@ -928,7 +1010,7 @@ impl Client { device_id: Option<&str>, initial_device_display_name: Option<&str>, ) -> Result { - info!("Logging in to {} as {:?}", self.homeserver, user); + info!("Logging in to {} as {:?}", self.homeserver().await, user); let request = assign!( login::Request::new( @@ -1037,7 +1119,7 @@ impl Client { where C: Future>, { - info!("Logging in to {}", self.homeserver); + info!("Logging in to {}", self.homeserver().await); let (signal_tx, signal_rx) = oneshot::channel(); let (data_tx, data_rx) = oneshot::channel(); let data_tx_mutex = Arc::new(std::sync::Mutex::new(Some(data_tx))); @@ -1109,7 +1191,7 @@ impl Client { tokio::spawn(server); - let sso_url = self.get_sso_login_url(redirect_url.as_str()).unwrap(); + let sso_url = self.get_sso_login_url(redirect_url.as_str()).await.unwrap(); match use_sso_login_url(sso_url).await { Ok(t) => t, @@ -1193,7 +1275,7 @@ impl Client { device_id: Option<&str>, initial_device_display_name: Option<&str>, ) -> Result { - info!("Logging in to {}", self.homeserver); + info!("Logging in to {}", self.homeserver().await); let request = assign!( login::Request::new( @@ -1264,7 +1346,7 @@ impl Client { &self, registration: impl Into>, ) -> Result { - info!("Registering to {}", self.homeserver); + info!("Registering to {}", self.homeserver().await); let request = registration.into(); self.send(request, None).await @@ -2387,7 +2469,13 @@ impl Client { #[cfg(test)] mod test { - use std::{collections::BTreeMap, convert::TryInto, io::Cursor, str::FromStr, time::Duration}; + use std::{ + collections::BTreeMap, + convert::{TryFrom, TryInto}, + io::Cursor, + str::FromStr, + time::Duration, + }; use matrix_sdk_base::identifiers::mxc_uri; use matrix_sdk_common::{ @@ -2399,7 +2487,7 @@ mod test { assign, directory::Filter, events::{room::message::MessageEventContent, AnyMessageEventContent}, - identifiers::{event_id, room_id, user_id}, + identifiers::{event_id, room_id, user_id, UserId}, thirdparty, }; use matrix_sdk_test::{test_json, EventBuilder, EventsJson}; @@ -2425,6 +2513,59 @@ mod test { client } + #[tokio::test] + async fn set_homeserver() { + let homeserver = Url::from_str("http://example.com/").unwrap(); + + let mut client = Client::new(homeserver).unwrap(); + + let homeserver = Url::from_str(&mockito::server_url()).unwrap(); + + client.set_homeserver(homeserver.clone()).await; + + assert_eq!(client.homeserver().await, homeserver); + } + + #[tokio::test] + async fn successful_discovery() { + let server_url = mockito::server_url(); + let domain = server_url.strip_prefix("http://").unwrap(); + let alice = UserId::try_from("@alice:".to_string() + domain).unwrap(); + + let _m_well_known = mock("GET", "/.well-known/matrix/client") + .with_status(200) + .with_body( + test_json::WELL_KNOWN.to_string().replace("HOMESERVER_URL", server_url.as_ref()), + ) + .create(); + + let _m_versions = mock("GET", "/_matrix/client/versions") + .with_status(200) + .with_body(test_json::VERSIONS.to_string()) + .create(); + let client = Client::new_from_user_id(alice).await.unwrap(); + + assert_eq!(client.homeserver().await, Url::parse(server_url.as_ref()).unwrap()); + } + + #[tokio::test] + async fn discovery_broken_server() { + let server_url = mockito::server_url(); + let domain = server_url.strip_prefix("http://").unwrap(); + let alice = UserId::try_from("@alice:".to_string() + domain).unwrap(); + + let _m = mock("GET", "/.well-known/matrix/client") + .with_status(200) + .with_body( + test_json::WELL_KNOWN.to_string().replace("HOMESERVER_URL", server_url.as_ref()), + ) + .create(); + + if Client::new_from_user_id(alice).await.is_ok() { + panic!("Creating a client from a user ID should fail when the .well-known server returns no version infromation."); + } + } + #[tokio::test] async fn login() { let homeserver = Url::from_str(&mockito::server_url()).unwrap(); @@ -2514,7 +2655,7 @@ mod test { .any(|flow| matches!(flow, LoginType::Sso(_))); assert!(can_sso); - let sso_url = client.get_sso_login_url("http://127.0.0.1:3030"); + let sso_url = client.get_sso_login_url("http://127.0.0.1:3030").await; assert!(sso_url.is_ok()); let _m = mock("POST", "/_matrix/client/r0/login") @@ -2626,7 +2767,7 @@ mod test { client.base_client.receive_sync_response(response).await.unwrap(); let room_id = room_id!("!SVkFJHzfwvuaIEawgC:localhost"); - assert_eq!(client.homeserver(), &Url::parse(&mockito::server_url()).unwrap()); + assert_eq!(client.homeserver().await, Url::parse(&mockito::server_url()).unwrap()); let room = client.get_joined_room(&room_id); assert!(room.is_some()); diff --git a/matrix_sdk/src/error.rs b/matrix_sdk/src/error.rs index fc12b91a..e2b032ec 100644 --- a/matrix_sdk/src/error.rs +++ b/matrix_sdk/src/error.rs @@ -31,6 +31,7 @@ use matrix_sdk_common::{ use reqwest::Error as ReqwestError; use serde_json::Error as JsonError; use thiserror::Error; +use url::ParseError as UrlParseError; /// Result type of the rust-sdk. pub type Result = std::result::Result; @@ -128,6 +129,10 @@ pub enum Error { /// An error encountered when trying to parse an identifier. #[error(transparent)] Identifier(#[from] IdentifierError), + + /// An error encountered when trying to parse a url. + #[error(transparent)] + Url(#[from] UrlParseError), } impl Error { diff --git a/matrix_sdk/src/http_client.rs b/matrix_sdk/src/http_client.rs index aced4169..cfd55be7 100644 --- a/matrix_sdk/src/http_client.rs +++ b/matrix_sdk/src/http_client.rs @@ -97,7 +97,7 @@ pub trait HttpSend: AsyncTraitDeps { #[derive(Clone, Debug)] pub(crate) struct HttpClient { pub(crate) inner: Arc, - pub(crate) homeserver: Arc, + pub(crate) homeserver: Arc>, pub(crate) session: Arc>>, pub(crate) request_config: RequestConfig, } @@ -106,6 +106,15 @@ pub(crate) struct HttpClient { use crate::OutgoingRequestAppserviceExt; impl HttpClient { + pub(crate) fn new( + inner: Arc, + homeserver: Arc>, + session: Arc>>, + request_config: RequestConfig, + ) -> Self { + HttpClient { inner, homeserver, session, request_config } + } + async fn send_request( &self, request: Request, @@ -161,7 +170,10 @@ impl HttpClient { }; let http_request = request - .try_into_http_request::(&self.homeserver.to_string(), access_token)? + .try_into_http_request::( + &self.homeserver.read().await.to_string(), + access_token, + )? .map(|body| body.freeze()); Ok(http_request) @@ -189,7 +201,7 @@ impl HttpClient { let http_request = request .try_into_http_request_with_user_id::( - &self.homeserver.to_string(), + &self.homeserver.read().await.to_string(), access_token, user_id, )? diff --git a/matrix_sdk_test/src/test_json/mod.rs b/matrix_sdk_test/src/test_json/mod.rs index c743e395..64b628f6 100644 --- a/matrix_sdk_test/src/test_json/mod.rs +++ b/matrix_sdk_test/src/test_json/mod.rs @@ -42,3 +42,29 @@ lazy_static! { ] }); } + +lazy_static! { + pub static ref WELL_KNOWN: JsonValue = json!({ + "m.homeserver": { + "base_url": "HOMESERVER_URL" + } + }); +} + +lazy_static! { + pub static ref VERSIONS: JsonValue = json!({ + "versions": [ + "r0.0.1", + "r0.1.0", + "r0.2.0", + "r0.3.0", + "r0.4.0", + "r0.5.0", + "r0.6.0" + ], + "unstable_features": { + "org.matrix.label_based_filtering":true, + "org.matrix.e2e_cross_signing":true + } + }); +}