From 8679e81555c2e71313ab5bfe4b1bd167fca783cf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=A9vin=20Commaille?= <zecakeh@pm.me>
Date: Tue, 23 Mar 2021 15:30:40 +0100
Subject: [PATCH] client: Add login_with_sso

---
 matrix_sdk/Cargo.toml    |  14 ++-
 matrix_sdk/src/client.rs | 256 +++++++++++++++++++++++++++++++++++++++
 matrix_sdk/src/lib.rs    |   4 +
 3 files changed, 273 insertions(+), 1 deletion(-)

diff --git a/matrix_sdk/Cargo.toml b/matrix_sdk/Cargo.toml
index b15d84d7..cb1f0fb3 100644
--- a/matrix_sdk/Cargo.toml
+++ b/matrix_sdk/Cargo.toml
@@ -25,8 +25,9 @@ markdown = ["matrix-sdk-base/markdown"]
 native-tls = ["reqwest/native-tls"]
 rustls-tls = ["reqwest/rustls-tls"]
 socks = ["reqwest/socks"]
+sso_login = ["warp", "rand", "tokio-stream"]
 
-docs = ["encryption", "sled_cryptostore", "sled_state_store"]
+docs = ["encryption", "sled_cryptostore", "sled_state_store", "sso_login"]
 
 [dependencies]
 dashmap = { version = "4.0.2", optional = true }
@@ -38,6 +39,7 @@ tracing = "0.1.22"
 url = "2.2.0"
 zeroize = "1.2.0"
 mime = "0.3.16"
+rand = { version = "0.8.2", optional = true }
 
 matrix-sdk-common = { version = "0.2.0", path = "../matrix_sdk_common" }
 
@@ -50,6 +52,16 @@ default_features = false
 version = "0.11.0"
 default_features = false
 
+[dependencies.tokio-stream]
+version = "0.1.4"
+features = ["net"]
+optional = true
+
+[dependencies.warp]
+version = "0.3.0"
+default-features = false
+optional = true
+
 [target.'cfg(not(target_arch = "wasm32"))'.dependencies.backoff]
 version = "0.3.0"
 features = ["tokio"]
diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs
index 050c6a6a..c8a135aa 100644
--- a/matrix_sdk/src/client.rs
+++ b/matrix_sdk/src/client.rs
@@ -15,6 +15,12 @@
 
 #[cfg(feature = "encryption")]
 use std::{collections::BTreeMap, io::Write, path::PathBuf};
