diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 0f42235..c53fa9e 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -88,7 +88,7 @@ pub struct Rooms { pub(super) referencedevents: Arc, pub(super) pdu_cache: Mutex>>, - pub(super) auth_chain_cache: Mutex, HashSet>>, + pub(super) auth_chain_cache: Mutex>>, } impl Rooms { @@ -315,19 +315,7 @@ impl Rooms { ); let (shortstatehash, already_existed) = - match self.statehash_shortstatehash.get(&state_hash)? { - Some(shortstatehash) => ( - utils::u64_from_bytes(&shortstatehash) - .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?, - true, - ), - None => { - let shortstatehash = db.globals.next_count()?; - self.statehash_shortstatehash - .insert(&state_hash, &shortstatehash.to_be_bytes())?; - (shortstatehash, false) - } - }; + self.get_or_create_shortstatehash(&state_hash, &db.globals)?; let new_state = if !already_existed { let mut new_state = HashSet::new(); @@ -352,25 +340,14 @@ impl Rooms { } }; - let shorteventid = - match self.eventid_shorteventid.get(eventid.as_bytes()).ok()? { - Some(shorteventid) => shorteventid.to_vec(), - None => { - let shorteventid = db.globals.next_count().ok()?; - self.eventid_shorteventid - .insert(eventid.as_bytes(), &shorteventid.to_be_bytes()) - .ok()?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), eventid.as_bytes()) - .ok()?; - shorteventid.to_be_bytes().to_vec() - } - }; + let shorteventid = self + .get_or_create_shorteventid(&eventid, &db.globals) + .ok()?; let mut state_id = shortstatehash.to_be_bytes().to_vec(); state_id.extend_from_slice(&shortstatekey); - Some((state_id, shorteventid)) + Some((state_id, shorteventid.to_be_bytes().to_vec())) }) .collect::>(); @@ -428,6 +405,61 @@ impl Rooms { Ok(()) } + /// Returns (shortstatehash, already_existed) + fn get_or_create_shortstatehash( + &self, + state_hash: &StateHashId, + globals: &super::globals::Globals, + ) -> Result<(u64, bool)> { + Ok(match self.statehash_shortstatehash.get(&state_hash)? { + Some(shortstatehash) => ( + utils::u64_from_bytes(&shortstatehash) + .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?, + true, + ), + None => { + let shortstatehash = globals.next_count()?; + self.statehash_shortstatehash + .insert(&state_hash, &shortstatehash.to_be_bytes())?; + (shortstatehash, false) + } + }) + } + + /// Returns (shortstatehash, already_existed) + pub fn get_or_create_shorteventid( + &self, + event_id: &EventId, + globals: &super::globals::Globals, + ) -> Result { + Ok(match self.eventid_shorteventid.get(event_id.as_bytes())? { + Some(shorteventid) => utils::u64_from_bytes(&shorteventid) + .map_err(|_| Error::bad_database("Invalid shorteventid in db."))?, + None => { + let shorteventid = globals.next_count()?; + self.eventid_shorteventid + .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; + self.shorteventid_eventid + .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; + shorteventid + } + }) + } + + pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result { + let bytes = self + .shorteventid_eventid + .get(&shorteventid.to_be_bytes())? + .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; + + EventId::try_from( + utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("EventID in roomid_pduleaves is invalid unicode.") + })?, + ) + .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid.")) + } + /// Returns the full room state. #[tracing::instrument(skip(self))] pub fn room_state_full( @@ -1116,17 +1148,7 @@ impl Rooms { state: &StateMap>, globals: &super::globals::Globals, ) -> Result<()> { - let shorteventid = match self.eventid_shorteventid.get(event_id.as_bytes())? { - Some(shorteventid) => shorteventid.to_vec(), - None => { - let shorteventid = globals.next_count()?; - self.eventid_shorteventid - .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; - shorteventid.to_be_bytes().to_vec() - } - }; + let shorteventid = self.get_or_create_shorteventid(&event_id, globals)?; let state_hash = self.calculate_hash( &state @@ -1135,69 +1157,45 @@ impl Rooms { .collect::>(), ); - let shortstatehash = match self.statehash_shortstatehash.get(&state_hash)? { - Some(shortstatehash) => { - // State already existed in db - self.shorteventid_shortstatehash - .insert(&shorteventid, &*shortstatehash)?; - return Ok(()); - } - None => { - let shortstatehash = globals.next_count()?; - self.statehash_shortstatehash - .insert(&state_hash, &shortstatehash.to_be_bytes())?; - shortstatehash.to_be_bytes().to_vec() - } - }; + let (shortstatehash, already_existed) = + self.get_or_create_shortstatehash(&state_hash, globals)?; - let batch = state - .iter() - .filter_map(|((event_type, state_key), pdu)| { - let mut statekey = event_type.as_ref().as_bytes().to_vec(); - statekey.push(0xff); - statekey.extend_from_slice(&state_key.as_bytes()); + if !already_existed { + let batch = state + .iter() + .filter_map(|((event_type, state_key), pdu)| { + let mut statekey = event_type.as_ref().as_bytes().to_vec(); + statekey.push(0xff); + statekey.extend_from_slice(&state_key.as_bytes()); - let shortstatekey = match self.statekey_shortstatekey.get(&statekey).ok()? { - Some(shortstatekey) => shortstatekey.to_vec(), - None => { - let shortstatekey = globals.next_count().ok()?; - self.statekey_shortstatekey - .insert(&statekey, &shortstatekey.to_be_bytes()) - .ok()?; - shortstatekey.to_be_bytes().to_vec() - } - }; + let shortstatekey = match self.statekey_shortstatekey.get(&statekey).ok()? { + Some(shortstatekey) => shortstatekey.to_vec(), + None => { + let shortstatekey = globals.next_count().ok()?; + self.statekey_shortstatekey + .insert(&statekey, &shortstatekey.to_be_bytes()) + .ok()?; + shortstatekey.to_be_bytes().to_vec() + } + }; - let shorteventid = match self - .eventid_shorteventid - .get(pdu.event_id.as_bytes()) - .ok()? - { - Some(shorteventid) => shorteventid.to_vec(), - None => { - let shorteventid = globals.next_count().ok()?; - self.eventid_shorteventid - .insert(pdu.event_id.as_bytes(), &shorteventid.to_be_bytes()) - .ok()?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), pdu.event_id.as_bytes()) - .ok()?; - shorteventid.to_be_bytes().to_vec() - } - }; + let shorteventid = self + .get_or_create_shorteventid(&pdu.event_id, globals) + .ok()?; - let mut state_id = shortstatehash.clone(); - state_id.extend_from_slice(&shortstatekey); + let mut state_id = shortstatehash.to_be_bytes().to_vec(); + state_id.extend_from_slice(&shortstatekey); - Some((state_id, shorteventid)) - }) - .collect::>(); + Some((state_id, shorteventid.to_be_bytes().to_vec())) + }) + .collect::>(); - self.stateid_shorteventid - .insert_batch(&mut batch.into_iter())?; + self.stateid_shorteventid + .insert_batch(&mut batch.into_iter())?; + } self.shorteventid_shortstatehash - .insert(&shorteventid, &*shortstatehash)?; + .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; Ok(()) } @@ -1212,26 +1210,16 @@ impl Rooms { new_pdu: &PduEvent, globals: &super::globals::Globals, ) -> Result { + let shorteventid = self.get_or_create_shorteventid(&new_pdu.event_id, globals)?; + let old_state = if let Some(old_shortstatehash) = self.roomid_shortstatehash.get(new_pdu.room_id.as_bytes())? { // Store state for event. The state does not include the event itself. // Instead it's the state before the pdu, so the room's old state. - - let shorteventid = match self.eventid_shorteventid.get(new_pdu.event_id.as_bytes())? { - Some(shorteventid) => shorteventid.to_vec(), - None => { - let shorteventid = globals.next_count()?; - self.eventid_shorteventid - .insert(new_pdu.event_id.as_bytes(), &shorteventid.to_be_bytes())?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), new_pdu.event_id.as_bytes())?; - shorteventid.to_be_bytes().to_vec() - } - }; - self.shorteventid_shortstatehash - .insert(&shorteventid, &old_shortstatehash)?; + .insert(&shorteventid.to_be_bytes(), &old_shortstatehash)?; + if new_pdu.state_key.is_none() { return utils::u64_from_bytes(&old_shortstatehash).map_err(|_| { Error::bad_database("Invalid shortstatehash in roomid_shortstatehash.") @@ -1264,19 +1252,7 @@ impl Rooms { } }; - let shorteventid = match self.eventid_shorteventid.get(new_pdu.event_id.as_bytes())? { - Some(shorteventid) => shorteventid.to_vec(), - None => { - let shorteventid = globals.next_count()?; - self.eventid_shorteventid - .insert(new_pdu.event_id.as_bytes(), &shorteventid.to_be_bytes())?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), new_pdu.event_id.as_bytes())?; - shorteventid.to_be_bytes().to_vec() - } - }; - - new_state.insert(shortstatekey, shorteventid); + new_state.insert(shortstatekey, shorteventid.to_be_bytes().to_vec()); let new_state_hash = self.calculate_hash( &new_state @@ -1516,11 +1492,7 @@ impl Rooms { ); // Generate short event id - let shorteventid = db.globals.next_count()?; - self.eventid_shorteventid - .insert(pdu.event_id.as_bytes(), &shorteventid.to_be_bytes())?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), pdu.event_id.as_bytes())?; + let _shorteventid = self.get_or_create_shorteventid(&pdu.event_id, &db.globals)?; // We append to state before appending the pdu, so we don't have a moment in time with the // pdu without it's state. This is okay because append_pdu can't fail. @@ -2655,9 +2627,7 @@ impl Rooms { } #[tracing::instrument(skip(self))] - pub fn auth_chain_cache( - &self, - ) -> std::sync::MutexGuard<'_, LruCache, HashSet>> { + pub fn auth_chain_cache(&self) -> std::sync::MutexGuard<'_, LruCache>> { self.auth_chain_cache.lock().unwrap() } } diff --git a/src/server_server.rs b/src/server_server.rs index 68adcd0..23c80ee 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -1044,13 +1044,16 @@ pub fn handle_incoming_pdu<'a>( if incoming_pdu.prev_events.len() == 1 { let prev_event = &incoming_pdu.prev_events[0]; - let state = db + let prev_event_sstatehash = db .rooms .pdu_shortstatehash(prev_event) - .map_err(|_| "Failed talking to db".to_owned())? - .map(|shortstatehash| db.rooms.state_full_ids(shortstatehash).ok()) - .flatten(); - if let Some(state) = state { + .map_err(|_| "Failed talking to db".to_owned())?; + + let state = + prev_event_sstatehash.map(|shortstatehash| db.rooms.state_full_ids(shortstatehash)); + + if let Some(Ok(state)) = state { + warn!("Using cached state"); let mut state = fetch_and_handle_events( db, origin, @@ -1088,6 +1091,7 @@ pub fn handle_incoming_pdu<'a>( } if state_at_incoming_event.is_none() { + warn!("Calling /state_ids"); // Call /state_ids to find out what the state at this pdu is. We trust the server's // response to some extend, but we still do a lot of checks on the events match db @@ -1755,35 +1759,50 @@ fn append_incoming_pdu( fn get_auth_chain(starting_events: Vec, db: &Database) -> Result> { let mut full_auth_chain = HashSet::new(); + let starting_events = starting_events + .iter() + .map(|id| { + (db.rooms + .get_or_create_shorteventid(id, &db.globals) + .map(|s| (s, id))) + }) + .collect::>>()?; + let mut cache = db.rooms.auth_chain_cache(); - for event_id in &starting_events { - if let Some(cached) = cache.get_mut(&[event_id.clone()][..]) { + for (sevent_id, event_id) in starting_events { + if let Some(cached) = cache.get_mut(&sevent_id) { full_auth_chain.extend(cached.iter().cloned()); } else { drop(cache); let mut auth_chain = HashSet::new(); get_auth_chain_recursive(&event_id, &mut auth_chain, db)?; cache = db.rooms.auth_chain_cache(); - cache.insert(vec![event_id.clone()], auth_chain.clone()); + cache.insert(sevent_id, auth_chain.clone()); full_auth_chain.extend(auth_chain); }; } - Ok(full_auth_chain) + full_auth_chain + .into_iter() + .map(|sid| db.rooms.get_eventid_from_short(sid)) + .collect() } fn get_auth_chain_recursive( event_id: &EventId, - found: &mut HashSet, + found: &mut HashSet, db: &Database, ) -> Result<()> { let r = db.rooms.get_pdu(&event_id); match r { Ok(Some(pdu)) => { for auth_event in &pdu.auth_events { - if !found.contains(auth_event) { - found.insert(auth_event.clone()); + let sauthevent = db + .rooms + .get_or_create_shorteventid(auth_event, &db.globals)?; + if !found.contains(&sauthevent) { + found.insert(sauthevent); get_auth_chain_recursive(&auth_event, found, db)?; } }