finish implementing better state store

next
Timo Kösters 2021-08-12 23:04:00 +02:00
parent 31f60ad6fd
commit 3eabaa2a95
No known key found for this signature in database
GPG Key ID: 356E705610F626D5
10 changed files with 645 additions and 526 deletions

View File

@ -249,6 +249,8 @@ pub async fn register_route(
let room_id = RoomId::new(db.globals.server_name()); let room_id = RoomId::new(db.globals.server_name());
db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?;
let mutex_state = Arc::clone( let mutex_state = Arc::clone(
db.globals db.globals
.roomid_mutex_state .roomid_mutex_state

View File

@ -44,7 +44,7 @@ pub async fn get_context_route(
let events_before = db let events_before = db
.rooms .rooms
.pdus_until(&sender_user, &body.room_id, base_token) .pdus_until(&sender_user, &body.room_id, base_token)?
.take( .take(
u32::try_from(body.limit).map_err(|_| { u32::try_from(body.limit).map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Limit value is invalid.") Error::BadRequest(ErrorKind::InvalidParam, "Limit value is invalid.")
@ -66,7 +66,7 @@ pub async fn get_context_route(
let events_after = db let events_after = db
.rooms .rooms
.pdus_after(&sender_user, &body.room_id, base_token) .pdus_after(&sender_user, &body.room_id, base_token)?
.take( .take(
u32::try_from(body.limit).map_err(|_| { u32::try_from(body.limit).map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Limit value is invalid.") Error::BadRequest(ErrorKind::InvalidParam, "Limit value is invalid.")

View File

@ -609,6 +609,8 @@ async fn join_room_by_id_helper(
) )
.await?; .await?;
db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?;
let pdu = PduEvent::from_id_val(&event_id, join_event.clone()) let pdu = PduEvent::from_id_val(&event_id, join_event.clone())
.map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?;

View File

@ -128,7 +128,7 @@ pub async fn get_message_events_route(
get_message_events::Direction::Forward => { get_message_events::Direction::Forward => {
let events_after = db let events_after = db
.rooms .rooms
.pdus_after(&sender_user, &body.room_id, from) .pdus_after(&sender_user, &body.room_id, from)?
.take(limit) .take(limit)
.filter_map(|r| r.ok()) // Filter out buggy events .filter_map(|r| r.ok()) // Filter out buggy events
.filter_map(|(pdu_id, pdu)| { .filter_map(|(pdu_id, pdu)| {
@ -158,7 +158,7 @@ pub async fn get_message_events_route(
get_message_events::Direction::Backward => { get_message_events::Direction::Backward => {
let events_before = db let events_before = db
.rooms .rooms
.pdus_until(&sender_user, &body.room_id, from) .pdus_until(&sender_user, &body.room_id, from)?
.take(limit) .take(limit)
.filter_map(|r| r.ok()) // Filter out buggy events .filter_map(|r| r.ok()) // Filter out buggy events
.filter_map(|(pdu_id, pdu)| { .filter_map(|(pdu_id, pdu)| {

View File

@ -33,6 +33,8 @@ pub async fn create_room_route(
let room_id = RoomId::new(db.globals.server_name()); let room_id = RoomId::new(db.globals.server_name());
db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?;
let mutex_state = Arc::clone( let mutex_state = Arc::clone(
db.globals db.globals
.roomid_mutex_state .roomid_mutex_state
@ -173,7 +175,6 @@ pub async fn create_room_route(
)?; )?;
// 4. Canonical room alias // 4. Canonical room alias
if let Some(room_alias_id) = &alias { if let Some(room_alias_id) = &alias {
db.rooms.build_and_append_pdu( db.rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
@ -193,7 +194,7 @@ pub async fn create_room_route(
&room_id, &room_id,
&db, &db,
&state_lock, &state_lock,
); )?;
} }
// 5. Events set by preset // 5. Events set by preset

View File

@ -205,7 +205,7 @@ async fn sync_helper(
let mut non_timeline_pdus = db let mut non_timeline_pdus = db
.rooms .rooms
.pdus_until(&sender_user, &room_id, u64::MAX) .pdus_until(&sender_user, &room_id, u64::MAX)?
.filter_map(|r| { .filter_map(|r| {
// Filter out buggy events // Filter out buggy events
if r.is_err() { if r.is_err() {
@ -248,13 +248,13 @@ async fn sync_helper(
let first_pdu_before_since = db let first_pdu_before_since = db
.rooms .rooms
.pdus_until(&sender_user, &room_id, since) .pdus_until(&sender_user, &room_id, since)?
.next() .next()
.transpose()?; .transpose()?;
let pdus_after_since = db let pdus_after_since = db
.rooms .rooms
.pdus_after(&sender_user, &room_id, since) .pdus_after(&sender_user, &room_id, since)?
.next() .next()
.is_some(); .is_some();
@ -286,7 +286,7 @@ async fn sync_helper(
for hero in db for hero in db
.rooms .rooms
.all_pdus(&sender_user, &room_id) .all_pdus(&sender_user, &room_id)?
.filter_map(|pdu| pdu.ok()) // Ignore all broken pdus .filter_map(|pdu| pdu.ok()) // Ignore all broken pdus
.filter(|(_, pdu)| pdu.kind == EventType::RoomMember) .filter(|(_, pdu)| pdu.kind == EventType::RoomMember)
.map(|(_, pdu)| { .map(|(_, pdu)| {
@ -328,11 +328,11 @@ async fn sync_helper(
} }
} }
( Ok::<_, Error>((
Some(joined_member_count), Some(joined_member_count),
Some(invited_member_count), Some(invited_member_count),
heroes, heroes,
) ))
}; };
let ( let (
@ -343,7 +343,7 @@ async fn sync_helper(
state_events, state_events,
) = if since_shortstatehash.is_none() { ) = if since_shortstatehash.is_none() {
// Probably since = 0, we will do an initial sync // Probably since = 0, we will do an initial sync
let (joined_member_count, invited_member_count, heroes) = calculate_counts(); let (joined_member_count, invited_member_count, heroes) = calculate_counts()?;
let current_state_ids = db.rooms.state_full_ids(current_shortstatehash)?; let current_state_ids = db.rooms.state_full_ids(current_shortstatehash)?;
let state_events = current_state_ids let state_events = current_state_ids
@ -510,7 +510,7 @@ async fn sync_helper(
} }
let (joined_member_count, invited_member_count, heroes) = if send_member_count { let (joined_member_count, invited_member_count, heroes) = if send_member_count {
calculate_counts() calculate_counts()?
} else { } else {
(None, None, Vec::new()) (None, None, Vec::new())
}; };

View File

@ -28,7 +28,7 @@ use ruma::{DeviceId, EventId, RoomId, ServerName, UserId};
use serde::{de::IgnoredAny, Deserialize}; use serde::{de::IgnoredAny, Deserialize};
use std::{ use std::{
collections::{BTreeMap, HashMap, HashSet}, collections::{BTreeMap, HashMap, HashSet},
convert::TryFrom, convert::{TryFrom, TryInto},
fs::{self, remove_dir_all}, fs::{self, remove_dir_all},
io::Write, io::Write,
mem::size_of, mem::size_of,
@ -266,7 +266,6 @@ impl Database {
shortroomid_roomid: builder.open_tree("shortroomid_roomid")?, shortroomid_roomid: builder.open_tree("shortroomid_roomid")?,
roomid_shortroomid: builder.open_tree("roomid_shortroomid")?, roomid_shortroomid: builder.open_tree("roomid_shortroomid")?,
stateid_shorteventid: builder.open_tree("stateid_shorteventid")?,
shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?, shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?,
eventid_shorteventid: builder.open_tree("eventid_shorteventid")?, eventid_shorteventid: builder.open_tree("eventid_shorteventid")?,
shorteventid_eventid: builder.open_tree("shorteventid_eventid")?, shorteventid_eventid: builder.open_tree("shorteventid_eventid")?,
@ -431,7 +430,6 @@ impl Database {
} }
if db.globals.database_version()? < 6 { if db.globals.database_version()? < 6 {
// TODO update to 6
// Set room member count // Set room member count
for (roomid, _) in db.rooms.roomid_shortstatehash.iter() { for (roomid, _) in db.rooms.roomid_shortstatehash.iter() {
let room_id = let room_id =
@ -445,263 +443,98 @@ impl Database {
println!("Migration: 5 -> 6 finished"); println!("Migration: 5 -> 6 finished");
} }
fn load_shortstatehash_info(
shortstatehash: &[u8],
db: &Database,
lru: &mut LruCache<
Vec<u8>,
Vec<(
Vec<u8>,
HashSet<Vec<u8>>,
HashSet<Vec<u8>>,
HashSet<Vec<u8>>,
)>,
>,
) -> Result<
Vec<(
Vec<u8>, // sstatehash
HashSet<Vec<u8>>, // full state
HashSet<Vec<u8>>, // added
HashSet<Vec<u8>>, // removed
)>,
> {
if let Some(result) = lru.get_mut(shortstatehash) {
return Ok(result.clone());
}
let value = db
.rooms
.shortstatehash_statediff
.get(shortstatehash)?
.ok_or_else(|| Error::bad_database("State hash does not exist"))?;
let parent = value[0..size_of::<u64>()].to_vec();
let mut add_mode = true;
let mut added = HashSet::new();
let mut removed = HashSet::new();
let mut i = size_of::<u64>();
while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) {
if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
add_mode = false;
i += size_of::<u64>();
continue;
}
if add_mode {
added.insert(v.to_vec());
} else {
removed.insert(v.to_vec());
}
i += 2 * size_of::<u64>();
}
if parent != 0_u64.to_be_bytes() {
let mut response = load_shortstatehash_info(&parent, db, lru)?;
let mut state = response.last().unwrap().1.clone();
state.extend(added.iter().cloned());
for r in &removed {
state.remove(r);
}
response.push((shortstatehash.to_vec(), state, added, removed));
lru.insert(shortstatehash.to_vec(), response.clone());
Ok(response)
} else {
let mut response = Vec::new();
response.push((shortstatehash.to_vec(), added.clone(), added, removed));
lru.insert(shortstatehash.to_vec(), response.clone());
Ok(response)
}
}
fn update_shortstatehash_level(
current_shortstatehash: &[u8],
statediffnew: HashSet<Vec<u8>>,
statediffremoved: HashSet<Vec<u8>>,
diff_to_sibling: usize,
mut parent_states: Vec<(
Vec<u8>, // sstatehash
HashSet<Vec<u8>>, // full state
HashSet<Vec<u8>>, // added
HashSet<Vec<u8>>, // removed
)>,
db: &Database,
) -> Result<()> {
let diffsum = statediffnew.len() + statediffremoved.len();
if parent_states.len() > 3 {
// Number of layers
// To many layers, we have to go deeper
let parent = parent_states.pop().unwrap();
let mut parent_new = parent.2;
let mut parent_removed = parent.3;
for removed in statediffremoved {
if !parent_new.remove(&removed) {
parent_removed.insert(removed);
}
}
parent_new.extend(statediffnew);
update_shortstatehash_level(
current_shortstatehash,
parent_new,
parent_removed,
diffsum,
parent_states,
db,
)?;
return Ok(());
}
if parent_states.len() == 0 {
// There is no parent layer, create a new state
let mut value = 0_u64.to_be_bytes().to_vec(); // 0 means no parent
for new in &statediffnew {
value.extend_from_slice(&new);
}
if !statediffremoved.is_empty() {
warn!("Tried to create new state with removals");
}
db.rooms
.shortstatehash_statediff
.insert(&current_shortstatehash, &value)?;
return Ok(());
};
// Else we have two options.
// 1. We add the current diff on top of the parent layer.
// 2. We replace a layer above
let parent = parent_states.pop().unwrap();
let parent_diff = parent.2.len() + parent.3.len();
if diffsum * diffsum >= 2 * diff_to_sibling * parent_diff {
// Diff too big, we replace above layer(s)
let mut parent_new = parent.2;
let mut parent_removed = parent.3;
for removed in statediffremoved {
if !parent_new.remove(&removed) {
parent_removed.insert(removed);
}
}
parent_new.extend(statediffnew);
update_shortstatehash_level(
current_shortstatehash,
parent_new,
parent_removed,
diffsum,
parent_states,
db,
)?;
} else {
// Diff small enough, we add diff as layer on top of parent
let mut value = parent.0.clone();
for new in &statediffnew {
value.extend_from_slice(&new);
}
if !statediffremoved.is_empty() {
value.extend_from_slice(&0_u64.to_be_bytes());
for removed in &statediffremoved {
value.extend_from_slice(&removed);
}
}
db.rooms
.shortstatehash_statediff
.insert(&current_shortstatehash, &value)?;
}
Ok(())
}
if db.globals.database_version()? < 7 { if db.globals.database_version()? < 7 {
// Upgrade state store // Upgrade state store
let mut lru = LruCache::new(1000); let mut last_roomstates: HashMap<RoomId, u64> = HashMap::new();
let mut last_roomstates: HashMap<RoomId, Vec<u8>> = HashMap::new(); let mut current_sstatehash: Option<u64> = None;
let mut current_sstatehash: Vec<u8> = Vec::new();
let mut current_room = None; let mut current_room = None;
let mut current_state = HashSet::new(); let mut current_state = HashSet::new();
let mut counter = 0; let mut counter = 0;
let mut handle_state =
|current_sstatehash: u64,
current_room: &RoomId,
current_state: HashSet<_>,
last_roomstates: &mut HashMap<_, _>| {
counter += 1;
println!("counter: {}", counter);
let last_roomsstatehash = last_roomstates.get(current_room);
let states_parents = last_roomsstatehash.map_or_else(
|| Ok(Vec::new()),
|&last_roomsstatehash| {
db.rooms.load_shortstatehash_info(dbg!(last_roomsstatehash))
},
)?;
let (statediffnew, statediffremoved) =
if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew = current_state
.difference(&parent_stateinfo.1)
.cloned()
.collect::<HashSet<_>>();
let statediffremoved = parent_stateinfo
.1
.difference(&current_state)
.cloned()
.collect::<HashSet<_>>();
(statediffnew, statediffremoved)
} else {
(current_state, HashSet::new())
};
db.rooms.save_state_from_diff(
dbg!(current_sstatehash),
statediffnew,
statediffremoved,
2, // every state change is 2 event changes on average
states_parents,
)?;
/*
let mut tmp = db.rooms.load_shortstatehash_info(&current_sstatehash, &db)?;
let state = tmp.pop().unwrap();
println!(
"{}\t{}{:?}: {:?} + {:?} - {:?}",
current_room,
" ".repeat(tmp.len()),
utils::u64_from_bytes(&current_sstatehash).unwrap(),
tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()),
state
.2
.iter()
.map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap())
.collect::<Vec<_>>(),
state
.3
.iter()
.map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap())
.collect::<Vec<_>>()
);
*/
Ok::<_, Error>(())
};
for (k, seventid) in db._db.open_tree("stateid_shorteventid")?.iter() { for (k, seventid) in db._db.open_tree("stateid_shorteventid")?.iter() {
let sstatehash = k[0..size_of::<u64>()].to_vec(); let sstatehash = utils::u64_from_bytes(&k[0..size_of::<u64>()])
.expect("number of bytes is correct");
let sstatekey = k[size_of::<u64>()..].to_vec(); let sstatekey = k[size_of::<u64>()..].to_vec();
if sstatehash != current_sstatehash { if Some(sstatehash) != current_sstatehash {
if !current_sstatehash.is_empty() { if let Some(current_sstatehash) = current_sstatehash {
counter += 1; handle_state(
println!("counter: {}", counter); current_sstatehash,
let current_room = current_room.as_ref().unwrap(); current_room.as_ref().unwrap(),
let last_roomsstatehash = last_roomstates.get(&current_room); current_state,
&mut last_roomstates,
let states_parents = last_roomsstatehash.map_or_else(
|| Ok(Vec::new()),
|last_roomsstatehash| {
load_shortstatehash_info(&last_roomsstatehash, &db, &mut lru)
},
)?; )?;
last_roomstates
let (statediffnew, statediffremoved) = .insert(current_room.clone().unwrap(), current_sstatehash);
if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew = current_state
.difference(&parent_stateinfo.1)
.cloned()
.collect::<HashSet<_>>();
let statediffremoved = parent_stateinfo
.1
.difference(&current_state)
.cloned()
.collect::<HashSet<_>>();
(statediffnew, statediffremoved)
} else {
(current_state, HashSet::new())
};
update_shortstatehash_level(
&current_sstatehash,
statediffnew,
statediffremoved,
2, // every state change is 2 event changes on average
states_parents,
&db,
)?;
/*
let mut tmp = load_shortstatehash_info(&current_sstatehash, &db)?;
let state = tmp.pop().unwrap();
println!(
"{}\t{}{:?}: {:?} + {:?} - {:?}",
current_room,
" ".repeat(tmp.len()),
utils::u64_from_bytes(&current_sstatehash).unwrap(),
tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()),
state
.2
.iter()
.map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap())
.collect::<Vec<_>>(),
state
.3
.iter()
.map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap())
.collect::<Vec<_>>()
);
*/
last_roomstates.insert(current_room.clone(), current_sstatehash);
} }
current_state = HashSet::new(); current_state = HashSet::new();
current_sstatehash = sstatehash; current_sstatehash = Some(sstatehash);
let event_id = db let event_id = db
.rooms .rooms
@ -721,7 +554,16 @@ impl Database {
let mut val = sstatekey; let mut val = sstatekey;
val.extend_from_slice(&seventid); val.extend_from_slice(&seventid);
current_state.insert(val); current_state.insert(val.try_into().expect("size is correct"));
}
if let Some(current_sstatehash) = current_sstatehash {
handle_state(
current_sstatehash,
current_room.as_ref().unwrap(),
current_state,
&mut last_roomstates,
)?;
} }
db.globals.bump_database_version(7)?; db.globals.bump_database_version(7)?;
@ -761,11 +603,28 @@ impl Database {
db.rooms.pduid_pdu.insert_batch(&mut batch)?; db.rooms.pduid_pdu.insert_batch(&mut batch)?;
for (key, _) in db.rooms.pduid_pdu.iter() { let mut batch2 = db.rooms.eventid_pduid.iter().filter_map(|(k, value)| {
if key.starts_with(b"!") { if !value.starts_with(b"!") {
db.rooms.pduid_pdu.remove(&key); return None;
} }
} let mut parts = value.splitn(2, |&b| b == 0xff);
let room_id = parts.next().unwrap();
let count = parts.next().unwrap();
let short_room_id = db
.rooms
.roomid_shortroomid
.get(&room_id)
.unwrap()
.expect("shortroomid should exist");
let mut new_value = short_room_id;
new_value.extend_from_slice(count);
Some((k, new_value))
});
db.rooms.eventid_pduid.insert_batch(&mut batch2)?;
db.globals.bump_database_version(8)?; db.globals.bump_database_version(8)?;
@ -803,7 +662,7 @@ impl Database {
for (key, _) in db.rooms.tokenids.iter() { for (key, _) in db.rooms.tokenids.iter() {
if key.starts_with(b"!") { if key.starts_with(b"!") {
db.rooms.pduid_pdu.remove(&key)?; db.rooms.tokenids.remove(&key)?;
} }
} }
@ -811,8 +670,6 @@ impl Database {
println!("Migration: 8 -> 9 finished"); println!("Migration: 8 -> 9 finished");
} }
panic!();
} }
let guard = db.read().await; let guard = db.read().await;

View File

@ -9,13 +9,13 @@ use std::{
path::{Path, PathBuf}, path::{Path, PathBuf},
pin::Pin, pin::Pin,
sync::Arc, sync::Arc,
time::{Duration, Instant},
}; };
use tokio::sync::oneshot::Sender; use tokio::sync::oneshot::Sender;
use tracing::debug; use tracing::debug;
thread_local! { thread_local! {
static READ_CONNECTION: RefCell<Option<&'static Connection>> = RefCell::new(None); static READ_CONNECTION: RefCell<Option<&'static Connection>> = RefCell::new(None);
static READ_CONNECTION_ITERATOR: RefCell<Option<&'static Connection>> = RefCell::new(None);
} }
struct PreparedStatementIterator<'a> { struct PreparedStatementIterator<'a> {
@ -77,6 +77,21 @@ impl Engine {
}) })
} }
fn read_lock_iterator(&self) -> &'static Connection {
READ_CONNECTION_ITERATOR.with(|cell| {
let connection = &mut cell.borrow_mut();
if (*connection).is_none() {
let c = Box::leak(Box::new(
Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap(),
));
**connection = Some(c);
}
connection.unwrap()
})
}
pub fn flush_wal(self: &Arc<Self>) -> Result<()> { pub fn flush_wal(self: &Arc<Self>) -> Result<()> {
self.write_lock() self.write_lock()
.pragma_update(Some(Main), "wal_checkpoint", &"TRUNCATE")?; .pragma_update(Some(Main), "wal_checkpoint", &"TRUNCATE")?;
@ -151,6 +166,34 @@ impl SqliteTable {
)?; )?;
Ok(()) Ok(())
} }
pub fn iter_with_guard<'a>(
&'a self,
guard: &'a Connection,
) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
let statement = Box::leak(Box::new(
guard
.prepare(&format!(
"SELECT key, value FROM {} ORDER BY key ASC",
&self.name
))
.unwrap(),
));
let statement_ref = NonAliasingBox(statement);
let iterator = Box::new(
statement
.query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1))))
.unwrap()
.map(|r| r.unwrap()),
);
Box::new(PreparedStatementIterator {
iterator,
statement_ref,
})
}
} }
impl Tree for SqliteTable { impl Tree for SqliteTable {
@ -219,30 +262,9 @@ impl Tree for SqliteTable {
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> { fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
let guard = self.engine.read_lock(); let guard = self.engine.read_lock_iterator();
let statement = Box::leak(Box::new( self.iter_with_guard(&guard)
guard
.prepare(&format!(
"SELECT key, value FROM {} ORDER BY key ASC",
&self.name
))
.unwrap(),
));
let statement_ref = NonAliasingBox(statement);
let iterator = Box::new(
statement
.query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1))))
.unwrap()
.map(|r| r.unwrap()),
);
Box::new(PreparedStatementIterator {
iterator,
statement_ref,
})
} }
#[tracing::instrument(skip(self, from, backwards))] #[tracing::instrument(skip(self, from, backwards))]
@ -251,7 +273,7 @@ impl Tree for SqliteTable {
from: &[u8], from: &[u8],
backwards: bool, backwards: bool,
) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> { ) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
let guard = self.engine.read_lock(); let guard = self.engine.read_lock_iterator();
let from = from.to_vec(); // TODO change interface? let from = from.to_vec(); // TODO change interface?
if backwards { if backwards {

View File

@ -24,7 +24,7 @@ use ruma::{
use std::{ use std::{
collections::{BTreeMap, BTreeSet, HashMap, HashSet}, collections::{BTreeMap, BTreeSet, HashMap, HashSet},
convert::{TryFrom, TryInto}, convert::{TryFrom, TryInto},
mem, mem::size_of,
sync::{Arc, Mutex}, sync::{Arc, Mutex},
}; };
use tokio::sync::MutexGuard; use tokio::sync::MutexGuard;
@ -37,10 +37,11 @@ use super::{abstraction::Tree, admin::AdminCommand, pusher};
/// This is created when a state group is added to the database by /// This is created when a state group is added to the database by
/// hashing the entire state. /// hashing the entire state.
pub type StateHashId = Vec<u8>; pub type StateHashId = Vec<u8>;
pub type CompressedStateEvent = [u8; 2 * size_of::<u64>()];
pub struct Rooms { pub struct Rooms {
pub edus: edus::RoomEdus, pub edus: edus::RoomEdus,
pub(super) pduid_pdu: Arc<dyn Tree>, // PduId = RoomId + Count pub(super) pduid_pdu: Arc<dyn Tree>, // PduId = ShortRoomId + Count
pub(super) eventid_pduid: Arc<dyn Tree>, pub(super) eventid_pduid: Arc<dyn Tree>,
pub(super) roomid_pduleaves: Arc<dyn Tree>, pub(super) roomid_pduleaves: Arc<dyn Tree>,
pub(super) alias_roomid: Arc<dyn Tree>, pub(super) alias_roomid: Arc<dyn Tree>,
@ -79,9 +80,6 @@ pub struct Rooms {
pub(super) eventid_shorteventid: Arc<dyn Tree>, pub(super) eventid_shorteventid: Arc<dyn Tree>,
pub(super) statehash_shortstatehash: Arc<dyn Tree>, pub(super) statehash_shortstatehash: Arc<dyn Tree>,
/// ShortStateHash = Count
/// StateId = ShortStateHash
pub(super) stateid_shorteventid: Arc<dyn Tree>,
pub(super) shortstatehash_statediff: Arc<dyn Tree>, // StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 + (shortstatekey+shorteventid--) pub(super) shortstatehash_statediff: Arc<dyn Tree>, // StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 + (shortstatekey+shorteventid--)
/// RoomId + EventId -> outlier PDU. /// RoomId + EventId -> outlier PDU.
@ -100,29 +98,30 @@ impl Rooms {
/// Builds a StateMap by iterating over all keys that start /// Builds a StateMap by iterating over all keys that start
/// with state_hash, this gives the full state for the given state_hash. /// with state_hash, this gives the full state for the given state_hash.
pub fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeSet<EventId>> { pub fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeSet<EventId>> {
Ok(self let full_state = self
.stateid_shorteventid .load_shortstatehash_info(shortstatehash)?
.scan_prefix(shortstatehash.to_be_bytes().to_vec()) .pop()
.map(|(_, bytes)| { .expect("there is always one layer")
self.get_eventid_from_short(utils::u64_from_bytes(&bytes).unwrap()) .1;
.ok() full_state
}) .into_iter()
.flatten() .map(|compressed| self.parse_compressed_state_event(compressed))
.collect()) .collect()
} }
pub fn state_full( pub fn state_full(
&self, &self,
shortstatehash: u64, shortstatehash: u64,
) -> Result<HashMap<(EventType, String), Arc<PduEvent>>> { ) -> Result<HashMap<(EventType, String), Arc<PduEvent>>> {
let state = self let full_state = self
.stateid_shorteventid .load_shortstatehash_info(shortstatehash)?
.scan_prefix(shortstatehash.to_be_bytes().to_vec()) .pop()
.map(|(_, bytes)| { .expect("there is always one layer")
self.get_eventid_from_short(utils::u64_from_bytes(&bytes).unwrap()) .1;
.ok() Ok(full_state
}) .into_iter()
.flatten() .map(|compressed| self.parse_compressed_state_event(compressed))
.filter_map(|r| r.ok())
.map(|eventid| self.get_pdu(&eventid)) .map(|eventid| self.get_pdu(&eventid))
.filter_map(|r| r.ok().flatten()) .filter_map(|r| r.ok().flatten())
.map(|pdu| { .map(|pdu| {
@ -138,9 +137,7 @@ impl Rooms {
)) ))
}) })
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.collect(); .collect())
Ok(state)
} }
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
@ -151,27 +148,19 @@ impl Rooms {
event_type: &EventType, event_type: &EventType,
state_key: &str, state_key: &str,
) -> Result<Option<EventId>> { ) -> Result<Option<EventId>> {
let mut key = event_type.as_ref().as_bytes().to_vec(); let shortstatekey = match self.get_shortstatekey(event_type, state_key)? {
key.push(0xff); Some(s) => s,
key.extend_from_slice(&state_key.as_bytes()); None => return Ok(None),
};
let shortstatekey = self.statekey_shortstatekey.get(&key)?; let full_state = self
.load_shortstatehash_info(shortstatehash)?
if let Some(shortstatekey) = shortstatekey { .pop()
let mut stateid = shortstatehash.to_be_bytes().to_vec(); .expect("there is always one layer")
stateid.extend_from_slice(&shortstatekey); .1;
Ok(full_state
Ok(self .into_iter()
.stateid_shorteventid .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
.get(&stateid)? .and_then(|compressed| self.parse_compressed_state_event(compressed).ok()))
.map(|bytes| {
self.get_eventid_from_short(utils::u64_from_bytes(&bytes).unwrap())
.ok()
})
.flatten())
} else {
Ok(None)
}
} }
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
@ -260,8 +249,7 @@ impl Rooms {
/// Checks if a room exists. /// Checks if a room exists.
pub fn exists(&self, room_id: &RoomId) -> Result<bool> { pub fn exists(&self, room_id: &RoomId) -> Result<bool> {
let mut prefix = room_id.as_bytes().to_vec(); let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec();
prefix.push(0xff);
// Look for PDUs in that room. // Look for PDUs in that room.
Ok(self Ok(self
@ -274,8 +262,7 @@ impl Rooms {
/// Checks if a room exists. /// Checks if a room exists.
pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> { pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> {
let mut prefix = room_id.as_bytes().to_vec(); let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec();
prefix.push(0xff);
// Look for PDUs in that room. // Look for PDUs in that room.
self.pduid_pdu self.pduid_pdu
@ -292,74 +279,78 @@ impl Rooms {
/// Force the creation of a new StateHash and insert it into the db. /// Force the creation of a new StateHash and insert it into the db.
/// ///
/// Whatever `state` is supplied to `force_state` __is__ the current room state snapshot. /// Whatever `state` is supplied to `force_state` becomes the new current room state snapshot.
pub fn force_state( pub fn force_state(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
state: HashMap<(EventType, String), EventId>, new_state: HashMap<(EventType, String), EventId>,
db: &Database, db: &Database,
) -> Result<()> { ) -> Result<()> {
let previous_shortstatehash = self.current_shortstatehash(&room_id)?;
let new_state_ids_compressed = new_state
.iter()
.filter_map(|((event_type, state_key), event_id)| {
let shortstatekey = self
.get_or_create_shortstatekey(event_type, state_key, &db.globals)
.ok()?;
Some(
self.compress_state_event(shortstatekey, event_id, &db.globals)
.ok()?,
)
})
.collect::<HashSet<_>>();
let state_hash = self.calculate_hash( let state_hash = self.calculate_hash(
&state &new_state
.values() .values()
.map(|event_id| event_id.as_bytes()) .map(|event_id| event_id.as_bytes())
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
); );
let (shortstatehash, already_existed) = let (new_shortstatehash, already_existed) =
self.get_or_create_shortstatehash(&state_hash, &db.globals)?; self.get_or_create_shortstatehash(&state_hash, &db.globals)?;
let new_state = if !already_existed { if Some(new_shortstatehash) == previous_shortstatehash {
let mut new_state = HashSet::new(); return Ok(());
}
let batch = state let states_parents = previous_shortstatehash
.iter() .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?;
.filter_map(|((event_type, state_key), eventid)| {
new_state.insert(eventid.clone());
let mut statekey = event_type.as_ref().as_bytes().to_vec(); let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last()
statekey.push(0xff); {
statekey.extend_from_slice(&state_key.as_bytes()); let statediffnew = new_state_ids_compressed
.difference(&parent_stateinfo.1)
.cloned()
.collect::<HashSet<_>>();
let shortstatekey = match self.statekey_shortstatekey.get(&statekey).ok()? { let statediffremoved = parent_stateinfo
Some(shortstatekey) => shortstatekey.to_vec(), .1
None => { .difference(&new_state_ids_compressed)
let shortstatekey = db.globals.next_count().ok()?; .cloned()
self.statekey_shortstatekey .collect::<HashSet<_>>();
.insert(&statekey, &shortstatekey.to_be_bytes())
.ok()?;
shortstatekey.to_be_bytes().to_vec()
}
};
let shorteventid = self (statediffnew, statediffremoved)
.get_or_create_shorteventid(&eventid, &db.globals)
.ok()?;
let mut state_id = shortstatehash.to_be_bytes().to_vec();
state_id.extend_from_slice(&shortstatekey);
Some((state_id, shorteventid.to_be_bytes().to_vec()))
})
.collect::<Vec<_>>();
self.stateid_shorteventid
.insert_batch(&mut batch.into_iter())?;
new_state
} else { } else {
self.state_full_ids(shortstatehash)?.into_iter().collect() (new_state_ids_compressed, HashSet::new())
}; };
let old_state = self if !already_existed {
.current_shortstatehash(&room_id)? self.save_state_from_diff(
.map(|s| self.state_full_ids(s)) new_shortstatehash,
.transpose()? statediffnew.clone(),
.map(|vec| vec.into_iter().collect::<HashSet<_>>()) statediffremoved.clone(),
.unwrap_or_default(); 2, // every state change is 2 event changes on average
states_parents,
)?;
};
for event_id in new_state.difference(&old_state) { for event_id in statediffnew
if let Some(pdu) = self.get_pdu_json(event_id)? { .into_iter()
.filter_map(|new| self.parse_compressed_state_event(new).ok())
{
if let Some(pdu) = self.get_pdu_json(&event_id)? {
if pdu.get("type").and_then(|val| val.as_str()) == Some("m.room.member") { if pdu.get("type").and_then(|val| val.as_str()) == Some("m.room.member") {
if let Ok(pdu) = serde_json::from_value::<PduEvent>( if let Ok(pdu) = serde_json::from_value::<PduEvent>(
serde_json::to_value(&pdu).expect("CanonicalJsonObj is a valid JsonValue"), serde_json::to_value(&pdu).expect("CanonicalJsonObj is a valid JsonValue"),
@ -392,7 +383,206 @@ impl Rooms {
} }
self.roomid_shortstatehash self.roomid_shortstatehash
.insert(room_id.as_bytes(), &shortstatehash.to_be_bytes())?; .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
Ok(())
}
/// Returns a stack with info on shortstatehash, full state, added diff and removed diff for the selected shortstatehash and each parent layer.
pub fn load_shortstatehash_info(
&self,
shortstatehash: u64,
) -> Result<
Vec<(
u64, // sstatehash
HashSet<CompressedStateEvent>, // full state
HashSet<CompressedStateEvent>, // added
HashSet<CompressedStateEvent>, // removed
)>,
> {
let value = self
.shortstatehash_statediff
.get(&shortstatehash.to_be_bytes())?
.ok_or_else(|| Error::bad_database("State hash does not exist"))?;
let parent =
utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length");
let mut add_mode = true;
let mut added = HashSet::new();
let mut removed = HashSet::new();
let mut i = size_of::<u64>();
while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) {
if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
add_mode = false;
i += size_of::<u64>();
continue;
}
if add_mode {
added.insert(v.try_into().expect("we checked the size above"));
} else {
removed.insert(v.try_into().expect("we checked the size above"));
}
i += 2 * size_of::<u64>();
}
if parent != 0_u64 {
let mut response = self.load_shortstatehash_info(parent)?;
let mut state = response.last().unwrap().1.clone();
state.extend(added.iter().cloned());
for r in &removed {
state.remove(r);
}
response.push((shortstatehash, state, added, removed));
Ok(response)
} else {
let mut response = Vec::new();
response.push((shortstatehash, added.clone(), added, removed));
Ok(response)
}
}
pub fn compress_state_event(
&self,
shortstatekey: u64,
event_id: &EventId,
globals: &super::globals::Globals,
) -> Result<CompressedStateEvent> {
let mut v = shortstatekey.to_be_bytes().to_vec();
v.extend_from_slice(
&self
.get_or_create_shorteventid(event_id, globals)?
.to_be_bytes(),
);
Ok(v.try_into().expect("we checked the size above"))
}
pub fn parse_compressed_state_event(
&self,
compressed_event: CompressedStateEvent,
) -> Result<EventId> {
self.get_eventid_from_short(
utils::u64_from_bytes(&compressed_event[size_of::<u64>()..])
.expect("bytes have right length"),
)
}
/// Creates a new shortstatehash that often is just a diff to an already existing
/// shortstatehash and therefore very efficient.
///
/// There are multiple layers of diffs. The bottom layer 0 always contains the full state. Layer
/// 1 contains diffs to states of layer 0, layer 2 diffs to layer 1 and so on. If layer n > 0
/// grows too big, it will be combined with layer n-1 to create a new diff on layer n-1 that's
/// based on layer n-2. If that layer is also too big, it will recursively fix above layers too.
///
/// * `shortstatehash` - Shortstatehash of this state
/// * `statediffnew` - Added to base. Each vec is shortstatekey+shorteventid
/// * `statediffremoved` - Removed from base. Each vec is shortstatekey+shorteventid
/// * `diff_to_sibling` - Approximately how much the diff grows each time for this layer
/// * `parent_states` - A stack with info on shortstatehash, full state, added diff and removed diff for each parent layer
pub fn save_state_from_diff(
&self,
shortstatehash: u64,
statediffnew: HashSet<CompressedStateEvent>,
statediffremoved: HashSet<CompressedStateEvent>,
diff_to_sibling: usize,
mut parent_states: Vec<(
u64, // sstatehash
HashSet<CompressedStateEvent>, // full state
HashSet<CompressedStateEvent>, // added
HashSet<CompressedStateEvent>, // removed
)>,
) -> Result<()> {
let diffsum = statediffnew.len() + statediffremoved.len();
if parent_states.len() > 3 {
// Number of layers
// To many layers, we have to go deeper
let parent = parent_states.pop().unwrap();
let mut parent_new = parent.2;
let mut parent_removed = parent.3;
for removed in statediffremoved {
if !parent_new.remove(&removed) {
parent_removed.insert(removed);
}
}
parent_new.extend(statediffnew);
self.save_state_from_diff(
shortstatehash,
parent_new,
parent_removed,
diffsum,
parent_states,
)?;
return Ok(());
}
if parent_states.len() == 0 {
// There is no parent layer, create a new state
let mut value = 0_u64.to_be_bytes().to_vec(); // 0 means no parent
for new in &statediffnew {
value.extend_from_slice(&new[..]);
}
if !statediffremoved.is_empty() {
warn!("Tried to create new state with removals");
}
self.shortstatehash_statediff
.insert(&shortstatehash.to_be_bytes(), &value)?;
return Ok(());
};
// Else we have two options.
// 1. We add the current diff on top of the parent layer.
// 2. We replace a layer above
let parent = parent_states.pop().unwrap();
let parent_diff = parent.2.len() + parent.3.len();
if diffsum * diffsum >= 2 * diff_to_sibling * parent_diff {
// Diff too big, we replace above layer(s)
let mut parent_new = parent.2;
let mut parent_removed = parent.3;
for removed in statediffremoved {
if !parent_new.remove(&removed) {
parent_removed.insert(removed);
}
}
parent_new.extend(statediffnew);
self.save_state_from_diff(
shortstatehash,
parent_new,
parent_removed,
diffsum,
parent_states,
)?;
} else {
// Diff small enough, we add diff as layer on top of parent
let mut value = parent.0.to_be_bytes().to_vec();
for new in &statediffnew {
value.extend_from_slice(&new[..]);
}
if !statediffremoved.is_empty() {
value.extend_from_slice(&0_u64.to_be_bytes());
for removed in &statediffremoved {
value.extend_from_slice(&removed[..]);
}
}
self.shortstatehash_statediff
.insert(&shortstatehash.to_be_bytes(), &value)?;
}
Ok(()) Ok(())
} }
@ -418,7 +608,6 @@ impl Rooms {
}) })
} }
/// Returns (shortstatehash, already_existed)
pub fn get_or_create_shorteventid( pub fn get_or_create_shorteventid(
&self, &self,
event_id: &EventId, event_id: &EventId,
@ -438,6 +627,71 @@ impl Rooms {
}) })
} }
pub fn get_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
let bytes = self
.roomid_shortroomid
.get(&room_id.as_bytes())?
.expect("every room has a shortroomid");
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db."))
}
pub fn get_shortstatekey(
&self,
event_type: &EventType,
state_key: &str,
) -> Result<Option<u64>> {
let mut statekey = event_type.as_ref().as_bytes().to_vec();
statekey.push(0xff);
statekey.extend_from_slice(&state_key.as_bytes());
self.statekey_shortstatekey
.get(&statekey)?
.map(|shortstatekey| {
utils::u64_from_bytes(&shortstatekey)
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
})
.transpose()
}
pub fn get_or_create_shortroomid(
&self,
room_id: &RoomId,
globals: &super::globals::Globals,
) -> Result<u64> {
Ok(match self.roomid_shortroomid.get(&room_id.as_bytes())? {
Some(short) => utils::u64_from_bytes(&short)
.map_err(|_| Error::bad_database("Invalid shortroomid in db."))?,
None => {
let short = globals.next_count()?;
self.roomid_shortroomid
.insert(&room_id.as_bytes(), &short.to_be_bytes())?;
short
}
})
}
pub fn get_or_create_shortstatekey(
&self,
event_type: &EventType,
state_key: &str,
globals: &super::globals::Globals,
) -> Result<u64> {
let mut statekey = event_type.as_ref().as_bytes().to_vec();
statekey.push(0xff);
statekey.extend_from_slice(&state_key.as_bytes());
Ok(match self.statekey_shortstatekey.get(&statekey)? {
Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey)
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?,
None => {
let shortstatekey = globals.next_count()?;
self.statekey_shortstatekey
.insert(&statekey, &shortstatekey.to_be_bytes())?;
shortstatekey
}
})
}
pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result<EventId> { pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result<EventId> {
if let Some(id) = self if let Some(id) = self
.shorteventid_cache .shorteventid_cache
@ -514,7 +768,7 @@ impl Rooms {
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn pdu_count(&self, pdu_id: &[u8]) -> Result<u64> { pub fn pdu_count(&self, pdu_id: &[u8]) -> Result<u64> {
Ok( Ok(
utils::u64_from_bytes(&pdu_id[pdu_id.len() - mem::size_of::<u64>()..pdu_id.len()]) utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..])
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?, .map_err(|_| Error::bad_database("PDU has invalid count bytes."))?,
) )
} }
@ -527,8 +781,7 @@ impl Rooms {
} }
pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result<u64> { pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result<u64> {
let mut prefix = room_id.as_bytes().to_vec(); let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec();
prefix.push(0xff);
let mut last_possible_key = prefix.clone(); let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
@ -758,6 +1011,8 @@ impl Rooms {
/// ///
/// By this point the incoming event should be fully authenticated, no auth happens /// By this point the incoming event should be fully authenticated, no auth happens
/// in `append_pdu`. /// in `append_pdu`.
///
/// Returns pdu id
#[tracing::instrument(skip(self, pdu, pdu_json, leaves, db))] #[tracing::instrument(skip(self, pdu, pdu_json, leaves, db))]
pub fn append_pdu( pub fn append_pdu(
&self, &self,
@ -766,7 +1021,8 @@ impl Rooms {
leaves: &[EventId], leaves: &[EventId],
db: &Database, db: &Database,
) -> Result<Vec<u8>> { ) -> Result<Vec<u8>> {
// returns pdu id let shortroomid = self.get_shortroomid(&pdu.room_id)?;
// Make unsigned fields correct. This is not properly documented in the spec, but state // Make unsigned fields correct. This is not properly documented in the spec, but state
// events need to have previous content in the unsigned field, so clients can easily // events need to have previous content in the unsigned field, so clients can easily
// interpret things like membership changes // interpret things like membership changes
@ -821,8 +1077,7 @@ impl Rooms {
self.reset_notification_counts(&pdu.sender, &pdu.room_id)?; self.reset_notification_counts(&pdu.sender, &pdu.room_id)?;
let count2 = db.globals.next_count()?; let count2 = db.globals.next_count()?;
let mut pdu_id = pdu.room_id.as_bytes().to_vec(); let mut pdu_id = shortroomid.to_be_bytes().to_vec();
pdu_id.push(0xff);
pdu_id.extend_from_slice(&count2.to_be_bytes()); pdu_id.extend_from_slice(&count2.to_be_bytes());
// There's a brief moment of time here where the count is updated but the pdu does not // There's a brief moment of time here where the count is updated but the pdu does not
@ -968,8 +1223,7 @@ impl Rooms {
.filter(|word| word.len() <= 50) .filter(|word| word.len() <= 50)
.map(str::to_lowercase) .map(str::to_lowercase)
.map(|word| { .map(|word| {
let mut key = pdu.room_id.as_bytes().to_vec(); let mut key = shortroomid.to_be_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(word.as_bytes()); key.extend_from_slice(word.as_bytes());
key.push(0xff); key.push(0xff);
key.extend_from_slice(&pdu_id); key.extend_from_slice(&pdu_id);
@ -1152,11 +1406,27 @@ impl Rooms {
pub fn set_event_state( pub fn set_event_state(
&self, &self,
event_id: &EventId, event_id: &EventId,
room_id: &RoomId,
state: &StateMap<Arc<PduEvent>>, state: &StateMap<Arc<PduEvent>>,
globals: &super::globals::Globals, globals: &super::globals::Globals,
) -> Result<()> { ) -> Result<()> {
let shorteventid = self.get_or_create_shorteventid(&event_id, globals)?; let shorteventid = self.get_or_create_shorteventid(&event_id, globals)?;
let previous_shortstatehash = self.current_shortstatehash(&room_id)?;
let state_ids_compressed = state
.iter()
.filter_map(|((event_type, state_key), pdu)| {
let shortstatekey = self
.get_or_create_shortstatekey(event_type, state_key, globals)
.ok()?;
Some(
self.compress_state_event(shortstatekey, &pdu.event_id, globals)
.ok()?,
)
})
.collect::<HashSet<_>>();
let state_hash = self.calculate_hash( let state_hash = self.calculate_hash(
&state &state
.values() .values()
@ -1168,37 +1438,33 @@ impl Rooms {
self.get_or_create_shortstatehash(&state_hash, globals)?; self.get_or_create_shortstatehash(&state_hash, globals)?;
if !already_existed { if !already_existed {
let batch = state let states_parents = previous_shortstatehash
.iter() .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?;
.filter_map(|((event_type, state_key), pdu)| {
let mut statekey = event_type.as_ref().as_bytes().to_vec();
statekey.push(0xff);
statekey.extend_from_slice(&state_key.as_bytes());
let shortstatekey = match self.statekey_shortstatekey.get(&statekey).ok()? { let (statediffnew, statediffremoved) =
Some(shortstatekey) => shortstatekey.to_vec(), if let Some(parent_stateinfo) = states_parents.last() {
None => { let statediffnew = state_ids_compressed
let shortstatekey = globals.next_count().ok()?; .difference(&parent_stateinfo.1)
self.statekey_shortstatekey .cloned()
.insert(&statekey, &shortstatekey.to_be_bytes()) .collect::<HashSet<_>>();
.ok()?;
shortstatekey.to_be_bytes().to_vec()
}
};
let shorteventid = self let statediffremoved = parent_stateinfo
.get_or_create_shorteventid(&pdu.event_id, globals) .1
.ok()?; .difference(&state_ids_compressed)
.cloned()
.collect::<HashSet<_>>();
let mut state_id = shortstatehash.to_be_bytes().to_vec(); (statediffnew, statediffremoved)
state_id.extend_from_slice(&shortstatekey); } else {
(state_ids_compressed, HashSet::new())
Some((state_id, shorteventid.to_be_bytes().to_vec())) };
}) self.save_state_from_diff(
.collect::<Vec<_>>(); shortstatehash,
statediffnew.clone(),
self.stateid_shorteventid statediffremoved.clone(),
.insert_batch(&mut batch.into_iter())?; 1_000_000, // high number because no state will be based on this one
states_parents,
)?;
} }
self.shorteventid_shortstatehash self.shorteventid_shortstatehash
@ -1219,82 +1485,52 @@ impl Rooms {
) -> Result<u64> { ) -> Result<u64> {
let shorteventid = self.get_or_create_shorteventid(&new_pdu.event_id, globals)?; let shorteventid = self.get_or_create_shorteventid(&new_pdu.event_id, globals)?;
let old_state = if let Some(old_shortstatehash) = let previous_shortstatehash = self.current_shortstatehash(&new_pdu.room_id)?;
self.roomid_shortstatehash.get(new_pdu.room_id.as_bytes())?
{ if let Some(p) = previous_shortstatehash {
// Store state for event. The state does not include the event itself.
// Instead it's the state before the pdu, so the room's old state.
self.shorteventid_shortstatehash self.shorteventid_shortstatehash
.insert(&shorteventid.to_be_bytes(), &old_shortstatehash)?; .insert(&shorteventid.to_be_bytes(), &p.to_be_bytes())?;
}
if new_pdu.state_key.is_none() {
return utils::u64_from_bytes(&old_shortstatehash).map_err(|_| {
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash.")
});
}
self.stateid_shorteventid
.scan_prefix(old_shortstatehash.clone())
// Chop the old_shortstatehash out leaving behind the short state key
.map(|(k, v)| (k[old_shortstatehash.len()..].to_vec(), v))
.collect::<HashMap<Vec<u8>, Vec<u8>>>()
} else {
HashMap::new()
};
if let Some(state_key) = &new_pdu.state_key { if let Some(state_key) = &new_pdu.state_key {
let mut new_state: HashMap<Vec<u8>, Vec<u8>> = old_state; let states_parents = previous_shortstatehash
.map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?;
let mut new_state_key = new_pdu.kind.as_ref().as_bytes().to_vec(); let shortstatekey =
new_state_key.push(0xff); self.get_or_create_shortstatekey(&new_pdu.kind, &state_key, globals)?;
new_state_key.extend_from_slice(state_key.as_bytes());
let shortstatekey = match self.statekey_shortstatekey.get(&new_state_key)? { let replaces = states_parents
Some(shortstatekey) => shortstatekey.to_vec(), .last()
None => { .map(|info| {
let shortstatekey = globals.next_count()?; info.1
self.statekey_shortstatekey .iter()
.insert(&new_state_key, &shortstatekey.to_be_bytes())?; .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
shortstatekey.to_be_bytes().to_vec() })
} .unwrap_or_default();
};
new_state.insert(shortstatekey, shorteventid.to_be_bytes().to_vec()); // TODO: statehash with deterministic inputs
let shortstatehash = globals.next_count()?;
let new_state_hash = self.calculate_hash( let mut statediffnew = HashSet::new();
&new_state let new = self.compress_state_event(shortstatekey, &new_pdu.event_id, globals)?;
.values() statediffnew.insert(new);
.map(|event_id| &**event_id)
.collect::<Vec<_>>(),
);
let shortstatehash = match self.statehash_shortstatehash.get(&new_state_hash)? { let mut statediffremoved = HashSet::new();
Some(shortstatehash) => { if let Some(replaces) = replaces {
warn!("state hash already existed?!"); statediffremoved.insert(replaces.clone());
utils::u64_from_bytes(&shortstatehash) }
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?
}
None => {
let shortstatehash = globals.next_count()?;
self.statehash_shortstatehash
.insert(&new_state_hash, &shortstatehash.to_be_bytes())?;
shortstatehash
}
};
let mut batch = new_state.into_iter().map(|(shortstatekey, shorteventid)| { self.save_state_from_diff(
let mut state_id = shortstatehash.to_be_bytes().to_vec(); shortstatehash,
state_id.extend_from_slice(&shortstatekey); statediffnew,
(state_id, shorteventid) statediffremoved,
}); 2,
states_parents,
self.stateid_shorteventid.insert_batch(&mut batch)?; )?;
Ok(shortstatehash) Ok(shortstatehash)
} else { } else {
Err(Error::bad_database( Ok(previous_shortstatehash.expect("first event in room must be a state event"))
"Tried to insert non-state event into room without a state.",
))
} }
} }
@ -1597,7 +1833,7 @@ impl Rooms {
&'a self, &'a self,
user_id: &UserId, user_id: &UserId,
room_id: &RoomId, room_id: &RoomId,
) -> impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a { ) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> {
self.pdus_since(user_id, room_id, 0) self.pdus_since(user_id, room_id, 0)
} }
@ -1609,16 +1845,17 @@ impl Rooms {
user_id: &UserId, user_id: &UserId,
room_id: &RoomId, room_id: &RoomId,
since: u64, since: u64,
) -> impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a { ) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec(); let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec();
prefix.push(0xff);
// Skip the first pdu if it's exactly at since, because we sent that last time // Skip the first pdu if it's exactly at since, because we sent that last time
let mut first_pdu_id = prefix.clone(); let mut first_pdu_id = prefix.clone();
first_pdu_id.extend_from_slice(&(since + 1).to_be_bytes()); first_pdu_id.extend_from_slice(&(since + 1).to_be_bytes());
let user_id = user_id.clone(); let user_id = user_id.clone();
self.pduid_pdu
Ok(self
.pduid_pdu
.iter_from(&first_pdu_id, false) .iter_from(&first_pdu_id, false)
.take_while(move |(k, _)| k.starts_with(&prefix)) .take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| { .map(move |(pdu_id, v)| {
@ -1628,7 +1865,7 @@ impl Rooms {
pdu.unsigned.remove("transaction_id"); pdu.unsigned.remove("transaction_id");
} }
Ok((pdu_id, pdu)) Ok((pdu_id, pdu))
}) }))
} }
/// Returns an iterator over all events and their tokens in a room that happened before the /// Returns an iterator over all events and their tokens in a room that happened before the
@ -1639,10 +1876,9 @@ impl Rooms {
user_id: &UserId, user_id: &UserId,
room_id: &RoomId, room_id: &RoomId,
until: u64, until: u64,
) -> impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a { ) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> {
// Create the first part of the full pdu id // Create the first part of the full pdu id
let mut prefix = room_id.as_bytes().to_vec(); let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec();
prefix.push(0xff);
let mut current = prefix.clone(); let mut current = prefix.clone();
current.extend_from_slice(&(until.saturating_sub(1)).to_be_bytes()); // -1 because we don't want event at `until` current.extend_from_slice(&(until.saturating_sub(1)).to_be_bytes()); // -1 because we don't want event at `until`
@ -1650,7 +1886,9 @@ impl Rooms {
let current: &[u8] = &current; let current: &[u8] = &current;
let user_id = user_id.clone(); let user_id = user_id.clone();
self.pduid_pdu
Ok(self
.pduid_pdu
.iter_from(current, true) .iter_from(current, true)
.take_while(move |(k, _)| k.starts_with(&prefix)) .take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| { .map(move |(pdu_id, v)| {
@ -1660,7 +1898,7 @@ impl Rooms {
pdu.unsigned.remove("transaction_id"); pdu.unsigned.remove("transaction_id");
} }
Ok((pdu_id, pdu)) Ok((pdu_id, pdu))
}) }))
} }
/// Returns an iterator over all events and their token in a room that happened after the event /// Returns an iterator over all events and their token in a room that happened after the event
@ -1671,10 +1909,9 @@ impl Rooms {
user_id: &UserId, user_id: &UserId,
room_id: &RoomId, room_id: &RoomId,
from: u64, from: u64,
) -> impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a { ) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> {
// Create the first part of the full pdu id // Create the first part of the full pdu id
let mut prefix = room_id.as_bytes().to_vec(); let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec();
prefix.push(0xff);
let mut current = prefix.clone(); let mut current = prefix.clone();
current.extend_from_slice(&(from + 1).to_be_bytes()); // +1 so we don't send the base event current.extend_from_slice(&(from + 1).to_be_bytes()); // +1 so we don't send the base event
@ -1682,7 +1919,9 @@ impl Rooms {
let current: &[u8] = &current; let current: &[u8] = &current;
let user_id = user_id.clone(); let user_id = user_id.clone();
self.pduid_pdu
Ok(self
.pduid_pdu
.iter_from(current, false) .iter_from(current, false)
.take_while(move |(k, _)| k.starts_with(&prefix)) .take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| { .map(move |(pdu_id, v)| {
@ -1692,7 +1931,7 @@ impl Rooms {
pdu.unsigned.remove("transaction_id"); pdu.unsigned.remove("transaction_id");
} }
Ok((pdu_id, pdu)) Ok((pdu_id, pdu))
}) }))
} }
/// Replace a PDU with the redacted form. /// Replace a PDU with the redacted form.
@ -2223,8 +2462,8 @@ impl Rooms {
room_id: &RoomId, room_id: &RoomId,
search_string: &str, search_string: &str,
) -> Result<(impl Iterator<Item = Vec<u8>> + 'a, Vec<String>)> { ) -> Result<(impl Iterator<Item = Vec<u8>> + 'a, Vec<String>)> {
let mut prefix = room_id.as_bytes().to_vec(); let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec();
prefix.push(0xff); let prefix_clone = prefix.clone();
let words = search_string let words = search_string
.split_terminator(|c: char| !c.is_alphanumeric()) .split_terminator(|c: char| !c.is_alphanumeric())
@ -2243,16 +2482,7 @@ impl Rooms {
.iter_from(&last_possible_id, true) // Newest pdus first .iter_from(&last_possible_id, true) // Newest pdus first
.take_while(move |(k, _)| k.starts_with(&prefix2)) .take_while(move |(k, _)| k.starts_with(&prefix2))
.map(|(key, _)| { .map(|(key, _)| {
let pduid_index = key let pdu_id = key[key.len() - size_of::<u64>()..].to_vec();
.iter()
.enumerate()
.filter(|(_, &b)| b == 0xff)
.nth(1)
.ok_or_else(|| Error::bad_database("Invalid tokenid in db."))?
.0
+ 1; // +1 because the pdu id starts AFTER the separator
let pdu_id = key[pduid_index..].to_vec();
Ok::<_, Error>(pdu_id) Ok::<_, Error>(pdu_id)
}) })
@ -2264,7 +2494,12 @@ impl Rooms {
// We compare b with a because we reversed the iterator earlier // We compare b with a because we reversed the iterator earlier
b.cmp(a) b.cmp(a)
}) })
.unwrap(), .unwrap()
.map(move |id| {
let mut pduid = prefix_clone.clone();
pduid.extend_from_slice(&id);
pduid
}),
words, words,
)) ))
} }

View File

@ -1704,7 +1704,7 @@ fn append_incoming_pdu(
// We append to state before appending the pdu, so we don't have a moment in time with the // We append to state before appending the pdu, so we don't have a moment in time with the
// pdu without it's state. This is okay because append_pdu can't fail. // pdu without it's state. This is okay because append_pdu can't fail.
db.rooms db.rooms
.set_event_state(&pdu.event_id, state, &db.globals)?; .set_event_state(&pdu.event_id, &pdu.room_id, state, &db.globals)?;
let pdu_id = db.rooms.append_pdu( let pdu_id = db.rooms.append_pdu(
pdu, pdu,