+#[cfg(feature = "sso_login")]
+use std::{
+    collections::HashMap,
+    io::{Error as IoError, ErrorKind as IoErrorKind},
+    ops::Range,
+};
 use std::{
     convert::TryInto,
     fmt::{self, Debug},
@@ -29,9 +35,19 @@ use std::{
 use dashmap::DashMap;
 use futures_timer::Delay as sleep;
 use http::HeaderValue;
+#[cfg(feature = "sso_login")]
+use http::Response;
 use mime::{self, Mime};
+#[cfg(feature = "sso_login")]
+use rand::{thread_rng, Rng};
 use reqwest::header::InvalidHeaderValue;
+#[cfg(feature = "sso_login")]
+use tokio::{net::TcpListener, sync::oneshot};
+#[cfg(feature = "sso_login")]
+use tokio_stream::wrappers::TcpListenerStream;
 use url::Url;
+#[cfg(feature = "sso_login")]
+use warp::Filter;
 #[cfg(feature = "encryption")]
 use zeroize::Zeroizing;
 
@@ -120,6 +136,12 @@ const SYNC_REQUEST_TIMEOUT: Duration = Duration::from_secs(15);
 const DEFAULT_UPLOAD_SPEED: u64 = 125_000;
 /// 5 min minimal upload request timeout, used to clamp the request timeout.
 const MIN_UPLOAD_REQUEST_TIMEOUT: Duration = Duration::from_secs(60 * 5);
+/// The range of ports the SSO server will try to bind to randomly
+#[cfg(feature = "sso_login")]
+const SSO_SERVER_BIND_RANGE: Range<u16> = 20000..30000;
+/// The number of timesthe SSO server will try to bind to a random port
+#[cfg(feature = "sso_login")]
+const SSO_SERVER_BIND_TRIES: u8 = 10;
 
 /// An async/await enabled Matrix client.
 ///
@@ -738,6 +760,200 @@ impl Client {
         Ok(response)
     }
 
+    /// Login to the server via Single Sign-On.
+    ///
+    /// This takes care of the whole SSO flow:
+    ///   * Spawn a local http server
+    ///   * Provide a callback to open the SSO login URL in a web browser
+    ///   * Wait for the local http server to get the loginToken
+    ///   * Call [`login_with_token`]
+    ///
+    /// If cancellation is needed the method should be wrapped in a cancellable
+    /// task. **Note** that users with root access to the system have the ability
+    /// to snoop in on the data/token that is passed to the local HTTP server
+    /// that will be spawned.
+    ///
+    /// If you need more control over the SSO login process, you should use
+    /// [`get_sso_login_url`] and [`login_with_token`] directly.
+    ///
+    /// This should only be used for the first login.
+    ///
+    /// The [`restore_login`] method should be used to restore a
+    /// logged in client after the first login.
+    ///
+    /// A device id should be provided to restore the correct stores, if the
+    /// device id isn't provided a new device will be created.
+    ///
+    /// # Arguments
+    ///
+    /// * `use_sso_login_url` - A callback that will receive the SSO Login URL. It
+    ///     should usually be used to open the SSO URL in a browser and must return
+    ///     `Ok(())` if the URL was successfully opened. If it returns `Err`, the
+    ///     error will be forwarded.
+    ///
+    /// * `server_url` - The local URL the server is going to try to bind to, e.g.
+    ///     `http://localhost:3030`. If `None`, the server will try to open a random
+    ///     port on localhost.
+    ///
+    /// * `server_response` - The text that will be shown on the webpage at the end
+    ///     of the login process. This can be an HTML page. If `None`, a default
+    ///     text will be displayed.
+    ///
+    /// * `device_id` - A unique id that will be associated with this session. If
+    ///     not given the homeserver will create one. Can be an existing device_id
+    ///     from a previous login call. Note that this should be provided only
+    ///     if the client also holds the encryption keys for this device.
+    ///
+    /// * `initial_device_display_name` - A public display name that will be
+    ///     associated with the device_id. Only necessary the first time you
+    ///     login with this device_id. It can be changed later.
+    ///
+    /// # Example
+    /// ```no_run
+    /// # use matrix_sdk::Client;
+    /// # use futures::executor::block_on;
+    /// # use url::Url;
+    /// # let homeserver = Url::parse("https://example.com").unwrap();
+    /// # block_on(async {
+    /// let client = Client::new(homeserver).unwrap();
+    ///
+    /// let response = client
+    ///     .login_with_sso(
+    ///         |sso_url| async move {
+    ///             // Open sso_url
+    ///             Ok(())
+    ///         },
+    ///         None,
+    ///         None,
+    ///         None,
+    ///         Some("My app")
+    ///     )
+    ///     .await
+    ///     .unwrap();
+    ///
+    /// println!("Logged in as {}, got device_id {} and access_token {}",
+    ///          response.user_id, response.device_id, response.access_token);
+    /// # })
+    /// ```
+    ///
+    /// [`get_sso_login_url`]: #method.get_sso_login_url
+    /// [`login_with_token`]: #method.login_with_token
+    /// [`restore_login`]: #method.restore_login
+    #[cfg(all(feature = "sso_login", not(target_arch = "wasm32")))]
+    #[cfg_attr(
+        feature = "docs",
+        doc(cfg(all(sso_login, not(target_arch = "wasm32"))))
+    )]
+    pub async fn login_with_sso<C>(
+        &self,
+        use_sso_login_url: impl Fn(String) -> C,
+        server_url: Option<&str>,
+        server_response: Option<&str>,
+        device_id: Option<&str>,
+        initial_device_display_name: Option<&str>,
+    ) -> Result<login::Response>
+    where
+        C: Future<Output = Result<()>>,
+    {
+        info!("Logging in to {}", self.homeserver);
+        let (signal_tx, signal_rx) = oneshot::channel();
+        let (data_tx, data_rx) = oneshot::channel();
+        let data_tx_mutex = Arc::new(std::sync::Mutex::new(Some(data_tx)));
+
+        let mut redirect_url = match server_url {
+            Some(s) => match Url::parse(s) {
+                Ok(url) => url,
+                Err(err) => return Err(IoError::new(IoErrorKind::InvalidData, err).into()),
+            },
+            None => Url::parse("http://localhost:0/").unwrap(),
+        };
+
+        let response = match server_response {
+            Some(s) => s.to_string(),
+            None => String::from(
+                "The Single Sign-On login process is complete. You can close this page now.",
+            ),
+        };
+
+        let route = warp::get()
+            .and(warp::query::<HashMap<String, String>>())
+            .map(move |p: HashMap<String, String>| {
+                if let Some(data_tx) = data_tx_mutex.lock().unwrap().take() {
+                    if let Some(token) = p.get("loginToken") {
+                        data_tx.send(Some(token.to_owned())).unwrap();
+                    } else {
+                        data_tx.send(None).unwrap();
+                    }
+                }
+                Response::builder().body(response.clone())
+            });
+
+        let listener = {
+            if redirect_url
+                .port()
+                .expect("The redirect URL doesn't include a port")
+                == 0
+            {
+                let host = redirect_url
+                    .host_str()
+                    .expect("The redirect URL doesn't have a host");
+                let mut n = 0u8;
+                let mut port = 0u16;
+                let mut res = Err(IoError::new(IoErrorKind::Other, ""));
+                let mut rng = thread_rng();
+
+                while res.is_err() && n < SSO_SERVER_BIND_TRIES {
+                    port = rng.gen_range(SSO_SERVER_BIND_RANGE);
+                    res = TcpListener::bind((host, port)).await;
+                    n += 1;
+                }
+                match res {
+                    Ok(s) => {
+                        redirect_url
+                            .set_port(Some(port))
+                            .expect("Could not set new port on redirect URL");
+                        s
+                    }
+                    Err(err) => return Err(err.into()),
+                }
+            } else {
+                match TcpListener::bind(redirect_url.as_str()).await {
+                    Ok(s) => s,
+                    Err(err) => return Err(err.into()),
+                }
+            }
+        };
+
+        let server = warp::serve(route).serve_incoming_with_graceful_shutdown(
+            TcpListenerStream::new(listener),
+            async {
+                signal_rx.await.ok();
+            },
+        );
+
+        tokio::spawn(server);
+
+        let sso_url = self.get_sso_login_url(redirect_url.as_str()).unwrap();
+
+        match use_sso_login_url(sso_url).await {
+            Ok(t) => t,
+            Err(err) => return Err(err),
+        };
+
+        let token = match data_rx.await {
+            Ok(Some(t)) => t,
+            Ok(None) => {
+                return Err(IoError::new(IoErrorKind::Other, "Could not get the loginToken").into())
+            }
+            Err(err) => return Err(IoError::new(IoErrorKind::Other, format!("{}", err)).into()),
+        };
+
+        let _ = signal_tx.send(());
+
+        self.login_with_token(token.as_str(), device_id, initial_device_display_name)
+            .await
+    }
+
     /// Login to the server with a token.
     ///
     /// This token is usually received in the SSO flow after following the URL
