diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index 4fc63f81..499ba845 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -33,11 +33,9 @@ use futures_timer::Delay as sleep; use std::future::Future; #[cfg(feature = "encryption")] use tracing::{debug, warn}; -use tracing::{error, info, instrument, trace}; +use tracing::{error, info, instrument}; -use http::Method as HttpMethod; -use http::Response as HttpResponse; -use reqwest::header::{HeaderValue, InvalidHeaderValue, AUTHORIZATION}; +use reqwest::header::{HeaderValue, InvalidHeaderValue}; use url::Url; use crate::events::room::message::MessageEventContent; @@ -49,9 +47,10 @@ use crate::Endpoint; use crate::identifiers::DeviceId; use crate::api; +use crate::http_client::HttpClient; #[cfg(not(target_arch = "wasm32"))] use crate::VERSION; -use crate::{Error, EventEmitter, Result}; +use crate::{EventEmitter, Result}; use matrix_sdk_base::{BaseClient, BaseClientConfig, Room, Session, StateStore}; const DEFAULT_SYNC_TIMEOUT: Duration = Duration::from_secs(30); @@ -62,9 +61,9 @@ const DEFAULT_SYNC_TIMEOUT: Duration = Duration::from_secs(30); #[derive(Clone)] pub struct Client { /// The URL of the homeserver to connect to. - homeserver: Url, + homeserver: Arc, /// The underlying HTTP client. - http_client: reqwest::Client, + http_client: HttpClient, /// User session data. pub(crate) base_client: BaseClient, } @@ -351,7 +350,12 @@ impl Client { http_client.default_headers(headers) }; - let http_client = http_client.build()?; + let homeserver = Arc::new(homeserver); + + let http_client = HttpClient { + homeserver: homeserver.clone(), + inner: http_client.build()?, + }; let base_client = BaseClient::new_with_config(config.base_config)?; @@ -1048,78 +1052,6 @@ impl Client { Ok(response) } - async fn send_request( - &self, - requires_auth: bool, - method: HttpMethod, - request: http::Request>, - ) -> Result { - let url = request.uri(); - let path_and_query = url.path_and_query().unwrap(); - let mut url = self.homeserver.clone(); - - url.set_path(path_and_query.path()); - url.set_query(path_and_query.query()); - - let request_builder = match method { - HttpMethod::GET => self.http_client.get(url), - HttpMethod::POST => { - let body = request.body().clone(); - self.http_client - .post(url) - .body(body) - .header(reqwest::header::CONTENT_TYPE, "application/json") - } - HttpMethod::PUT => { - let body = request.body().clone(); - self.http_client - .put(url) - .body(body) - .header(reqwest::header::CONTENT_TYPE, "application/json") - } - HttpMethod::DELETE => { - let body = request.body().clone(); - self.http_client - .delete(url) - .body(body) - .header(reqwest::header::CONTENT_TYPE, "application/json") - } - method => panic!("Unsupported method {}", method), - }; - - let request_builder = if requires_auth { - let session = self.base_client.session().read().await; - - if let Some(session) = session.as_ref() { - let header_value = format!("Bearer {}", &session.access_token); - request_builder.header(AUTHORIZATION, header_value) - } else { - return Err(Error::AuthenticationRequired); - } - } else { - request_builder - }; - - Ok(request_builder.send().await?) - } - - async fn response_to_http_response( - &self, - mut response: reqwest::Response, - ) -> Result>> { - let status = response.status(); - let mut http_builder = HttpResponse::builder().status(status); - let headers = http_builder.headers_mut().unwrap(); - - for (k, v) in response.headers_mut().drain() { - if let Some(key) = k { - headers.insert(key, v); - } - } - let body = response.bytes().await?.as_ref().to_owned(); - Ok(http_builder.body(body).unwrap()) - } - /// Send an arbitrary request to the server, without updating client state. /// /// **Warning:** Because this method *does not* update the client state, it is @@ -1162,20 +1094,9 @@ impl Client { &self, request: Request, ) -> Result { - let request: http::Request> = request.try_into()?; - let response = self - .send_request( - Request::METADATA.requires_authentication, - Request::METADATA.method, - request, - ) - .await?; - - trace!("Got response: {:?}", response); - - let response = self.response_to_http_response(response).await?; - - Ok(::try_from(response)?) + self.http_client + .send(request, self.base_client.session().read().await.as_ref()) + .await } /// Send an arbitrary request to the server, without updating client state. @@ -1215,22 +1136,9 @@ impl Client { &self, request: Request, ) -> Result { - let request: http::Request> = request.try_into()?; - let response = self - .send_request( - Request::METADATA.requires_authentication, - Request::METADATA.method, - request, - ) - .await?; - - trace!("Got response: {:?}", response); - - let response = self.response_to_http_response(response).await?; - - let uiaa: Result<_> = ::try_from(response).map_err(Into::into); - - Ok(uiaa?) + self.http_client + .send_uiaa(request, self.base_client.session().read().await.as_ref()) + .await } /// Synchronize the client's state with the latest state on the server. diff --git a/matrix_sdk/src/http_client.rs b/matrix_sdk/src/http_client.rs new file mode 100644 index 00000000..deb97ac9 --- /dev/null +++ b/matrix_sdk/src/http_client.rs @@ -0,0 +1,148 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::convert::TryFrom; +use std::sync::Arc; + +use http::{Method as HttpMethod, Response as HttpResponse}; +use reqwest::{header::AUTHORIZATION, Client, Response}; +use tracing::trace; +use url::Url; + +use crate::{api::r0::uiaa::UiaaResponse, Endpoint, Error, Result, Session}; + +#[derive(Clone, Debug)] +pub(crate) struct HttpClient { + pub(crate) inner: Client, + pub(crate) homeserver: Arc, +} + +impl HttpClient { + async fn send_request( + &self, + requires_auth: bool, + method: HttpMethod, + request: http::Request>, + session: Option<&Session>, + ) -> Result { + let url = request.uri(); + let path_and_query = url.path_and_query().unwrap(); + let mut url = (&*self.homeserver).clone(); + + url.set_path(path_and_query.path()); + url.set_query(path_and_query.query()); + + let request_builder = match method { + HttpMethod::GET => self.inner.get(url), + HttpMethod::POST => { + let body = request.body().clone(); + self.inner + .post(url) + .body(body) + .header(reqwest::header::CONTENT_TYPE, "application/json") + } + HttpMethod::PUT => { + let body = request.body().clone(); + self.inner + .put(url) + .body(body) + .header(reqwest::header::CONTENT_TYPE, "application/json") + } + HttpMethod::DELETE => { + let body = request.body().clone(); + self.inner + .delete(url) + .body(body) + .header(reqwest::header::CONTENT_TYPE, "application/json") + } + method => panic!("Unsupported method {}", method), + }; + + let request_builder = if requires_auth { + if let Some(session) = session { + let header_value = format!("Bearer {}", &session.access_token); + request_builder.header(AUTHORIZATION, header_value) + } else { + return Err(Error::AuthenticationRequired); + } + } else { + request_builder + }; + + Ok(request_builder.send().await?) + } + + async fn response_to_http_response( + &self, + mut response: Response, + ) -> Result>> { + let status = response.status(); + let mut http_builder = HttpResponse::builder().status(status); + let headers = http_builder.headers_mut().unwrap(); + + for (k, v) in response.headers_mut().drain() { + if let Some(key) = k { + headers.insert(key, v); + } + } + let body = response.bytes().await?.as_ref().to_owned(); + Ok(http_builder.body(body).unwrap()) + } + + pub async fn send + std::fmt::Debug>( + &self, + request: Request, + session: Option<&Session>, + ) -> Result { + let request: http::Request> = request.try_into()?; + let response = self + .send_request( + Request::METADATA.requires_authentication, + Request::METADATA.method, + request, + session, + ) + .await?; + + trace!("Got response: {:?}", response); + + let response = self.response_to_http_response(response).await?; + + Ok(::try_from(response)?) + } + + pub async fn send_uiaa + std::fmt::Debug>( + &self, + request: Request, + session: Option<&Session>, + ) -> Result { + let request: http::Request> = request.try_into()?; + let response = self + .send_request( + Request::METADATA.requires_authentication, + Request::METADATA.method, + request, + session, + ) + .await?; + + trace!("Got response: {:?}", response); + + let response = self.response_to_http_response(response).await?; + + let uiaa: Result<_> = ::try_from(response).map_err(Into::into); + + Ok(uiaa?) + } +} diff --git a/matrix_sdk/src/lib.rs b/matrix_sdk/src/lib.rs index 9798c67e..a5177891 100644 --- a/matrix_sdk/src/lib.rs +++ b/matrix_sdk/src/lib.rs @@ -51,6 +51,7 @@ pub use matrix_sdk_base::{Device, TrustState}; mod client; mod error; +mod http_client; mod request_builder; pub use client::{Client, ClientConfig, SyncSettings}; pub use error::{Error, Result};