diff --git a/matrix_sdk/Cargo.toml b/matrix_sdk/Cargo.toml index b15d84d7..cb1f0fb3 100644 --- a/matrix_sdk/Cargo.toml +++ b/matrix_sdk/Cargo.toml @@ -25,8 +25,9 @@ markdown = ["matrix-sdk-base/markdown"] native-tls = ["reqwest/native-tls"] rustls-tls = ["reqwest/rustls-tls"] socks = ["reqwest/socks"] +sso_login = ["warp", "rand", "tokio-stream"] -docs = ["encryption", "sled_cryptostore", "sled_state_store"] +docs = ["encryption", "sled_cryptostore", "sled_state_store", "sso_login"] [dependencies] dashmap = { version = "4.0.2", optional = true } @@ -38,6 +39,7 @@ tracing = "0.1.22" url = "2.2.0" zeroize = "1.2.0" mime = "0.3.16" +rand = { version = "0.8.2", optional = true } matrix-sdk-common = { version = "0.2.0", path = "../matrix_sdk_common" } @@ -50,6 +52,16 @@ default_features = false version = "0.11.0" default_features = false +[dependencies.tokio-stream] +version = "0.1.4" +features = ["net"] +optional = true + +[dependencies.warp] +version = "0.3.0" +default-features = false +optional = true + [target.'cfg(not(target_arch = "wasm32"))'.dependencies.backoff] version = "0.3.0" features = ["tokio"] diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index 050c6a6a..c8a135aa 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -15,6 +15,12 @@ #[cfg(feature = "encryption")] use std::{collections::BTreeMap, io::Write, path::PathBuf}; +#[cfg(feature = "sso_login")] +use std::{ + collections::HashMap, + io::{Error as IoError, ErrorKind as IoErrorKind}, + ops::Range, +}; use std::{ convert::TryInto, fmt::{self, Debug}, @@ -29,9 +35,19 @@ use std::{ use dashmap::DashMap; use futures_timer::Delay as sleep; use http::HeaderValue; +#[cfg(feature = "sso_login")] +use http::Response; use mime::{self, Mime}; +#[cfg(feature = "sso_login")] +use rand::{thread_rng, Rng}; use reqwest::header::InvalidHeaderValue; +#[cfg(feature = "sso_login")] +use tokio::{net::TcpListener, sync::oneshot}; +#[cfg(feature = "sso_login")] +use tokio_stream::wrappers::TcpListenerStream; use url::Url; +#[cfg(feature = "sso_login")] +use warp::Filter; #[cfg(feature = "encryption")] use zeroize::Zeroizing; @@ -120,6 +136,12 @@ const SYNC_REQUEST_TIMEOUT: Duration = Duration::from_secs(15); const DEFAULT_UPLOAD_SPEED: u64 = 125_000; /// 5 min minimal upload request timeout, used to clamp the request timeout. const MIN_UPLOAD_REQUEST_TIMEOUT: Duration = Duration::from_secs(60 * 5); +/// The range of ports the SSO server will try to bind to randomly +#[cfg(feature = "sso_login")] +const SSO_SERVER_BIND_RANGE: Range<u16> = 20000..30000; +/// The number of timesthe SSO server will try to bind to a random port +#[cfg(feature = "sso_login")] +const SSO_SERVER_BIND_TRIES: u8 = 10; /// An async/await enabled Matrix client. /// @@ -738,6 +760,200 @@ impl Client { Ok(response) } + /// Login to the server via Single Sign-On. + /// + /// This takes care of the whole SSO flow: + /// * Spawn a local http server + /// * Provide a callback to open the SSO login URL in a web browser + /// * Wait for the local http server to get the loginToken + /// * Call [`login_with_token`] + /// + /// If cancellation is needed the method should be wrapped in a cancellable + /// task. **Note** that users with root access to the system have the ability + /// to snoop in on the data/token that is passed to the local HTTP server + /// that will be spawned. + /// + /// If you need more control over the SSO login process, you should use + /// [`get_sso_login_url`] and [`login_with_token`] directly. + /// + /// This should only be used for the first login. + /// + /// The [`restore_login`] method should be used to restore a + /// logged in client after the first login. + /// + /// A device id should be provided to restore the correct stores, if the + /// device id isn't provided a new device will be created. + /// + /// # Arguments + /// + /// * `use_sso_login_url` - A callback that will receive the SSO Login URL. It + /// should usually be used to open the SSO URL in a browser and must return + /// `Ok(())` if the URL was successfully opened. If it returns `Err`, the + /// error will be forwarded. + /// + /// * `server_url` - The local URL the server is going to try to bind to, e.g. + /// `http://localhost:3030`. If `None`, the server will try to open a random + /// port on localhost. + /// + /// * `server_response` - The text that will be shown on the webpage at the end + /// of the login process. This can be an HTML page. If `None`, a default + /// text will be displayed. + /// + /// * `device_id` - A unique id that will be associated with this session. If + /// not given the homeserver will create one. Can be an existing device_id + /// from a previous login call. Note that this should be provided only + /// if the client also holds the encryption keys for this device. + /// + /// * `initial_device_display_name` - A public display name that will be + /// associated with the device_id. Only necessary the first time you + /// login with this device_id. It can be changed later. + /// + /// # Example + /// ```no_run + /// # use matrix_sdk::Client; + /// # use futures::executor::block_on; + /// # use url::Url; + /// # let homeserver = Url::parse("https://example.com").unwrap(); + /// # block_on(async { + /// let client = Client::new(homeserver).unwrap(); + /// + /// let response = client + /// .login_with_sso( + /// |sso_url| async move { + /// // Open sso_url + /// Ok(()) + /// }, + /// None, + /// None, + /// None, + /// Some("My app") + /// ) + /// .await + /// .unwrap(); + /// + /// println!("Logged in as {}, got device_id {} and access_token {}", + /// response.user_id, response.device_id, response.access_token); + /// # }) + /// ``` + /// + /// [`get_sso_login_url`]: #method.get_sso_login_url + /// [`login_with_token`]: #method.login_with_token + /// [`restore_login`]: #method.restore_login + #[cfg(all(feature = "sso_login", not(target_arch = "wasm32")))] + #[cfg_attr( + feature = "docs", + doc(cfg(all(sso_login, not(target_arch = "wasm32")))) + )] + pub async fn login_with_sso<C>( + &self, + use_sso_login_url: impl Fn(String) -> C, + server_url: Option<&str>, + server_response: Option<&str>, + device_id: Option<&str>, + initial_device_display_name: Option<&str>, + ) -> Result<login::Response> + where + C: Future<Output = Result<()>>, + { + info!("Logging in to {}", self.homeserver); + let (signal_tx, signal_rx) = oneshot::channel(); + let (data_tx, data_rx) = oneshot::channel(); + let data_tx_mutex = Arc::new(std::sync::Mutex::new(Some(data_tx))); + + let mut redirect_url = match server_url { + Some(s) => match Url::parse(s) { + Ok(url) => url, + Err(err) => return Err(IoError::new(IoErrorKind::InvalidData, err).into()), + }, + None => Url::parse("http://localhost:0/").unwrap(), + }; + + let response = match server_response { + Some(s) => s.to_string(), + None => String::from( + "The Single Sign-On login process is complete. You can close this page now.", + ), + }; + + let route = warp::get() + .and(warp::query::<HashMap<String, String>>()) + .map(move |p: HashMap<String, String>| { + if let Some(data_tx) = data_tx_mutex.lock().unwrap().take() { + if let Some(token) = p.get("loginToken") { + data_tx.send(Some(token.to_owned())).unwrap(); + } else { + data_tx.send(None).unwrap(); + } + } + Response::builder().body(response.clone()) + }); + + let listener = { + if redirect_url + .port() + .expect("The redirect URL doesn't include a port") + == 0 + { + let host = redirect_url + .host_str() + .expect("The redirect URL doesn't have a host"); + let mut n = 0u8; + let mut port = 0u16; + let mut res = Err(IoError::new(IoErrorKind::Other, "")); + let mut rng = thread_rng(); + + while res.is_err() && n < SSO_SERVER_BIND_TRIES { + port = rng.gen_range(SSO_SERVER_BIND_RANGE); + res = TcpListener::bind((host, port)).await; + n += 1; + } + match res { + Ok(s) => { + redirect_url + .set_port(Some(port)) + .expect("Could not set new port on redirect URL"); + s + } + Err(err) => return Err(err.into()), + } + } else { + match TcpListener::bind(redirect_url.as_str()).await { + Ok(s) => s, + Err(err) => return Err(err.into()), + } + } + }; + + let server = warp::serve(route).serve_incoming_with_graceful_shutdown( + TcpListenerStream::new(listener), + async { + signal_rx.await.ok(); + }, + ); + + tokio::spawn(server); + + let sso_url = self.get_sso_login_url(redirect_url.as_str()).unwrap(); + + match use_sso_login_url(sso_url).await { + Ok(t) => t, + Err(err) => return Err(err), + }; + + let token = match data_rx.await { + Ok(Some(t)) => t, + Ok(None) => { + return Err(IoError::new(IoErrorKind::Other, "Could not get the loginToken").into()) + } + Err(err) => return Err(IoError::new(IoErrorKind::Other, format!("{}", err)).into()), + }; + + let _ = signal_tx.send(()); + + self.login_with_token(token.as_str(), device_id, initial_device_display_name) + .await + } + /// Login to the server with a token. /// /// This token is usually received in the SSO flow after following the URL @@ -1990,6 +2206,46 @@ mod test { assert!(logged_in, "Client should be logged in"); } + #[cfg(feature = "sso_login")] + #[tokio::test] + async fn login_with_sso() { + let _m_login = mock("POST", "/_matrix/client/r0/login") + .with_status(200) + .with_body(test_json::LOGIN.to_string()) + .create(); + + let homeserver = Url::from_str(&mockito::server_url()).unwrap(); + let client = Client::new(homeserver).unwrap(); + + client + .login_with_sso( + |sso_url| async move { + let sso_url = Url::parse(sso_url.as_str()).unwrap(); + + let (_, redirect) = sso_url + .query_pairs() + .find(|(key, _)| key == "redirectUrl") + .unwrap(); + + let mut redirect_url = Url::parse(redirect.into_owned().as_str()).unwrap(); + redirect_url.set_query(Some("loginToken=tinytoken")); + + reqwest::get(redirect_url.to_string()).await.unwrap(); + + Ok(()) + }, + None, + None, + None, + None, + ) + .await + .unwrap(); + + let logged_in = client.logged_in().await; + assert!(logged_in, "Client should be logged in"); + } + #[tokio::test] async fn login_with_sso_token() { let homeserver = Url::from_str(&mockito::server_url()).unwrap(); diff --git a/matrix_sdk/src/lib.rs b/matrix_sdk/src/lib.rs index 467c9eaa..177c497a 100644 --- a/matrix_sdk/src/lib.rs +++ b/matrix_sdk/src/lib.rs @@ -45,6 +45,7 @@ //! of Synapse in compliance with the Matrix API specification. //! * `markdown`: Support for sending markdown formatted messages. //! * `socks`: Enables SOCKS support in reqwest, the default HTTP client. +//! * `sso_login`: Enables SSO login with a local http server. #![deny( missing_debug_implementations, @@ -64,6 +65,9 @@ compile_error!("one of 'native-tls' or 'rustls-tls' features must be enabled"); #[cfg(all(feature = "native-tls", feature = "rustls-tls",))] compile_error!("only one of 'native-tls' or 'rustls-tls' features can be enabled"); +#[cfg(all(feature = "sso_login", target_arch = "wasm32"))] +compile_error!("'sso_login' cannot be enabled on 'wasm32' arch"); + #[cfg(feature = "encryption")] #[cfg_attr(feature = "docs", doc(cfg(encryption)))] pub use matrix_sdk_base::crypto::{EncryptionInfo, LocalTrust};