chore(sdk): Move the sso related imports into the sso login method

master
Damir Jelić 2021-09-14 13:20:03 +02:00
parent 29d11db73a
commit cf26557cc2
1 changed files with 20 additions and 28 deletions

View File

@ -23,19 +23,11 @@ use std::{
result::Result as StdResult, result::Result as StdResult,
sync::Arc, sync::Arc,
}; };
#[cfg(feature = "sso_login")]
use std::{
collections::HashMap,
io::{Error as IoError, ErrorKind as IoErrorKind},
ops::Range,
};
use dashmap::DashMap; use dashmap::DashMap;
use futures::FutureExt; use futures::FutureExt;
use futures_timer::Delay as sleep; use futures_timer::Delay as sleep;
use http::HeaderValue; use http::HeaderValue;
#[cfg(feature = "sso_login")]
use http::Response;
use matrix_sdk_base::{ use matrix_sdk_base::{
deserialized_responses::{JoinedRoom, LeftRoom, SyncResponse}, deserialized_responses::{JoinedRoom, LeftRoom, SyncResponse},
media::{MediaEventContent, MediaFormat, MediaRequest, MediaThumbnailSize, MediaType}, media::{MediaEventContent, MediaFormat, MediaRequest, MediaThumbnailSize, MediaType},
@ -47,8 +39,6 @@ use matrix_sdk_common::{
uuid::Uuid, uuid::Uuid,
}; };
use mime::{self, Mime}; use mime::{self, Mime};
#[cfg(feature = "sso_login")]
use rand::{thread_rng, Rng};
use reqwest::header::InvalidHeaderValue; use reqwest::header::InvalidHeaderValue;
use ruma::{ use ruma::{
api::{ api::{
@ -79,14 +69,8 @@ use ruma::{
DeviceIdBox, MxcUri, RoomId, RoomIdOrAliasId, ServerName, UInt, UserId, DeviceIdBox, MxcUri, RoomId, RoomIdOrAliasId, ServerName, UInt, UserId,
}; };
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
#[cfg(feature = "sso_login")]
use tokio::{net::TcpListener, sync::oneshot};
#[cfg(feature = "sso_login")]
use tokio_stream::wrappers::TcpListenerStream;
use tracing::{error, info, instrument, warn}; use tracing::{error, info, instrument, warn};
use url::Url; use url::Url;
#[cfg(feature = "sso_login")]
use warp::Filter;
use crate::{ use crate::{
error::{HttpError, HttpResult}, error::{HttpError, HttpResult},
@ -101,12 +85,6 @@ const DEFAULT_SYNC_TIMEOUT: Duration = Duration::from_secs(30);
const DEFAULT_UPLOAD_SPEED: u64 = 125_000; const DEFAULT_UPLOAD_SPEED: u64 = 125_000;
/// 5 min minimal upload request timeout, used to clamp the request timeout. /// 5 min minimal upload request timeout, used to clamp the request timeout.
const MIN_UPLOAD_REQUEST_TIMEOUT: Duration = Duration::from_secs(60 * 5); const MIN_UPLOAD_REQUEST_TIMEOUT: Duration = Duration::from_secs(60 * 5);
/// The range of ports the SSO server will try to bind to randomly
#[cfg(feature = "sso_login")]
const SSO_SERVER_BIND_RANGE: Range<u16> = 20000..30000;
/// The number of times the SSO server will try to bind to a random port
#[cfg(feature = "sso_login")]
const SSO_SERVER_BIND_TRIES: u8 = 10;
type EventHandlerFut = Pin<Box<dyn Future<Output = ()> + Send>>; type EventHandlerFut = Pin<Box<dyn Future<Output = ()> + Send>>;
type EventHandlerFn = Box<dyn Fn(EventHandlerData<'_>) -> EventHandlerFut + Send + Sync>; type EventHandlerFn = Box<dyn Fn(EventHandlerData<'_>) -> EventHandlerFut + Send + Sync>;
@ -1252,9 +1230,23 @@ impl Client {
where where
C: Future<Output = Result<()>>, C: Future<Output = Result<()>>,
{ {
use std::{
collections::HashMap,
io::{Error as IoError, ErrorKind as IoErrorKind},
ops::Range,
};
use rand::{thread_rng, Rng};
use warp::Filter;
/// The range of ports the SSO server will try to bind to randomly
const SSO_SERVER_BIND_RANGE: Range<u16> = 20000..30000;
/// The number of times the SSO server will try to bind to a random port
const SSO_SERVER_BIND_TRIES: u8 = 10;
info!("Logging in to {}", self.homeserver().await); info!("Logging in to {}", self.homeserver().await);
let (signal_tx, signal_rx) = oneshot::channel(); let (signal_tx, signal_rx) = tokio::sync::oneshot::channel();
let (data_tx, data_rx) = oneshot::channel(); let (data_tx, data_rx) = tokio::sync::oneshot::channel();
let data_tx_mutex = Arc::new(std::sync::Mutex::new(Some(data_tx))); let data_tx_mutex = Arc::new(std::sync::Mutex::new(Some(data_tx)));
let mut redirect_url = match server_url { let mut redirect_url = match server_url {
@ -1281,7 +1273,7 @@ impl Client {
data_tx.send(None).unwrap(); data_tx.send(None).unwrap();
} }
} }
Response::builder().body(response.clone()) http::Response::builder().body(response.clone())
}, },
); );
@ -1295,7 +1287,7 @@ impl Client {
while res.is_err() && n < SSO_SERVER_BIND_TRIES { while res.is_err() && n < SSO_SERVER_BIND_TRIES {
port = rng.gen_range(SSO_SERVER_BIND_RANGE); port = rng.gen_range(SSO_SERVER_BIND_RANGE);
res = TcpListener::bind((host, port)).await; res = tokio::net::TcpListener::bind((host, port)).await;
n += 1; n += 1;
} }
match res { match res {
@ -1308,7 +1300,7 @@ impl Client {
Err(err) => return Err(err.into()), Err(err) => return Err(err.into()),
} }
} else { } else {
match TcpListener::bind(redirect_url.as_str()).await { match tokio::net::TcpListener::bind(redirect_url.as_str()).await {
Ok(s) => s, Ok(s) => s,
Err(err) => return Err(err.into()), Err(err) => return Err(err.into()),
} }
@ -1316,7 +1308,7 @@ impl Client {
}; };
let server = warp::serve(route).serve_incoming_with_graceful_shutdown( let server = warp::serve(route).serve_incoming_with_graceful_shutdown(
TcpListenerStream::new(listener), tokio_stream::wrappers::TcpListenerStream::new(listener),
async { async {
signal_rx.await.ok(); signal_rx.await.ok();
}, },