@@ -1990,6 +2206,46 @@ mod test {
         assert!(logged_in, "Client should be logged in");
     }
 
+    #[cfg(feature = "sso_login")]
+    #[tokio::test]
+    async fn login_with_sso() {
+        let _m_login = mock("POST", "/_matrix/client/r0/login")
+            .with_status(200)
+            .with_body(test_json::LOGIN.to_string())
+            .create();
+
+        let homeserver = Url::from_str(&mockito::server_url()).unwrap();
+        let client = Client::new(homeserver).unwrap();
+
+        client
+            .login_with_sso(
+                |sso_url| async move {
+                    let sso_url = Url::parse(sso_url.as_str()).unwrap();
+
+                    let (_, redirect) = sso_url
+                        .query_pairs()
+                        .find(|(key, _)| key == "redirectUrl")
+                        .unwrap();
+
+                    let mut redirect_url = Url::parse(redirect.into_owned().as_str()).unwrap();
+                    redirect_url.set_query(Some("loginToken=tinytoken"));
+
+                    reqwest::get(redirect_url.to_string()).await.unwrap();
+
+                    Ok(())
+                },
+                None,
+                None,
+                None,
+                None,
+            )
+            .await
+            .unwrap();
+
+        let logged_in = client.logged_in().await;
+        assert!(logged_in, "Client should be logged in");
+    }
+
     #[tokio::test]
     async fn login_with_sso_token() {
         let homeserver = Url::from_str(&mockito::server_url()).unwrap();
diff --git a/matrix_sdk/src/lib.rs b/matrix_sdk/src/lib.rs
index 467c9eaa..177c497a 100644
--- a/matrix_sdk/src/lib.rs
+++ b/matrix_sdk/src/lib.rs
@@ -45,6 +45,7 @@
 //! of Synapse in compliance with the Matrix API specification.
 //! * `markdown`: Support for sending markdown formatted messages.
 //! * `socks`: Enables SOCKS support in reqwest, the default HTTP client.
+//! * `sso_login`: Enables SSO login with a local http server.
 
 #![deny(
     missing_debug_implementations,
@@ -64,6 +65,9 @@ compile_error!("one of 'native-tls' or 'rustls-tls' features must be enabled");
 #[cfg(all(feature = "native-tls", feature = "rustls-tls",))]
 compile_error!("only one of 'native-tls' or 'rustls-tls' features can be enabled");
 
+#[cfg(all(feature = "sso_login", target_arch = "wasm32"))]
+compile_error!("'sso_login' cannot be enabled on 'wasm32' arch");
+
 #[cfg(feature = "encryption")]
 #[cfg_attr(feature = "docs", doc(cfg(encryption)))]
 pub use matrix_sdk_base::crypto::{EncryptionInfo, LocalTrust};