diff --git a/contrib/mitmproxy/failures.py b/contrib/mitmproxy/failures.py new file mode 100644 index 00000000..10dc1898 --- /dev/null +++ b/contrib/mitmproxy/failures.py @@ -0,0 +1,47 @@ +""" +A mitmproxy script that introduces certain request failures in a deterministic +way. + +Used mainly for Matrix style requests. + +To run execute it with mitmproxy: + + >>> mitmproxy -s failures.py` + +""" +import time +import json + +from mitmproxy import http +from mitmproxy.script import concurrent + +REQUEST_COUNT = 0 + + +@concurrent +def request(flow): + global REQUEST_COUNT + + REQUEST_COUNT += 1 + + if REQUEST_COUNT % 2 == 0: + return + elif REQUEST_COUNT % 3 == 0: + flow.response = http.HTTPResponse.make( + 500, + b"Gateway error", + ) + elif REQUEST_COUNT % 7 == 0: + if "sync" in flow.request.pretty_url: + time.sleep(60) + else: + time.sleep(30) + else: + flow.response = http.HTTPResponse.make( + 429, + json.dumps({ + "errcode": "M_LIMIT_EXCEEDED", + "error": "Too many requests", + "retry_after_ms": 2000 + }) + ) diff --git a/matrix_sdk/Cargo.toml b/matrix_sdk/Cargo.toml index 648e358c..95ffc18b 100644 --- a/matrix_sdk/Cargo.toml +++ b/matrix_sdk/Cargo.toml @@ -50,6 +50,11 @@ default_features = false version = "0.11.0" default_features = false +[target.'cfg(not(target_arch = "wasm32"))'.dependencies.backoff] +git = "https://github.com/ihrwein/backoff" +features = ["tokio"] +rev = "fa3fb91431729ce871d29c62b93425b8aec740f4" + [dependencies.tracing-futures] version = "0.2.4" default-features = false diff --git a/matrix_sdk/examples/get_profiles.rs b/matrix_sdk/examples/get_profiles.rs index c75d25b4..4d20bb76 100644 --- a/matrix_sdk/examples/get_profiles.rs +++ b/matrix_sdk/examples/get_profiles.rs @@ -19,7 +19,7 @@ async fn get_profile(client: Client, mxid: &UserId) -> MatrixResult let request = profile::get_profile::Request::new(mxid); // Start the request using matrix_sdk::Client::send - let resp = client.send(request).await?; + let resp = client.send(request, None).await?; // Use the response and construct a UserProfile struct. // See https://docs.rs/ruma-client-api/0.9.0/ruma_client_api/r0/profile/get_profile/struct.Response.html diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index 2ea72ff1..016da325 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -118,6 +118,7 @@ use matrix_sdk_common::{ }; use crate::{ + error::HttpError, http_client::{client_with_config, HttpClient, HttpSend}, Error, OutgoingRequest, Result, }; @@ -131,6 +132,12 @@ use crate::{ }; const DEFAULT_SYNC_TIMEOUT: Duration = Duration::from_secs(30); +/// Give the sync a bit more time than the default request timeout does. +const SYNC_REQUEST_TIMEOUT: Duration = Duration::from_secs(15); +/// A conservative upload speed of 1Mbps +const DEFAULT_UPLOAD_SPEED: u64 = 125_000; +/// 5 min minimal upload request timeout, used to clamp the request timeout. +const MIN_UPLOAD_REQUEST_TIMEOUT: Duration = Duration::from_secs(60 * 5); /// An async/await enabled Matrix client. /// @@ -451,7 +458,7 @@ impl Client { pub async fn display_name(&self) -> Result> { let user_id = self.user_id().await.ok_or(Error::AuthenticationRequired)?; let request = get_display_name::Request::new(&user_id); - let response = self.send(request).await?; + let response = self.send(request, None).await?; Ok(response.displayname) } @@ -474,7 +481,7 @@ impl Client { pub async fn set_display_name(&self, name: Option<&str>) -> Result<()> { let user_id = self.user_id().await.ok_or(Error::AuthenticationRequired)?; let request = set_display_name::Request::new(&user_id, name); - self.send(request).await?; + self.send(request, None).await?; Ok(()) } @@ -499,7 +506,7 @@ impl Client { pub async fn avatar_url(&self) -> Result> { let user_id = self.user_id().await.ok_or(Error::AuthenticationRequired)?; let request = get_avatar_url::Request::new(&user_id); - let response = self.send(request).await?; + let response = self.send(request, None).await?; Ok(response.avatar_url) } @@ -512,7 +519,7 @@ impl Client { pub async fn set_avatar_url(&self, url: Option<&str>) -> Result<()> { let user_id = self.user_id().await.ok_or(Error::AuthenticationRequired)?; let request = set_avatar_url::Request::new(&user_id, url); - self.send(request).await?; + self.send(request, None).await?; Ok(()) } @@ -671,7 +678,7 @@ impl Client { } ); - let response = self.send(request).await?; + let response = self.send(request, None).await?; self.base_client.receive_login_response(&response).await?; Ok(response) @@ -733,7 +740,7 @@ impl Client { info!("Registering to {}", self.homeserver); let request = registration.into(); - self.send(request).await + self.send(request, None).await } /// Get or upload a sync filter. @@ -747,7 +754,7 @@ impl Client { } else { let user_id = self.user_id().await.ok_or(Error::AuthenticationRequired)?; let request = FilterUploadRequest::new(&user_id, definition); - let response = self.send(request).await?; + let response = self.send(request, None).await?; self.base_client .receive_filter_upload(filter_name, &response) @@ -767,7 +774,7 @@ impl Client { /// * `room_id` - The `RoomId` of the room to be joined. pub async fn join_room_by_id(&self, room_id: &RoomId) -> Result { let request = join_room_by_id::Request::new(room_id); - self.send(request).await + self.send(request, None).await } /// Join a room by `RoomId`. @@ -787,7 +794,7 @@ impl Client { let request = assign!(join_room_by_id_or_alias::Request::new(alias), { server_name: server_names, }); - self.send(request).await + self.send(request, None).await } /// Forget a room by `RoomId`. @@ -799,7 +806,7 @@ impl Client { /// * `room_id` - The `RoomId` of the room to be forget. pub async fn forget_room_by_id(&self, room_id: &RoomId) -> Result { let request = forget_room::Request::new(room_id); - self.send(request).await + self.send(request, None).await } /// Ban a user from a room by `RoomId` and `UserId`. @@ -820,7 +827,7 @@ impl Client { reason: Option<&str>, ) -> Result { let request = assign!(ban_user::Request::new(room_id, user_id), { reason }); - self.send(request).await + self.send(request, None).await } /// Kick a user out of the specified room. @@ -841,7 +848,7 @@ impl Client { reason: Option<&str>, ) -> Result { let request = assign!(kick_user::Request::new(room_id, user_id), { reason }); - self.send(request).await + self.send(request, None).await } /// Leave the specified room. @@ -853,7 +860,7 @@ impl Client { /// * `room_id` - The `RoomId` of the room to leave. pub async fn leave_room(&self, room_id: &RoomId) -> Result { let request = leave_room::Request::new(room_id); - self.send(request).await + self.send(request, None).await } /// Invite the specified user by `UserId` to the given room. @@ -873,7 +880,7 @@ impl Client { let recipient = InvitationRecipient::UserId { user_id }; let request = invite_user::Request::new(room_id, recipient); - self.send(request).await + self.send(request, None).await } /// Invite the specified user by third party id to the given room. @@ -892,7 +899,7 @@ impl Client { ) -> Result { let recipient = InvitationRecipient::ThirdPartyId(invite_id); let request = invite_user::Request::new(room_id, recipient); - self.send(request).await + self.send(request, None).await } /// Search the homeserver's directory of public rooms. @@ -938,7 +945,7 @@ impl Client { since, server, }); - self.send(request).await + self.send(request, None).await } /// Search the homeserver's directory of public rooms with a filter. @@ -976,7 +983,7 @@ impl Client { room_search: impl Into>, ) -> Result { let request = room_search.into(); - self.send(request).await + self.send(request, None).await } /// Create a room using the `RoomBuilder` and send the request. @@ -1008,7 +1015,7 @@ impl Client { room: impl Into>, ) -> Result { let request = room.into(); - self.send(request).await + self.send(request, None).await } /// Sends a request to `/_matrix/client/r0/rooms/{room_id}/messages` and returns @@ -1043,8 +1050,8 @@ impl Client { &self, request: impl Into>, ) -> Result { - let req = request.into(); - self.send(req).await + let request = request.into(); + self.send(request, None).await } /// Send a request to notify the room of a user typing. @@ -1087,7 +1094,7 @@ impl Client { let user_id = self.user_id().await.ok_or(Error::AuthenticationRequired)?; let request = TypingRequest::new(&user_id, room_id, typing.into()); - self.send(request).await + self.send(request, None).await } /// Send a request to notify the room the user has read specific event. @@ -1106,7 +1113,7 @@ impl Client { ) -> Result { let request = create_receipt::Request::new(room_id, create_receipt::ReceiptType::Read, event_id); - self.send(request).await + self.send(request, None).await } /// Send a request to notify the room user has read up to specific event. @@ -1129,7 +1136,7 @@ impl Client { let request = assign!(set_read_marker::Request::new(room_id, fully_read), { read_receipt }); - self.send(request).await + self.send(request, None).await } /// Share a group session for the given room. @@ -1260,7 +1267,7 @@ impl Client { let txn_id = txn_id.unwrap_or_else(Uuid::new_v4).to_string(); let request = send_message_event::Request::new(&room_id, &txn_id, &content); - let response = self.send(request).await?; + let response = self.send(request, None).await?; Ok(response) } @@ -1447,11 +1454,16 @@ impl Client { let mut data = Vec::new(); reader.read_to_end(&mut data)?; + let timeout = std::cmp::max( + Duration::from_secs(data.len() as u64 / DEFAULT_UPLOAD_SPEED), + MIN_UPLOAD_REQUEST_TIMEOUT, + ); + let request = assign!(create_content::Request::new(data), { content_type: Some(content_type.essence_str()), }); - self.http_client.upload(request).await + Ok(self.http_client.upload(request, Some(timeout)).await?) } /// Send an arbitrary request to the server, without updating client state. @@ -1465,6 +1477,9 @@ impl Client { /// /// * `request` - A filled out and valid request for the endpoint to be hit /// + /// * `timeout` - An optional request timeout setting, this overrides the + /// default request setting if one was set. + /// /// # Example /// /// ```no_run @@ -1485,18 +1500,22 @@ impl Client { /// let request = profile::get_profile::Request::new(&user_id); /// /// // Start the request using Client::send() - /// let response = client.send(request).await.unwrap(); + /// let response = client.send(request, None).await.unwrap(); /// /// // Check the corresponding Response struct to find out what types are /// // returned /// # }) /// ``` - pub async fn send(&self, request: Request) -> Result + pub async fn send( + &self, + request: Request, + timeout: Option, + ) -> Result where Request: OutgoingRequest + Debug, - Error: From>, + HttpError: From>, { - self.http_client.send(request).await + Ok(self.http_client.send(request, timeout).await?) } #[cfg(feature = "encryption")] @@ -1511,7 +1530,7 @@ impl Client { request.messages.clone(), ); - self.send(request).await + self.send(request, None).await } /// Get information of all our own devices. @@ -1540,7 +1559,7 @@ impl Client { pub async fn devices(&self) -> Result { let request = get_devices::Request::new(); - self.send(request).await + self.send(request, None).await } /// Delete the given devices from the server. @@ -1605,13 +1624,13 @@ impl Client { let mut request = delete_devices::Request::new(devices); request.auth = auth_data; - self.send(request).await + self.send(request, None).await } /// Get the room members for the given room. pub async fn room_members(&self, room_id: &RoomId) -> Result { let request = get_member_events::Request::new(room_id); - let response = self.send(request).await?; + let response = self.send(request, None).await?; Ok(self.base_client.receive_members(room_id, &response).await?) } @@ -1637,7 +1656,12 @@ impl Client { timeout: sync_settings.timeout, }); - let response = self.send(request).await?; + let timeout = sync_settings + .timeout + .unwrap_or_else(|| Duration::from_secs(0)) + + SYNC_REQUEST_TIMEOUT; + + let response = self.send(request, Some(timeout)).await?; Ok(self.base_client.receive_sync_response(response).await?) } @@ -1778,7 +1802,7 @@ impl Client { } OutgoingRequests::SignatureUpload(request) => { // TODO remove this unwrap. - if let Ok(resp) = self.send(request.clone()).await { + if let Ok(resp) = self.send(request.clone(), None).await { self.base_client .mark_request_as_sent(&r.request_id(), &resp) .await @@ -1838,7 +1862,7 @@ impl Client { let _lock = self.key_claim_lock.lock().await; if let Some((request_id, request)) = self.base_client.get_missing_sessions(users).await? { - let response = self.send(request).await?; + let response = self.send(request, None).await?; self.base_client .mark_request_as_sent(&request_id, &response) .await?; @@ -1897,7 +1921,7 @@ impl Client { request.one_time_keys.as_ref().map_or(0, |k| k.len()) ); - let response = self.send(request.clone()).await?; + let response = self.send(request.clone(), None).await?; self.base_client .mark_request_as_sent(request_id, &response) .await?; @@ -1926,7 +1950,7 @@ impl Client { ) -> Result { let request = assign!(get_keys::Request::new(), { device_keys }); - let response = self.send(request).await?; + let response = self.send(request, None).await?; self.base_client .mark_request_as_sent(request_id, &response) .await?; @@ -2079,8 +2103,8 @@ impl Client { user_signing_key: request.user_signing_key, }); - self.send(request).await?; - self.send(signature_request).await?; + self.send(request, None).await?; + self.send(signature_request, None).await?; Ok(()) } @@ -2276,7 +2300,7 @@ impl Client { #[cfg(test)] mod test { - use crate::ClientConfig; + use crate::{ClientConfig, HttpError}; use super::{ get_public_rooms, get_public_rooms_filtered, register::RegistrationKind, Client, @@ -2471,12 +2495,12 @@ mod test { .create(); if let Err(err) = client.login("example", "wordpass", None, None).await { - if let crate::Error::RumaResponse(crate::FromHttpResponseError::Http( - crate::ServerError::Known(crate::api::Error { + if let crate::Error::Http(HttpError::FromHttpResponse( + crate::FromHttpResponseError::Http(crate::ServerError::Known(crate::api::Error { kind, message, status_code, - }), + })), )) = err { if let crate::api::error::ErrorKind::Forbidden = kind { @@ -2517,10 +2541,10 @@ mod test { }); if let Err(err) = client.register(user).await { - if let crate::Error::UiaaError(crate::FromHttpResponseError::Http( + if let crate::Error::Http(HttpError::UiaaError(crate::FromHttpResponseError::Http( // TODO this should be a UiaaError need to investigate crate::ServerError::Unknown(e), - )) = err + ))) = err { assert!(e.to_string().starts_with("EOF while parsing")) } else { diff --git a/matrix_sdk/src/error.rs b/matrix_sdk/src/error.rs index 29b0d9b4..1c2f9e57 100644 --- a/matrix_sdk/src/error.rs +++ b/matrix_sdk/src/error.rs @@ -14,13 +14,14 @@ //! Error conditions. +use http::StatusCode; use matrix_sdk_base::{Error as MatrixError, StoreError}; use matrix_sdk_common::{ api::{ r0::uiaa::{UiaaInfo, UiaaResponse as UiaaError}, Error as RumaClientError, }, - FromHttpResponseError as RumaResponseError, IntoHttpError as RumaIntoHttpError, ServerError, + FromHttpResponseError, IntoHttpError, ServerError, }; use reqwest::Error as ReqwestError; use serde_json::Error as JsonError; @@ -33,9 +34,14 @@ use matrix_sdk_base::crypto::store::CryptoStoreError; /// Result type of the rust-sdk. pub type Result = std::result::Result; -/// Internal representation of errors. +/// An HTTP error, representing either a connection error or an error while +/// converting the raw HTTP response into a Matrix response. #[derive(Error, Debug)] -pub enum Error { +pub enum HttpError { + /// An error at the HTTP layer. + #[error(transparent)] + Reqwest(#[from] ReqwestError), + /// Queried endpoint requires authentication but was called on an anonymous client. #[error("the queried endpoint requires authentication but was called before logging in")] AuthenticationRequired, @@ -44,9 +50,41 @@ pub enum Error { #[error("the queried endpoint is not meant for clients")] NotClientRequest, - /// An error at the HTTP layer. + /// An error converting between ruma_client_api types and Hyper types. #[error(transparent)] - Reqwest(#[from] ReqwestError), + FromHttpResponse(#[from] FromHttpResponseError), + + /// An error converting between ruma_client_api types and Hyper types. + #[error(transparent)] + IntoHttp(#[from] IntoHttpError), + + /// An error occurred while authenticating. + /// + /// When registering or authenticating the Matrix server can send a `UiaaResponse` + /// as the error type, this is a User-Interactive Authentication API response. This + /// represents an error with information about how to authenticate the user. + #[error(transparent)] + UiaaError(#[from] FromHttpResponseError), + + /// The server returned a status code that should be retried. + #[error("Server returned an error {0}")] + Server(StatusCode), + + /// The given request can't be cloned and thus can't be retried. + #[error("The request cannot be cloned")] + UnableToCloneRequest, +} + +/// Internal representation of errors. +#[derive(Error, Debug)] +pub enum Error { + /// Error doing an HTTP request. + #[error(transparent)] + Http(#[from] HttpError), + + /// Queried endpoint requires authentication but was called on an anonymous client. + #[error("the queried endpoint requires authentication but was called before logging in")] + AuthenticationRequired, /// An error de/serializing type for the `StateStore` #[error(transparent)] @@ -56,14 +94,6 @@ pub enum Error { #[error(transparent)] IO(#[from] IoError), - /// An error converting between ruma_client_api types and Hyper types. - #[error("can't parse the JSON response as a Matrix response")] - RumaResponse(RumaResponseError), - - /// An error converting between ruma_client_api types and Hyper types. - #[error("can't convert between ruma_client_api and hyper types.")] - IntoHttp(RumaIntoHttpError), - /// An error occurred in the Matrix client library. #[error(transparent)] MatrixError(#[from] MatrixError), @@ -76,14 +106,6 @@ pub enum Error { /// An error occured in the state store. #[error(transparent)] StateStore(#[from] StoreError), - - /// An error occurred while authenticating. - /// - /// When registering or authenticating the Matrix server can send a `UiaaResponse` - /// as the error type, this is a User-Interactive Authentication API response. This - /// represents an error with information about how to authenticate the user. - #[error("User-Interactive Authentication required.")] - UiaaError(RumaResponseError), } impl Error { @@ -99,9 +121,9 @@ impl Error { /// This method is an convenience method to get to the info the server /// returned on the first, failed request. pub fn uiaa_response(&self) -> Option<&UiaaInfo> { - if let Error::UiaaError(RumaResponseError::Http(ServerError::Known( + if let Error::Http(HttpError::UiaaError(FromHttpResponseError::Http(ServerError::Known( UiaaError::AuthResponse(i), - ))) = self + )))) = self { Some(i) } else { @@ -110,20 +132,8 @@ impl Error { } } -impl From> for Error { - fn from(error: RumaResponseError) -> Self { - Self::UiaaError(error) - } -} - -impl From> for Error { - fn from(error: RumaResponseError) -> Self { - Self::RumaResponse(error) - } -} - -impl From for Error { - fn from(error: RumaIntoHttpError) -> Self { - Self::IntoHttp(error) +impl From for Error { + fn from(e: ReqwestError) -> Self { + Error::Http(HttpError::Reqwest(e)) } } diff --git a/matrix_sdk/src/http_client.rs b/matrix_sdk/src/http_client.rs index fd84a424..cd7d6265 100644 --- a/matrix_sdk/src/http_client.rs +++ b/matrix_sdk/src/http_client.rs @@ -14,17 +14,26 @@ use std::{convert::TryFrom, fmt::Debug, sync::Arc}; +#[cfg(all(not(test), not(target_arch = "wasm32")))] +use backoff::{future::retry, Error as RetryError, ExponentialBackoff}; +#[cfg(all(not(test), not(target_arch = "wasm32")))] +use http::StatusCode; use http::{HeaderValue, Method as HttpMethod, Response as HttpResponse}; use reqwest::{Client, Response}; use tracing::trace; use url::Url; use matrix_sdk_common::{ - api::r0::media::create_content, async_trait, locks::RwLock, AsyncTraitDeps, AuthScheme, - FromHttpResponseError, + api::r0::media::create_content, async_trait, instant::Duration, locks::RwLock, AsyncTraitDeps, + AuthScheme, FromHttpResponseError, }; -use crate::{ClientConfig, Error, OutgoingRequest, Result, Session}; +use crate::{error::HttpError, ClientConfig, OutgoingRequest, Session}; + +#[cfg(not(target_arch = "wasm32"))] +const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(5); +#[cfg(not(target_arch = "wasm32"))] +const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10); /// Abstraction around the http layer. The allows implementors to use different /// http libraries. @@ -43,7 +52,8 @@ pub trait HttpSend: AsyncTraitDeps { /// /// ``` /// use std::convert::TryFrom; - /// use matrix_sdk::{HttpSend, Result, async_trait}; + /// use matrix_sdk::{HttpSend, async_trait, HttpError}; + /// # use std::time::Duration; /// /// #[derive(Debug)] /// struct Client(reqwest::Client); @@ -52,7 +62,7 @@ pub trait HttpSend: AsyncTraitDeps { /// async fn response_to_http_response( /// &self, /// mut response: reqwest::Response, - /// ) -> Result>> { + /// ) -> Result>, HttpError> { /// // Convert the reqwest response to a http one. /// todo!() /// } @@ -60,7 +70,11 @@ pub trait HttpSend: AsyncTraitDeps { /// /// #[async_trait] /// impl HttpSend for Client { - /// async fn send_request(&self, request: http::Request>) -> Result>> { + /// async fn send_request( + /// &self, + /// request: http::Request>, + /// timeout: Option, + /// ) -> Result>, HttpError> { /// Ok(self /// .response_to_http_response( /// self.0 @@ -74,7 +88,8 @@ pub trait HttpSend: AsyncTraitDeps { async fn send_request( &self, request: http::Request>, - ) -> Result>>; + timeout: Option, + ) -> Result>, HttpError>; } #[derive(Clone, Debug)] @@ -90,7 +105,8 @@ impl HttpClient { request: Request, session: Arc>>, content_type: Option, - ) -> Result>> { + timeout: Option, + ) -> Result>, HttpError> { let mut request = { let read_guard; let access_token = match Request::METADATA.authentication { @@ -100,11 +116,11 @@ impl HttpClient { if let Some(session) = read_guard.as_ref() { Some(session.access_token.as_str()) } else { - return Err(Error::AuthenticationRequired); + return Err(HttpError::AuthenticationRequired); } } AuthScheme::None => None, - _ => return Err(Error::NotClientRequest), + _ => return Err(HttpError::NotClientRequest), }; request.try_into_http_request(&self.homeserver.to_string(), access_token)? @@ -118,44 +134,51 @@ impl HttpClient { } } - self.inner.send_request(request).await + self.inner.send_request(request, timeout).await } pub async fn upload( &self, request: create_content::Request<'_>, - ) -> Result { + timeout: Option, + ) -> Result { let response = self - .send_request(request, self.session.clone(), None) + .send_request(request, self.session.clone(), None, timeout) .await?; Ok(create_content::Response::try_from(response)?) } - pub async fn send(&self, request: Request) -> Result + pub async fn send( + &self, + request: Request, + timeout: Option, + ) -> Result where - Request: OutgoingRequest, - Error: From>, + Request: OutgoingRequest + Debug, + HttpError: From>, { let content_type = HeaderValue::from_static("application/json"); let response = self - .send_request(request, self.session.clone(), Some(content_type)) + .send_request(request, self.session.clone(), Some(content_type), timeout) .await?; trace!("Got response: {:?}", response); - Ok(Request::IncomingResponse::try_from(response)?) + let response = Request::IncomingResponse::try_from(response)?; + + Ok(response) } } /// Build a client with the specified configuration. -pub(crate) fn client_with_config(config: &ClientConfig) -> Result { +pub(crate) fn client_with_config(config: &ClientConfig) -> Result { let http_client = reqwest::Client::builder(); #[cfg(not(target_arch = "wasm32"))] let http_client = { let http_client = match config.timeout { Some(x) => http_client.timeout(x), - None => http_client, + None => http_client.timeout(DEFAULT_REQUEST_TIMEOUT), }; let http_client = if config.disable_ssl_verification { @@ -173,12 +196,15 @@ pub(crate) fn client_with_config(config: &ClientConfig) -> Result { let user_agent = match &config.user_agent { Some(a) => a.clone(), - None => HeaderValue::from_str(&format!("matrix-rust-sdk {}", crate::VERSION)).unwrap(), + None => HeaderValue::from_str(&format!("matrix-rust-sdk {}", crate::VERSION)) + .expect("Can't construct the version header"), }; headers.insert(reqwest::header::USER_AGENT, user_agent); - http_client.default_headers(headers) + http_client + .default_headers(headers) + .connect_timeout(DEFAULT_CONNECTION_TIMEOUT) }; #[cfg(target_arch = "wasm32")] @@ -188,11 +214,15 @@ pub(crate) fn client_with_config(config: &ClientConfig) -> Result { Ok(http_client.build()?) } -async fn response_to_http_response(mut response: Response) -> Result>> { +async fn response_to_http_response( + mut response: Response, +) -> Result>, reqwest::Error> { let status = response.status(); let mut http_builder = HttpResponse::builder().status(status); - let headers = http_builder.headers_mut().unwrap(); + let headers = http_builder + .headers_mut() + .expect("Can't get the response builder headers"); for (k, v) in response.headers_mut().drain() { if let Some(key) = k { @@ -202,7 +232,63 @@ async fn response_to_http_response(mut response: Response) -> Result>, + _: Option, +) -> Result>, HttpError> { + let request = reqwest::Request::try_from(request)?; + let response = client.execute(request).await?; + + Ok(response_to_http_response(response).await?) +} + +#[cfg(all(not(test), not(target_arch = "wasm32")))] +async fn send_request( + client: &Client, + request: http::Request>, + timeout: Option, +) -> Result>, HttpError> { + let backoff = ExponentialBackoff::default(); + let mut request = reqwest::Request::try_from(request)?; + + if let Some(timeout) = timeout { + *request.timeout_mut() = Some(timeout); + } + + let request = &request; + + let request = || async move { + let request = request.try_clone().ok_or(HttpError::UnableToCloneRequest)?; + + let response = client + .execute(request) + .await + .map_err(|e| RetryError::Transient(HttpError::Reqwest(e)))?; + + let status_code = response.status(); + // TODO TOO_MANY_REQUESTS will have a retry timeout which we should + // use. + if status_code.is_server_error() || response.status() == StatusCode::TOO_MANY_REQUESTS { + return Err(RetryError::Transient(HttpError::Server(status_code))); + } + + let response = response_to_http_response(response) + .await + .map_err(|e| RetryError::Permanent(HttpError::Reqwest(e)))?; + + Ok(response) + }; + + let response = retry(backoff, request).await?; + + Ok(response) } #[cfg_attr(target_arch = "wasm32", async_trait(?Send))] @@ -211,10 +297,8 @@ impl HttpSend for Client { async fn send_request( &self, request: http::Request>, - ) -> Result>> { - Ok( - response_to_http_response(self.execute(reqwest::Request::try_from(request)?).await?) - .await?, - ) + timeout: Option, + ) -> Result>, HttpError> { + send_request(&self, request, timeout).await } } diff --git a/matrix_sdk/src/lib.rs b/matrix_sdk/src/lib.rs index 8528e920..23e14f8f 100644 --- a/matrix_sdk/src/lib.rs +++ b/matrix_sdk/src/lib.rs @@ -90,7 +90,7 @@ pub use client::{Client, ClientConfig, LoopCtrl, SyncSettings}; #[cfg(feature = "encryption")] #[cfg_attr(feature = "docs", doc(cfg(encryption)))] pub use device::Device; -pub use error::{Error, Result}; +pub use error::{Error, HttpError, Result}; pub use http_client::HttpSend; #[cfg(feature = "encryption")] #[cfg_attr(feature = "docs", doc(cfg(encryption)))] diff --git a/matrix_sdk/src/sas.rs b/matrix_sdk/src/sas.rs index ccbc4224..389633f7 100644 --- a/matrix_sdk/src/sas.rs +++ b/matrix_sdk/src/sas.rs @@ -54,7 +54,7 @@ impl Sas { } if let Some(s) = signature { - self.client.send(s).await?; + self.client.send(s, None).await?; } Ok(())