diff --git a/src/error.rs b/src/error.rs index 6c37bed..e2664e2 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,5 +1,11 @@ use log::{error, warn}; -use ruma::api::client::{error::ErrorKind, r0::uiaa::UiaaInfo}; +use ruma::{ + api::client::{ + error::{Error as RumaError, ErrorKind}, + r0::uiaa::UiaaInfo, + }, + ServerName, +}; use thiserror::Error; #[cfg(feature = "conduit_bin")] @@ -10,7 +16,7 @@ use { response::{self, Responder}, Request, }, - ruma::api::client::{error::Error as RumaError, r0::uiaa::UiaaResponse}, + ruma::api::client::r0::uiaa::UiaaResponse, }; pub type Result = std::result::Result; @@ -33,6 +39,8 @@ pub enum Error { source: reqwest::Error, }, #[error("{0}")] + FederationError(Box, RumaError), + #[error("{0}")] BadServerResponse(&'static str), #[error("{0}")] BadConfig(&'static str), @@ -66,8 +74,13 @@ where 'o: 'r, { fn respond_to(self, r: &'r Request<'_>) -> response::Result<'o> { - if let Self::Uiaa(uiaainfo) = &self { - return RumaResponse::from(UiaaResponse::AuthResponse(uiaainfo.clone())).respond_to(r); + if let Self::Uiaa(uiaainfo) = self { + return RumaResponse::from(UiaaResponse::AuthResponse(uiaainfo)).respond_to(r); + } + + if let Self::FederationError(origin, mut error) = self { + error.message = format!("Answer from {}: {}", origin, error.message); + return RumaResponse::from(error).respond_to(r); } let message = format!("{}", self); diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs index f2b9b9f..147df3c 100644 --- a/src/ruma_wrapper.rs +++ b/src/ruma_wrapper.rs @@ -137,9 +137,7 @@ where let x_matrix = match request .headers() .get_one("Authorization") - .and_then(|s| - // Split off "X-Matrix " and parse the rest - s.get(9..)) + .and_then(|s| s.get(9..)) // Split off "X-Matrix " and parse the rest .map(|s| { s.split_terminator(',') .map(|field| { diff --git a/src/server_server.rs b/src/server_server.rs index 699cbbe..82e51fc 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -9,7 +9,7 @@ use regex::Regex; use rocket::{response::content::Json, State}; use ruma::{ api::{ - client::error::ErrorKind, + client::error::{Error as RumaError, ErrorKind}, federation::{ device::get_devices::{self, v1::UserDevice}, directory::{get_public_rooms, get_public_rooms_filtered}, @@ -27,7 +27,7 @@ use ruma::{ query::{get_profile_information, get_room_information}, transactions::{edu::Edu, send_transaction_message}, }, - IncomingResponse, OutgoingRequest, OutgoingResponse, SendAccessToken, + EndpointError, IncomingResponse, OutgoingRequest, OutgoingResponse, SendAccessToken, }, directory::{IncomingFilter, IncomingRoomNetwork}, events::{ @@ -261,12 +261,21 @@ where ); } - let response = T::IncomingResponse::try_from_http_response( - http_response_builder - .body(body) - .expect("reqwest body is valid http body"), - ); - response.map_err(|_| Error::BadServerResponse("Server returned bad response.")) + let http_response = http_response_builder + .body(body) + .expect("reqwest body is valid http body"); + + if status == 200 { + let response = T::IncomingResponse::try_from_http_response(http_response); + response.map_err(|_| Error::BadServerResponse("Server returned bad 200 response.")) + } else { + Err(Error::FederationError( + destination.to_owned(), + RumaError::try_from_http_response(http_response).map_err(|_| { + Error::BadServerResponse("Server returned bad error response.") + })?, + )) + } } Err(e) => Err(e.into()), }