diff --git a/src/database/rooms.rs b/src/database/rooms.rs
index d62d4b0..325a2e2 100644
--- a/src/database/rooms.rs
+++ b/src/database/rooms.rs
@@ -150,7 +150,7 @@ impl Rooms {
         }
     }

-    /// Returns the last state hash key added to the db.
+    /// Returns the state hash for this pdu.
     pub fn pdu_state_hash(&self, pdu_id: &[u8]) -> Result<Option<StateHashId>> {
         Ok(self.pduid_statehash.get(pdu_id)?)
     }
diff --git a/src/server_server.rs b/src/server_server.rs
index 77f0fa8..0eb7d6f 100644
--- a/src/server_server.rs
+++ b/src/server_server.rs
@@ -5,6 +5,7 @@ use log::{error, info, warn};
 use rocket::{get, post, put, response::content::Json, State};
 use ruma::{
     api::{
+        client::r0::state,
         federation::{
             directory::{get_public_rooms, get_public_rooms_filtered},
             discovery::{
@@ -590,6 +591,8 @@ pub async fn send_transaction_message_route<'a>(
                 continue;
             }
         };
+
+        // 1. check the server is in the room (optional)
         if !db.rooms.exists(&room_id)? {
             resolved_map.insert(event_id, Err("Room is unknown to this server".to_string()));
             continue;
@@ -634,14 +637,13 @@ pub async fn send_transaction_message_route<'a>(
         // the auth events that it references.
        let mut auth_cache = EventMap::new();

-        // 1. check the server is in the room (optional)
         // 2. check content hash, redact if doesn't match
         // 3. fetch any missing auth events doing all checks listed here starting at 1. These are not timeline events
         // 4. reject "due to auth events" if can't get all the auth events or some of the auth events are also rejected "due to auth events"
         // 5. reject "due to auth events" if the event doesn't pass auth based on the auth events
-        // 6. persist this event as an outlier
         // 7. if not timeline event: stop
-        let pdu = match validate_event(
+        // 8. fetch any missing prev events doing all checks listed here starting at 1. These are timeline events
+        let (pdu, previous) = match validate_event(
             &db,
             value,
             event_id.clone(),
@@ -659,59 +661,16 @@ pub async fn send_transaction_message_route<'a>(
             }
         };

-        let pdu = Arc::new(pdu.clone());
-
-        // Fetch any unknown prev_events or retrieve them from the DB
-        let previous = match fetch_events(
-            &db,
-            server_name,
-            &pub_key_map,
-            &pdu.prev_events,
-            &mut auth_cache,
-        )
-        .await
-        {
-            Ok(mut evs) if evs.len() == 1 => Some(evs.remove(0)),
-            _ => None,
+        let single_prev = if previous.len() == 1 {
+            previous.first().cloned()
+        } else {
+            None
         };

-        // [auth_cache] At this point we have the auth chain of the incoming event.
-        let mut event_map: state_res::EventMap<Arc<PduEvent>> = auth_cache
-            .iter()
-            .map(|(k, v)| (k.clone(), v.clone()))
-            .collect();
+        // 6. persist the event as an outlier.
+        db.rooms.append_pdu_outlier(pdu.event_id(), &pdu)?;

-        // Check that the event passes auth based on the auth_events
-        let is_authed = state_res::event_auth::auth_check(
-            &RoomVersionId::Version6,
-            &pdu,
-            previous.clone(),
-            &pdu.auth_events
-                .iter()
-                .map(|id| {
-                    auth_cache
-                        .get(id)
-                        .map(|pdu| ((pdu.kind(), pdu.state_key()), pdu.clone()))
-                        .ok_or_else(|| {
-                            Error::Conflict(
-                                "Auth event not found, event failed recursive auth checks.",
-                            )
-                        })
-                })
-                .collect::<Result<BTreeMap<_, _>>>()?,
-            None, // TODO: third party invite
-        )
-        .map_err(|_e| Error::Conflict("Auth check failed"))?;
-
-        if !is_authed {
-            resolved_map.insert(
-                pdu.event_id().clone(),
-                Err("Event has failed auth check with auth events".into()),
-            );
-            continue;
-        }
-        // End of step 4.
-
-        // Step 5. event passes auth based on state at the event
+        // Step 10. check that the event passes auth based on the state calculated at the event
         let (state_at_event, incoming_auth_events): (StateMap<Arc<PduEvent>>, Vec<Arc<PduEvent>>) = match db
             .sending
@@ -757,9 +716,7 @@ pub async fn send_transaction_message_route<'a>(
                         &res.auth_chain_ids,
                         &mut auth_cache,
                     )
-                    .await?
-                    .into_iter()
-                    .collect(),
+                    .await?,
                 )
             }
             Err(_) => {
@@ -771,10 +728,11 @@ pub async fn send_transaction_message_route<'a>(
             }
         };

+        // 10. This is the actual auth check for state at the event
         if !state_res::event_auth::auth_check(
             &RoomVersionId::Version6,
             &pdu,
-            previous.clone(),
+            single_prev.clone(),
             &state_at_event,
             None, // TODO: third party invite
         )
@@ -787,10 +745,34 @@ pub async fn send_transaction_message_route<'a>(
             );
             continue;
         }
-        // End of step 5.
+        // End of step 10.
+
+        // 12. check if the event passes auth based on the "current state" of the room, if not "soft fail" it
+        let current_state = db
+            .rooms
+            .room_state_full(pdu.room_id())?
+            .into_iter()
+            .map(|(k, v)| ((k.0, Some(k.1)), Arc::new(v)))
+            .collect();
+
+        if !state_res::event_auth::auth_check(
+            &RoomVersionId::Version6,
+            &pdu,
+            single_prev.clone(),
+            &current_state,
+            None,
+        )
+        .map_err(|_e| Error::Conflict("Auth check failed"))?
+        {
+            // Soft fail, we add the event as an outlier.
+            resolved_map.insert(
+                pdu.event_id().clone(),
+                Err("Event has been soft failed".into()),
+            );
+        };

         // Gather the forward extremities and resolve
-        let fork_states = match forward_extremity_ids(
+        let fork_states = match forward_extremities(
             &db,
             &pdu,
             server_name,
@@ -806,7 +788,9 @@ pub async fn send_transaction_message_route<'a>(
             }
         };

-        // Step 6. event passes auth based on state of all forks and current room state
+        // 13. start state-res with all previous forward extremities minus the ones that are in
+        // the prev_events of this event plus the new one created by this event and use
+        // the result as the new room state
         let state_at_forks = if fork_states.is_empty() {
             // State is empty
             Default::default()
@@ -852,6 +836,7 @@ pub async fn send_transaction_message_route<'a>(
         }
         info!("{} event's were not in the auth_cache", number_fetches);

+        let mut event_map = EventMap::new();
         // Add everything we will need to event_map
         event_map.extend(
             auth_events
@@ -904,7 +889,7 @@ pub async fn send_transaction_message_route<'a>(
         if !state_res::event_auth::auth_check(
             &RoomVersionId::Version6,
             &pdu,
-            previous,
+            single_prev,
             &state_at_forks,
             None,
         )
@@ -925,14 +910,19 @@ pub async fn send_transaction_message_route<'a>(
     Ok(dbg!(send_transaction_message::v1::Response { pdus: resolved_map }).into())
 }

+/// An async function that can recursively call itself.
+type AsyncRecursiveResult<'a, T> = Pin<Box<dyn Future<Output = StdResult<T, String>> + 'a + Send>>;
+
 /// TODO: don't add as outlier if event is fetched as a result of gathering auth_events
 /// Validate any event that is given to us by another server.
 ///
 /// 1. Is a valid event, otherwise it is dropped (PduEvent deserialization satisfies this).
-/// 2. Passes signature checks, otherwise event is dropped.
-/// 3. Passes hash checks, otherwise it is redacted before being processed further.
-/// 4. Passes auth_chain collection (we can gather the events that auth this event recursively).
-/// 5. Once the event has passed all checks it can be added as an outlier to the DB.
+/// 2. check content hash, redact if doesn't match
+/// 3. fetch any missing auth events doing all checks listed here starting at 1. These are not timeline events
+/// 4. reject "due to auth events" if can't get all the auth events or some of the auth events are also rejected "due to auth events"
reject "due to auth events" if can't get all the auth events or some of the auth events are also rejected "due to auth events" +/// 5. reject "due to auth events" if the event doesn't pass auth based on the auth events +/// 7. if not timeline event: stop +/// 8. fetch any missing prev events doing all checks listed here starting at 1. These are timeline events fn validate_event<'a>( db: &'a Database, value: CanonicalJsonObject, @@ -940,9 +930,24 @@ fn validate_event<'a>( pub_key_map: &'a PublicKeyMap, origin: &'a ServerName, auth_cache: &'a mut EventMap>, -) -> Pin> + 'a + Send>> { +) -> AsyncRecursiveResult<'a, (Arc, Vec>)> { Box::pin(async move { - let mut val = signature_and_hash_check(&pub_key_map, value)?; + let mut val = + match ruma::signatures::verify_event(pub_key_map, &value, &RoomVersionId::Version6) { + Ok(ver) => { + if let ruma::signatures::Verified::Signatures = ver { + match ruma::signatures::redact(&value, &RoomVersionId::Version6) { + Ok(obj) => obj, + Err(_) => return Err("Redaction failed".to_string()), + } + } else { + value + } + } + Err(_e) => { + return Err("Signature verification failed".to_string()); + } + }; // Now that we have checked the signature and hashes we can add the eventID and convert // to our PduEvent type also finally verifying the first step listed above @@ -959,11 +964,42 @@ fn validate_event<'a>( .await .map_err(|_| "Event failed auth chain check".to_string())?; - db.rooms - .append_pdu_outlier(pdu.event_id(), &pdu) + let pdu = Arc::new(pdu.clone()); + + // 8. fetch any missing prev events doing all checks listed here starting at 1. These are timeline events + let previous = fetch_events(&db, origin, &pub_key_map, &pdu.prev_events, auth_cache) + .await .map_err(|e| e.to_string())?; - Ok(pdu) + // Check that the event passes auth based on the auth_events + let is_authed = state_res::event_auth::auth_check( + &RoomVersionId::Version6, + &pdu, + if previous.len() == 1 { + previous.first().cloned() + } else { + None + }, + &pdu.auth_events + .iter() + .map(|id| { + auth_cache + .get(id) + .map(|pdu| ((pdu.kind(), pdu.state_key()), pdu.clone())) + .ok_or_else(|| { + "Auth event not found, event failed recursive auth checks.".to_string() + }) + }) + .collect::, _>>()?, + None, // TODO: third party invite + ) + .map_err(|_e| "Auth check failed".to_string())?; + + if !is_authed { + return Err("Event has failed auth check with auth events".to_string()); + } + + Ok((pdu, previous)) }) } @@ -990,7 +1026,10 @@ async fn fetch_check_auth_events( let ev = fetch_events(db, origin, key_map, &[ev_id.clone()], auth_cache) .await - .map(|mut vec| vec.remove(0))?; + .map(|mut vec| { + vec.pop() + .ok_or_else(|| Error::Conflict("Event was not found in fetch_events")) + })??; stack.extend(ev.auth_events()); auth_cache.insert(ev.event_id().clone(), ev); @@ -1028,11 +1067,12 @@ async fn fetch_events( { Ok(res) => { let (event_id, value) = crate::pdu::gen_event_id_canonical_json(&res.pdu); - let pdu = validate_event(db, value, event_id, key_map, origin, auth_cache) - .await - .map_err(|_| Error::Conflict("Authentication of event failed"))?; + let (pdu, _) = + validate_event(db, value, event_id, key_map, origin, auth_cache) + .await + .map_err(|_| Error::Conflict("Authentication of event failed"))?; - Arc::new(pdu) + pdu } Err(_) => return Err(Error::BadServerResponse("Failed to fetch event")), }, @@ -1063,31 +1103,11 @@ async fn fetch_signing_keys( } } -fn signature_and_hash_check( - pub_key_map: &ruma::signatures::PublicKeyMap, - value: CanonicalJsonObject, -) -> 
-) -> std::result::Result<CanonicalJsonObject, String> {
-    Ok(
-        match ruma::signatures::verify_event(pub_key_map, &value, &RoomVersionId::Version6) {
-            Ok(ver) => {
-                if let ruma::signatures::Verified::Signatures = ver {
-                    error!("CONTENT HASH FAILED");
-                    match ruma::signatures::redact(&value, &RoomVersionId::Version6) {
-                        Ok(obj) => obj,
-                        Err(_) => return Err("Redaction failed".to_string()),
-                    }
-                } else {
-                    value
-                }
-            }
-            Err(_e) => {
-                return Err("Signature verification failed".to_string());
-            }
-        },
-    )
-}
-
-async fn forward_extremity_ids(
+/// Gather all state snapshots needed to resolve the current state of the room.
+///
+/// Step 11. ensure that the state is derived from the previous current state (i.e. it was calculated by doing state res
+/// where one of the inputs was a previously trusted set of state; don't just trust a set of state we got from a remote)
+async fn forward_extremities(
     db: &Database,
     pdu: &PduEvent,
     origin: &ServerName,
@@ -1102,6 +1122,8 @@ pub async fn send_transaction_message_route<'a>(
         }
     }

+    let current_hash = db.rooms.current_state_hash(pdu.room_id())?;
+    let mut includes_current_state = false;
     let mut fork_states = vec![];
     for id in &current_leaves {
         if let Some(id) = db.rooms.get_pdu_id(id)? {
             let state_hash = db
                 .rooms
                 .pdu_state_hash(&id)?
                 .expect("found pdu with no statehash");
+
+            if current_hash.as_ref() == Some(&state_hash) {
+                includes_current_state = true;
+            }
             let state = db
                 .rooms
                 .state_full(&pdu.room_id, &state_hash)?
@@ -1144,6 +1170,17 @@ pub async fn send_transaction_message_route<'a>(
         }
     }

+    // This guarantees that our current room state is included
+    if !includes_current_state && current_hash.is_some() {
+        fork_states.push(
+            db.rooms
+                .state_full(pdu.room_id(), current_hash.as_ref().unwrap())?
+                .into_iter()
+                .map(|(k, v)| ((k.0, Some(k.1)), Arc::new(v)))
+                .collect(),
+        )
+    }
+
     Ok(fork_states)
 }
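
Note on the AsyncRecursiveResult alias added above: boxing the returned future is what lets validate_event await itself indirectly (through fetch_events) without producing an infinitely sized future type. Below is a minimal, self-contained sketch of the same pattern using only std; the names AsyncRecursive and count_down are hypothetical and not part of this patch.

    use std::future::Future;
    use std::pin::Pin;

    // Same shape as the alias in the patch: a boxed, pinned future so an
    // async fn can call itself recursively.
    type AsyncRecursive<'a, T> = Pin<Box<dyn Future<Output = Result<T, String>> + 'a + Send>>;

    // Hypothetical example: recurse until the counter reaches zero.
    fn count_down(n: u32) -> AsyncRecursive<'static, u32> {
        Box::pin(async move {
            if n == 0 {
                return Ok(0);
            }
            // The recursive call is awaited through the boxed future.
            let rest = count_down(n - 1).await?;
            Ok(rest + 1)
        })
    }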
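
The step 13 comment describes the next forward-extremity set as the previous extremities, minus anything referenced by the new event's prev_events, plus the new event itself. A rough sketch of that set arithmetic, under the assumption of a String event id; next_forward_extremities is a hypothetical helper, not code from this patch.

    use std::collections::HashSet;

    // Hypothetical stand-in for ruma's EventId.
    type Id = String;

    // Previous extremities minus the new event's prev_events, plus the new event.
    fn next_forward_extremities(current: &HashSet<Id>, prev_events: &[Id], new_event: Id) -> HashSet<Id> {
        let mut next: HashSet<Id> = current
            .iter()
            .filter(|id| !prev_events.contains(*id))
            .cloned()
            .collect();
        next.insert(new_event);
        next
    }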