From ea3aaa6b5c06e01bef52a66b64fe45d74d5f60c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Thu, 17 Sep 2020 14:44:47 +0200 Subject: [PATCH] improvement: more efficient /sync with gaps --- Cargo.lock | 48 +++---- src/client_server/context.rs | 10 +- src/client_server/membership.rs | 11 +- src/client_server/message.rs | 12 ++ src/client_server/sync.rs | 152 +++++++++++++------- src/database/rooms.rs | 247 ++++++++++++-------------------- src/main.rs | 1 + src/pdu.rs | 6 +- src/server_server.rs | 4 +- src/stateres.rs | 59 -------- 10 files changed, 251 insertions(+), 299 deletions(-) delete mode 100644 src/stateres.rs diff --git a/Cargo.lock b/Cargo.lock index 30144ca..e142d72 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -134,9 +134,9 @@ checksum = "0e4cec68f03f32e44924783795810fa50a7035d8c8ebe78580ad7e6c703fba38" [[package]] name = "cc" -version = "1.0.59" +version = "1.0.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66120af515773fb005778dc07c261bd201ec8ce50bd6e7144c927753fe013381" +checksum = "ef611cc68ff783f18535d77ddd080185275713d852c4f5cbb6122c462a7a825c" [[package]] name = "cfg-if" @@ -213,7 +213,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1373a16a4937bc34efec7b391f9c1500c30b8478a701a4f44c9165cc0475a6e0" dependencies = [ "percent-encoding", - "time 0.2.19", + "time 0.2.20", "version_check", ] @@ -342,9 +342,9 @@ checksum = "134951f4028bdadb9b84baf4232681efbf277da25144b9b0ad65df75946c422b" [[package]] name = "either" -version = "1.6.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd56b59865bce947ac5958779cfa508f6c3b9497cc762b7e24a12d11ccde2c4f" +checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457" [[package]] name = "encoding_rs" @@ -1370,7 +1370,7 @@ dependencies = [ "rocket_codegen", "rocket_http", "state", - "time 0.2.19", + "time 0.2.20", "tokio", "toml", "version_check", @@ -1405,7 +1405,7 @@ dependencies = [ "ref-cast", "smallvec", "state", - "time 0.2.19", + "time 0.2.20", "tokio", "tokio-rustls", "unicode-xid", @@ -1414,7 +1414,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.0.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#6ccb3ecaf69167ba405379826a9d87a98f168df8" +source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#425d34d4cfb5aefe5bab6957d71bc9389384c1e5" dependencies = [ "ruma-api", "ruma-appservice-api", @@ -1430,7 +1430,7 @@ dependencies = [ [[package]] name = "ruma-api" version = "0.17.0-alpha.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#6ccb3ecaf69167ba405379826a9d87a98f168df8" +source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#425d34d4cfb5aefe5bab6957d71bc9389384c1e5" dependencies = [ "http", "percent-encoding", @@ -1445,7 +1445,7 @@ dependencies = [ [[package]] name = "ruma-api-macros" version = "0.17.0-alpha.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#6ccb3ecaf69167ba405379826a9d87a98f168df8" +source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#425d34d4cfb5aefe5bab6957d71bc9389384c1e5" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -1456,7 +1456,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.2.0-alpha.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#6ccb3ecaf69167ba405379826a9d87a98f168df8" +source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#425d34d4cfb5aefe5bab6957d71bc9389384c1e5" dependencies = [ "ruma-api", "ruma-common", @@ -1469,7 +1469,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.10.0-alpha.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#6ccb3ecaf69167ba405379826a9d87a98f168df8" +source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#425d34d4cfb5aefe5bab6957d71bc9389384c1e5" dependencies = [ "assign", "http", @@ -1488,7 +1488,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.2.0" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#6ccb3ecaf69167ba405379826a9d87a98f168df8" +source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#425d34d4cfb5aefe5bab6957d71bc9389384c1e5" dependencies = [ "js_int", "ruma-api", @@ -1502,7 +1502,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.22.0-alpha.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#6ccb3ecaf69167ba405379826a9d87a98f168df8" +source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#425d34d4cfb5aefe5bab6957d71bc9389384c1e5" dependencies = [ "js_int", "ruma-common", @@ -1517,7 +1517,7 @@ dependencies = [ [[package]] name = "ruma-events-macros" version = "0.22.0-alpha.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#6ccb3ecaf69167ba405379826a9d87a98f168df8" +source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#425d34d4cfb5aefe5bab6957d71bc9389384c1e5" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -1528,7 +1528,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.0.3" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#6ccb3ecaf69167ba405379826a9d87a98f168df8" +source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#425d34d4cfb5aefe5bab6957d71bc9389384c1e5" dependencies = [ "js_int", "ruma-api", @@ -1543,7 +1543,7 @@ dependencies = [ [[package]] name = "ruma-identifiers" version = "0.17.4" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#6ccb3ecaf69167ba405379826a9d87a98f168df8" +source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#425d34d4cfb5aefe5bab6957d71bc9389384c1e5" dependencies = [ "rand", "ruma-identifiers-macros", @@ -1555,7 +1555,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-macros" version = "0.17.4" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#6ccb3ecaf69167ba405379826a9d87a98f168df8" +source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#425d34d4cfb5aefe5bab6957d71bc9389384c1e5" dependencies = [ "proc-macro2", "quote", @@ -1566,7 +1566,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.1.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#6ccb3ecaf69167ba405379826a9d87a98f168df8" +source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#425d34d4cfb5aefe5bab6957d71bc9389384c1e5" dependencies = [ "serde", "strum", @@ -1575,7 +1575,7 @@ dependencies = [ [[package]] name = "ruma-serde" version = "0.2.3" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#6ccb3ecaf69167ba405379826a9d87a98f168df8" +source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#425d34d4cfb5aefe5bab6957d71bc9389384c1e5" dependencies = [ "form_urlencoded", "itoa", @@ -1587,7 +1587,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.6.0-dev.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#6ccb3ecaf69167ba405379826a9d87a98f168df8" +source = "git+https://github.com/timokoesters/ruma?branch=timo-fed-fixes#425d34d4cfb5aefe5bab6957d71bc9389384c1e5" dependencies = [ "base64", "ring", @@ -1831,7 +1831,7 @@ checksum = "7345c971d1ef21ffdbd103a75990a15eb03604fc8b8852ca8cb418ee1a099028" [[package]] name = "state-res" version = "0.1.0" -source = "git+https://github.com/timokoesters/state-res?branch=spec-comp#1d01b6e65b6afd50e65085fb40f1e7d2782f519e" +source = "git+https://github.com/timokoesters/state-res?branch=spec-comp#d11a3feb5307715ab5d86af8f25d4bccfee6264b" dependencies = [ "itertools", "js_int", @@ -1981,9 +1981,9 @@ dependencies = [ [[package]] name = "time" -version = "0.2.19" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80c1a1fd93112fc50b11c43a1def21f926be3c18884fad676ea879572da070a1" +checksum = "0d4953c513c9bf1b97e9cdd83f11d60c4b0a83462880a360d80d96953a953fee" dependencies = [ "const_fn", "libc", diff --git a/src/client_server/context.rs b/src/client_server/context.rs index 9593726..4c9be20 100644 --- a/src/client_server/context.rs +++ b/src/client_server/context.rs @@ -49,7 +49,10 @@ pub fn get_context_route( .filter_map(|r| r.ok()) // Remove buggy events .collect::>(); - let start_token = events_before.last().map(|(count, _)| count.to_string()); + let start_token = events_before + .last() + .and_then(|(pdu_id, _)| db.rooms.pdu_count(pdu_id).ok()) + .map(|count| count.to_string()); let events_before = events_before .into_iter() @@ -68,7 +71,10 @@ pub fn get_context_route( .filter_map(|r| r.ok()) // Remove buggy events .collect::>(); - let end_token = events_after.last().map(|(count, _)| count.to_string()); + let end_token = events_after + .last() + .and_then(|(pdu_id, _)| db.rooms.pdu_count(pdu_id).ok()) + .map(|count| count.to_string()); let events_after = events_after .into_iter() diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index c4eed95..628045d 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -601,8 +601,7 @@ async fn join_room_by_id_helper( .cloned() .collect::>(); - let power_level = - resolved_control_events.get(&(EventType::RoomPowerLevels, Some("".into()))); + let power_level = resolved_control_events.get(&(EventType::RoomPowerLevels, "".into())); // Sort the remaining non control events let sorted_event_ids = state_res::StateResolution::mainline_sort( room_id, @@ -644,13 +643,7 @@ async fn join_room_by_id_helper( )?; if state_events.contains(ev_id) { - state.insert( - ( - pdu.kind(), - pdu.state_key().expect("State events have a state key"), - ), - pdu_id, - ); + state.insert((pdu.kind(), pdu.state_key()), pdu_id); } } diff --git a/src/client_server/message.rs b/src/client_server/message.rs index 3944d5b..5a4488f 100644 --- a/src/client_server/message.rs +++ b/src/client_server/message.rs @@ -117,6 +117,12 @@ pub fn get_message_events_route( .pdus_after(&sender_id, &body.room_id, from) .take(limit) .filter_map(|r| r.ok()) // Filter out buggy events + .filter_map(|(pdu_id, pdu)| { + db.rooms + .pdu_count(&pdu_id) + .map(|pdu_count| (pdu_count, pdu)) + .ok() + }) .take_while(|&(k, _)| Some(Ok(k)) != to) // Stop at `to` .collect::>(); @@ -141,6 +147,12 @@ pub fn get_message_events_route( .pdus_until(&sender_id, &body.room_id, from) .take(limit) .filter_map(|r| r.ok()) // Filter out buggy events + .filter_map(|(pdu_id, pdu)| { + db.rooms + .pdu_count(&pdu_id) + .map(|pdu_count| (pdu_count, pdu)) + .ok() + }) .take_while(|&(k, _)| Some(Ok(k)) != to) // Stop at `to` .collect::>(); diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index 6ece180..0e40bfb 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -105,50 +105,92 @@ pub async fn sync_events_route( .room_state_get(&room_id, &EventType::RoomEncryption, "")? .is_some(); - // TODO: optimize this? - let mut send_member_count = false; - let mut joined_since_last_sync = false; - let mut new_encrypted_room = false; - for (state_key, pdu) in db + // Database queries: + let since_state_hash = db .rooms - .pdus_since(&sender_id, &room_id, since)? - .filter_map(|r| r.ok()) - .filter_map(|(_, pdu)| Some((pdu.state_key.clone()?, pdu))) - { - if pdu.kind == EventType::RoomMember { - send_member_count = true; + .pdus_until(sender_id, &room_id, since) + .next() + .and_then(|pdu| pdu.ok()) + .and_then(|pdu| db.rooms.pdu_state_hash(&pdu.0).ok()?); - let content = serde_json::from_value::< - Raw, - >(pdu.content.clone()) + let since_members = since_state_hash + .as_ref() + .and_then(|state_hash| db.rooms.state_type(state_hash, &EventType::RoomMember).ok()); + + let since_encryption = since_state_hash.as_ref().and_then(|state_hash| { + db.rooms + .state_get(&state_hash, &EventType::RoomEncryption, "") + .ok() + }); + + let current_members = db.rooms.room_state_type(&room_id, &EventType::RoomMember)?; + + // Calculations: + let new_encrypted_room = encrypted_room && since_encryption.is_none(); + + let send_member_count = since_members.as_ref().map_or(true, |since_members| { + current_members.len() != since_members.len() + }); + + let since_sender_member = since_members.as_ref().and_then(|members| { + members.get(sender_id.as_str()).and_then(|pdu| { + serde_json::from_value::>( + pdu.content.clone(), + ) .expect("Raw::from_value always works") .deserialize() - .map_err(|_| Error::bad_database("Invalid PDU in database."))?; + .map_err(|_| Error::bad_database("Invalid PDU in database.")) + .ok() + }) + }); - if pdu.state_key == Some(sender_id.to_string()) - && content.membership == MembershipState::Join - { - joined_since_last_sync = true; - } else if encrypted_room && content.membership == MembershipState::Join { - // A new user joined an encrypted room - let user_id = UserId::try_from(state_key) - .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; - // Add encryption update if we didn't share an encrypted room already - if !share_encrypted_room(&db, &sender_id, &user_id, &room_id) { - device_list_updates.insert(user_id); + if encrypted_room { + for (user_id, current_member) in current_members { + let current_membership = serde_json::from_value::< + Raw, + >(current_member.content.clone()) + .expect("Raw::from_value always works") + .deserialize() + .map_err(|_| Error::bad_database("Invalid PDU in database."))? + .membership; + + let since_membership = since_members + .as_ref() + .and_then(|members| { + members.get(&user_id).and_then(|since_member| { + serde_json::from_value::< + Raw, + >(since_member.content.clone()) + .expect("Raw::from_value always works") + .deserialize() + .map_err(|_| Error::bad_database("Invalid PDU in database.")) + .ok() + }) + }) + .map_or(MembershipState::Leave, |member| member.membership); + + let user_id = UserId::try_from(user_id) + .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; + + match (since_membership, current_membership) { + (MembershipState::Leave, MembershipState::Join) => { + // A new user joined an encrypted room + if !share_encrypted_room(&db, &sender_id, &user_id, &room_id) { + device_list_updates.insert(user_id); + } } - } else if encrypted_room && content.membership == MembershipState::Leave { - // Write down users that have left encrypted rooms we are in - left_encrypted_users.insert( - UserId::try_from(state_key) - .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?, - ); + (MembershipState::Join, MembershipState::Leave) => { + // Write down users that have left encrypted rooms we are in + left_encrypted_users.insert(user_id); + } + _ => {} } - } else if pdu.kind == EventType::RoomEncryption { - new_encrypted_room = true; } } + let joined_since_last_sync = + since_sender_member.map_or(true, |member| member.membership != MembershipState::Join); + if joined_since_last_sync && encrypted_room || new_encrypted_room { // If the user is in a new encrypted room, give them all joined users device_list_updates.extend( @@ -390,23 +432,37 @@ pub async fn sync_events_route( state: sync_events::State { events: Vec::new() }, }; - let mut left_since_last_sync = false; - for pdu in db.rooms.pdus_since(&sender_id, &room_id, since)? { - let (_, pdu) = pdu?; - if pdu.kind == EventType::RoomMember && pdu.state_key == Some(sender_id.to_string()) { - let content = serde_json::from_value::< - Raw, - >(pdu.content.clone()) + let since_member = db + .rooms + .pdus_until(sender_id, &room_id, since) + .next() + .and_then(|pdu| pdu.ok()) + .and_then(|pdu| { + db.rooms + .pdu_state_hash(&pdu.0) + .ok()? + .ok_or_else(|| Error::bad_database("Pdu in db doesn't have a state hash.")) + .ok() + }) + .and_then(|state_hash| { + db.rooms + .state_get(&state_hash, &EventType::RoomMember, sender_id.as_str()) + .ok()? + .ok_or_else(|| Error::bad_database("State hash in db doesn't have a state.")) + .ok() + }) + .and_then(|pdu| { + serde_json::from_value::>( + pdu.content.clone(), + ) .expect("Raw::from_value always works") .deserialize() - .map_err(|_| Error::bad_database("Invalid PDU in database."))?; + .map_err(|_| Error::bad_database("Invalid PDU in database.")) + .ok() + }); - if content.membership == MembershipState::Leave { - left_since_last_sync = true; - break; - } - } - } + let left_since_last_sync = + since_member.map_or(false, |member| member.membership == MembershipState::Join); if left_since_last_sync { device_list_left.extend( diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 263f51b..5958626 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -31,7 +31,7 @@ use std::{ /// /// This is created when a state group is added to the database by /// hashing the entire state. -pub type StateHashId = Vec; +pub type StateHashId = IVec; #[derive(Clone)] pub struct Rooms { @@ -100,7 +100,7 @@ impl StateStore for Rooms { impl Rooms { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. - pub fn state_full(&self, state_hash: StateHashId) -> Result> { + pub fn state_full(&self, state_hash: &StateHashId) -> Result> { self.stateid_pduid .scan_prefix(&state_hash) .values() @@ -115,61 +115,87 @@ impl Rooms { }) .map(|pdu| { let pdu = pdu?; - Ok(((pdu.kind, pdu.state_key), pdu.event_id)) + Ok(( + ( + pdu.kind.clone(), + pdu.state_key + .as_ref() + .ok_or_else(|| Error::bad_database("State event has no state key."))? + .clone(), + ), + pdu, + )) }) .collect::>>() } - // TODO make this return Result - /// Fetches the previous StateHash ID to `current`. - pub fn prev_state_hash(&self, current: StateHashId) -> Option { - let mut found = false; - for pair in self.pduid_statehash.iter().rev() { - let prev = pair.ok()?.1; - if current == prev.as_ref() { - found = true; - } - if current != prev.as_ref() && found { - return Some(prev.to_vec()); - } + /// Returns all state entries for this type. + pub fn state_type( + &self, + state_hash: &StateHashId, + event_type: &EventType, + ) -> Result> { + let mut prefix = state_hash.to_vec(); + prefix.push(0xff); + prefix.extend_from_slice(&event_type.to_string().as_bytes()); + prefix.push(0xff); + + let mut hashmap = HashMap::new(); + for pdu in self + .stateid_pduid + .scan_prefix(&prefix) + .values() + .map(|pdu_id| { + Ok::<_, Error>( + serde_json::from_slice::(&self.pduid_pdu.get(pdu_id?)?.ok_or_else( + || Error::bad_database("PDU in state not found in database."), + )?) + .map_err(|_| Error::bad_database("Invalid PDU bytes in room state."))?, + ) + }) + { + let pdu = pdu?; + let state_key = pdu.state_key.clone().ok_or_else(|| { + Error::bad_database("Room state contains event without state_key.") + })?; + hashmap.insert(state_key, pdu); } - None + Ok(hashmap) + } + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + pub fn state_get( + &self, + state_hash: &StateHashId, + event_type: &EventType, + state_key: &str, + ) -> Result> { + let mut key = state_hash.to_vec(); + key.push(0xff); + key.extend_from_slice(&event_type.to_string().as_bytes()); + key.push(0xff); + key.extend_from_slice(&state_key.as_bytes()); + + self.stateid_pduid.get(&key)?.map_or(Ok(None), |pdu_id| { + Ok::<_, Error>(Some( + serde_json::from_slice::( + &self.pduid_pdu.get(pdu_id)?.ok_or_else(|| { + Error::bad_database("PDU in state not found in database.") + })?, + ) + .map_err(|_| Error::bad_database("Invalid PDU bytes in room state."))?, + )) + }) + } + + /// Returns the last state hash key added to the db. + pub fn pdu_state_hash(&self, pdu_id: &[u8]) -> Result> { + Ok(self.pduid_statehash.get(pdu_id)?) } /// Returns the last state hash key added to the db. pub fn current_state_hash(&self, room_id: &RoomId) -> Result> { - Ok(self - .roomid_statehash - .get(room_id.as_bytes())? - .map(|bytes| bytes.to_vec())) - } - - /// This fetches auth event_ids from the current state using the - /// full `roomstateid_pdu` tree. - pub fn get_auth_event_ids( - &self, - room_id: &RoomId, - kind: &EventType, - sender: &UserId, - state_key: Option<&str>, - content: serde_json::Value, - ) -> Result> { - let auth_events = state_res::auth_types_for_event( - kind.clone(), - sender, - state_key.map(|s| s.to_string()), - content, - ); - - let mut events = vec![]; - for (event_type, state_key) in auth_events { - if let Some(state_key) = state_key.as_ref() { - if let Some(id) = self.room_state_get(room_id, &event_type, state_key)? { - events.push(id.event_id); - } - } - } - Ok(events) + Ok(self.roomid_statehash.get(room_id.as_bytes())?) } /// This fetches auth events from the current state. @@ -190,10 +216,8 @@ impl Rooms { let mut events = StateMap::new(); for (event_type, state_key) in auth_events { - if let Some(s_key) = state_key.as_ref() { - if let Some(pdu) = self.room_state_get(room_id, &event_type, s_key)? { - events.insert((event_type, state_key), pdu); - } + if let Some(pdu) = self.room_state_get(room_id, &event_type, &state_key)? { + events.insert((event_type, state_key), pdu); } } Ok(events) @@ -206,7 +230,7 @@ impl Rooms { // We only hash the pdu's event ids, not the whole pdu let bytes = pdu_id_bytes.join(&0xff); let hash = digest::digest(&digest::SHA256, &bytes); - Ok(hash.as_ref().to_vec()) + Ok(hash.as_ref().into()) } /// Checks if a room exists. @@ -230,7 +254,7 @@ impl Rooms { ) -> Result<()> { let state_hash = self.calculate_hash(&state.values().map(|pdu_id| &**pdu_id).collect::>())?; - let mut prefix = state_hash.clone(); + let mut prefix = state_hash.to_vec(); prefix.push(0xff); for ((event_type, state_key), pdu_id) in state { @@ -248,41 +272,11 @@ impl Rooms { } /// Returns the full room state. - pub fn room_state_full( - &self, - room_id: &RoomId, - ) -> Result> { + pub fn room_state_full(&self, room_id: &RoomId) -> Result> { if let Some(current_state_hash) = self.current_state_hash(room_id)? { - let mut prefix = current_state_hash; - prefix.push(0xff); - - let mut hashmap = HashMap::new(); - for pdu in self - .stateid_pduid - .scan_prefix(prefix) - .values() - .map(|pdu_id| { - Ok::<_, Error>( - serde_json::from_slice::( - &self.pduid_pdu.get(pdu_id?)?.ok_or_else(|| { - Error::bad_database("PDU in state not found in database.") - })?, - ) - .map_err(|_| { - Error::bad_database("Invalid PDU bytes in current room state.") - })?, - ) - }) - { - let pdu = pdu?; - let state_key = pdu.state_key.clone().ok_or_else(|| { - Error::bad_database("Room state contains event without state_key.") - })?; - hashmap.insert((pdu.kind.clone(), state_key), pdu); - } - Ok(hashmap) + self.state_full(¤t_state_hash) } else { - Ok(HashMap::new()) + Ok(BTreeMap::new()) } } @@ -293,36 +287,7 @@ impl Rooms { event_type: &EventType, ) -> Result> { if let Some(current_state_hash) = self.current_state_hash(room_id)? { - let mut prefix = current_state_hash; - prefix.push(0xff); - prefix.extend_from_slice(&event_type.to_string().as_bytes()); - prefix.push(0xff); - - let mut hashmap = HashMap::new(); - for pdu in self - .stateid_pduid - .scan_prefix(&prefix) - .values() - .map(|pdu_id| { - Ok::<_, Error>( - serde_json::from_slice::( - &self.pduid_pdu.get(pdu_id?)?.ok_or_else(|| { - Error::bad_database("PDU in state not found in database.") - })?, - ) - .map_err(|_| { - Error::bad_database("Invalid PDU bytes in current room state.") - })?, - ) - }) - { - let pdu = pdu?; - let state_key = pdu.state_key.clone().ok_or_else(|| { - Error::bad_database("Room state contains event without state_key.") - })?; - hashmap.insert(state_key, pdu); - } - Ok(hashmap) + self.state_type(¤t_state_hash, event_type) } else { Ok(HashMap::new()) } @@ -336,20 +301,7 @@ impl Rooms { state_key: &str, ) -> Result> { if let Some(current_state_hash) = self.current_state_hash(room_id)? { - let mut key = current_state_hash; - key.push(0xff); - key.extend_from_slice(&event_type.to_string().as_bytes()); - key.push(0xff); - key.extend_from_slice(&state_key.as_bytes()); - - self.stateid_pduid.get(&key)?.map_or(Ok(None), |pdu_id| { - Ok::<_, Error>(Some( - serde_json::from_slice::(&self.pduid_pdu.get(pdu_id)?.ok_or_else( - || Error::bad_database("PDU in state not found in database."), - )?) - .map_err(|_| Error::bad_database("Invalid PDU bytes in current room state."))?, - )) - }) + self.state_get(¤t_state_hash, event_type, state_key) } else { Ok(None) } @@ -562,14 +514,15 @@ impl Rooms { /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `pduid_statehash`. /// The incoming event is the `pdu_id` passed to this method. - fn append_to_state(&self, new_pdu_id: &[u8], new_pdu: &PduEvent) -> Result { + pub fn append_to_state(&self, new_pdu_id: &[u8], new_pdu: &PduEvent) -> Result { let old_state = if let Some(old_state_hash) = self.roomid_statehash.get(new_pdu.room_id.as_bytes())? { // Store state for event. The state does not include the event itself. // Instead it's the state before the pdu, so the room's old state. - self.pduid_statehash.insert(new_pdu_id, &old_state_hash)?; + self.pduid_statehash + .insert(dbg!(new_pdu_id), &old_state_hash)?; if new_pdu.state_key.is_none() { - return Ok(old_state_hash.to_vec()); + return Ok(old_state_hash); } let mut prefix = old_state_hash.to_vec(); @@ -841,9 +794,7 @@ impl Rooms { let pdu_id = self.append_pdu(&pdu, &pdu_json, globals, account_data)?; - if pdu.state_key.is_some() { - self.append_to_state(&pdu_id, &pdu)?; - } + self.append_to_state(&pdu_id, &pdu)?; for server in self .room_servers(room_id) @@ -905,7 +856,7 @@ impl Rooms { user_id: &UserId, room_id: &RoomId, until: u64, - ) -> impl Iterator> { + ) -> impl Iterator> { // Create the first part of the full pdu id let mut prefix = room_id.to_string().as_bytes().to_vec(); prefix.push(0xff); @@ -916,23 +867,18 @@ impl Rooms { let current: &[u8] = ¤t; let user_id = user_id.clone(); - let prefixlen = prefix.len(); self.pduid_pdu .range(..current) .rev() .filter_map(|r| r.ok()) .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(k, v)| { + .map(move |(pdu_id, v)| { let mut pdu = serde_json::from_slice::(&v) .map_err(|_| Error::bad_database("PDU in db is invalid."))?; if pdu.sender != user_id { pdu.unsigned.remove("transaction_id"); } - Ok(( - utils::u64_from_bytes(&k[prefixlen..]) - .map_err(|_| Error::bad_database("Invalid pdu id in db."))?, - pdu, - )) + Ok((pdu_id, pdu)) }) } @@ -943,7 +889,7 @@ impl Rooms { user_id: &UserId, room_id: &RoomId, from: u64, - ) -> impl Iterator> { + ) -> impl Iterator> { // Create the first part of the full pdu id let mut prefix = room_id.to_string().as_bytes().to_vec(); prefix.push(0xff); @@ -954,22 +900,17 @@ impl Rooms { let current: &[u8] = ¤t; let user_id = user_id.clone(); - let prefixlen = prefix.len(); self.pduid_pdu .range(current..) .filter_map(|r| r.ok()) .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(k, v)| { + .map(move |(pdu_id, v)| { let mut pdu = serde_json::from_slice::(&v) .map_err(|_| Error::bad_database("PDU in db is invalid."))?; if pdu.sender != user_id { pdu.unsigned.remove("transaction_id"); } - Ok(( - utils::u64_from_bytes(&k[prefixlen..]) - .map_err(|_| Error::bad_database("Invalid pdu id in db."))?, - pdu, - )) + Ok((pdu_id, pdu)) }) } diff --git a/src/main.rs b/src/main.rs index f81c7f4..06fda59 100644 --- a/src/main.rs +++ b/src/main.rs @@ -123,6 +123,7 @@ fn setup_rocket() -> rocket::Rocket { server_server::get_server_keys, server_server::get_server_keys_deprecated, server_server::get_public_rooms_route, + server_server::get_public_rooms_filtered_route, server_server::send_transaction_message_route, ], ) diff --git a/src/pdu.rs b/src/pdu.rs index 957d9e0..d5b5415 100644 --- a/src/pdu.rs +++ b/src/pdu.rs @@ -1,4 +1,4 @@ -use crate::{Error, Result}; +use crate::Error; use js_int::UInt; use ruma::{ events::pdu::PduStub, @@ -35,7 +35,7 @@ pub struct PduEvent { } impl PduEvent { - pub fn redact(&mut self, reason: &PduEvent) -> Result<()> { + pub fn redact(&mut self, reason: &PduEvent) -> crate::Result<()> { self.unsigned.clear(); let allowed: &[&str] = match self.kind { @@ -244,7 +244,7 @@ impl From<&state_res::StateEvent> for PduEvent { .expect("time is valid"), kind: pdu.kind(), content: pdu.content().clone(), - state_key: pdu.state_key(), + state_key: Some(pdu.state_key()), prev_events: pdu.prev_event_ids(), depth: pdu.depth().clone(), auth_events: pdu.auth_events(), diff --git a/src/server_server.rs b/src/server_server.rs index aef3991..6f2b179 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -329,8 +329,10 @@ pub fn send_transaction_message_route<'a>( let pdu = serde_json::from_value::(value.clone()) .expect("all ruma pdus are conduit pdus"); if db.rooms.exists(&pdu.room_id)? { - db.rooms + let pdu_id = db + .rooms .append_pdu(&pdu, &value, &db.globals, &db.account_data)?; + db.rooms.append_to_state(&pdu_id, &pdu)?; } } Ok(send_transaction_message::v1::Response { diff --git a/src/stateres.rs b/src/stateres.rs deleted file mode 100644 index ee47099..0000000 --- a/src/stateres.rs +++ /dev/null @@ -1,59 +0,0 @@ -use std::collections::HashMap; - -fn stateres(state_a: HashMap, state_b: HashMap) { - let mut unconflicted = todo!("state at fork event"); - - let mut conflicted: HashMap = state_a - .iter() - .filter(|(key_a, value_a)| match state_b.remove(key_a) { - Some(value_b) if value_a == value_b => unconflicted.insert(key_a, value_a), - _ => false, - }) - .collect(); - - // We removed unconflicted from state_b, now we can easily insert all events that are only in fork b - conflicted.extend(state_b); - - let partial_state = unconflicted.clone(); - - let full_conflicted = conflicted.clone(); // TODO: auth events - - let output_rev = Vec::new(); - let event_map = HashMap::new(); - let incoming_edges = HashMap::new(); - - for event in full_conflicted { - event_map.insert(event.event_id, event); - incoming_edges.insert(event.event_id, 0); - } - - for e in conflicted_control_events { - for a in e.auth_events { - incoming_edges[a.event_id] += 1; - } - } - - while incoming_edges.len() > 0 { - let mut count_0 = incoming_edges - .iter() - .filter(|(_, c)| c == 0) - .collect::>(); - - count_0.sort_by(|(x, _), (y, _)| { - x.power_level - .cmp(&a.power_level) - .then_with(|| x.origin_server.ts.cmp(&y.origin_server_ts)) - .then_with(|| x.event_id.cmp(&y.event_id)) - }); - - for (id, count) in count_0 { - output_rev.push(event_map[id]); - - for auth_event in event_map[id].auth_events { - incoming_edges[auth_event.event_id] -= 1; - } - - incoming_edges.remove(id); - } - } -}