diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs index c60c04e..5685ac6 100644 --- a/src/ruma_wrapper.rs +++ b/src/ruma_wrapper.rs @@ -4,11 +4,12 @@ use ruma::{ identifiers::{DeviceId, UserId}, Outgoing, }; +use std::collections::BTreeMap; use std::ops::Deref; #[cfg(feature = "conduit_bin")] use { - crate::utils, + crate::{server_server, utils}, log::{debug, warn}, rocket::{ data::{ @@ -21,7 +22,11 @@ use { tokio::io::AsyncReadExt, Request, State, }, - ruma::api::{AuthScheme, IncomingRequest}, + ruma::{ + api::{AuthScheme, IncomingRequest}, + signatures::CanonicalJsonValue, + ServerName, + }, std::convert::TryFrom, std::io::Cursor, }; @@ -72,6 +77,11 @@ where .map(|s| s[7..].to_owned()) // Split off "Bearer " .or_else(|| request.get_query_value("access_token").and_then(|r| r.ok())); + let limit = db.globals.max_request_size(); + let mut handle = data.open(ByteUnit::Byte(limit.into())); + let mut body = Vec::new(); + handle.read_to_end(&mut body).await.unwrap(); + let (sender_user, sender_device, from_appservice) = if let Some((_id, registration)) = db.appservice .iter_all() @@ -129,7 +139,139 @@ where return Failure((Status::raw(582), ())); } } - AuthScheme::ServerSignatures => (None, None, false), + AuthScheme::ServerSignatures => { + // Get origin from header + let x_matrix = match request + .headers() + .get_one("Authorization") + .map(|s| { + s[9..] + .split_terminator(',').map(|field| {let mut splits = field.splitn(2, '='); (splits.next(), splits.next().map(|s| s.trim_matches('"')))}).collect::>() + }) // Split off "X-Matrix " and parse the rest + { + Some(t) => t, + None => { + warn!("No Authorization header"); + + // Forbidden + return Failure((Status::raw(580), ())); + } + }; + + let origin_str = match x_matrix.get(&Some("origin")) { + Some(Some(o)) => *o, + _ => { + warn!("Invalid X-Matrix header origin field: {:?}", x_matrix); + + // Forbidden + return Failure((Status::raw(580), ())); + } + }; + + let origin = match Box::::try_from(origin_str) { + Ok(s) => s, + _ => { + warn!( + "Invalid server name in X-Matrix header origin field: {:?}", + x_matrix + ); + + // Forbidden + return Failure((Status::raw(580), ())); + } + }; + + let key = match x_matrix.get(&Some("key")) { + Some(Some(k)) => *k, + _ => { + warn!("Invalid X-Matrix header key field: {:?}", x_matrix); + + // Forbidden + return Failure((Status::raw(580), ())); + } + }; + + let sig = match x_matrix.get(&Some("sig")) { + Some(Some(s)) => *s, + _ => { + warn!("Invalid X-Matrix header sig field: {:?}", x_matrix); + + // Forbidden + return Failure((Status::raw(580), ())); + } + }; + + let json_body = serde_json::from_slice::(&body); + + let mut request_map = BTreeMap::::new(); + + if let Ok(json_body) = json_body { + request_map.insert("content".to_owned(), json_body); + }; + + request_map.insert( + "method".to_owned(), + CanonicalJsonValue::String(request.method().to_string()), + ); + request_map.insert( + "uri".to_owned(), + CanonicalJsonValue::String(request.uri().to_string()), + ); + request_map.insert( + "origin".to_owned(), + CanonicalJsonValue::String(origin.as_str().to_owned()), + ); + request_map.insert( + "destination".to_owned(), + CanonicalJsonValue::String( + db.globals.server_name().as_str().to_owned(), + ), + ); + + let mut origin_signatures = BTreeMap::new(); + origin_signatures + .insert(key.to_owned(), CanonicalJsonValue::String(sig.to_owned())); + + let mut signatures = BTreeMap::new(); + signatures.insert( + origin.as_str().to_owned(), + CanonicalJsonValue::Object(origin_signatures), + ); + + request_map.insert( + "signatures".to_owned(), + CanonicalJsonValue::Object(signatures), + ); + + let keys = match server_server::fetch_signing_keys( + &db, + &origin, + vec![&key.to_owned()], + ) + .await + { + Ok(b) => b, + Err(e) => { + warn!("Failed to fetch signing keys: {}", e); + + // Forbidden + return Failure((Status::raw(580), ())); + } + }; + + let mut pub_key_map = BTreeMap::new(); + pub_key_map.insert(origin.as_str().to_owned(), keys); + + match ruma::signatures::verify_json(&pub_key_map, &request_map) { + Ok(()) => (None, None, false), + Err(e) => { + warn!("Failed to verify json request: {}: {:?} {:?}", e, pub_key_map, request_map); + + // Forbidden + return Failure((Status::raw(580), ())); + } + } + } AuthScheme::None => (None, None, false), } }; @@ -141,11 +283,6 @@ where http_request = http_request.header(header.name.as_str(), &*header.value); } - let limit = db.globals.max_request_size(); - let mut handle = data.open(ByteUnit::Byte(limit.into())); - let mut body = Vec::new(); - handle.read_to_end(&mut body).await.unwrap(); - let http_request = http_request.body(&*body).unwrap(); debug!("{:?}", http_request); match ::try_from_http_request(http_request) { diff --git a/src/server_server.rs b/src/server_server.rs index ad198fc..90b5099 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -2106,7 +2106,10 @@ pub fn get_room_information_route<'a>( let room_id = db .rooms .id_from_alias(&body.room_alias)? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Room alias not found."))?; + .ok_or(Error::BadRequest( + ErrorKind::NotFound, + "Room alias not found.", + ))?; Ok(get_room_information::v1::Response { room_id,