diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index 74f47a2f..3a8494eb 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -47,15 +47,14 @@ use url::Url; use crate::{ events::{room::message::MessageEventContent, EventType}, + http_client::DefaultHttpClient, identifiers::{EventId, RoomId, RoomIdOrAliasId, UserId}, - Endpoint, + Endpoint, HttpSend, }; #[cfg(feature = "encryption")] use crate::{identifiers::DeviceId, sas::Sas}; -#[cfg(not(target_arch = "wasm32"))] -use crate::VERSION; use crate::{api, http_client::HttpClient, EventEmitter, Result}; use matrix_sdk_base::{BaseClient, BaseClientConfig, Room, Session, StateStore}; @@ -108,11 +107,12 @@ impl Debug for Client { #[derive(Default)] pub struct ClientConfig { #[cfg(not(target_arch = "wasm32"))] - proxy: Option, - user_agent: Option, - disable_ssl_verification: bool, - base_config: BaseClientConfig, - timeout: Option, + pub(crate) proxy: Option, + pub(crate) user_agent: Option, + pub(crate) disable_ssl_verification: bool, + pub(crate) base_config: BaseClientConfig, + pub(crate) timeout: Option, + pub(crate) client: Option>, } // #[cfg_attr(tarpaulin, skip)] @@ -212,6 +212,15 @@ impl ClientConfig { self.timeout = Some(timeout); self } + + /// Specify a client to handle sending requests and receiving responses. + /// + /// Any type that implements the `HttpSend` trait can be used to send/receive + /// `http` types. + pub fn client(mut self, client: Arc) -> Self { + self.client = Some(client); + self + } } #[derive(Debug, Default, Clone)] @@ -322,48 +331,21 @@ impl Client { config: ClientConfig, ) -> Result { let homeserver = if let Ok(u) = homeserver_url.try_into() { - u + Arc::new(u) } else { panic!("Error parsing homeserver url") }; - let http_client = reqwest::Client::builder(); - - #[cfg(not(target_arch = "wasm32"))] - let http_client = { - let http_client = match config.timeout { - Some(x) => http_client.timeout(x), - None => http_client, - }; - - let http_client = if config.disable_ssl_verification { - http_client.danger_accept_invalid_certs(true) - } else { - http_client - }; - - let http_client = match config.proxy { - Some(p) => http_client.proxy(p), - None => http_client, - }; - - let mut headers = reqwest::header::HeaderMap::new(); - - let user_agent = match config.user_agent { - Some(a) => a, - None => HeaderValue::from_str(&format!("matrix-rust-sdk {}", VERSION)).unwrap(), - }; - - headers.insert(reqwest::header::USER_AGENT, user_agent); - - http_client.default_headers(headers) - }; - - let homeserver = Arc::new(homeserver); - - let http_client = HttpClient { - homeserver: homeserver.clone(), - inner: http_client.build()?, + let http_client = if let Some(client) = config.client { + HttpClient { + homeserver: homeserver.clone(), + inner: client, + } + } else { + HttpClient { + homeserver: homeserver.clone(), + inner: Arc::new(DefaultHttpClient::with_config(&config)?), + } }; let base_client = BaseClient::new_with_config(config.base_config)?; diff --git a/matrix_sdk/src/http_client.rs b/matrix_sdk/src/http_client.rs index 3e2367a1..000e8e36 100644 --- a/matrix_sdk/src/http_client.rs +++ b/matrix_sdk/src/http_client.rs @@ -12,10 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{ - convert::{TryFrom, TryInto}, - sync::Arc, -}; +use std::{convert::TryFrom, sync::Arc}; use http::{HeaderValue, Method as HttpMethod, Response as HttpResponse}; use reqwest::{Client, Response}; @@ -24,11 +21,11 @@ use url::Url; use matrix_sdk_common::{locks::RwLock, FromHttpResponseError}; -use crate::{Endpoint, Error, Result, Session}; +use crate::{ClientConfig, Endpoint, Error, HttpSend, Result, Session}; #[derive(Clone, Debug)] pub(crate) struct HttpClient { - pub(crate) inner: Client, + pub(crate) inner: Arc, pub(crate) homeserver: Arc, } @@ -62,7 +59,7 @@ impl HttpClient { ); } - Ok(self.inner.execute(request.try_into()?).await?) + self.inner.send_request(request).await } async fn response_to_http_response( @@ -100,3 +97,62 @@ impl HttpClient { Ok(Request::IncomingResponse::try_from(response)?) } } + +/// Default http client used if none is specified using `Client::with_client`. +#[derive(Clone, Debug)] +pub struct DefaultHttpClient { + inner: Client, +} + +impl DefaultHttpClient { + /// Build a client with the specified configuration. + pub fn with_config(config: &ClientConfig) -> Result { + let http_client = reqwest::Client::builder(); + + #[cfg(not(target_arch = "wasm32"))] + let http_client = { + let http_client = match config.timeout { + Some(x) => http_client.timeout(x), + None => http_client, + }; + + let http_client = if config.disable_ssl_verification { + http_client.danger_accept_invalid_certs(true) + } else { + http_client + }; + + let http_client = match &config.proxy { + Some(p) => http_client.proxy(p.clone()), + None => http_client, + }; + + let mut headers = reqwest::header::HeaderMap::new(); + + let user_agent = match &config.user_agent { + Some(a) => a.clone(), + None => { + HeaderValue::from_str(&format!("matrix-rust-sdk {}", crate::VERSION)).unwrap() + } + }; + + headers.insert(reqwest::header::USER_AGENT, user_agent); + + http_client.default_headers(headers) + }; + + Ok(Self { + inner: http_client.build()?, + }) + } +} + +#[async_trait::async_trait] +impl HttpSend for DefaultHttpClient { + async fn send_request(&self, request: http::Request>) -> Result { + Ok(self + .inner + .execute(reqwest::Request::try_from(request)?) + .await?) + } +}