From 3eabaa2a95645378b130134220f264a23e0fd7ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Thu, 12 Aug 2021 23:04:00 +0200 Subject: [PATCH] finish implementing better state store --- src/client_server/account.rs | 2 + src/client_server/context.rs | 4 +- src/client_server/membership.rs | 2 + src/client_server/message.rs | 4 +- src/client_server/room.rs | 5 +- src/client_server/sync.rs | 16 +- src/database.rs | 373 +++++----------- src/database/abstraction/sqlite.rs | 72 +-- src/database/rooms.rs | 691 +++++++++++++++++++---------- src/server_server.rs | 2 +- 10 files changed, 645 insertions(+), 526 deletions(-) diff --git a/src/client_server/account.rs b/src/client_server/account.rs index 48159c9..d4f103c 100644 --- a/src/client_server/account.rs +++ b/src/client_server/account.rs @@ -249,6 +249,8 @@ pub async fn register_route( let room_id = RoomId::new(db.globals.server_name()); + db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?; + let mutex_state = Arc::clone( db.globals .roomid_mutex_state diff --git a/src/client_server/context.rs b/src/client_server/context.rs index dbc121e..701e584 100644 --- a/src/client_server/context.rs +++ b/src/client_server/context.rs @@ -44,7 +44,7 @@ pub async fn get_context_route( let events_before = db .rooms - .pdus_until(&sender_user, &body.room_id, base_token) + .pdus_until(&sender_user, &body.room_id, base_token)? .take( u32::try_from(body.limit).map_err(|_| { Error::BadRequest(ErrorKind::InvalidParam, "Limit value is invalid.") @@ -66,7 +66,7 @@ pub async fn get_context_route( let events_after = db .rooms - .pdus_after(&sender_user, &body.room_id, base_token) + .pdus_after(&sender_user, &body.room_id, base_token)? .take( u32::try_from(body.limit).map_err(|_| { Error::BadRequest(ErrorKind::InvalidParam, "Limit value is invalid.") diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index 716a615..de6fa5a 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -609,6 +609,8 @@ async fn join_room_by_id_helper( ) .await?; + db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?; + let pdu = PduEvent::from_id_val(&event_id, join_event.clone()) .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; diff --git a/src/client_server/message.rs b/src/client_server/message.rs index 9cb6faa..70cc00f 100644 --- a/src/client_server/message.rs +++ b/src/client_server/message.rs @@ -128,7 +128,7 @@ pub async fn get_message_events_route( get_message_events::Direction::Forward => { let events_after = db .rooms - .pdus_after(&sender_user, &body.room_id, from) + .pdus_after(&sender_user, &body.room_id, from)? .take(limit) .filter_map(|r| r.ok()) // Filter out buggy events .filter_map(|(pdu_id, pdu)| { @@ -158,7 +158,7 @@ pub async fn get_message_events_route( get_message_events::Direction::Backward => { let events_before = db .rooms - .pdus_until(&sender_user, &body.room_id, from) + .pdus_until(&sender_user, &body.room_id, from)? .take(limit) .filter_map(|r| r.ok()) // Filter out buggy events .filter_map(|(pdu_id, pdu)| { diff --git a/src/client_server/room.rs b/src/client_server/room.rs index 89241f5..c323be4 100644 --- a/src/client_server/room.rs +++ b/src/client_server/room.rs @@ -33,6 +33,8 @@ pub async fn create_room_route( let room_id = RoomId::new(db.globals.server_name()); + db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?; + let mutex_state = Arc::clone( db.globals .roomid_mutex_state @@ -173,7 +175,6 @@ pub async fn create_room_route( )?; // 4. 
Canonical room alias - if let Some(room_alias_id) = &alias { db.rooms.build_and_append_pdu( PduBuilder { @@ -193,7 +194,7 @@ pub async fn create_room_route( &room_id, &db, &state_lock, - ); + )?; } // 5. Events set by preset diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index 937a252..c196b2a 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -205,7 +205,7 @@ async fn sync_helper( let mut non_timeline_pdus = db .rooms - .pdus_until(&sender_user, &room_id, u64::MAX) + .pdus_until(&sender_user, &room_id, u64::MAX)? .filter_map(|r| { // Filter out buggy events if r.is_err() { @@ -248,13 +248,13 @@ async fn sync_helper( let first_pdu_before_since = db .rooms - .pdus_until(&sender_user, &room_id, since) + .pdus_until(&sender_user, &room_id, since)? .next() .transpose()?; let pdus_after_since = db .rooms - .pdus_after(&sender_user, &room_id, since) + .pdus_after(&sender_user, &room_id, since)? .next() .is_some(); @@ -286,7 +286,7 @@ async fn sync_helper( for hero in db .rooms - .all_pdus(&sender_user, &room_id) + .all_pdus(&sender_user, &room_id)? .filter_map(|pdu| pdu.ok()) // Ignore all broken pdus .filter(|(_, pdu)| pdu.kind == EventType::RoomMember) .map(|(_, pdu)| { @@ -328,11 +328,11 @@ async fn sync_helper( } } - ( + Ok::<_, Error>(( Some(joined_member_count), Some(invited_member_count), heroes, - ) + )) }; let ( @@ -343,7 +343,7 @@ async fn sync_helper( state_events, ) = if since_shortstatehash.is_none() { // Probably since = 0, we will do an initial sync - let (joined_member_count, invited_member_count, heroes) = calculate_counts(); + let (joined_member_count, invited_member_count, heroes) = calculate_counts()?; let current_state_ids = db.rooms.state_full_ids(current_shortstatehash)?; let state_events = current_state_ids @@ -510,7 +510,7 @@ async fn sync_helper( } let (joined_member_count, invited_member_count, heroes) = if send_member_count { - calculate_counts() + calculate_counts()? } else { (None, None, Vec::new()) }; diff --git a/src/database.rs b/src/database.rs index 2e471d8..4e34019 100644 --- a/src/database.rs +++ b/src/database.rs @@ -28,7 +28,7 @@ use ruma::{DeviceId, EventId, RoomId, ServerName, UserId}; use serde::{de::IgnoredAny, Deserialize}; use std::{ collections::{BTreeMap, HashMap, HashSet}, - convert::TryFrom, + convert::{TryFrom, TryInto}, fs::{self, remove_dir_all}, io::Write, mem::size_of, @@ -266,7 +266,6 @@ impl Database { shortroomid_roomid: builder.open_tree("shortroomid_roomid")?, roomid_shortroomid: builder.open_tree("roomid_shortroomid")?, - stateid_shorteventid: builder.open_tree("stateid_shorteventid")?, shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?, eventid_shorteventid: builder.open_tree("eventid_shorteventid")?, shorteventid_eventid: builder.open_tree("shorteventid_eventid")?, @@ -431,7 +430,6 @@ impl Database { } if db.globals.database_version()? 
< 6 { - // TODO update to 6 // Set room member count for (roomid, _) in db.rooms.roomid_shortstatehash.iter() { let room_id = @@ -445,263 +443,98 @@ impl Database { println!("Migration: 5 -> 6 finished"); } - fn load_shortstatehash_info( - shortstatehash: &[u8], - db: &Database, - lru: &mut LruCache< - Vec, - Vec<( - Vec, - HashSet>, - HashSet>, - HashSet>, - )>, - >, - ) -> Result< - Vec<( - Vec, // sstatehash - HashSet>, // full state - HashSet>, // added - HashSet>, // removed - )>, - > { - if let Some(result) = lru.get_mut(shortstatehash) { - return Ok(result.clone()); - } - - let value = db - .rooms - .shortstatehash_statediff - .get(shortstatehash)? - .ok_or_else(|| Error::bad_database("State hash does not exist"))?; - let parent = value[0..size_of::()].to_vec(); - - let mut add_mode = true; - let mut added = HashSet::new(); - let mut removed = HashSet::new(); - - let mut i = size_of::(); - while let Some(v) = value.get(i..i + 2 * size_of::()) { - if add_mode && v.starts_with(&0_u64.to_be_bytes()) { - add_mode = false; - i += size_of::(); - continue; - } - if add_mode { - added.insert(v.to_vec()); - } else { - removed.insert(v.to_vec()); - } - i += 2 * size_of::(); - } - - if parent != 0_u64.to_be_bytes() { - let mut response = load_shortstatehash_info(&parent, db, lru)?; - let mut state = response.last().unwrap().1.clone(); - state.extend(added.iter().cloned()); - for r in &removed { - state.remove(r); - } - - response.push((shortstatehash.to_vec(), state, added, removed)); - - lru.insert(shortstatehash.to_vec(), response.clone()); - Ok(response) - } else { - let mut response = Vec::new(); - response.push((shortstatehash.to_vec(), added.clone(), added, removed)); - lru.insert(shortstatehash.to_vec(), response.clone()); - Ok(response) - } - } - - fn update_shortstatehash_level( - current_shortstatehash: &[u8], - statediffnew: HashSet>, - statediffremoved: HashSet>, - diff_to_sibling: usize, - mut parent_states: Vec<( - Vec, // sstatehash - HashSet>, // full state - HashSet>, // added - HashSet>, // removed - )>, - db: &Database, - ) -> Result<()> { - let diffsum = statediffnew.len() + statediffremoved.len(); - - if parent_states.len() > 3 { - // Number of layers - // To many layers, we have to go deeper - let parent = parent_states.pop().unwrap(); - - let mut parent_new = parent.2; - let mut parent_removed = parent.3; - - for removed in statediffremoved { - if !parent_new.remove(&removed) { - parent_removed.insert(removed); - } - } - parent_new.extend(statediffnew); - - update_shortstatehash_level( - current_shortstatehash, - parent_new, - parent_removed, - diffsum, - parent_states, - db, - )?; - - return Ok(()); - } - - if parent_states.len() == 0 { - // There is no parent layer, create a new state - let mut value = 0_u64.to_be_bytes().to_vec(); // 0 means no parent - for new in &statediffnew { - value.extend_from_slice(&new); - } - - if !statediffremoved.is_empty() { - warn!("Tried to create new state with removals"); - } - - db.rooms - .shortstatehash_statediff - .insert(¤t_shortstatehash, &value)?; - - return Ok(()); - }; - - // Else we have two options. - // 1. We add the current diff on top of the parent layer. - // 2. 
We replace a layer above - - let parent = parent_states.pop().unwrap(); - let parent_diff = parent.2.len() + parent.3.len(); - - if diffsum * diffsum >= 2 * diff_to_sibling * parent_diff { - // Diff too big, we replace above layer(s) - let mut parent_new = parent.2; - let mut parent_removed = parent.3; - - for removed in statediffremoved { - if !parent_new.remove(&removed) { - parent_removed.insert(removed); - } - } - - parent_new.extend(statediffnew); - update_shortstatehash_level( - current_shortstatehash, - parent_new, - parent_removed, - diffsum, - parent_states, - db, - )?; - } else { - // Diff small enough, we add diff as layer on top of parent - let mut value = parent.0.clone(); - for new in &statediffnew { - value.extend_from_slice(&new); - } - - if !statediffremoved.is_empty() { - value.extend_from_slice(&0_u64.to_be_bytes()); - for removed in &statediffremoved { - value.extend_from_slice(&removed); - } - } - - db.rooms - .shortstatehash_statediff - .insert(¤t_shortstatehash, &value)?; - } - - Ok(()) - } - if db.globals.database_version()? < 7 { // Upgrade state store - let mut lru = LruCache::new(1000); - let mut last_roomstates: HashMap> = HashMap::new(); - let mut current_sstatehash: Vec = Vec::new(); + let mut last_roomstates: HashMap = HashMap::new(); + let mut current_sstatehash: Option = None; let mut current_room = None; let mut current_state = HashSet::new(); let mut counter = 0; + + let mut handle_state = + |current_sstatehash: u64, + current_room: &RoomId, + current_state: HashSet<_>, + last_roomstates: &mut HashMap<_, _>| { + counter += 1; + println!("counter: {}", counter); + let last_roomsstatehash = last_roomstates.get(current_room); + + let states_parents = last_roomsstatehash.map_or_else( + || Ok(Vec::new()), + |&last_roomsstatehash| { + db.rooms.load_shortstatehash_info(dbg!(last_roomsstatehash)) + }, + )?; + + let (statediffnew, statediffremoved) = + if let Some(parent_stateinfo) = states_parents.last() { + let statediffnew = current_state + .difference(&parent_stateinfo.1) + .cloned() + .collect::>(); + + let statediffremoved = parent_stateinfo + .1 + .difference(¤t_state) + .cloned() + .collect::>(); + + (statediffnew, statediffremoved) + } else { + (current_state, HashSet::new()) + }; + + db.rooms.save_state_from_diff( + dbg!(current_sstatehash), + statediffnew, + statediffremoved, + 2, // every state change is 2 event changes on average + states_parents, + )?; + + /* + let mut tmp = db.rooms.load_shortstatehash_info(¤t_sstatehash, &db)?; + let state = tmp.pop().unwrap(); + println!( + "{}\t{}{:?}: {:?} + {:?} - {:?}", + current_room, + " ".repeat(tmp.len()), + utils::u64_from_bytes(¤t_sstatehash).unwrap(), + tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()), + state + .2 + .iter() + .map(|b| utils::u64_from_bytes(&b[size_of::()..]).unwrap()) + .collect::>(), + state + .3 + .iter() + .map(|b| utils::u64_from_bytes(&b[size_of::()..]).unwrap()) + .collect::>() + ); + */ + + Ok::<_, Error>(()) + }; + for (k, seventid) in db._db.open_tree("stateid_shorteventid")?.iter() { - let sstatehash = k[0..size_of::()].to_vec(); + let sstatehash = utils::u64_from_bytes(&k[0..size_of::()]) + .expect("number of bytes is correct"); let sstatekey = k[size_of::()..].to_vec(); - if sstatehash != current_sstatehash { - if !current_sstatehash.is_empty() { - counter += 1; - println!("counter: {}", counter); - let current_room = current_room.as_ref().unwrap(); - let last_roomsstatehash = last_roomstates.get(¤t_room); - - let states_parents = last_roomsstatehash.map_or_else( - 
|| Ok(Vec::new()), - |last_roomsstatehash| { - load_shortstatehash_info(&last_roomsstatehash, &db, &mut lru) - }, + if Some(sstatehash) != current_sstatehash { + if let Some(current_sstatehash) = current_sstatehash { + handle_state( + current_sstatehash, + current_room.as_ref().unwrap(), + current_state, + &mut last_roomstates, )?; - - let (statediffnew, statediffremoved) = - if let Some(parent_stateinfo) = states_parents.last() { - let statediffnew = current_state - .difference(&parent_stateinfo.1) - .cloned() - .collect::>(); - - let statediffremoved = parent_stateinfo - .1 - .difference(¤t_state) - .cloned() - .collect::>(); - - (statediffnew, statediffremoved) - } else { - (current_state, HashSet::new()) - }; - - update_shortstatehash_level( - ¤t_sstatehash, - statediffnew, - statediffremoved, - 2, // every state change is 2 event changes on average - states_parents, - &db, - )?; - - /* - let mut tmp = load_shortstatehash_info(¤t_sstatehash, &db)?; - let state = tmp.pop().unwrap(); - println!( - "{}\t{}{:?}: {:?} + {:?} - {:?}", - current_room, - " ".repeat(tmp.len()), - utils::u64_from_bytes(¤t_sstatehash).unwrap(), - tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()), - state - .2 - .iter() - .map(|b| utils::u64_from_bytes(&b[size_of::()..]).unwrap()) - .collect::>(), - state - .3 - .iter() - .map(|b| utils::u64_from_bytes(&b[size_of::()..]).unwrap()) - .collect::>() - ); - */ - - last_roomstates.insert(current_room.clone(), current_sstatehash); + last_roomstates + .insert(current_room.clone().unwrap(), current_sstatehash); } current_state = HashSet::new(); - current_sstatehash = sstatehash; + current_sstatehash = Some(sstatehash); let event_id = db .rooms @@ -721,7 +554,16 @@ impl Database { let mut val = sstatekey; val.extend_from_slice(&seventid); - current_state.insert(val); + current_state.insert(val.try_into().expect("size is correct")); + } + + if let Some(current_sstatehash) = current_sstatehash { + handle_state( + current_sstatehash, + current_room.as_ref().unwrap(), + current_state, + &mut last_roomstates, + )?; } db.globals.bump_database_version(7)?; @@ -761,11 +603,28 @@ impl Database { db.rooms.pduid_pdu.insert_batch(&mut batch)?; - for (key, _) in db.rooms.pduid_pdu.iter() { - if key.starts_with(b"!") { - db.rooms.pduid_pdu.remove(&key); + let mut batch2 = db.rooms.eventid_pduid.iter().filter_map(|(k, value)| { + if !value.starts_with(b"!") { + return None; } - } + let mut parts = value.splitn(2, |&b| b == 0xff); + let room_id = parts.next().unwrap(); + let count = parts.next().unwrap(); + + let short_room_id = db + .rooms + .roomid_shortroomid + .get(&room_id) + .unwrap() + .expect("shortroomid should exist"); + + let mut new_value = short_room_id; + new_value.extend_from_slice(count); + + Some((k, new_value)) + }); + + db.rooms.eventid_pduid.insert_batch(&mut batch2)?; db.globals.bump_database_version(8)?; @@ -803,7 +662,7 @@ impl Database { for (key, _) in db.rooms.tokenids.iter() { if key.starts_with(b"!") { - db.rooms.pduid_pdu.remove(&key)?; + db.rooms.tokenids.remove(&key)?; } } @@ -811,8 +670,6 @@ impl Database { println!("Migration: 8 -> 9 finished"); } - - panic!(); } let guard = db.read().await; diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index f420021..3c4ae9c 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -9,13 +9,13 @@ use std::{ path::{Path, PathBuf}, pin::Pin, sync::Arc, - time::{Duration, Instant}, }; use tokio::sync::oneshot::Sender; use tracing::debug; 
thread_local! { static READ_CONNECTION: RefCell> = RefCell::new(None); + static READ_CONNECTION_ITERATOR: RefCell> = RefCell::new(None); } struct PreparedStatementIterator<'a> { @@ -77,6 +77,21 @@ impl Engine { }) } + fn read_lock_iterator(&self) -> &'static Connection { + READ_CONNECTION_ITERATOR.with(|cell| { + let connection = &mut cell.borrow_mut(); + + if (*connection).is_none() { + let c = Box::leak(Box::new( + Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap(), + )); + **connection = Some(c); + } + + connection.unwrap() + }) + } + pub fn flush_wal(self: &Arc) -> Result<()> { self.write_lock() .pragma_update(Some(Main), "wal_checkpoint", &"TRUNCATE")?; @@ -151,6 +166,34 @@ impl SqliteTable { )?; Ok(()) } + + pub fn iter_with_guard<'a>( + &'a self, + guard: &'a Connection, + ) -> Box + 'a> { + let statement = Box::leak(Box::new( + guard + .prepare(&format!( + "SELECT key, value FROM {} ORDER BY key ASC", + &self.name + )) + .unwrap(), + )); + + let statement_ref = NonAliasingBox(statement); + + let iterator = Box::new( + statement + .query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) + .unwrap() + .map(|r| r.unwrap()), + ); + + Box::new(PreparedStatementIterator { + iterator, + statement_ref, + }) + } } impl Tree for SqliteTable { @@ -219,30 +262,9 @@ impl Tree for SqliteTable { #[tracing::instrument(skip(self))] fn iter<'a>(&'a self) -> Box + 'a> { - let guard = self.engine.read_lock(); + let guard = self.engine.read_lock_iterator(); - let statement = Box::leak(Box::new( - guard - .prepare(&format!( - "SELECT key, value FROM {} ORDER BY key ASC", - &self.name - )) - .unwrap(), - )); - - let statement_ref = NonAliasingBox(statement); - - let iterator = Box::new( - statement - .query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) - .unwrap() - .map(|r| r.unwrap()), - ); - - Box::new(PreparedStatementIterator { - iterator, - statement_ref, - }) + self.iter_with_guard(&guard) } #[tracing::instrument(skip(self, from, backwards))] @@ -251,7 +273,7 @@ impl Tree for SqliteTable { from: &[u8], backwards: bool, ) -> Box + 'a> { - let guard = self.engine.read_lock(); + let guard = self.engine.read_lock_iterator(); let from = from.to_vec(); // TODO change interface? if backwards { diff --git a/src/database/rooms.rs b/src/database/rooms.rs index a3a1c41..fc01e8a 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -24,7 +24,7 @@ use ruma::{ use std::{ collections::{BTreeMap, BTreeSet, HashMap, HashSet}, convert::{TryFrom, TryInto}, - mem, + mem::size_of, sync::{Arc, Mutex}, }; use tokio::sync::MutexGuard; @@ -37,10 +37,11 @@ use super::{abstraction::Tree, admin::AdminCommand, pusher}; /// This is created when a state group is added to the database by /// hashing the entire state. pub type StateHashId = Vec; +pub type CompressedStateEvent = [u8; 2 * size_of::()]; pub struct Rooms { pub edus: edus::RoomEdus, - pub(super) pduid_pdu: Arc, // PduId = RoomId + Count + pub(super) pduid_pdu: Arc, // PduId = ShortRoomId + Count pub(super) eventid_pduid: Arc, pub(super) roomid_pduleaves: Arc, pub(super) alias_roomid: Arc, @@ -79,9 +80,6 @@ pub struct Rooms { pub(super) eventid_shorteventid: Arc, pub(super) statehash_shortstatehash: Arc, - /// ShortStateHash = Count - /// StateId = ShortStateHash - pub(super) stateid_shorteventid: Arc, pub(super) shortstatehash_statediff: Arc, // StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 + (shortstatekey+shorteventid--) /// RoomId + EventId -> outlier PDU. 
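Aside (not part of the patch): the `shortstatehash_statediff` value format described above — parent shortstatehash (or 0), then the added shortstatekey+shorteventid pairs, then a `0_u64` separator followed by the removed pairs — is compact but easy to mis-handle. The following standalone sketch shows one way that byte layout could be packed and unpacked; the free functions and the `Compressed` alias are hypothetical stand-ins for what `Rooms::save_state_from_diff` and `Rooms::load_shortstatehash_info` do further down in this diff, under the assumption that shortstatekeys are never 0.

use std::collections::HashSet;
use std::convert::TryInto;
use std::mem::size_of;

// One entry: shortstatekey (8 bytes, big-endian) followed by shorteventid (8 bytes, big-endian).
type Compressed = [u8; 2 * size_of::<u64>()];

// Pack a diff: parent shortstatehash (0 = no parent), the added entries, then a
// 0_u64 separator and the removed entries (separator omitted when nothing was removed).
fn pack_statediff(parent: u64, added: &HashSet<Compressed>, removed: &HashSet<Compressed>) -> Vec<u8> {
    let mut value = parent.to_be_bytes().to_vec();
    for a in added {
        value.extend_from_slice(a);
    }
    if !removed.is_empty() {
        value.extend_from_slice(&0_u64.to_be_bytes());
        for r in removed {
            value.extend_from_slice(r);
        }
    }
    value
}

// Unpack the same layout. Assumes shortstatekeys are never 0, so a 16-byte window
// starting with eight zero bytes can only be the separator.
fn unpack_statediff(value: &[u8]) -> (u64, HashSet<Compressed>, HashSet<Compressed>) {
    let parent = u64::from_be_bytes(value[0..size_of::<u64>()].try_into().expect("8 bytes"));
    let (mut added, mut removed) = (HashSet::new(), HashSet::new());
    let mut add_mode = true;
    let mut i = size_of::<u64>();
    while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) {
        if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
            add_mode = false; // hit the separator; everything after it is a removal
            i += size_of::<u64>();
            continue;
        }
        let entry: Compressed = v.try_into().expect("16 bytes");
        if add_mode {
            added.insert(entry);
        } else {
            removed.insert(entry);
        }
        i += 2 * size_of::<u64>();
    }
    (parent, added, removed)
}

Round-tripping `pack_statediff` through `unpack_statediff` recovers the same (parent, added, removed) triple, which is the property the version-7 migration and the state loading code below rely on.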
@@ -100,29 +98,30 @@ impl Rooms { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. pub fn state_full_ids(&self, shortstatehash: u64) -> Result> { - Ok(self - .stateid_shorteventid - .scan_prefix(shortstatehash.to_be_bytes().to_vec()) - .map(|(_, bytes)| { - self.get_eventid_from_short(utils::u64_from_bytes(&bytes).unwrap()) - .ok() - }) - .flatten() - .collect()) + let full_state = self + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + full_state + .into_iter() + .map(|compressed| self.parse_compressed_state_event(compressed)) + .collect() } pub fn state_full( &self, shortstatehash: u64, ) -> Result>> { - let state = self - .stateid_shorteventid - .scan_prefix(shortstatehash.to_be_bytes().to_vec()) - .map(|(_, bytes)| { - self.get_eventid_from_short(utils::u64_from_bytes(&bytes).unwrap()) - .ok() - }) - .flatten() + let full_state = self + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + Ok(full_state + .into_iter() + .map(|compressed| self.parse_compressed_state_event(compressed)) + .filter_map(|r| r.ok()) .map(|eventid| self.get_pdu(&eventid)) .filter_map(|r| r.ok().flatten()) .map(|pdu| { @@ -138,9 +137,7 @@ impl Rooms { )) }) .filter_map(|r| r.ok()) - .collect(); - - Ok(state) + .collect()) } /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). @@ -151,27 +148,19 @@ impl Rooms { event_type: &EventType, state_key: &str, ) -> Result> { - let mut key = event_type.as_ref().as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&state_key.as_bytes()); - - let shortstatekey = self.statekey_shortstatekey.get(&key)?; - - if let Some(shortstatekey) = shortstatekey { - let mut stateid = shortstatehash.to_be_bytes().to_vec(); - stateid.extend_from_slice(&shortstatekey); - - Ok(self - .stateid_shorteventid - .get(&stateid)? - .map(|bytes| { - self.get_eventid_from_short(utils::u64_from_bytes(&bytes).unwrap()) - .ok() - }) - .flatten()) - } else { - Ok(None) - } + let shortstatekey = match self.get_shortstatekey(event_type, state_key)? { + Some(s) => s, + None => return Ok(None), + }; + let full_state = self + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + Ok(full_state + .into_iter() + .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) + .and_then(|compressed| self.parse_compressed_state_event(compressed).ok())) } /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). @@ -260,8 +249,7 @@ impl Rooms { /// Checks if a room exists. pub fn exists(&self, room_id: &RoomId) -> Result { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); // Look for PDUs in that room. Ok(self @@ -274,8 +262,7 @@ impl Rooms { /// Checks if a room exists. pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); // Look for PDUs in that room. self.pduid_pdu @@ -292,74 +279,78 @@ impl Rooms { /// Force the creation of a new StateHash and insert it into the db. /// - /// Whatever `state` is supplied to `force_state` __is__ the current room state snapshot. + /// Whatever `state` is supplied to `force_state` becomes the new current room state snapshot. 
pub fn force_state( &self, room_id: &RoomId, - state: HashMap<(EventType, String), EventId>, + new_state: HashMap<(EventType, String), EventId>, db: &Database, ) -> Result<()> { + let previous_shortstatehash = self.current_shortstatehash(&room_id)?; + + let new_state_ids_compressed = new_state + .iter() + .filter_map(|((event_type, state_key), event_id)| { + let shortstatekey = self + .get_or_create_shortstatekey(event_type, state_key, &db.globals) + .ok()?; + Some( + self.compress_state_event(shortstatekey, event_id, &db.globals) + .ok()?, + ) + }) + .collect::>(); + let state_hash = self.calculate_hash( - &state + &new_state .values() .map(|event_id| event_id.as_bytes()) .collect::>(), ); - let (shortstatehash, already_existed) = + let (new_shortstatehash, already_existed) = self.get_or_create_shortstatehash(&state_hash, &db.globals)?; - let new_state = if !already_existed { - let mut new_state = HashSet::new(); + if Some(new_shortstatehash) == previous_shortstatehash { + return Ok(()); + } - let batch = state - .iter() - .filter_map(|((event_type, state_key), eventid)| { - new_state.insert(eventid.clone()); + let states_parents = previous_shortstatehash + .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; - let mut statekey = event_type.as_ref().as_bytes().to_vec(); - statekey.push(0xff); - statekey.extend_from_slice(&state_key.as_bytes()); + let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() + { + let statediffnew = new_state_ids_compressed + .difference(&parent_stateinfo.1) + .cloned() + .collect::>(); - let shortstatekey = match self.statekey_shortstatekey.get(&statekey).ok()? { - Some(shortstatekey) => shortstatekey.to_vec(), - None => { - let shortstatekey = db.globals.next_count().ok()?; - self.statekey_shortstatekey - .insert(&statekey, &shortstatekey.to_be_bytes()) - .ok()?; - shortstatekey.to_be_bytes().to_vec() - } - }; + let statediffremoved = parent_stateinfo + .1 + .difference(&new_state_ids_compressed) + .cloned() + .collect::>(); - let shorteventid = self - .get_or_create_shorteventid(&eventid, &db.globals) - .ok()?; - - let mut state_id = shortstatehash.to_be_bytes().to_vec(); - state_id.extend_from_slice(&shortstatekey); - - Some((state_id, shorteventid.to_be_bytes().to_vec())) - }) - .collect::>(); - - self.stateid_shorteventid - .insert_batch(&mut batch.into_iter())?; - - new_state + (statediffnew, statediffremoved) } else { - self.state_full_ids(shortstatehash)?.into_iter().collect() + (new_state_ids_compressed, HashSet::new()) }; - let old_state = self - .current_shortstatehash(&room_id)? - .map(|s| self.state_full_ids(s)) - .transpose()? - .map(|vec| vec.into_iter().collect::>()) - .unwrap_or_default(); + if !already_existed { + self.save_state_from_diff( + new_shortstatehash, + statediffnew.clone(), + statediffremoved.clone(), + 2, // every state change is 2 event changes on average + states_parents, + )?; + }; - for event_id in new_state.difference(&old_state) { - if let Some(pdu) = self.get_pdu_json(event_id)? { + for event_id in statediffnew + .into_iter() + .filter_map(|new| self.parse_compressed_state_event(new).ok()) + { + if let Some(pdu) = self.get_pdu_json(&event_id)? 
{ if pdu.get("type").and_then(|val| val.as_str()) == Some("m.room.member") { if let Ok(pdu) = serde_json::from_value::( serde_json::to_value(&pdu).expect("CanonicalJsonObj is a valid JsonValue"), @@ -392,7 +383,206 @@ impl Rooms { } self.roomid_shortstatehash - .insert(room_id.as_bytes(), &shortstatehash.to_be_bytes())?; + .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; + + Ok(()) + } + + /// Returns a stack with info on shortstatehash, full state, added diff and removed diff for the selected shortstatehash and each parent layer. + pub fn load_shortstatehash_info( + &self, + shortstatehash: u64, + ) -> Result< + Vec<( + u64, // sstatehash + HashSet, // full state + HashSet, // added + HashSet, // removed + )>, + > { + let value = self + .shortstatehash_statediff + .get(&shortstatehash.to_be_bytes())? + .ok_or_else(|| Error::bad_database("State hash does not exist"))?; + let parent = + utils::u64_from_bytes(&value[0..size_of::()]).expect("bytes have right length"); + + let mut add_mode = true; + let mut added = HashSet::new(); + let mut removed = HashSet::new(); + + let mut i = size_of::(); + while let Some(v) = value.get(i..i + 2 * size_of::()) { + if add_mode && v.starts_with(&0_u64.to_be_bytes()) { + add_mode = false; + i += size_of::(); + continue; + } + if add_mode { + added.insert(v.try_into().expect("we checked the size above")); + } else { + removed.insert(v.try_into().expect("we checked the size above")); + } + i += 2 * size_of::(); + } + + if parent != 0_u64 { + let mut response = self.load_shortstatehash_info(parent)?; + let mut state = response.last().unwrap().1.clone(); + state.extend(added.iter().cloned()); + for r in &removed { + state.remove(r); + } + + response.push((shortstatehash, state, added, removed)); + + Ok(response) + } else { + let mut response = Vec::new(); + response.push((shortstatehash, added.clone(), added, removed)); + Ok(response) + } + } + + pub fn compress_state_event( + &self, + shortstatekey: u64, + event_id: &EventId, + globals: &super::globals::Globals, + ) -> Result { + let mut v = shortstatekey.to_be_bytes().to_vec(); + v.extend_from_slice( + &self + .get_or_create_shorteventid(event_id, globals)? + .to_be_bytes(), + ); + Ok(v.try_into().expect("we checked the size above")) + } + + pub fn parse_compressed_state_event( + &self, + compressed_event: CompressedStateEvent, + ) -> Result { + self.get_eventid_from_short( + utils::u64_from_bytes(&compressed_event[size_of::()..]) + .expect("bytes have right length"), + ) + } + + /// Creates a new shortstatehash that often is just a diff to an already existing + /// shortstatehash and therefore very efficient. + /// + /// There are multiple layers of diffs. The bottom layer 0 always contains the full state. Layer + /// 1 contains diffs to states of layer 0, layer 2 diffs to layer 1 and so on. If layer n > 0 + /// grows too big, it will be combined with layer n-1 to create a new diff on layer n-1 that's + /// based on layer n-2. If that layer is also too big, it will recursively fix above layers too. + /// + /// * `shortstatehash` - Shortstatehash of this state + /// * `statediffnew` - Added to base. Each vec is shortstatekey+shorteventid + /// * `statediffremoved` - Removed from base. 
Each vec is shortstatekey+shorteventid + /// * `diff_to_sibling` - Approximately how much the diff grows each time for this layer + /// * `parent_states` - A stack with info on shortstatehash, full state, added diff and removed diff for each parent layer + pub fn save_state_from_diff( + &self, + shortstatehash: u64, + statediffnew: HashSet, + statediffremoved: HashSet, + diff_to_sibling: usize, + mut parent_states: Vec<( + u64, // sstatehash + HashSet, // full state + HashSet, // added + HashSet, // removed + )>, + ) -> Result<()> { + let diffsum = statediffnew.len() + statediffremoved.len(); + + if parent_states.len() > 3 { + // Number of layers + // To many layers, we have to go deeper + let parent = parent_states.pop().unwrap(); + + let mut parent_new = parent.2; + let mut parent_removed = parent.3; + + for removed in statediffremoved { + if !parent_new.remove(&removed) { + parent_removed.insert(removed); + } + } + parent_new.extend(statediffnew); + + self.save_state_from_diff( + shortstatehash, + parent_new, + parent_removed, + diffsum, + parent_states, + )?; + + return Ok(()); + } + + if parent_states.len() == 0 { + // There is no parent layer, create a new state + let mut value = 0_u64.to_be_bytes().to_vec(); // 0 means no parent + for new in &statediffnew { + value.extend_from_slice(&new[..]); + } + + if !statediffremoved.is_empty() { + warn!("Tried to create new state with removals"); + } + + self.shortstatehash_statediff + .insert(&shortstatehash.to_be_bytes(), &value)?; + + return Ok(()); + }; + + // Else we have two options. + // 1. We add the current diff on top of the parent layer. + // 2. We replace a layer above + + let parent = parent_states.pop().unwrap(); + let parent_diff = parent.2.len() + parent.3.len(); + + if diffsum * diffsum >= 2 * diff_to_sibling * parent_diff { + // Diff too big, we replace above layer(s) + let mut parent_new = parent.2; + let mut parent_removed = parent.3; + + for removed in statediffremoved { + if !parent_new.remove(&removed) { + parent_removed.insert(removed); + } + } + + parent_new.extend(statediffnew); + self.save_state_from_diff( + shortstatehash, + parent_new, + parent_removed, + diffsum, + parent_states, + )?; + } else { + // Diff small enough, we add diff as layer on top of parent + let mut value = parent.0.to_be_bytes().to_vec(); + for new in &statediffnew { + value.extend_from_slice(&new[..]); + } + + if !statediffremoved.is_empty() { + value.extend_from_slice(&0_u64.to_be_bytes()); + for removed in &statediffremoved { + value.extend_from_slice(&removed[..]); + } + } + + self.shortstatehash_statediff + .insert(&shortstatehash.to_be_bytes(), &value)?; + } Ok(()) } @@ -418,7 +608,6 @@ impl Rooms { }) } - /// Returns (shortstatehash, already_existed) pub fn get_or_create_shorteventid( &self, event_id: &EventId, @@ -438,6 +627,71 @@ impl Rooms { }) } + pub fn get_shortroomid(&self, room_id: &RoomId) -> Result { + let bytes = self + .roomid_shortroomid + .get(&room_id.as_bytes())? + .expect("every room has a shortroomid"); + utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db.")) + } + + pub fn get_shortstatekey( + &self, + event_type: &EventType, + state_key: &str, + ) -> Result> { + let mut statekey = event_type.as_ref().as_bytes().to_vec(); + statekey.push(0xff); + statekey.extend_from_slice(&state_key.as_bytes()); + + self.statekey_shortstatekey + .get(&statekey)? 
+ .map(|shortstatekey| { + utils::u64_from_bytes(&shortstatekey) + .map_err(|_| Error::bad_database("Invalid shortstatekey in db.")) + }) + .transpose() + } + + pub fn get_or_create_shortroomid( + &self, + room_id: &RoomId, + globals: &super::globals::Globals, + ) -> Result { + Ok(match self.roomid_shortroomid.get(&room_id.as_bytes())? { + Some(short) => utils::u64_from_bytes(&short) + .map_err(|_| Error::bad_database("Invalid shortroomid in db."))?, + None => { + let short = globals.next_count()?; + self.roomid_shortroomid + .insert(&room_id.as_bytes(), &short.to_be_bytes())?; + short + } + }) + } + + pub fn get_or_create_shortstatekey( + &self, + event_type: &EventType, + state_key: &str, + globals: &super::globals::Globals, + ) -> Result { + let mut statekey = event_type.as_ref().as_bytes().to_vec(); + statekey.push(0xff); + statekey.extend_from_slice(&state_key.as_bytes()); + + Ok(match self.statekey_shortstatekey.get(&statekey)? { + Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey) + .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?, + None => { + let shortstatekey = globals.next_count()?; + self.statekey_shortstatekey + .insert(&statekey, &shortstatekey.to_be_bytes())?; + shortstatekey + } + }) + } + pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result { if let Some(id) = self .shorteventid_cache @@ -514,7 +768,7 @@ impl Rooms { #[tracing::instrument(skip(self))] pub fn pdu_count(&self, pdu_id: &[u8]) -> Result { Ok( - utils::u64_from_bytes(&pdu_id[pdu_id.len() - mem::size_of::()..pdu_id.len()]) + utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::()..]) .map_err(|_| Error::bad_database("PDU has invalid count bytes."))?, ) } @@ -527,8 +781,7 @@ impl Rooms { } pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); let mut last_possible_key = prefix.clone(); last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); @@ -758,6 +1011,8 @@ impl Rooms { /// /// By this point the incoming event should be fully authenticated, no auth happens /// in `append_pdu`. + /// + /// Returns pdu id #[tracing::instrument(skip(self, pdu, pdu_json, leaves, db))] pub fn append_pdu( &self, @@ -766,7 +1021,8 @@ impl Rooms { leaves: &[EventId], db: &Database, ) -> Result> { - // returns pdu id + let shortroomid = self.get_shortroomid(&pdu.room_id)?; + // Make unsigned fields correct. 
This is not properly documented in the spec, but state // events need to have previous content in the unsigned field, so clients can easily // interpret things like membership changes @@ -821,8 +1077,7 @@ impl Rooms { self.reset_notification_counts(&pdu.sender, &pdu.room_id)?; let count2 = db.globals.next_count()?; - let mut pdu_id = pdu.room_id.as_bytes().to_vec(); - pdu_id.push(0xff); + let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&count2.to_be_bytes()); // There's a brief moment of time here where the count is updated but the pdu does not @@ -968,8 +1223,7 @@ impl Rooms { .filter(|word| word.len() <= 50) .map(str::to_lowercase) .map(|word| { - let mut key = pdu.room_id.as_bytes().to_vec(); - key.push(0xff); + let mut key = shortroomid.to_be_bytes().to_vec(); key.extend_from_slice(word.as_bytes()); key.push(0xff); key.extend_from_slice(&pdu_id); @@ -1152,11 +1406,27 @@ impl Rooms { pub fn set_event_state( &self, event_id: &EventId, + room_id: &RoomId, state: &StateMap>, globals: &super::globals::Globals, ) -> Result<()> { let shorteventid = self.get_or_create_shorteventid(&event_id, globals)?; + let previous_shortstatehash = self.current_shortstatehash(&room_id)?; + + let state_ids_compressed = state + .iter() + .filter_map(|((event_type, state_key), pdu)| { + let shortstatekey = self + .get_or_create_shortstatekey(event_type, state_key, globals) + .ok()?; + Some( + self.compress_state_event(shortstatekey, &pdu.event_id, globals) + .ok()?, + ) + }) + .collect::>(); + let state_hash = self.calculate_hash( &state .values() @@ -1168,37 +1438,33 @@ impl Rooms { self.get_or_create_shortstatehash(&state_hash, globals)?; if !already_existed { - let batch = state - .iter() - .filter_map(|((event_type, state_key), pdu)| { - let mut statekey = event_type.as_ref().as_bytes().to_vec(); - statekey.push(0xff); - statekey.extend_from_slice(&state_key.as_bytes()); + let states_parents = previous_shortstatehash + .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; - let shortstatekey = match self.statekey_shortstatekey.get(&statekey).ok()? 
{ - Some(shortstatekey) => shortstatekey.to_vec(), - None => { - let shortstatekey = globals.next_count().ok()?; - self.statekey_shortstatekey - .insert(&statekey, &shortstatekey.to_be_bytes()) - .ok()?; - shortstatekey.to_be_bytes().to_vec() - } - }; + let (statediffnew, statediffremoved) = + if let Some(parent_stateinfo) = states_parents.last() { + let statediffnew = state_ids_compressed + .difference(&parent_stateinfo.1) + .cloned() + .collect::>(); - let shorteventid = self - .get_or_create_shorteventid(&pdu.event_id, globals) - .ok()?; + let statediffremoved = parent_stateinfo + .1 + .difference(&state_ids_compressed) + .cloned() + .collect::>(); - let mut state_id = shortstatehash.to_be_bytes().to_vec(); - state_id.extend_from_slice(&shortstatekey); - - Some((state_id, shorteventid.to_be_bytes().to_vec())) - }) - .collect::>(); - - self.stateid_shorteventid - .insert_batch(&mut batch.into_iter())?; + (statediffnew, statediffremoved) + } else { + (state_ids_compressed, HashSet::new()) + }; + self.save_state_from_diff( + shortstatehash, + statediffnew.clone(), + statediffremoved.clone(), + 1_000_000, // high number because no state will be based on this one + states_parents, + )?; } self.shorteventid_shortstatehash @@ -1219,82 +1485,52 @@ impl Rooms { ) -> Result { let shorteventid = self.get_or_create_shorteventid(&new_pdu.event_id, globals)?; - let old_state = if let Some(old_shortstatehash) = - self.roomid_shortstatehash.get(new_pdu.room_id.as_bytes())? - { - // Store state for event. The state does not include the event itself. - // Instead it's the state before the pdu, so the room's old state. + let previous_shortstatehash = self.current_shortstatehash(&new_pdu.room_id)?; + + if let Some(p) = previous_shortstatehash { self.shorteventid_shortstatehash - .insert(&shorteventid.to_be_bytes(), &old_shortstatehash)?; - - if new_pdu.state_key.is_none() { - return utils::u64_from_bytes(&old_shortstatehash).map_err(|_| { - Error::bad_database("Invalid shortstatehash in roomid_shortstatehash.") - }); - } - - self.stateid_shorteventid - .scan_prefix(old_shortstatehash.clone()) - // Chop the old_shortstatehash out leaving behind the short state key - .map(|(k, v)| (k[old_shortstatehash.len()..].to_vec(), v)) - .collect::, Vec>>() - } else { - HashMap::new() - }; + .insert(&shorteventid.to_be_bytes(), &p.to_be_bytes())?; + } if let Some(state_key) = &new_pdu.state_key { - let mut new_state: HashMap, Vec> = old_state; + let states_parents = previous_shortstatehash + .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; - let mut new_state_key = new_pdu.kind.as_ref().as_bytes().to_vec(); - new_state_key.push(0xff); - new_state_key.extend_from_slice(state_key.as_bytes()); + let shortstatekey = + self.get_or_create_shortstatekey(&new_pdu.kind, &state_key, globals)?; - let shortstatekey = match self.statekey_shortstatekey.get(&new_state_key)? 
{ - Some(shortstatekey) => shortstatekey.to_vec(), - None => { - let shortstatekey = globals.next_count()?; - self.statekey_shortstatekey - .insert(&new_state_key, &shortstatekey.to_be_bytes())?; - shortstatekey.to_be_bytes().to_vec() - } - }; + let replaces = states_parents + .last() + .map(|info| { + info.1 + .iter() + .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) + }) + .unwrap_or_default(); - new_state.insert(shortstatekey, shorteventid.to_be_bytes().to_vec()); + // TODO: statehash with deterministic inputs + let shortstatehash = globals.next_count()?; - let new_state_hash = self.calculate_hash( - &new_state - .values() - .map(|event_id| &**event_id) - .collect::>(), - ); + let mut statediffnew = HashSet::new(); + let new = self.compress_state_event(shortstatekey, &new_pdu.event_id, globals)?; + statediffnew.insert(new); - let shortstatehash = match self.statehash_shortstatehash.get(&new_state_hash)? { - Some(shortstatehash) => { - warn!("state hash already existed?!"); - utils::u64_from_bytes(&shortstatehash) - .map_err(|_| Error::bad_database("PDU has invalid count bytes."))? - } - None => { - let shortstatehash = globals.next_count()?; - self.statehash_shortstatehash - .insert(&new_state_hash, &shortstatehash.to_be_bytes())?; - shortstatehash - } - }; + let mut statediffremoved = HashSet::new(); + if let Some(replaces) = replaces { + statediffremoved.insert(replaces.clone()); + } - let mut batch = new_state.into_iter().map(|(shortstatekey, shorteventid)| { - let mut state_id = shortstatehash.to_be_bytes().to_vec(); - state_id.extend_from_slice(&shortstatekey); - (state_id, shorteventid) - }); - - self.stateid_shorteventid.insert_batch(&mut batch)?; + self.save_state_from_diff( + shortstatehash, + statediffnew, + statediffremoved, + 2, + states_parents, + )?; Ok(shortstatehash) } else { - Err(Error::bad_database( - "Tried to insert non-state event into room without a state.", - )) + Ok(previous_shortstatehash.expect("first event in room must be a state event")) } } @@ -1597,7 +1833,7 @@ impl Rooms { &'a self, user_id: &UserId, room_id: &RoomId, - ) -> impl Iterator, PduEvent)>> + 'a { + ) -> Result, PduEvent)>> + 'a> { self.pdus_since(user_id, room_id, 0) } @@ -1609,16 +1845,17 @@ impl Rooms { user_id: &UserId, room_id: &RoomId, since: u64, - ) -> impl Iterator, PduEvent)>> + 'a { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + ) -> Result, PduEvent)>> + 'a> { + let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); // Skip the first pdu if it's exactly at since, because we sent that last time let mut first_pdu_id = prefix.clone(); first_pdu_id.extend_from_slice(&(since + 1).to_be_bytes()); let user_id = user_id.clone(); - self.pduid_pdu + + Ok(self + .pduid_pdu .iter_from(&first_pdu_id, false) .take_while(move |(k, _)| k.starts_with(&prefix)) .map(move |(pdu_id, v)| { @@ -1628,7 +1865,7 @@ impl Rooms { pdu.unsigned.remove("transaction_id"); } Ok((pdu_id, pdu)) - }) + })) } /// Returns an iterator over all events and their tokens in a room that happened before the @@ -1639,10 +1876,9 @@ impl Rooms { user_id: &UserId, room_id: &RoomId, until: u64, - ) -> impl Iterator, PduEvent)>> + 'a { + ) -> Result, PduEvent)>> + 'a> { // Create the first part of the full pdu id - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); let mut current = prefix.clone(); current.extend_from_slice(&(until.saturating_sub(1)).to_be_bytes()); // -1 because we don't want 
event at `until` @@ -1650,7 +1886,9 @@ impl Rooms { let current: &[u8] = ¤t; let user_id = user_id.clone(); - self.pduid_pdu + + Ok(self + .pduid_pdu .iter_from(current, true) .take_while(move |(k, _)| k.starts_with(&prefix)) .map(move |(pdu_id, v)| { @@ -1660,7 +1898,7 @@ impl Rooms { pdu.unsigned.remove("transaction_id"); } Ok((pdu_id, pdu)) - }) + })) } /// Returns an iterator over all events and their token in a room that happened after the event @@ -1671,10 +1909,9 @@ impl Rooms { user_id: &UserId, room_id: &RoomId, from: u64, - ) -> impl Iterator, PduEvent)>> + 'a { + ) -> Result, PduEvent)>> + 'a> { // Create the first part of the full pdu id - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); let mut current = prefix.clone(); current.extend_from_slice(&(from + 1).to_be_bytes()); // +1 so we don't send the base event @@ -1682,7 +1919,9 @@ impl Rooms { let current: &[u8] = ¤t; let user_id = user_id.clone(); - self.pduid_pdu + + Ok(self + .pduid_pdu .iter_from(current, false) .take_while(move |(k, _)| k.starts_with(&prefix)) .map(move |(pdu_id, v)| { @@ -1692,7 +1931,7 @@ impl Rooms { pdu.unsigned.remove("transaction_id"); } Ok((pdu_id, pdu)) - }) + })) } /// Replace a PDU with the redacted form. @@ -2223,8 +2462,8 @@ impl Rooms { room_id: &RoomId, search_string: &str, ) -> Result<(impl Iterator> + 'a, Vec)> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); + let prefix_clone = prefix.clone(); let words = search_string .split_terminator(|c: char| !c.is_alphanumeric()) @@ -2243,16 +2482,7 @@ impl Rooms { .iter_from(&last_possible_id, true) // Newest pdus first .take_while(move |(k, _)| k.starts_with(&prefix2)) .map(|(key, _)| { - let pduid_index = key - .iter() - .enumerate() - .filter(|(_, &b)| b == 0xff) - .nth(1) - .ok_or_else(|| Error::bad_database("Invalid tokenid in db."))? - .0 - + 1; // +1 because the pdu id starts AFTER the separator - - let pdu_id = key[pduid_index..].to_vec(); + let pdu_id = key[key.len() - size_of::()..].to_vec(); Ok::<_, Error>(pdu_id) }) @@ -2264,7 +2494,12 @@ impl Rooms { // We compare b with a because we reversed the iterator earlier b.cmp(a) }) - .unwrap(), + .unwrap() + .map(move |id| { + let mut pduid = prefix_clone.clone(); + pduid.extend_from_slice(&id); + pduid + }), words, )) } diff --git a/src/server_server.rs b/src/server_server.rs index b3f0353..0c226ac 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -1704,7 +1704,7 @@ fn append_incoming_pdu( // We append to state before appending the pdu, so we don't have a moment in time with the // pdu without it's state. This is okay because append_pdu can't fail. db.rooms - .set_event_state(&pdu.event_id, state, &db.globals)?; + .set_event_state(&pdu.event_id, &pdu.room_id, state, &db.globals)?; let pdu_id = db.rooms.append_pdu( pdu,
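Aside (not part of the patch): a minimal self-contained sketch of the 16-byte compressed state event this patch introduces — big-endian shortstatekey followed by big-endian shorteventid — and of the prefix scan used to look a state key up in a resolved full state. The free functions and the example short ids here are hypothetical; the real methods are `Rooms::compress_state_event`, `Rooms::parse_compressed_state_event`, and the `find`/`starts_with` search in `Rooms::state_get_id`.

use std::collections::HashSet;
use std::convert::TryInto;
use std::mem::size_of;

type CompressedStateEvent = [u8; 2 * size_of::<u64>()];

// Concatenate the two 8-byte big-endian ids into one fixed-size key.
fn compress(shortstatekey: u64, shorteventid: u64) -> CompressedStateEvent {
    let mut v = shortstatekey.to_be_bytes().to_vec();
    v.extend_from_slice(&shorteventid.to_be_bytes());
    v.try_into().expect("8 + 8 bytes")
}

// The second half is the shorteventid; the real code then maps it back to an
// EventId through the shorteventid_eventid tree.
fn shorteventid_of(compressed: &CompressedStateEvent) -> u64 {
    u64::from_be_bytes(compressed[size_of::<u64>()..].try_into().expect("8 bytes"))
}

// A (event_type, state_key) pair appears at most once per snapshot, so a prefix
// match on the first 8 bytes finds the event currently holding that state key.
fn lookup(full_state: &HashSet<CompressedStateEvent>, shortstatekey: u64) -> Option<u64> {
    full_state
        .iter()
        .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
        .map(shorteventid_of)
}

fn main() {
    let mut state = HashSet::new();
    state.insert(compress(5, 100)); // e.g. (m.room.create, "")
    state.insert(compress(7, 103)); // e.g. (m.room.member, "@alice:example.org")
    assert_eq!(lookup(&state, 7), Some(103));
    assert_eq!(lookup(&state, 9), None);
}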