diff --git a/conduit-example.toml b/conduit-example.toml index b82da2c..bb3ae33 100644 --- a/conduit-example.toml +++ b/conduit-example.toml @@ -23,11 +23,11 @@ port = 6167 max_request_size = 20_000_000 # in bytes # Disable registration. No new users will be able to register on this server -#allow_registration = true +#allow_registration = false # Disable encryption, so no new encrypted rooms can be created # Note: existing rooms will continue to work -#allow_encryption = true +#allow_encryption = false #allow_federation = false #cache_capacity = 1073741824 # in bytes, 1024 * 1024 * 1024 diff --git a/src/database.rs b/src/database.rs index 8fcffd9..ea65d6f 100644 --- a/src/database.rs +++ b/src/database.rs @@ -22,7 +22,7 @@ use std::fs::remove_dir_all; use std::sync::{Arc, RwLock}; use tokio::sync::Semaphore; -#[derive(Clone, Deserialize)] +#[derive(Clone, Debug, Deserialize)] pub struct Config { server_name: Box, database_path: String, @@ -102,7 +102,12 @@ impl Database { let (admin_sender, admin_receiver) = mpsc::unbounded(); let db = Self { - globals: globals::Globals::load(db.open_tree("global")?, config).await?, + globals: globals::Globals::load( + db.open_tree("global")?, + db.open_tree("servertimeout_signingkey")?, + config, + ) + .await?, users: users::Users { userid_password: db.open_tree("userid_password")?, userid_displayname: db.open_tree("userid_displayname")?, @@ -155,6 +160,7 @@ impl Database { stateid_pduid: db.open_tree("stateid_pduid")?, pduid_statehash: db.open_tree("pduid_statehash")?, roomid_statehash: db.open_tree("roomid_statehash")?, + eventid_outlierpdu: db.open_tree("eventid_outlierpdu")?, }, account_data: account_data::AccountData { roomuserdataid_accountdata: db.open_tree("roomuserdataid_accountdata")?, diff --git a/src/database/globals.rs b/src/database/globals.rs index beb7de5..7eb162b 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -1,7 +1,10 @@ use crate::{database::Config, utils, Error, Result}; use log::error; -use ruma::ServerName; -use std::collections::HashMap; +use ruma::{ + api::federation::discovery::{ServerSigningKeys, VerifyKey}, + ServerName, ServerSigningKeyId, +}; +use std::collections::{BTreeMap, HashMap}; use std::sync::Arc; use std::sync::RwLock; use std::time::Duration; @@ -20,10 +23,15 @@ pub struct Globals { reqwest_client: reqwest::Client, dns_resolver: TokioAsyncResolver, jwt_decoding_key: Option>, + pub(super) servertimeout_signingkey: sled::Tree, // ServerName -> algorithm:key + pubkey } impl Globals { - pub async fn load(globals: sled::Tree, config: Config) -> Result { + pub async fn load( + globals: sled::Tree, + server_keys: sled::Tree, + config: Config, + ) -> Result { let bytes = &*globals .update_and_fetch("keypair", utils::generate_keypair)? .expect("utils::generate_keypair always returns Some"); @@ -82,6 +90,7 @@ impl Globals { })?, actual_destination_cache: Arc::new(RwLock::new(HashMap::new())), jwt_decoding_key, + servertimeout_signingkey: server_keys, }) } @@ -139,4 +148,66 @@ impl Globals { pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey<'_>> { self.jwt_decoding_key.as_ref() } + + /// TODO: the key valid until timestamp is only honored in room version > 4 + /// Remove the outdated keys and insert the new ones. + /// + /// This doesn't actually check that the keys provided are newer than the old set. + pub fn add_signing_key(&self, origin: &ServerName, keys: &ServerSigningKeys) -> Result<()> { + // Remove outdated keys + let now = crate::utils::millis_since_unix_epoch(); + for item in self.servertimeout_signingkey.scan_prefix(origin.as_bytes()) { + let (k, _) = item?; + let valid_until = k + .splitn(2, |&b| b == 0xff) + .nth(1) + .map(crate::utils::u64_from_bytes) + .ok_or_else(|| Error::bad_database("Invalid signing keys."))? + .map_err(|_| Error::bad_database("Invalid signing key valid until bytes"))?; + + if now > valid_until { + self.servertimeout_signingkey.remove(k)?; + } + } + + let mut key = origin.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice( + &(keys + .valid_until_ts + .duration_since(std::time::UNIX_EPOCH) + .expect("time is valid") + .as_millis() as u64) + .to_be_bytes(), + ); + + self.servertimeout_signingkey.insert( + key, + serde_json::to_vec(&keys.verify_keys).expect("ServerSigningKeys are a valid string"), + )?; + Ok(()) + } + + /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server. + pub fn signing_keys_for( + &self, + origin: &ServerName, + ) -> Result> { + let now = crate::utils::millis_since_unix_epoch(); + for item in self.servertimeout_signingkey.scan_prefix(origin.as_bytes()) { + let (k, bytes) = item?; + let valid_until = k + .splitn(2, |&b| b == 0xff) + .nth(1) + .map(crate::utils::u64_from_bytes) + .ok_or_else(|| Error::bad_database("Invalid signing keys."))? + .map_err(|_| Error::bad_database("Invalid signing key valid until bytes"))?; + // If these keys are still valid use em! + if valid_until > now { + return serde_json::from_slice(&bytes) + .map_err(|_| Error::bad_database("Invalid BTreeMap<> of signing keys")); + } + } + Ok(BTreeMap::default()) + } } diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 88a772b..81abd62 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -65,6 +65,9 @@ pub struct Rooms { /// The state for a given state hash. pub(super) statekey_short: sled::Tree, // StateKey = EventType + StateKey, Short = Count pub(super) stateid_pduid: sled::Tree, // StateId = StateHash + Short, PduId = Count (without roomid) + + /// Any pdu that has passed the steps up to auth with auth_events. + pub(super) eventid_outlierpdu: sled::Tree, } impl Rooms { @@ -188,72 +191,6 @@ impl Rooms { Ok(events) } - /// Returns a Vec of the related auth events to the given `event`. - /// - /// A recursive list of all the auth_events going back to `RoomCreate` for each event in `event_ids`. - pub fn auth_events_full( - &self, - _room_id: &RoomId, - event_ids: &[EventId], - ) -> Result> { - let mut result = BTreeMap::new(); - let mut stack = event_ids.to_vec(); - - // DFS for auth event chain - while !stack.is_empty() { - let ev_id = stack.pop().unwrap(); - if result.contains_key(&ev_id) { - continue; - } - - if let Some(ev) = self.get_pdu(&ev_id)? { - stack.extend(ev.auth_events()); - result.insert(ev.event_id().clone(), ev); - } - } - - Ok(result.into_iter().map(|(_, v)| v).collect()) - } - - /// Returns a Vec representing the difference in auth chains of the given `events`. - /// - /// Each inner `Vec` of `event_ids` represents a state set (state at each forward extremity). - pub fn auth_chain_diff( - &self, - room_id: &RoomId, - event_ids: Vec>, - ) -> Result> { - use std::collections::BTreeSet; - - let mut chains = vec![]; - for ids in event_ids { - // TODO state store `auth_event_ids` returns self in the event ids list - // when an event returns `auth_event_ids` self is not contained - let chain = self - .auth_events_full(room_id, &ids)? - .into_iter() - .map(|pdu| pdu.event_id) - .collect::>(); - chains.push(chain); - } - - if let Some(chain) = chains.first() { - let rest = chains.iter().skip(1).flatten().cloned().collect(); - let common = chain.intersection(&rest).collect::>(); - - Ok(chains - .iter() - .flatten() - .filter(|id| !common.contains(&id)) - .cloned() - .collect::>() - .into_iter() - .collect()) - } else { - Ok(vec![]) - } - } - /// Generate a new StateHash. /// /// A unique hash made from hashing all PDU ids of the state joined with 0xff. @@ -475,6 +412,31 @@ impl Rooms { Ok(()) } + /// Returns the pdu from the outlier tree. + pub fn get_pdu_outlier(&self, event_id: &EventId) -> Result> { + self.eventid_outlierpdu + .get(event_id.as_bytes())? + .map_or(Ok(None), |pdu| { + Ok(Some( + serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid PDU in db."))?, + )) + }) + } + + /// Returns true if the event_id was previously inserted. + pub fn append_pdu_outlier(&self, event_id: &EventId, pdu: &PduEvent) -> Result { + log::info!("Number of outlier pdu's {}", self.eventid_outlierpdu.len()); + let res = self + .eventid_outlierpdu + .insert( + event_id.as_bytes(), + &*serde_json::to_string(&pdu).expect("PduEvent is always a valid String"), + ) + .map(|op| op.is_some())?; + Ok(res) + } + /// Creates a new persisted data unit and adds it to a room. /// /// By this point the incoming event should be fully authenticated, no auth happens @@ -516,6 +478,9 @@ impl Rooms { } } + // We no longer keep this pdu as an outlier + self.eventid_outlierpdu.remove(pdu.event_id().as_bytes())?; + self.replace_pdu_leaves(&pdu.room_id, &pdu.event_id)?; // Mark as read first so the sending client doesn't get a notification even if appending diff --git a/src/error.rs b/src/error.rs index c57843c..fed545c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -122,10 +122,9 @@ impl log::Log for ConduitLogger { let output = format!("{} - {}", record.level(), record.args()); if self.enabled(record.metadata()) - && (record - .module_path() - .map_or(false, |path| path.starts_with("conduit::")) - || record + && (record.module_path().map_or(false, |path| { + path.starts_with("conduit::") || path.starts_with("state") + }) || record .module_path() .map_or(true, |path| !path.starts_with("rocket::")) // Rockets logs are annoying && record.metadata().level() <= log::Level::Warn) diff --git a/src/main.rs b/src/main.rs index 4cab764..e5c0399 100644 --- a/src/main.rs +++ b/src/main.rs @@ -167,6 +167,7 @@ fn setup_rocket() -> rocket::Rocket { .figment() .extract() .expect("It looks like your config is invalid. Please take a look at the error"); + let data = Database::load_or_create(config) .await .expect("config is valid"); diff --git a/src/server_server.rs b/src/server_server.rs index 64e0a05..6907e34 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -1,5 +1,4 @@ use crate::{client_server, utils, ConduitResult, Database, Error, PduEvent, Result, Ruma}; -use get_devices::v1::UserDevice; use get_profile_information::v1::ProfileField; use http::header::{HeaderValue, AUTHORIZATION, HOST}; use log::{error, info, warn}; @@ -7,7 +6,6 @@ use rocket::{get, post, put, response::content::Json, State}; use ruma::{ api::{ federation::{ - device::get_devices, directory::{get_public_rooms, get_public_rooms_filtered}, discovery::{ get_server_keys, get_server_version::v1 as get_server_version, ServerSigningKeys, @@ -20,7 +18,6 @@ use ruma::{ OutgoingRequest, }, directory::{IncomingFilter, IncomingRoomNetwork}, - events::pdu::Pdu, serde::to_canonical_value, signatures::{CanonicalJsonObject, CanonicalJsonValue, PublicKeyMap}, EventId, RoomId, RoomVersionId, ServerName, ServerSigningKeyId, UserId, @@ -28,9 +25,12 @@ use ruma::{ use state_res::{Event, EventMap, StateMap}; use std::{ collections::{BTreeMap, BTreeSet}, - convert::{TryFrom, TryInto}, + convert::TryFrom, fmt::Debug, + future::Future, net::{IpAddr, SocketAddr}, + pin::Pin, + result::Result as StdResult, sync::Arc, time::{Duration, SystemTime}, }; @@ -575,6 +575,26 @@ pub async fn send_transaction_message_route<'a>( // We do not add the event_id field to the pdu here because of signature and hashes checks let (event_id, value) = crate::pdu::gen_event_id_canonical_json(pdu); + // If we have no idea about this room skip the PDU + let room_id = match value + .get("room_id") + .map(|id| match id { + CanonicalJsonValue::String(id) => RoomId::try_from(id.as_str()).ok(), + _ => None, + }) + .flatten() + { + Some(id) => id, + None => { + resolved_map.insert(event_id, Err("Event needs a valid RoomId".to_string())); + continue; + } + }; + if !db.rooms.exists(&room_id)? { + resolved_map.insert(event_id, Err("Room is unknown to this server".to_string())); + continue; + } + let server_name = &body.body.origin; let mut pub_key_map = BTreeMap::new(); if let Some(sig) = value.get("signatures") { @@ -583,20 +603,12 @@ pub async fn send_transaction_message_route<'a>( for key in entity.keys() { // TODO: save this in a DB maybe... // fetch the public signing key - let res = db - .sending - .send_federation_request( - &db.globals, - <&ServerName>::try_from(key.as_str()).unwrap(), - get_server_keys::v2::Request::new(), - ) - .await?; + let origin = <&ServerName>::try_from(key.as_str()).unwrap(); + let keys = fetch_signing_keys(&db, origin).await?; pub_key_map.insert( - res.server_key.server_name.to_string(), - res.server_key - .verify_keys - .into_iter() + origin.to_string(), + keys.into_iter() .map(|(k, v)| (k.to_string(), v.key)) .collect(), ); @@ -615,10 +627,31 @@ pub async fn send_transaction_message_route<'a>( continue; } - // Ruma/PduEvent satisfies - 1. Is a valid event, otherwise it is dropped. - // 2. Passes signature checks, otherwise event is dropped. - // 3. Passes hash checks, otherwise it is redacted before being processed further. - let mut val = match signature_and_hash_check(&pub_key_map, value) { + // TODO: make this persist but not a DB Tree... + // This is all the auth_events that have been recursively fetched so they don't have to be + // deserialized over and over again. This could potentially also be some sort of trie (suffix tree) + // like structure so that once an auth event is known it would know (using indexes maybe) all of + // the auth events that it references. + let mut auth_cache = EventMap::new(); + + // 1. check the server is in the room (optional) + // 2. check content hash, redact if doesn't match + // 3. fetch any missing auth events doing all checks listed here starting at 1. These are not timeline events + // 4. reject "due to auth events" if can't get all the auth events or some of the auth events are also rejected "due to auth events" + // 5. reject "due to auth events" if the event doesn't pass auth based on the auth events + // 6. persist this event as an outlier + // 7. if not timeline event: stop + let pdu = match validate_event( + &db, + value, + event_id.clone(), + &pub_key_map, + server_name, + // All the auth events gathered will be here + &mut auth_cache, + ) + .await + { Ok(pdu) => pdu, Err(e) => { resolved_map.insert(event_id, Err(e)); @@ -626,59 +659,31 @@ pub async fn send_transaction_message_route<'a>( } }; - // Now that we have checked the signature and hashes we can add the eventID and convert - // to our PduEvent type also finally verifying the first step listed above - val.insert( - "event_id".to_owned(), - to_canonical_value(&event_id).expect("EventId is a valid CanonicalJsonValue"), - ); - let pdu = match serde_json::from_value::( - serde_json::to_value(val).expect("CanonicalJsonObj is a valid JsonValue"), - ) { - Ok(pdu) => pdu, - Err(_) => { - resolved_map.insert(event_id, Err("Event is not a valid PDU".into())); - continue; - } - }; + let pdu = Arc::new(pdu.clone()); - // If we have no idea about this room skip the PDU - if !db.rooms.exists(&pdu.room_id)? { - resolved_map.insert(event_id, Err("Room is unknown to this server".into())); - continue; - } - - let event = Arc::new(pdu.clone()); - dbg!(&*event); // Fetch any unknown prev_events or retrieve them from the DB - let previous = match fetch_events(&db, server_name, &pub_key_map, &pdu.prev_events).await { - Ok(mut evs) if evs.len() == 1 => Some(Arc::new(evs.remove(0))), + let previous = match fetch_events( + &db, + server_name, + &pub_key_map, + &pdu.prev_events, + &mut auth_cache, + ) + .await + { + Ok(mut evs) if evs.len() == 1 => Some(evs.remove(0)), _ => None, }; - // 4. Passes authorization rules based on the event's auth events, otherwise it is rejected. - // Recursively gather all auth events checking that the previous auth events are valid. - let auth_events: Vec = - match fetch_check_auth_events(&db, server_name, &pub_key_map, &pdu.prev_events).await { - Ok(events) => events, - Err(_) => { - resolved_map.insert( - pdu.event_id, - Err("Failed to recursively gather auth events".into()), - ); - continue; - } - }; - - let mut event_map: state_res::EventMap> = auth_events + let mut event_map: state_res::EventMap> = auth_cache .iter() - .map(|v| (v.event_id().clone(), Arc::new(v.clone()))) + .map(|(k, v)| (k.clone(), v.clone())) .collect(); // Check that the event passes auth based on the auth_events let is_authed = state_res::event_auth::auth_check( &RoomVersionId::Version6, - &event, + &pdu, previous.clone(), &pdu.auth_events .iter() @@ -696,9 +701,10 @@ pub async fn send_transaction_message_route<'a>( None, // TODO: third party invite ) .map_err(|_e| Error::Conflict("Auth check failed"))?; + if !is_authed { resolved_map.insert( - pdu.event_id, + pdu.event_id().clone(), Err("Event has failed auth check with auth events".into()), ); continue; @@ -720,7 +726,14 @@ pub async fn send_transaction_message_route<'a>( .await { Ok(res) => { - let state = fetch_events(&db, server_name, &pub_key_map, &res.pdu_ids).await?; + let state = fetch_events( + &db, + server_name, + &pub_key_map, + &res.pdu_ids, + &mut auth_cache, + ) + .await?; // Sanity check: there are no conflicting events in the state we received let mut seen = BTreeSet::new(); for ev in &state { @@ -732,21 +745,26 @@ pub async fn send_transaction_message_route<'a>( let state = state .into_iter() - .map(|pdu| ((pdu.kind.clone(), pdu.state_key.clone()), Arc::new(pdu))) + .map(|pdu| ((pdu.kind.clone(), pdu.state_key.clone()), pdu)) .collect(); ( state, - fetch_events(&db, server_name, &pub_key_map, &res.auth_chain_ids) - .await? - .into_iter() - .map(Arc::new) - .collect(), + fetch_events( + &db, + server_name, + &pub_key_map, + &res.auth_chain_ids, + &mut auth_cache, + ) + .await? + .into_iter() + .collect(), ) } Err(_) => { resolved_map.insert( - event.event_id().clone(), + pdu.event_id().clone(), Err("Fetching state for event failed".into()), ); continue; @@ -755,7 +773,7 @@ pub async fn send_transaction_message_route<'a>( if !state_res::event_auth::auth_check( &RoomVersionId::Version6, - &event, + &pdu, previous.clone(), &state_at_event, None, // TODO: third party invite @@ -764,37 +782,21 @@ pub async fn send_transaction_message_route<'a>( { // Event failed auth with state_at resolved_map.insert( - pdu.event_id, + event_id, Err("Event has failed auth check with state at the event".into()), ); continue; } // End of step 5. - // The event could still be soft failed - append_state_soft(&db, &pdu)?; - // Gather the forward extremities and resolve - let forward_extrems = forward_extremity_ids(&db, &pdu.room_id)?; - let mut fork_states: Vec>> = vec![]; - for id in &forward_extrems { - if let Some(id) = db.rooms.get_pdu_id(id)? { - let state_hash = db - .rooms - .pdu_state_hash(&id)? - .expect("found pdu with no statehash"); - let state = db - .rooms - .state_full(&pdu.room_id, &state_hash)? - .into_iter() - .map(|(k, v)| ((k.0, Some(k.1)), Arc::new(v))) - .collect(); - - fork_states.push(state); - } else { - todo!("we don't know of a pdu that is part of our known forks OOPS") + let fork_states = match forward_extremity_ids(&db, &pdu) { + Ok(states) => states, + Err(_) => { + resolved_map.insert(event_id, Err("Failed to gather forward extremities".into())); + continue; } - } + }; // Step 6. event passes auth based on state of all forks and current room state let state_at_forks = if fork_states.is_empty() { @@ -803,19 +805,47 @@ pub async fn send_transaction_message_route<'a>( } else if fork_states.len() == 1 { fork_states[0].clone() } else { - let auth_events = fork_states - .iter() - .map(|map| { - db.rooms - .auth_events_full( - pdu.room_id(), - &map.values() - .map(|pdu| pdu.event_id().clone()) - .collect::>(), + let mut auth_events = vec![]; + // this keeps track if we error so we can break out of these inner loops + // to continue on with the incoming PDU's + let mut failed = false; + for map in &fork_states { + let mut state_auth = vec![]; + for pdu in map.values() { + let event = match auth_cache.get(pdu.event_id()) { + Some(aev) => aev.clone(), + // We should know about every event at this point but just incase... + None => match fetch_events( + &db, + server_name, + &pub_key_map, + &[pdu.event_id().clone()], + &mut auth_cache, ) - .map(|pdus| pdus.into_iter().map(Arc::new).collect::>()) - }) - .collect::>>()?; + .await + .map(|mut vec| vec.remove(0)) + { + Ok(aev) => aev.clone(), + Err(_) => { + resolved_map.insert( + event_id.clone(), + Err("Event has been soft failed".into()), + ); + failed = true; + break; + } + }, + }; + state_auth.push(event); + } + if failed { + break; + } + auth_events.push(state_auth); + } + if failed { + continue; + } // Add everything we will need to event_map event_map.extend( @@ -862,74 +892,163 @@ pub async fn send_transaction_message_route<'a>( if !state_res::event_auth::auth_check( &RoomVersionId::Version6, - &event, + &pdu, previous, &state_at_forks, None, ) .map_err(|_e| Error::Conflict("Auth check failed"))? { - // Soft fail + // Soft fail, we add the event as an outlier. resolved_map.insert( - event.event_id().clone(), + pdu.event_id().clone(), Err("Event has been soft failed".into()), ); } else { append_state(&db, &pdu)?; // Event has passed all auth/stateres checks - resolved_map.insert(event.event_id().clone(), Ok(())); + resolved_map.insert(pdu.event_id().clone(), Ok(())); } } Ok(dbg!(send_transaction_message::v1::Response { pdus: resolved_map }).into()) } -async fn auth_each_event( - db: &Database, +/// Validate any event that is given to us by another server. +/// +/// 1. Is a valid event, otherwise it is dropped (PduEvent deserialization satisfies this). +/// 2. Passes signature checks, otherwise event is dropped. +/// 3. Passes hash checks, otherwise it is redacted before being processed further. +/// 4. Passes auth_chain collection (we can gather the events that auth this event recursively). +/// 5. Once the event has passed all checks it can be added as an outlier to the DB. +fn validate_event<'a>( + db: &'a Database, value: CanonicalJsonObject, event_id: EventId, - pub_key_map: &PublicKeyMap, - server_name: &ServerName, - auth_cache: EventMap>, -) -> std::result::Result { - // Ruma/PduEvent satisfies - 1. Is a valid event, otherwise it is dropped. - // 2. Passes signature checks, otherwise event is dropped. - // 3. Passes hash checks, otherwise it is redacted before being processed further. - let mut val = signature_and_hash_check(&pub_key_map, value)?; + pub_key_map: &'a PublicKeyMap, + origin: &'a ServerName, + auth_cache: &'a mut EventMap>, +) -> Pin> + 'a + Send>> { + Box::pin(async move { + let mut val = signature_and_hash_check(&pub_key_map, value)?; - // Now that we have checked the signature and hashes we can add the eventID and convert - // to our PduEvent type also finally verifying the first step listed above - val.insert( - "event_id".to_owned(), - to_canonical_value(&event_id).expect("EventId is a valid CanonicalJsonValue"), - ); - let pdu = serde_json::from_value::( - serde_json::to_value(val).expect("CanonicalJsonObj is a valid JsonValue"), - ) - .map_err(|_| "Event is not a valid PDU".to_string())?; + // Now that we have checked the signature and hashes we can add the eventID and convert + // to our PduEvent type also finally verifying the first step listed above + val.insert( + "event_id".to_owned(), + to_canonical_value(&event_id).expect("EventId is a valid CanonicalJsonValue"), + ); + let pdu = serde_json::from_value::( + serde_json::to_value(val).expect("CanonicalJsonObj is a valid JsonValue"), + ) + .map_err(|_| "Event is not a valid PDU".to_string())?; - // If we have no idea about this room skip the PDU - if !db.rooms.exists(&pdu.room_id).map_err(|e| e.to_string())? { - return Err("Room is unknown to this server".into()); - } + fetch_check_auth_events(db, origin, pub_key_map, &pdu.auth_events, auth_cache) + .await + .map_err(|_| "Event failed auth chain check".to_string())?; - // Fetch any unknown prev_events or retrieve them from the DB - let previous = match fetch_events(&db, server_name, &pub_key_map, &pdu.prev_events).await { - Ok(mut evs) if evs.len() == 1 => Some(Arc::new(evs.remove(0))), - _ => None, - }; + db.rooms + .append_pdu_outlier(pdu.event_id(), &pdu) + .map_err(|e| e.to_string())?; - // 4. Passes authorization rules based on the event's auth events, otherwise it is rejected. - // Recursively gather all auth events checking that the previous auth events are valid. - let auth_events: Vec = - match fetch_check_auth_events(&db, server_name, &pub_key_map, &pdu.prev_events).await { - Ok(events) => events, - Err(_) => return Err("Failed to recursively gather auth events".into()), - }; - - Ok(pdu) + Ok(pdu) + }) } +/// Find the event and auth it. +/// +/// 1. Look in the main timeline (pduid_pdu tree) +/// 2. Look at outlier pdu tree +/// 3. Ask origin server over federation +/// 4. TODO: Ask other servers over federation? +async fn fetch_events( + db: &Database, + origin: &ServerName, + key_map: &PublicKeyMap, + events: &[EventId], + auth_cache: &mut EventMap>, +) -> Result>> { + let mut pdus = vec![]; + for id in events { + let pdu = match db.rooms.get_pdu(&id)? { + Some(pdu) => Arc::new(pdu), + None => match db.rooms.get_pdu_outlier(&id)? { + Some(pdu) => Arc::new(pdu), + None => match db + .sending + .send_federation_request( + &db.globals, + origin, + get_event::v1::Request { event_id: &id }, + ) + .await + { + Ok(res) => { + let (event_id, value) = crate::pdu::gen_event_id_canonical_json(&res.pdu); + let pdu = validate_event(db, value, event_id, key_map, origin, auth_cache) + .await + .map_err(|_| Error::Conflict("Authentication of event failed"))?; + + Arc::new(pdu) + } + Err(_) => return Err(Error::BadServerResponse("Failed to fetch event")), + }, + }, + }; + pdus.push(pdu); + } + Ok(pdus) +} + +/// The check in `fetch_check_auth_events` is that a complete chain is found for the +/// events `auth_events`. If the chain is found to have any missing events it fails. +/// +/// The `auth_cache` is filled instead of returning a `Vec`. +async fn fetch_check_auth_events( + db: &Database, + origin: &ServerName, + key_map: &PublicKeyMap, + event_ids: &[EventId], + auth_cache: &mut EventMap>, +) -> Result<()> { + let mut stack = event_ids.to_vec(); + + // DFS for auth event chain + while !stack.is_empty() { + let ev_id = stack.pop().unwrap(); + if auth_cache.contains_key(&ev_id) { + continue; + } + + let ev = fetch_events(db, origin, key_map, &[ev_id.clone()], auth_cache) + .await + .map(|mut vec| vec.remove(0))?; + + stack.extend(ev.auth_events()); + auth_cache.insert(ev.event_id().clone(), ev); + } + Ok(()) +} + +/// Search the DB for the signing keys of the given server, if we don't have them +/// fetch them from the server and save to our DB. +async fn fetch_signing_keys( + db: &Database, + origin: &ServerName, +) -> Result> { + match db.globals.signing_keys_for(origin)? { + keys if !keys.is_empty() => Ok(keys), + _ => { + let keys = db + .sending + .send_federation_request(&db.globals, origin, get_server_keys::v2::Request::new()) + .await + .map_err(|_| Error::BadServerResponse("Failed to request server keys"))?; + db.globals.add_signing_key(origin, &keys.server_key)?; + Ok(keys.server_key.verify_keys) + } + } +} fn signature_and_hash_check( pub_key_map: &ruma::signatures::PublicKeyMap, value: CanonicalJsonObject, @@ -954,122 +1073,29 @@ fn signature_and_hash_check( ) } -/// The check in `fetch_check_auth_events` is that a complete chain is found for the -/// events `auth_events`. If the chain is found to have missing events it fails. -async fn fetch_check_auth_events( - db: &Database, - origin: &ServerName, - key_map: &PublicKeyMap, - event_ids: &[EventId], -) -> Result> { - let mut result = BTreeMap::new(); - let mut stack = event_ids.to_vec(); +fn forward_extremity_ids(db: &Database, pdu: &PduEvent) -> Result>>> { + let mut fork_states = vec![]; + for id in &db.rooms.get_pdu_leaves(pdu.room_id())? { + if let Some(id) = db.rooms.get_pdu_id(id)? { + let state_hash = db + .rooms + .pdu_state_hash(&id)? + .expect("found pdu with no statehash"); + let state = db + .rooms + .state_full(&pdu.room_id, &state_hash)? + .into_iter() + .map(|(k, v)| ((k.0, Some(k.1)), Arc::new(v))) + .collect(); - // DFS for auth event chain - while !stack.is_empty() { - let ev_id = stack.pop().unwrap(); - if result.contains_key(&ev_id) { - continue; - } - - let ev = match db.rooms.get_pdu(&ev_id)? { - Some(pdu) => pdu, - None => match db - .sending - .send_federation_request( - &db.globals, - origin, - get_event::v1::Request { event_id: &ev_id }, - ) - .await - { - Ok(res) => { - let (event_id, value) = crate::pdu::gen_event_id_canonical_json(&res.pdu); - match signature_and_hash_check(key_map, value) { - Ok(mut val) => { - val.insert( - "event_id".to_owned(), - to_canonical_value(&event_id) - .expect("EventId is a valid CanonicalJsonValue"), - ); - serde_json::from_value::( - serde_json::to_value(val) - .expect("CanonicalJsonObj is a valid JsonValue"), - ) - .expect("Pdu is valid Canonical JSON Map") - } - Err(e) => { - // TODO: I would assume we just keep going - error!("{:?}", e); - continue; - } - } - } - Err(_) => return Err(Error::BadServerResponse("Failed to fetch event")), - }, - }; - stack.extend(ev.auth_events()); - result.insert(ev.event_id().clone(), ev); - } - - Ok(result.into_iter().map(|(_, v)| v).collect()) -} - -/// TODO: this needs to add events to the DB in a way that does not -/// effect the state of the room -async fn fetch_events( - db: &Database, - origin: &ServerName, - key_map: &PublicKeyMap, - events: &[EventId], -) -> Result> { - let mut pdus = vec![]; - for id in events { - match db.rooms.get_pdu(id)? { - Some(pdu) => pdus.push(pdu), - None => match db - .sending - .send_federation_request( - &db.globals, - origin, - get_event::v1::Request { event_id: id }, - ) - .await - { - Ok(res) => { - let (event_id, value) = crate::pdu::gen_event_id_canonical_json(&res.pdu); - match signature_and_hash_check(key_map, value) { - Ok(mut val) => { - // TODO: add to our DB somehow? - val.insert( - "event_id".to_owned(), - to_canonical_value(&event_id) - .expect("EventId is a valid CanonicalJsonValue"), - ); - let pdu = serde_json::from_value::( - serde_json::to_value(val) - .expect("CanonicalJsonObj is a valid JsonValue"), - ) - .expect("Pdu is valid Canonical JSON Map"); - - pdus.push(pdu); - } - Err(e) => { - // TODO: I would assume we just keep going - error!("{:?}", e); - continue; - } - } - } - Err(_) => return Err(Error::BadServerResponse("Failed to fetch event")), - }, + fork_states.push(state); + } else { + return Err(Error::Conflict( + "we don't know of a pdu that is part of our known forks OOPS", + )); } } - Ok(pdus) -} - -fn forward_extremity_ids(db: &Database, room_id: &RoomId) -> Result> { - db.rooms.get_pdu_leaves(room_id) + Ok(fork_states) } fn append_state(db: &Database, pdu: &PduEvent) -> Result<()> { @@ -1078,9 +1104,12 @@ fn append_state(db: &Database, pdu: &PduEvent) -> Result<()> { pdu_id.push(0xff); pdu_id.extend_from_slice(&count.to_be_bytes()); - db.rooms.append_to_state(&pdu_id, pdu, &db.globals)?; + // We append to state before appending the pdu, so we don't have a moment in time with the + // pdu without it's state. This is okay because append_pdu can't fail. + let statehashid = db.rooms.append_to_state(&pdu_id, &pdu, &db.globals)?; + db.rooms.append_pdu( - pdu, + &pdu, utils::to_canonical_object(pdu).expect("Pdu is valid canonical object"), count, pdu_id.clone().into(), @@ -1089,78 +1118,9 @@ fn append_state(db: &Database, pdu: &PduEvent) -> Result<()> { &db.admin, )?; - for appservice in db.appservice.iter_all().filter_map(|r| r.ok()) { - db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?; - } - - Ok(()) -} - -/// TODO: This should not write to the current room state (roomid_statehash) -fn append_state_soft(db: &Database, pdu: &PduEvent) -> Result<()> { - let count = db.globals.next_count()?; - let mut pdu_id = pdu.room_id.as_bytes().to_vec(); - pdu_id.push(0xff); - pdu_id.extend_from_slice(&count.to_be_bytes()); - - // db.rooms.append_pdu( - // pdu, - // &utils::to_canonical_object(pdu).expect("Pdu is valid canonical object"), - // count, - // pdu_id.clone().into(), - // &db.globals, - // &db.account_data, - // &db.admin, - // )?; - - Ok(()) -} - -fn forward_extremity_ids(db: &Database, room_id: &RoomId) -> Result> { - todo!() -} - -fn append_state(db: &Database, pdu: &PduEvent) -> Result<()> { - let count = db.globals.next_count()?; - let mut pdu_id = pdu.room_id.as_bytes().to_vec(); - pdu_id.push(0xff); - pdu_id.extend_from_slice(&count.to_be_bytes()); - - db.rooms.append_to_state(&pdu_id, pdu, &db.globals)?; - db.rooms.append_pdu( - pdu, - &utils::to_canonical_object(pdu).expect("Pdu is valid canonical object"), - count, - pdu_id.clone().into(), - &db.globals, - &db.account_data, - &db.admin, - )?; - - for appservice in db.appservice.iter_all().filter_map(|r| r.ok()) { - db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?; - } - - Ok(()) -} - -/// TODO: This should not write to the current room state (roomid_statehash) -fn append_state_soft(db: &Database, pdu: &PduEvent) -> Result<()> { - let count = db.globals.next_count()?; - let mut pdu_id = pdu.room_id.as_bytes().to_vec(); - pdu_id.push(0xff); - pdu_id.extend_from_slice(&count.to_be_bytes()); - - db.rooms.append_to_state(&pdu_id, pdu, &db.globals)?; - db.rooms.append_pdu( - pdu, - &utils::to_canonical_object(pdu).expect("Pdu is valid canonical object"), - count, - pdu_id.clone().into(), - &db.globals, - &db.account_data, - &db.admin, - )?; + // We set the room state after inserting the pdu, so that we never have a moment in time + // where events in the current room state do not exist + db.rooms.set_room_state(&pdu.room_id, &statehashid)?; for appservice in db.appservice.iter_all().filter_map(|r| r.ok()) { db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?;