diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index d491ca0..f648978 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -2,7 +2,7 @@ use super::State; use crate::{ client_server, pdu::{PduBuilder, PduEvent}, - utils, ConduitResult, Database, Error, Result, Ruma, + server_server, utils, ConduitResult, Database, Error, Result, Ruma, }; use log::{error, warn}; use ruma::{ @@ -21,7 +21,7 @@ use ruma::{ serde::{to_canonical_value, CanonicalJsonObject, Raw}, EventId, RoomId, RoomVersionId, ServerName, UserId, }; -use std::{collections::BTreeMap, convert::TryFrom, sync::Arc}; +use std::{collections::BTreeMap, convert::TryFrom}; #[cfg(feature = "conduit_bin")] use rocket::{get, post}; @@ -515,27 +515,6 @@ async fn join_room_by_id_helper( ) .await?; - let add_event_id = |pdu: &Raw| -> Result<(EventId, CanonicalJsonObject)> { - let mut value = serde_json::from_str(pdu.json().get()).map_err(|e| { - error!("{:?}: {:?}", pdu, e); - Error::BadServerResponse("Invalid PDU in server response") - })?; - let event_id = EventId::try_from(&*format!( - "${}", - ruma::signatures::reference_hash(&value, &room_version) - .expect("ruma can calculate reference hashes") - )) - .expect("ruma's reference hashes are valid event ids"); - - value.insert( - "event_id".to_owned(), - to_canonical_value(&event_id) - .expect("a valid EventId can be converted to CanonicalJsonValue"), - ); - - Ok((event_id, value)) - }; - let count = db.globals.next_count()?; let mut pdu_id = room_id.as_bytes().to_vec(); @@ -546,23 +525,15 @@ async fn join_room_by_id_helper( .map_err(|_| Error::BadServerResponse("Invalid PDU in send_join response."))?; let mut state = BTreeMap::new(); + let mut pub_key_map = BTreeMap::new(); + + for pdu in send_join_response.room_state.state.iter() { + let (event_id, value) = validate_and_add_event_id(pdu, &room_version, &mut pub_key_map, &db).await?; + let pdu = PduEvent::from_id_val(&event_id, value.clone()).map_err(|e| { + warn!("{:?}: {}", value, e); + Error::BadServerResponse("Invalid PDU in send_join response.") + })?; - for pdu in send_join_response - .room_state - .state - .iter() - .map(add_event_id) - .map(|r| { - let (event_id, value) = r?; - PduEvent::from_id_val(&event_id, value.clone()) - .map(|ev| (event_id, Arc::new(ev))) - .map_err(|e| { - warn!("{:?}: {}", value, e); - Error::BadServerResponse("Invalid PDU in send_join response.") - }) - }) - { - let (_id, pdu) = pdu?; db.rooms.add_pdu_outlier(&pdu)?; if let Some(state_key) = &pdu.state_key { if pdu.kind == EventType::RoomMember { @@ -612,22 +583,12 @@ async fn join_room_by_id_helper( db.rooms.force_state(room_id, state, &db.globals)?; - for pdu in send_join_response - .room_state - .auth_chain - .iter() - .map(add_event_id) - .map(|r| { - let (event_id, value) = r?; - PduEvent::from_id_val(&event_id, value.clone()) - .map(|ev| (event_id, Arc::new(ev))) - .map_err(|e| { - warn!("{:?}: {}", value, e); - Error::BadServerResponse("Invalid PDU in send_join response.") - }) - }) - { - let (_id, pdu) = pdu?; + for pdu in send_join_response.room_state.auth_chain.iter() { + let (event_id, value) = validate_and_add_event_id(pdu, &room_version, &mut pub_key_map, &db).await?; + let pdu = PduEvent::from_id_val(&event_id, value.clone()).map_err(|e| { + warn!("{:?}: {}", value, e); + Error::BadServerResponse("Invalid PDU in send_join response.") + })?; db.rooms.add_pdu_outlier(&pdu)?; } @@ -674,3 +635,32 @@ async fn join_room_by_id_helper( Ok(join_room_by_id::Response::new(room_id.clone()).into()) } + +async fn validate_and_add_event_id( + pdu: &Raw, + room_version: &RoomVersionId, + pub_key_map: &mut BTreeMap>, + db: &Database, +) -> Result<(EventId, CanonicalJsonObject)> { + let mut value = serde_json::from_str::(pdu.json().get()).map_err(|e| { + error!("{:?}: {:?}", pdu, e); + Error::BadServerResponse("Invalid PDU in server response") + })?; + + server_server::fetch_required_signing_keys(&value, pub_key_map, db).await?; + + let event_id = EventId::try_from(&*format!( + "${}", + ruma::signatures::reference_hash(&value, &room_version) + .expect("ruma can calculate reference hashes") + )) + .expect("ruma's reference hashes are valid event ids"); + + value.insert( + "event_id".to_owned(), + to_canonical_value(&event_id) + .expect("a valid EventId can be converted to CanonicalJsonValue"), + ); + + Ok((event_id, value)) +} diff --git a/src/server_server.rs b/src/server_server.rs index 304bc19..39b626f 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -658,44 +658,7 @@ fn handle_incoming_pdu<'a>( // We go through all the signatures we see on the value and fetch the corresponding signing // keys - for (signature_server, signature) in match value - .get("signatures") - .ok_or_else(|| "No signatures in server response pdu.".to_string())? - { - CanonicalJsonValue::Object(map) => map, - _ => return Err("Invalid signatures object in server response pdu.".to_string()), - } { - let signature_object = match signature { - CanonicalJsonValue::Object(map) => map, - _ => { - return Err( - "Invalid signatures content object in server response pdu.".to_string() - ) - } - }; - - let signature_ids = signature_object.keys().collect::>(); - - debug!("Fetching signing keys for {}", signature_server); - let keys = match fetch_signing_keys( - &db, - &Box::::try_from(&**signature_server).map_err(|_| { - "Invalid servername in signatures of server response pdu.".to_string() - })?, - signature_ids, - ) - .await - { - Ok(keys) => keys, - Err(_) => { - return Err( - "Signature verification failed: Could not fetch signing key.".to_string(), - ); - } - }; - - pub_key_map.insert(signature_server.clone(), keys); - } + fetch_required_signing_keys(&value, pub_key_map, db).await.map_err(|e| e.to_string())?; // 2. Check signatures, otherwise drop // 3. check content hash, redact if doesn't match @@ -1639,38 +1602,58 @@ pub fn get_profile_information_route<'a>( .into()) } -/* -#[cfg_attr( - feature = "conduit_bin", - get("/_matrix/federation/v2/invite/<_>/<_>", data = "") -)] -pub fn get_user_devices_route<'a>( - db: State<'a, Database>, - body: Ruma>, -) -> ConduitResult { - if !db.globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - - let mut displayname = None; - let mut avatar_url = None; - - match body.field { - Some(ProfileField::DisplayName) => displayname = db.users.displayname(&body.user_id)?, - Some(ProfileField::AvatarUrl) => avatar_url = db.users.avatar_url(&body.user_id)?, - None => { - displayname = db.users.displayname(&body.user_id)?; - avatar_url = db.users.avatar_url(&body.user_id)?; +pub async fn fetch_required_signing_keys( + event: &BTreeMap, + pub_key_map: &mut BTreeMap>, + db: &Database, +) -> Result<()> { + // We go through all the signatures we see on the value and fetch the corresponding signing + // keys + for (signature_server, signature) in match event + .get("signatures") + .ok_or_else(|| Error::BadServerResponse("No signatures in server response pdu."))? + { + CanonicalJsonValue::Object(map) => map, + _ => { + return Err(Error::BadServerResponse( + "Invalid signatures object in server response pdu.", + )) } + } { + let signature_object = match signature { + CanonicalJsonValue::Object(map) => map, + _ => { + return Err(Error::BadServerResponse( + "Invalid signatures content object in server response pdu.", + )) + } + }; + + let signature_ids = signature_object.keys().collect::>(); + + debug!("Fetching signing keys for {}", signature_server); + let keys = match fetch_signing_keys( + db, + &Box::::try_from(&**signature_server).map_err(|_| { + Error::BadServerResponse("Invalid servername in signatures of server response pdu.") + })?, + signature_ids, + ) + .await + { + Ok(keys) => keys, + Err(_) => { + return Err(Error::BadServerResponse( + "Signature verification failed: Could not fetch signing key.", + )); + } + }; + + pub_key_map.insert(signature_server.clone(), keys); } - Ok(get_profile_information::v1::Response { - displayname, - avatar_url, - } - .into()) + Ok(()) } -*/ #[cfg(test)] mod tests {