diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index 1b17dec1..3f4244ca 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -130,9 +130,8 @@ use crate::{ EventHandler, }; +const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10); const DEFAULT_SYNC_TIMEOUT: Duration = Duration::from_secs(30); -/// Give the sync a bit more time than the default request timeout does. -const SYNC_REQUEST_TIMEOUT: Duration = Duration::from_secs(15); /// A conservative upload speed of 1Mbps const DEFAULT_UPLOAD_SPEED: u64 = 125_000; /// 5 min minimal upload request timeout, used to clamp the request timeout. @@ -199,7 +198,7 @@ pub struct ClientConfig { pub(crate) user_agent: Option, pub(crate) disable_ssl_verification: bool, pub(crate) base_config: BaseClientConfig, - pub(crate) timeout: Option, + pub(crate) request_config: RequestConfig, pub(crate) client: Option>, } @@ -213,6 +212,7 @@ impl Debug for ClientConfig { res.field("user_agent", &self.user_agent) .field("disable_ssl_verification", &self.disable_ssl_verification) + .field("request_config", &self.request_config) .finish() } } @@ -295,9 +295,9 @@ impl ClientConfig { self } - /// Set a timeout duration for all HTTP requests. The default is no timeout. - pub fn timeout(mut self, timeout: Duration) -> Self { - self.timeout = Some(timeout); + /// Set the default timeout, fail and retry behavior for all HTTP requests. + pub fn request_config(mut self, request_config: RequestConfig) -> Self { + self.request_config = request_config; self } @@ -311,7 +311,7 @@ impl ClientConfig { } } -#[derive(Debug, Default, Clone)] +#[derive(Debug, Clone)] /// Settings for a sync call. pub struct SyncSettings<'a> { pub(crate) filter: Option>, @@ -320,6 +320,17 @@ pub struct SyncSettings<'a> { pub(crate) full_state: bool, } +impl<'a> Default for SyncSettings<'a> { + fn default() -> Self { + Self { + filter: Default::default(), + timeout: Some(DEFAULT_SYNC_TIMEOUT), + token: Default::default(), + full_state: Default::default(), + } + } +} + impl<'a> SyncSettings<'a> { /// Create new default sync settings. pub fn new() -> Self { @@ -371,6 +382,84 @@ impl<'a> SyncSettings<'a> { } } +/// Configuration for requests the `Client` makes. +/// +/// This sets how often and for how long a request should be repeated. As well as how long a +/// successful request is allowed to take. +/// +/// By default requests are retried indefinitely and use no timeout. +/// +/// # Example +/// +/// ``` +/// # use matrix_sdk::RequestConfig; +/// # use std::time::Duration; +/// // This sets makes requests fail after a single send request and sets the timeout to 30s +/// let request_config = RequestConfig::new() +/// .disable_retry() +/// .timeout(Duration::from_secs(30)); +/// ``` +#[derive(Copy, Clone)] +pub struct RequestConfig { + pub(crate) timeout: Duration, + pub(crate) retry_limit: Option, + pub(crate) retry_timeout: Option, +} + +#[cfg(not(tarpaulin_include))] +impl Debug for RequestConfig { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut res = fmt.debug_struct("RequestConfig"); + + res.field("timeout", &self.timeout) + .field("retry_limit", &self.retry_limit) + .field("retry_timeout", &self.retry_timeout) + .finish() + } +} + +impl Default for RequestConfig { + fn default() -> Self { + Self { + timeout: DEFAULT_REQUEST_TIMEOUT, + retry_limit: Default::default(), + retry_timeout: Default::default(), + } + } +} + +impl RequestConfig { + /// Create a new default `RequestConfig`. + pub fn new() -> Self { + Default::default() + } + + /// This is a convince method to disable the retries of a request. Setting the `retry_limit` to `0` + /// has the same effect. + pub fn disable_retry(mut self) -> Self { + self.retry_limit = Some(0); + self + } + + /// The number of times a request should be retried. The default is no limit + pub fn retry_limit(mut self, retry_limit: u64) -> Self { + self.retry_limit = Some(retry_limit); + self + } + + /// Set the timeout duration for all HTTP requests. + pub fn timeout(mut self, timeout: Duration) -> Self { + self.timeout = timeout; + self + } + + /// Set a timeout for how long a request should be retried. The default is no timeout, meaning requests are retried forever. + pub fn retry_timeout(mut self, retry_timeout: Duration) -> Self { + self.retry_timeout = Some(retry_timeout); + self + } +} + impl Client { /// Creates a new client for making HTTP requests to the given homeserver. /// @@ -412,6 +501,7 @@ impl Client { homeserver: homeserver.clone(), inner: client, session, + request_config: config.request_config, }; Ok(Self { @@ -1374,7 +1464,11 @@ impl Client { content_type: Some(content_type.essence_str()), }); - Ok(self.http_client.upload(request, Some(timeout)).await?) + let request_config = self.http_client.request_config.timeout(timeout); + Ok(self + .http_client + .upload(request, Some(request_config)) + .await?) } /// Send a room message to a room. @@ -1486,13 +1580,13 @@ impl Client { pub async fn send( &self, request: Request, - timeout: Option, + config: Option, ) -> Result where Request: OutgoingRequest + Debug, HttpError: From>, { - Ok(self.http_client.send(request, timeout).await?) + Ok(self.http_client.send(request, config).await?) } #[cfg(feature = "encryption")] @@ -1625,12 +1719,14 @@ impl Client { timeout: sync_settings.timeout, }); - let timeout = sync_settings - .timeout - .unwrap_or_else(|| Duration::from_secs(0)) - + SYNC_REQUEST_TIMEOUT; + let request_config = self.http_client.request_config.timeout( + sync_settings + .timeout + .unwrap_or_else(|| Duration::from_secs(0)) + + self.http_client.request_config.timeout, + ); - let response = self.send(request, Some(timeout)).await?; + let response = self.send(request, Some(request_config)).await?; let sync_response = self.base_client.receive_sync_response(response).await?; if let Some(handler) = self.event_handler.read().await.as_ref() { @@ -1727,7 +1823,6 @@ impl Client { } loop { - let filter = sync_settings.filter.clone(); let response = self.sync_once(sync_settings.clone()).await; let response = match response { @@ -1809,14 +1904,11 @@ impl Client { last_sync_time = Some(now); - sync_settings = SyncSettings::new().timeout(DEFAULT_SYNC_TIMEOUT).token( + sync_settings.token = Some( self.sync_token() .await .expect("No sync token found after initial sync"), ); - if let Some(f) = filter { - sync_settings = sync_settings.filter(f); - } } } @@ -2247,7 +2339,7 @@ impl Client { #[cfg(test)] mod test { - use crate::{ClientConfig, HttpError, RoomMember}; + use crate::{ClientConfig, HttpError, RequestConfig, RoomMember}; use super::{ get_public_rooms, get_public_rooms_filtered, register::RegistrationKind, Client, Session, @@ -3407,4 +3499,75 @@ mod test { } } } + + #[tokio::test] + async fn retry_limit_http_requests() { + let homeserver = Url::from_str(&mockito::server_url()).unwrap(); + let config = ClientConfig::default().request_config(RequestConfig::new().retry_limit(3)); + assert!(config.request_config.retry_limit.unwrap() == 3); + let client = Client::new_with_config(homeserver, config).unwrap(); + + let m = mock("POST", "/_matrix/client/r0/login") + .with_status(501) + .expect(3) + .create(); + + if client + .login("example", "wordpass", None, None) + .await + .is_err() + { + m.assert(); + } else { + panic!("this request should return an `Err` variant") + } + } + + #[tokio::test] + async fn retry_timeout_http_requests() { + let homeserver = Url::from_str(&mockito::server_url()).unwrap(); + // Keep this timeout small so that the test doesn't take long + let retry_timeout = Duration::from_secs(5); + let config = ClientConfig::default() + .request_config(RequestConfig::new().retry_timeout(retry_timeout)); + assert!(config.request_config.retry_timeout.unwrap() == retry_timeout); + let client = Client::new_with_config(homeserver, config).unwrap(); + + let m = mock("POST", "/_matrix/client/r0/login") + .with_status(501) + .expect_at_least(2) + .create(); + + if client + .login("example", "wordpass", None, None) + .await + .is_err() + { + m.assert(); + } else { + panic!("this request should return an `Err` variant") + } + } + + #[tokio::test] + async fn no_retry_http_requests() { + let homeserver = Url::from_str(&mockito::server_url()).unwrap(); + let config = ClientConfig::default().request_config(RequestConfig::new().disable_retry()); + assert!(config.request_config.retry_limit.unwrap() == 0); + let client = Client::new_with_config(homeserver, config).unwrap(); + + let m = mock("POST", "/_matrix/client/r0/login") + .with_status(501) + .create(); + + if client + .login("example", "wordpass", None, None) + .await + .is_err() + { + m.assert(); + } else { + panic!("this request should return an `Err` variant") + } + } } diff --git a/matrix_sdk/src/http_client.rs b/matrix_sdk/src/http_client.rs index 3e1cb8f2..1a34aae4 100644 --- a/matrix_sdk/src/http_client.rs +++ b/matrix_sdk/src/http_client.rs @@ -14,26 +14,23 @@ use std::{convert::TryFrom, fmt::Debug, sync::Arc}; -#[cfg(all(not(test), not(target_arch = "wasm32")))] +#[cfg(all(not(target_arch = "wasm32")))] use backoff::{future::retry, Error as RetryError, ExponentialBackoff}; -#[cfg(all(not(test), not(target_arch = "wasm32")))] +#[cfg(all(not(target_arch = "wasm32")))] use http::StatusCode; use http::{HeaderValue, Response as HttpResponse}; use reqwest::{Client, Response}; +#[cfg(all(not(target_arch = "wasm32")))] +use std::sync::atomic::{AtomicU64, Ordering}; use tracing::trace; use url::Url; use matrix_sdk_common::{ - api::r0::media::create_content, async_trait, instant::Duration, locks::RwLock, AsyncTraitDeps, - AuthScheme, FromHttpResponseError, + api::r0::media::create_content, async_trait, locks::RwLock, AsyncTraitDeps, AuthScheme, + FromHttpResponseError, }; -use crate::{error::HttpError, ClientConfig, OutgoingRequest, Session}; - -#[cfg(not(target_arch = "wasm32"))] -const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(5); -#[cfg(not(target_arch = "wasm32"))] -const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10); +use crate::{error::HttpError, ClientConfig, OutgoingRequest, RequestConfig, Session}; /// Abstraction around the http layer. The allows implementors to use different /// http libraries. @@ -48,12 +45,13 @@ pub trait HttpSend: AsyncTraitDeps { /// /// * `request` - The http request that has been converted from a ruma `Request`. /// + /// * `request_config` - The config used for this request. + /// /// # Examples /// /// ``` /// use std::convert::TryFrom; - /// use matrix_sdk::{HttpSend, async_trait, HttpError}; - /// # use std::time::Duration; + /// use matrix_sdk::{HttpSend, async_trait, HttpError, RequestConfig}; /// /// #[derive(Debug)] /// struct Client(reqwest::Client); @@ -73,7 +71,7 @@ pub trait HttpSend: AsyncTraitDeps { /// async fn send_request( /// &self, /// request: http::Request>, - /// timeout: Option, + /// config: RequestConfig, /// ) -> Result>, HttpError> { /// Ok(self /// .response_to_http_response( @@ -88,7 +86,7 @@ pub trait HttpSend: AsyncTraitDeps { async fn send_request( &self, request: http::Request>, - timeout: Option, + config: RequestConfig, ) -> Result>, HttpError>; } @@ -97,6 +95,7 @@ pub(crate) struct HttpClient { pub(crate) inner: Arc, pub(crate) homeserver: Arc, pub(crate) session: Arc>>, + pub(crate) request_config: RequestConfig, } impl HttpClient { @@ -104,7 +103,7 @@ impl HttpClient { &self, request: Request, session: Arc>>, - timeout: Option, + config: Option, ) -> Result>, HttpError> { let request = { let read_guard; @@ -125,16 +124,21 @@ impl HttpClient { request.try_into_http_request(&self.homeserver.to_string(), access_token)? }; - self.inner.send_request(request, timeout).await + let config = match config { + Some(config) => config, + None => self.request_config, + }; + + self.inner.send_request(request, config).await } pub async fn upload( &self, request: create_content::Request<'_>, - timeout: Option, + config: Option, ) -> Result { let response = self - .send_request(request, self.session.clone(), timeout) + .send_request(request, self.session.clone(), config) .await?; Ok(create_content::Response::try_from(response)?) } @@ -142,14 +146,14 @@ impl HttpClient { pub async fn send( &self, request: Request, - timeout: Option, + config: Option, ) -> Result where Request: OutgoingRequest + Debug, HttpError: From>, { let response = self - .send_request(request, self.session.clone(), timeout) + .send_request(request, self.session.clone(), config) .await?; trace!("Got response: {:?}", response); @@ -166,11 +170,6 @@ pub(crate) fn client_with_config(config: &ClientConfig) -> Result http_client.timeout(x), - None => http_client.timeout(DEFAULT_REQUEST_TIMEOUT), - }; - let http_client = if config.disable_ssl_verification { http_client.danger_accept_invalid_certs(true) } else { @@ -194,7 +193,7 @@ pub(crate) fn client_with_config(config: &ClientConfig) -> Result>, - _: Option, + _: RequestConfig, ) -> Result>, HttpError> { let request = reqwest::Request::try_from(request)?; let response = client.execute(request).await?; @@ -239,34 +238,52 @@ async fn send_request( Ok(response_to_http_response(response).await?) } -#[cfg(all(not(test), not(target_arch = "wasm32")))] +#[cfg(all(not(target_arch = "wasm32")))] async fn send_request( client: &Client, request: http::Request>, - timeout: Option, + config: RequestConfig, ) -> Result>, HttpError> { - let backoff = ExponentialBackoff::default(); + let mut backoff = ExponentialBackoff::default(); let mut request = reqwest::Request::try_from(request)?; + let retry_limit = config.retry_limit; + let retry_count = AtomicU64::new(1); - if let Some(timeout) = timeout { - *request.timeout_mut() = Some(timeout); - } + *request.timeout_mut() = Some(config.timeout); + + backoff.max_elapsed_time = config.retry_timeout; let request = &request; + let retry_count = &retry_count; let request = || async move { + let stop = if let Some(retry_limit) = retry_limit { + retry_count.fetch_add(1, Ordering::Relaxed) >= retry_limit + } else { + false + }; + + // Turn errors into permanent errors when the retry limit is reached + let error_type = if stop { + RetryError::Permanent + } else { + RetryError::Transient + }; + let request = request.try_clone().ok_or(HttpError::UnableToCloneRequest)?; let response = client .execute(request) .await - .map_err(|e| RetryError::Transient(HttpError::Reqwest(e)))?; + .map_err(|e| error_type(HttpError::Reqwest(e)))?; let status_code = response.status(); // TODO TOO_MANY_REQUESTS will have a retry timeout which we should // use. - if status_code.is_server_error() || response.status() == StatusCode::TOO_MANY_REQUESTS { - return Err(RetryError::Transient(HttpError::Server(status_code))); + if !stop + && (status_code.is_server_error() || response.status() == StatusCode::TOO_MANY_REQUESTS) + { + return Err(error_type(HttpError::Server(status_code))); } let response = response_to_http_response(response) @@ -287,8 +304,8 @@ impl HttpSend for Client { async fn send_request( &self, request: http::Request>, - timeout: Option, + config: RequestConfig, ) -> Result>, HttpError> { - send_request(&self, request, timeout).await + send_request(&self, request, config).await } } diff --git a/matrix_sdk/src/lib.rs b/matrix_sdk/src/lib.rs index dbaea1e9..f7ac4928 100644 --- a/matrix_sdk/src/lib.rs +++ b/matrix_sdk/src/lib.rs @@ -95,7 +95,7 @@ mod sas; #[cfg(feature = "encryption")] mod verification_request; -pub use client::{Client, ClientConfig, LoopCtrl, SyncSettings}; +pub use client::{Client, ClientConfig, LoopCtrl, RequestConfig, SyncSettings}; #[cfg(feature = "encryption")] #[cfg_attr(feature = "docs", doc(cfg(encryption)))] pub use device::Device;