From c4f5a0a6316200f797bf81d77c3cd2815a73706d Mon Sep 17 00:00:00 2001 From: Devin Ragotzy Date: Thu, 6 Aug 2020 08:29:59 -0400 Subject: [PATCH] Keep track of State at event for state resolution feat: first steps towards joining rooms over federation Add state-res as a dependency of conduit Add reverse_topological_power_sort before append_pdu Implement statehashstatid_pduid tree for keeping track of state Clean up implementation of state_hash as key for tracking state --- Cargo.lock | 206 +++++++++++- Cargo.toml | 7 +- src/client_server/account.rs | 6 +- src/client_server/alias.rs | 16 +- src/client_server/context.rs | 30 +- src/client_server/device.rs | 8 +- src/client_server/directory.rs | 13 +- src/client_server/filter.rs | 21 +- src/client_server/keys.rs | 2 +- src/client_server/media.rs | 4 +- src/client_server/membership.rs | 125 ++++++-- src/client_server/message.rs | 43 +-- src/client_server/room.rs | 2 +- src/client_server/session.rs | 20 +- src/client_server/state.rs | 48 ++- src/client_server/sync.rs | 2 +- src/client_server/unversioned.rs | 13 +- src/database.rs | 5 +- src/database/rooms.rs | 525 +++++++++++++++++++++---------- src/database/uiaa.rs | 6 +- src/pdu.rs | 29 ++ src/ruma_wrapper.rs | 10 +- src/server_server.rs | 27 +- src/utils.rs | 6 + 24 files changed, 818 insertions(+), 356 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0a7334c..ffee8ea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -75,6 +75,15 @@ dependencies = [ "opaque-debug 0.2.3", ] +[[package]] +name = "ansi_term" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" +dependencies = [ + "winapi 0.3.9", +] + [[package]] name = "arc-swap" version = "0.4.7" @@ -248,6 +257,17 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822" +[[package]] +name = "chrono" +version = "0.4.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "942f72db697d8767c22d46a598e01f2d3b475501ea43d0db4f16d90259182d0b" +dependencies = [ + "num-integer", + "num-traits", + "time 0.1.43", +] + [[package]] name = "cloudabi" version = "0.1.0" @@ -281,6 +301,7 @@ dependencies = [ "serde", "serde_json", "sled", + "state-res", "thiserror", "tokio", ] @@ -456,6 +477,12 @@ version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "134951f4028bdadb9b84baf4232681efbf277da25144b9b0ad65df75946c422b" +[[package]] +name = "either" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd56b59865bce947ac5958779cfa508f6c3b9497cc762b7e24a12d11ccde2c4f" + [[package]] name = "encoding_rs" version = "0.8.23" @@ -872,6 +899,15 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "47be2f14c678be2fdcab04ab1171db51b2762ce6f0a8ee87c8dd4a04ed216135" +[[package]] +name = "itertools" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "284f18f85651fe11e8a991b2adb42cb078325c996ed026d994719efcfca1d54b" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "0.4.6" @@ -951,6 +987,21 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7d947cbb889ed21c2a84be6ffbaebf5b4e0f4340638cba0444907e38b56be084" +[[package]] +name = "maplit" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" + +[[package]] +name = "matchers" +version = "0.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f099785f7595cc4b4553a174ce30dd7589ef93391ff414dbb67f62392b9e0ce1" +dependencies = [ + "regex-automata", +] + [[package]] name = "matches" version = "0.1.8" @@ -1439,6 +1490,31 @@ dependencies = [ "syn", ] +[[package]] +name = "regex" +version = "1.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c3780fcf44b193bc4d09f36d2a3c87b251da4a046c87795a0d35f4f927ad8e6" +dependencies = [ + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae1ded71d66a4a97f5e961fd0cb25a5f366a42a41570d16a763a69c092c26ae4" +dependencies = [ + "byteorder", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.6.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26412eb97c6b088a6997e05f69403a802a92d520de2f8e63c2b65f9e0f47c4e8" + [[package]] name = "remove_dir_all" version = "0.5.3" @@ -1560,21 +1636,23 @@ dependencies = [ [[package]] name = "ruma" version = "0.0.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#c2adc9ecb85538505ff351dbd883c9106f651744" +source = "git+https://github.com/ruma/ruma?rev=aff914050eb297bd82b8aafb12158c88a9e480e1#aff914050eb297bd82b8aafb12158c88a9e480e1" dependencies = [ "ruma-api", + "ruma-appservice-api", "ruma-client-api", "ruma-common", "ruma-events", "ruma-federation-api", "ruma-identifiers", + "ruma-serde", "ruma-signatures", ] [[package]] name = "ruma-api" version = "0.17.0-alpha.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#c2adc9ecb85538505ff351dbd883c9106f651744" +source = "git+https://github.com/ruma/ruma?rev=aff914050eb297bd82b8aafb12158c88a9e480e1#aff914050eb297bd82b8aafb12158c88a9e480e1" dependencies = [ "http", "percent-encoding", @@ -1589,7 +1667,7 @@ dependencies = [ [[package]] name = "ruma-api-macros" version = "0.17.0-alpha.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#c2adc9ecb85538505ff351dbd883c9106f651744" +source = "git+https://github.com/ruma/ruma?rev=aff914050eb297bd82b8aafb12158c88a9e480e1#aff914050eb297bd82b8aafb12158c88a9e480e1" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -1597,14 +1675,28 @@ dependencies = [ "syn", ] +[[package]] +name = "ruma-appservice-api" +version = "0.2.0-alpha.1" +source = "git+https://github.com/ruma/ruma?rev=aff914050eb297bd82b8aafb12158c88a9e480e1#aff914050eb297bd82b8aafb12158c88a9e480e1" +dependencies = [ + "ruma-api", + "ruma-common", + "ruma-events", + "ruma-identifiers", + "serde", + "serde_json", +] + [[package]] name = "ruma-client-api" version = "0.10.0-alpha.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#c2adc9ecb85538505ff351dbd883c9106f651744" +source = "git+https://github.com/ruma/ruma?rev=aff914050eb297bd82b8aafb12158c88a9e480e1#aff914050eb297bd82b8aafb12158c88a9e480e1" dependencies = [ "assign", "http", "js_int", + "percent-encoding", "ruma-api", "ruma-common", "ruma-events", @@ -1618,7 +1710,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.2.0" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#c2adc9ecb85538505ff351dbd883c9106f651744" +source = "git+https://github.com/ruma/ruma?rev=aff914050eb297bd82b8aafb12158c88a9e480e1#aff914050eb297bd82b8aafb12158c88a9e480e1" dependencies = [ "js_int", "ruma-identifiers", @@ -1631,7 +1723,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.22.0-alpha.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#c2adc9ecb85538505ff351dbd883c9106f651744" +source = "git+https://github.com/ruma/ruma?rev=aff914050eb297bd82b8aafb12158c88a9e480e1#aff914050eb297bd82b8aafb12158c88a9e480e1" dependencies = [ "js_int", "ruma-common", @@ -1646,7 +1738,7 @@ dependencies = [ [[package]] name = "ruma-events-macros" version = "0.22.0-alpha.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#c2adc9ecb85538505ff351dbd883c9106f651744" +source = "git+https://github.com/ruma/ruma?rev=aff914050eb297bd82b8aafb12158c88a9e480e1#aff914050eb297bd82b8aafb12158c88a9e480e1" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -1657,7 +1749,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.0.3" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#c2adc9ecb85538505ff351dbd883c9106f651744" +source = "git+https://github.com/ruma/ruma?rev=aff914050eb297bd82b8aafb12158c88a9e480e1#aff914050eb297bd82b8aafb12158c88a9e480e1" dependencies = [ "js_int", "ruma-api", @@ -1672,7 +1764,7 @@ dependencies = [ [[package]] name = "ruma-identifiers" version = "0.17.4" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#c2adc9ecb85538505ff351dbd883c9106f651744" +source = "git+https://github.com/ruma/ruma?rev=aff914050eb297bd82b8aafb12158c88a9e480e1#aff914050eb297bd82b8aafb12158c88a9e480e1" dependencies = [ "rand", "ruma-identifiers-macros", @@ -1684,7 +1776,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-macros" version = "0.17.4" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#c2adc9ecb85538505ff351dbd883c9106f651744" +source = "git+https://github.com/ruma/ruma?rev=aff914050eb297bd82b8aafb12158c88a9e480e1#aff914050eb297bd82b8aafb12158c88a9e480e1" dependencies = [ "proc-macro2", "quote", @@ -1695,7 +1787,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.1.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#c2adc9ecb85538505ff351dbd883c9106f651744" +source = "git+https://github.com/ruma/ruma?rev=aff914050eb297bd82b8aafb12158c88a9e480e1#aff914050eb297bd82b8aafb12158c88a9e480e1" dependencies = [ "ruma-serde", "serde", @@ -1706,7 +1798,7 @@ dependencies = [ [[package]] name = "ruma-serde" version = "0.2.3" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#c2adc9ecb85538505ff351dbd883c9106f651744" +source = "git+https://github.com/ruma/ruma?rev=aff914050eb297bd82b8aafb12158c88a9e480e1#aff914050eb297bd82b8aafb12158c88a9e480e1" dependencies = [ "form_urlencoded", "itoa", @@ -1718,7 +1810,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.6.0-dev.1" -source = "git+https://github.com/timokoesters/ruma?branch=timo-fixes#c2adc9ecb85538505ff351dbd883c9106f651744" +source = "git+https://github.com/ruma/ruma?rev=aff914050eb297bd82b8aafb12158c88a9e480e1#aff914050eb297bd82b8aafb12158c88a9e480e1" dependencies = [ "base64 0.12.3", "ring", @@ -1910,6 +2002,15 @@ dependencies = [ "opaque-debug 0.3.0", ] +[[package]] +name = "sharded-slab" +version = "0.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06d5a3f5166fb5b42a5439f2eee8b9de149e235961e3eb21c5808fc3ea17ff3e" +dependencies = [ + "lazy_static", +] + [[package]] name = "signal-hook-registry" version = "1.2.1" @@ -1983,6 +2084,22 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7345c971d1ef21ffdbd103a75990a15eb03604fc8b8852ca8cb418ee1a099028" +[[package]] +name = "state-res" +version = "0.1.0" +source = "git+https://github.com/ruma/state-res#789c8140890e076d38b23fa1147c4ff0500c0d38" +dependencies = [ + "itertools", + "js_int", + "maplit", + "ruma", + "serde", + "serde_json", + "thiserror", + "tracing", + "tracing-subscriber", +] + [[package]] name = "stdweb" version = "0.4.20" @@ -2104,6 +2221,15 @@ dependencies = [ "syn", ] +[[package]] +name = "thread_local" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d40c6d1b69745a6ec6fb1ca717914848da4b44ae29d9b3080cbee91d72a69b14" +dependencies = [ + "lazy_static", +] + [[package]] name = "time" version = "0.1.43" @@ -2251,9 +2377,21 @@ checksum = "6d79ca061b032d6ce30c660fded31189ca0b9922bf483cd70759f13a2d86786c" dependencies = [ "cfg-if", "log", + "tracing-attributes", "tracing-core", ] +[[package]] +name = "tracing-attributes" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fe233f4227389ab7df5b32649239da7ebe0b281824b4e84b342d04d3fd8c25e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tracing-core" version = "0.1.14" @@ -2263,6 +2401,48 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "tracing-log" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e0f8c7178e13481ff6765bd169b33e8d554c5d2bbede5e32c356194be02b9b9" +dependencies = [ + "lazy_static", + "log", + "tracing-core", +] + +[[package]] +name = "tracing-serde" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6ccba2f8f16e0ed268fc765d9b7ff22e965e7185d32f8f1ec8294fe17d86e79" +dependencies = [ + "serde", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "abd165311cc4d7a555ad11cc77a37756df836182db0d81aac908c8184c584f40" +dependencies = [ + "ansi_term", + "chrono", + "lazy_static", + "matchers", + "regex", + "serde", + "serde_json", + "sharded-slab", + "smallvec", + "thread_local", + "tracing-core", + "tracing-log", + "tracing-serde", +] + [[package]] name = "try-lock" version = "0.2.3" diff --git a/Cargo.toml b/Cargo.toml index 4945e3c..4c14d71 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,9 +16,7 @@ edition = "2018" #rocket = { git = "https://github.com/SergioBenitez/Rocket.git", rev = "8d779caa22c63b15a6c3ceb75d8f6d4971b2eb67", features = ["tls"] } # Used to handle requests rocket = { git = "https://github.com/timokoesters/Rocket.git", branch = "empty_parameters", features = ["tls"] } -#ruma = { git = "https://github.com/ruma/ruma", features = ["rand", "client-api", "federation-api", "unstable-pre-spec", "unstable-synapse-quirks"], rev = "987d48666cf166cf12100b5dbc61b5e3385c4014" } # Used for matrix spec type definitions and helpers -ruma = { git = "https://github.com/timokoesters/ruma", features = ["rand", "client-api", "federation-api", "unstable-pre-spec", "unstable-synapse-quirks"], branch = "timo-fixes" } # Used for matrix spec type definitions and helpers -#ruma = { path = "../ruma/ruma", features = ["rand", "client-api", "federation-api", "unstable-pre-spec", "unstable-synapse-quirks"] } +ruma = { git = "https://github.com/ruma/ruma", features = ["rand", "client-api", "federation-api", "unstable-pre-spec", "unstable-synapse-quirks"], rev = "aff914050eb297bd82b8aafb12158c88a9e480e1" } # Used for matrix spec type definitions and helpers tokio = "0.2.22" # Used for long polling sled = "0.32.0" # Used for storing data permanently log = "0.4.8" # Used for emitting log entries @@ -33,6 +31,9 @@ reqwest = "0.10.6" # Used to send requests thiserror = "1.0.19" # Used for conduit::Error type image = { version = "0.23.4", default-features = false, features = ["jpeg", "png", "gif"] } # Used to generate thumbnails for images base64 = "0.12.3" # Used to encode server public key +# state-res = { path = "../../state-res" } +state-res = { git = "https://github.com/ruma/state-res", version = "0.1.0" } + [features] default = ["conduit_bin"] diff --git a/src/client_server/account.rs b/src/client_server/account.rs index 9837d1b..9fa1a9c 100644 --- a/src/client_server/account.rs +++ b/src/client_server/account.rs @@ -75,7 +75,7 @@ pub fn get_register_available_route( )] pub fn register_route( db: State<'_, Database>, - body: Ruma, + body: Ruma, ) -> ConduitResult { if db.globals.registration_disabled() { return Err(Error::BadRequest( @@ -223,7 +223,7 @@ pub fn register_route( )] pub fn change_password_route( db: State<'_, Database>, - body: Ruma, + body: Ruma, ) -> ConduitResult { let sender_id = body.sender_id.as_ref().expect("user is authenticated"); let device_id = body.device_id.as_ref().expect("user is authenticated"); @@ -305,7 +305,7 @@ pub fn whoami_route(body: Ruma) -> ConduitResult, - body: Ruma, + body: Ruma, ) -> ConduitResult { let sender_id = body.sender_id.as_ref().expect("user is authenticated"); let device_id = body.device_id.as_ref().expect("user is authenticated"); diff --git a/src/client_server/alias.rs b/src/client_server/alias.rs index a029388..7dc9078 100644 --- a/src/client_server/alias.rs +++ b/src/client_server/alias.rs @@ -26,7 +26,7 @@ pub fn create_alias_route( db.rooms .set_alias(&body.room_alias, Some(&body.room_id), &db.globals)?; - Ok(create_alias::Response.into()) + Ok(create_alias::Response::new().into()) } #[cfg_attr( @@ -39,7 +39,7 @@ pub fn delete_alias_route( ) -> ConduitResult { db.rooms.set_alias(&body.room_alias, None, &db.globals)?; - Ok(delete_alias::Response.into()) + Ok(delete_alias::Response::new().into()) } #[cfg_attr( @@ -60,11 +60,7 @@ pub async fn get_alias_route( ) .await?; - return Ok(get_alias::Response { - room_id: response.room_id, - servers: response.servers, - } - .into()); + return Ok(get_alias::Response::new(response.room_id, response.servers).into()); } let room_id = db @@ -75,9 +71,5 @@ pub async fn get_alias_route( "Room with alias not found.", ))?; - Ok(get_alias::Response { - room_id, - servers: vec![db.globals.server_name().to_string()], - } - .into()) + Ok(get_alias::Response::new(room_id, vec![db.globals.server_name().to_string()]).into()) } diff --git a/src/client_server/context.rs b/src/client_server/context.rs index 7a6cbce..7b1fad9 100644 --- a/src/client_server/context.rs +++ b/src/client_server/context.rs @@ -12,7 +12,7 @@ use rocket::get; )] pub fn get_context_route( db: State<'_, Database>, - body: Ruma, + body: Ruma, ) -> ConduitResult { let sender_id = body.sender_id.as_ref().expect("user is authenticated"); @@ -75,18 +75,18 @@ pub fn get_context_route( .map(|(_, pdu)| pdu.to_room_event()) .collect::>(); - Ok(get_context::Response { - start: start_token, - end: end_token, - events_before, - event: Some(base_event), - events_after, - state: db // TODO: State at event - .rooms - .room_state_full(&body.room_id)? - .values() - .map(|pdu| pdu.to_state_event()) - .collect(), - } - .into()) + let mut resp = get_context::Response::new(); + resp.start = start_token; + resp.end = end_token; + resp.events_before = events_before; + resp.event = Some(base_event); + resp.events_after = events_after; + resp.state = db // TODO: State at event + .rooms + .room_state_full(&body.room_id)? + .values() + .map(|pdu| pdu.to_state_event()) + .collect(); + + Ok(resp.into()) } diff --git a/src/client_server/device.rs b/src/client_server/device.rs index 379f827..89033f0 100644 --- a/src/client_server/device.rs +++ b/src/client_server/device.rs @@ -37,7 +37,7 @@ pub fn get_devices_route( )] pub fn get_device_route( db: State<'_, Database>, - body: Ruma, + body: Ruma, _device_id: String, ) -> ConduitResult { let sender_id = body.sender_id.as_ref().expect("user is authenticated"); @@ -56,7 +56,7 @@ pub fn get_device_route( )] pub fn update_device_route( db: State<'_, Database>, - body: Ruma, + body: Ruma, _device_id: String, ) -> ConduitResult { let sender_id = body.sender_id.as_ref().expect("user is authenticated"); @@ -80,7 +80,7 @@ pub fn update_device_route( )] pub fn delete_device_route( db: State<'_, Database>, - body: Ruma, + body: Ruma, _device_id: String, ) -> ConduitResult { let sender_id = body.sender_id.as_ref().expect("user is authenticated"); @@ -127,7 +127,7 @@ pub fn delete_device_route( )] pub fn delete_devices_route( db: State<'_, Database>, - body: Ruma, + body: Ruma, ) -> ConduitResult { let sender_id = body.sender_id.as_ref().expect("user is authenticated"); let device_id = body.device_id.as_ref().expect("user is authenticated"); diff --git a/src/client_server/directory.rs b/src/client_server/directory.rs index 26188f7..0aace15 100644 --- a/src/client_server/directory.rs +++ b/src/client_server/directory.rs @@ -6,7 +6,7 @@ use ruma::{ error::ErrorKind, r0::{ directory::{ - self, get_public_rooms, get_public_rooms_filtered, get_room_visibility, + get_public_rooms, get_public_rooms_filtered, get_room_visibility, set_room_visibility, }, room, @@ -14,6 +14,7 @@ use ruma::{ }, federation, }, + directory::PublicRoomsChunk, events::{ room::{avatar, canonical_alias, guest_access, history_visibility, name, topic}, EventType, @@ -35,15 +36,15 @@ pub async fn get_public_rooms_filtered_route( if let Some(other_server) = body .server .clone() - .filter(|server| server != &db.globals.server_name().as_str()) + .filter(|server| server != db.globals.server_name().as_str()) { let response = server_server::send_request( &db, other_server, federation::directory::get_public_rooms::v1::Request { limit: body.limit, - since: body.since.clone(), - room_network: federation::directory::get_public_rooms::v1::RoomNetwork::Matrix, + since: body.since.as_deref(), + room_network: ruma::directory::RoomNetwork::Matrix, }, ) .await?; @@ -107,7 +108,7 @@ pub async fn get_public_rooms_filtered_route( // TODO: Do not load full state? let state = db.rooms.room_state_full(&room_id)?; - let chunk = directory::PublicRoomsChunk { + let chunk = PublicRoomsChunk { aliases: Vec::new(), canonical_alias: state .get(&(EventType::RoomCanonicalAlias, "".to_owned())) @@ -272,7 +273,7 @@ pub async fn get_public_rooms_route( body: get_public_rooms_filtered::IncomingRequest { filter: None, limit, - room_network: get_public_rooms_filtered::RoomNetwork::Matrix, + room_network: ruma::directory::RoomNetwork::Matrix, server, since, }, diff --git a/src/client_server/filter.rs b/src/client_server/filter.rs index 165419a..4322de3 100644 --- a/src/client_server/filter.rs +++ b/src/client_server/filter.rs @@ -7,23 +7,18 @@ use rocket::{get, post}; #[cfg_attr(feature = "conduit_bin", get("/_matrix/client/r0/user/<_>/filter/<_>"))] pub fn get_filter_route() -> ConduitResult { // TODO - Ok(get_filter::Response { - filter: filter::FilterDefinition { - event_fields: None, - event_format: None, - account_data: None, - room: None, - presence: None, - }, - } + Ok(get_filter::Response::new(filter::FilterDefinition { + event_fields: None, + event_format: None, + account_data: None, + room: None, + presence: None, + }) .into()) } #[cfg_attr(feature = "conduit_bin", post("/_matrix/client/r0/user/<_>/filter"))] pub fn create_filter_route() -> ConduitResult { // TODO - Ok(create_filter::Response { - filter_id: utils::random_string(10), - } - .into()) + Ok(create_filter::Response::new(utils::random_string(10)).into()) } diff --git a/src/client_server/keys.rs b/src/client_server/keys.rs index f88878c..3311529 100644 --- a/src/client_server/keys.rs +++ b/src/client_server/keys.rs @@ -167,7 +167,7 @@ pub fn claim_keys_route( )] pub fn upload_signing_keys_route( db: State<'_, Database>, - body: Ruma, + body: Ruma, ) -> ConduitResult { let sender_id = body.sender_id.as_ref().expect("user is authenticated"); let device_id = body.device_id.as_ref().expect("user is authenticated"); diff --git a/src/client_server/media.rs b/src/client_server/media.rs index efcb3a6..79c1f08 100644 --- a/src/client_server/media.rs +++ b/src/client_server/media.rs @@ -53,7 +53,7 @@ pub fn create_content_route( )] pub fn get_content_route( db: State<'_, Database>, - body: Ruma, + body: Ruma, _server_name: String, _media_id: String, ) -> ConduitResult { @@ -85,7 +85,7 @@ pub fn get_content_route( )] pub fn get_content_thumbnail_route( db: State<'_, Database>, - body: Ruma, + body: Ruma, _server_name: String, _media_id: String, ) -> ConduitResult { diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index 84c0ebd..c04cf7f 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -20,6 +20,8 @@ use ruma::{ events::{room::member, EventType}, EventId, Raw, RoomId, RoomVersionId, }; +use state_res::StateEvent; + use std::{collections::BTreeMap, convert::TryFrom}; #[cfg(feature = "conduit_bin")] @@ -92,17 +94,73 @@ pub async fn join_room_by_id_route( let send_join_response = server_server::send_request( &db, body.room_id.server_name().to_string(), - federation::membership::create_join_event::v2::Request { + federation::membership::create_join_event::v1::Request { room_id: body.room_id.clone(), event_id, - pdu_stub: serde_json::from_value::>(join_event_stub_value) + pdu_stub: serde_json::from_value(join_event_stub_value) .expect("Raw::from_value always works"), }, ) .await?; - dbg!(send_join_response); - todo!("Take send_join_response and 'create' the room using that data"); + dbg!(&send_join_response); + // todo!("Take send_join_response and 'create' the room using that data"); + + let mut event_map = send_join_response + .room_state + .state + .iter() + .map(|pdu| pdu.deserialize().map(StateEvent::Full)) + .map(|ev| { + let ev = ev?; + Ok::<_, serde_json::Error>((ev.event_id(), ev)) + }) + .collect::, _>>() + .map_err(|_| Error::bad_database("Invalid PDU found in db."))?; + + let _auth_chain = send_join_response + .room_state + .auth_chain + .iter() + .flat_map(|pdu| pdu.deserialize().ok()) + .map(StateEvent::Full) + .collect::>(); + + // TODO make StateResolution's methods free functions ? or no self param ? + let sorted_events_ids = state_res::StateResolution::default() + .reverse_topological_power_sort( + &body.room_id, + &event_map.keys().cloned().collect::>(), + &mut event_map, + &db.rooms, + &[], // TODO auth_diff: is this none since we have a set of resolved events we only want to sort + ); + + for ev_id in &sorted_events_ids { + // this is a `state_res::StateEvent` that holds a `ruma::Pdu` + let pdu = event_map.get(ev_id).ok_or_else(|| { + Error::Conflict("Found event_id in sorted events that is not in resolved state") + })?; + + db.rooms.append_pdu( + PduBuilder { + room_id: pdu.room_id().unwrap_or(&body.room_id).clone(), + sender: pdu.sender().clone(), + event_type: pdu.kind(), + content: pdu.content().clone(), + unsigned: Some( + pdu.unsigned() + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(), + ), + state_key: pdu.state_key(), + redacts: pdu.redacts().cloned(), + }, + &db.globals, + &db.account_data, + )?; + } } let event = member::MemberEventContent { @@ -127,10 +185,7 @@ pub async fn join_room_by_id_route( &db.account_data, )?; - Ok(join_room_by_id::Response { - room_id: body.room_id.clone(), - } - .into()) + Ok(join_room_by_id::Response::new(body.room_id.clone()).into()) } #[cfg_attr( @@ -140,7 +195,7 @@ pub async fn join_room_by_id_route( pub async fn join_room_by_id_or_alias_route( db: State<'_, Database>, db2: State<'_, Database>, - body: Ruma, + body: Ruma, ) -> ConduitResult { let room_id = match RoomId::try_from(body.room_id_or_alias.clone()) { Ok(room_id) => room_id, @@ -148,7 +203,13 @@ pub async fn join_room_by_id_or_alias_route( client_server::get_alias_route( db, Ruma { - body: alias::get_alias::IncomingRequest { room_alias }, + body: alias::get_alias::IncomingRequest::try_from(http::Request::new( + serde_json::json!({ "room_alias": room_alias }) + .to_string() + .as_bytes() + .to_vec(), + )) + .unwrap(), sender_id: body.sender_id.clone(), device_id: body.device_id.clone(), json_body: None, @@ -160,14 +221,32 @@ pub async fn join_room_by_id_or_alias_route( } }; + // TODO ruma needs to implement the same constructors for the Incoming variants + let tps = if let Some(in_tps) = &body.third_party_signed { + Some(ruma::api::client::r0::membership::ThirdPartySigned { + token: &in_tps.token, + sender: &in_tps.sender, + signatures: in_tps.signatures.clone(), + mxid: &in_tps.mxid, + }) + } else { + None + }; + let body = Ruma { sender_id: body.sender_id.clone(), device_id: body.device_id.clone(), json_body: None, - body: join_room_by_id::IncomingRequest { - room_id, - third_party_signed: body.third_party_signed.clone(), - }, + body: join_room_by_id::IncomingRequest::try_from(http::Request::new( + serde_json::json!({ + "room_id": room_id, + "third_party_signed": tps, + }) + .to_string() + .as_bytes() + .to_vec(), + )) + .unwrap(), }; Ok(join_room_by_id_or_alias::Response { @@ -219,7 +298,7 @@ pub fn leave_room_route( &db.account_data, )?; - Ok(leave_room::Response.into()) + Ok(leave_room::Response::new().into()) } #[cfg_attr( @@ -266,7 +345,7 @@ pub fn invite_user_route( )] pub fn kick_user_route( db: State<'_, Database>, - body: Ruma, + body: Ruma, ) -> ConduitResult { let sender_id = body.sender_id.as_ref().expect("user is authenticated"); @@ -304,7 +383,7 @@ pub fn kick_user_route( &db.account_data, )?; - Ok(kick_user::Response.into()) + Ok(kick_user::Response::new().into()) } #[cfg_attr( @@ -313,7 +392,7 @@ pub fn kick_user_route( )] pub fn ban_user_route( db: State<'_, Database>, - body: Ruma, + body: Ruma, ) -> ConduitResult { let sender_id = body.sender_id.as_ref().expect("user is authenticated"); @@ -359,7 +438,7 @@ pub fn ban_user_route( &db.account_data, )?; - Ok(ban_user::Response.into()) + Ok(ban_user::Response::new().into()) } #[cfg_attr( @@ -368,7 +447,7 @@ pub fn ban_user_route( )] pub fn unban_user_route( db: State<'_, Database>, - body: Ruma, + body: Ruma, ) -> ConduitResult { let sender_id = body.sender_id.as_ref().expect("user is authenticated"); @@ -405,7 +484,7 @@ pub fn unban_user_route( &db.account_data, )?; - Ok(unban_user::Response.into()) + Ok(unban_user::Response::new().into()) } #[cfg_attr( @@ -414,13 +493,13 @@ pub fn unban_user_route( )] pub fn forget_room_route( db: State<'_, Database>, - body: Ruma, + body: Ruma, ) -> ConduitResult { let sender_id = body.sender_id.as_ref().expect("user is authenticated"); db.rooms.forget(&body.room_id, &sender_id)?; - Ok(forget_room::Response.into()) + Ok(forget_room::Response::new().into()) } #[cfg_attr( diff --git a/src/client_server/message.rs b/src/client_server/message.rs index d851214..1b461d2 100644 --- a/src/client_server/message.rs +++ b/src/client_server/message.rs @@ -1,8 +1,11 @@ use super::State; use crate::{pdu::PduBuilder, ConduitResult, Database, Error, Ruma}; -use ruma::api::client::{ - error::ErrorKind, - r0::message::{get_message_events, send_message_event}, +use ruma::{ + api::client::{ + error::ErrorKind, + r0::message::{get_message_events, send_message_event}, + }, + events::EventContent, }; use std::convert::TryInto; @@ -26,7 +29,7 @@ pub fn send_message_event_route( PduBuilder { room_id: body.room_id.clone(), sender: sender_id.clone(), - event_type: body.event_type.clone(), + event_type: body.content.event_type().into(), content: serde_json::from_str( body.json_body .ok_or(Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))? @@ -41,7 +44,7 @@ pub fn send_message_event_route( &db.account_data, )?; - Ok(send_message_event::Response { event_id }.into()) + Ok(send_message_event::Response::new(event_id).into()) } #[cfg_attr( @@ -50,7 +53,7 @@ pub fn send_message_event_route( )] pub fn get_message_events_route( db: State<'_, Database>, - body: Ruma, + body: Ruma, ) -> ConduitResult { let sender_id = body.sender_id.as_ref().expect("user is authenticated"); @@ -92,13 +95,13 @@ pub fn get_message_events_route( .map(|(_, pdu)| pdu.to_room_event()) .collect::>(); - Ok(get_message_events::Response { - start: Some(body.from.clone()), - end: end_token, - chunk: events_after, - state: Vec::new(), - } - .into()) + let mut resp = get_message_events::Response::new(); + resp.start = Some(body.from.clone()); + resp.end = end_token; + resp.chunk = events_after; + resp.state = Vec::new(); + + Ok(resp.into()) } get_message_events::Direction::Backward => { let events_before = db @@ -116,13 +119,13 @@ pub fn get_message_events_route( .map(|(_, pdu)| pdu.to_room_event()) .collect::>(); - Ok(get_message_events::Response { - start: Some(body.from.clone()), - end: start_token, - chunk: events_before, - state: Vec::new(), - } - .into()) + let mut resp = get_message_events::Response::new(); + resp.start = Some(body.from.clone()); + resp.end = start_token; + resp.chunk = events_before; + resp.state = Vec::new(); + + Ok(resp.into()) } } } diff --git a/src/client_server/room.rs b/src/client_server/room.rs index b5f1529..589a2dc 100644 --- a/src/client_server/room.rs +++ b/src/client_server/room.rs @@ -315,7 +315,7 @@ pub fn create_room_route( db.rooms.set_public(&room_id, true)?; } - Ok(create_room::Response { room_id }.into()) + Ok(create_room::Response::new(room_id).into()) } #[cfg_attr( diff --git a/src/client_server/session.rs b/src/client_server/session.rs index 4011058..948b455 100644 --- a/src/client_server/session.rs +++ b/src/client_server/session.rs @@ -1,5 +1,4 @@ -use super::State; -use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; +use super::{State, DEVICE_ID_LENGTH, TOKEN_LENGTH}; use crate::{utils, ConduitResult, Database, Error, Ruma}; use ruma::{ api::client::{ @@ -18,10 +17,7 @@ use rocket::{get, post}; /// when logging in. #[cfg_attr(feature = "conduit_bin", get("/_matrix/client/r0/login"))] pub fn get_login_types_route() -> ConduitResult { - Ok(get_login_types::Response { - flows: vec![get_login_types::LoginType::Password], - } - .into()) + Ok(get_login_types::Response::new(vec![get_login_types::LoginType::Password]).into()) } /// # `POST /_matrix/client/r0/login` @@ -40,15 +36,15 @@ pub fn get_login_types_route() -> ConduitResult { )] pub fn login_route( db: State<'_, Database>, - body: Ruma, + body: Ruma, ) -> ConduitResult { // Validate login method let user_id = // TODO: Other login methods - if let (login::UserInfo::MatrixId(username), login::LoginInfo::Password { password }) = - (body.user.clone(), body.login_info.clone()) + if let (login::IncomingUserInfo::MatrixId(username), login::IncomingLoginInfo::Password { password }) = + (&body.user, &body.login_info) { - let user_id = UserId::parse_with_server_name(username, db.globals.server_name()) + let user_id = UserId::parse_with_server_name(username.to_string(), db.globals.server_name()) .map_err(|_| Error::BadRequest( ErrorKind::InvalidUsername, "Username is invalid." @@ -126,7 +122,7 @@ pub fn logout_route( db.users.remove_device(&sender_id, device_id)?; - Ok(logout::Response.into()) + Ok(logout::Response::new().into()) } /// # `POST /_matrix/client/r0/logout/all` @@ -154,5 +150,5 @@ pub fn logout_all_route( } } - Ok(logout_all::Response.into()) + Ok(logout_all::Response::new().into()) } diff --git a/src/client_server/state.rs b/src/client_server/state.rs index 60b3e9f..14cc497 100644 --- a/src/client_server/state.rs +++ b/src/client_server/state.rs @@ -8,9 +8,9 @@ use ruma::{ send_state_event_for_empty_key, send_state_event_for_key, }, }, - events::{room::canonical_alias, EventType}, - Raw, + events::{AnyStateEventContent, EventContent}, }; +use std::convert::TryFrom; #[cfg(feature = "conduit_bin")] use rocket::{get, put}; @@ -33,17 +33,10 @@ pub fn send_state_event_for_key_route( ) .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?; - if body.event_type == EventType::RoomCanonicalAlias { - let canonical_alias = serde_json::from_value::< - Raw, - >(content.clone()) - .expect("from_value::> can never fail") - .deserialize() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid canonical alias."))?; + if let AnyStateEventContent::RoomCanonicalAlias(canonical_alias) = &body.content { + let mut aliases = canonical_alias.alt_aliases.clone(); - let mut aliases = canonical_alias.alt_aliases; - - if let Some(alias) = canonical_alias.alias { + if let Some(alias) = canonical_alias.alias.clone() { aliases.push(alias); } @@ -68,7 +61,7 @@ pub fn send_state_event_for_key_route( PduBuilder { room_id: body.room_id.clone(), sender: sender_id.clone(), - event_type: body.event_type.clone(), + event_type: body.content.event_type().into(), content, unsigned: None, state_key: Some(body.state_key.clone()), @@ -78,7 +71,7 @@ pub fn send_state_event_for_key_route( &db.account_data, )?; - Ok(send_state_event_for_key::Response { event_id }.into()) + Ok(send_state_event_for_key::Response::new(event_id).into()) } #[cfg_attr( @@ -93,25 +86,28 @@ pub fn send_state_event_for_empty_key_route( let Ruma { body: send_state_event_for_empty_key::IncomingRequest { - room_id, - event_type, - data, + room_id, content, .. }, sender_id, device_id, json_body, } = body; - Ok(send_state_event_for_empty_key::Response { - event_id: send_state_event_for_key_route( + Ok(send_state_event_for_empty_key::Response::new( + send_state_event_for_key_route( db, Ruma { - body: send_state_event_for_key::IncomingRequest { - room_id, - event_type, - data, - state_key: "".to_owned(), - }, + body: send_state_event_for_key::IncomingRequest::try_from(http::Request::new( + serde_json::json!({ + "room_id": room_id, + "state_key": "", + "content": content, + }) + .to_string() + .as_bytes() + .to_vec(), + )) + .unwrap(), sender_id, device_id, json_body, @@ -119,7 +115,7 @@ pub fn send_state_event_for_empty_key_route( )? .0 .event_id, - } + ) .into()) } diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index 2307f02..ae4c224 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -31,7 +31,7 @@ use std::{ )] pub async fn sync_events_route( db: State<'_, Database>, - body: Ruma, + body: Ruma, ) -> ConduitResult { let sender_id = body.sender_id.as_ref().expect("user is authenticated"); let device_id = body.device_id.as_ref().expect("user is authenticated"); diff --git a/src/client_server/unversioned.rs b/src/client_server/unversioned.rs index 3ff8bec..ea7f633 100644 --- a/src/client_server/unversioned.rs +++ b/src/client_server/unversioned.rs @@ -1,6 +1,5 @@ use crate::ConduitResult; use ruma::api::client::unversioned::get_supported_versions; -use std::collections::BTreeMap; #[cfg(feature = "conduit_bin")] use rocket::get; @@ -17,13 +16,11 @@ use rocket::get; /// unstable features in their stable releases #[cfg_attr(feature = "conduit_bin", get("/_matrix/client/versions"))] pub fn get_supported_versions_route() -> ConduitResult { - let mut unstable_features = BTreeMap::new(); + let mut resp = + get_supported_versions::Response::new(vec!["r0.5.0".to_owned(), "r0.6.0".to_owned()]); - unstable_features.insert("org.matrix.e2e_cross_signing".to_owned(), true); + resp.unstable_features + .insert("org.matrix.e2e_cross_signing".to_owned(), true); - Ok(get_supported_versions::Response { - versions: vec!["r0.5.0".to_owned(), "r0.6.0".to_owned()], - unstable_features, - } - .into()) + Ok(resp.into()) } diff --git a/src/database.rs b/src/database.rs index 7bbb6dd..6cd65c3 100644 --- a/src/database.rs +++ b/src/database.rs @@ -97,8 +97,8 @@ impl Database { }, pduid_pdu: db.open_tree("pduid_pdu")?, eventid_pduid: db.open_tree("eventid_pduid")?, + roomstateid_pduid: db.open_tree("roomstateid_pduid")?, roomid_pduleaves: db.open_tree("roomid_pduleaves")?, - roomstateid_pdu: db.open_tree("roomstateid_pdu")?, alias_roomid: db.open_tree("alias_roomid")?, aliasid_alias: db.open_tree("alias_roomid")?, @@ -111,6 +111,9 @@ impl Database { userroomid_invited: db.open_tree("userroomid_invited")?, roomuserid_invited: db.open_tree("roomuserid_invited")?, userroomid_left: db.open_tree("userroomid_left")?, + + stateid_pduid: db.open_tree("stateid_pduid")?, + pduid_statehash: db.open_tree("pduid_statehash")?, }, account_data: account_data::AccountData { roomuserdataid_accountdata: db.open_tree("roomuserdataid_accountdata")?, diff --git a/src/database/rooms.rs b/src/database/rooms.rs index d2cd5e9..0d36326 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -9,7 +9,7 @@ use ruma::{ events::{ ignored_user_list, room::{ - join_rules, member, + member, power_levels::{self, PowerLevelsEventContent}, redaction, }, @@ -18,19 +18,31 @@ use ruma::{ EventId, Raw, RoomAliasId, RoomId, UserId, }; use sled::IVec; +use state_res::{event_auth, Requester, StateEvent, StateMap, StateStore}; + use std::{ - collections::{BTreeMap, HashMap}, + collections::{hash_map::DefaultHasher, BTreeMap, HashMap}, convert::{TryFrom, TryInto}, + hash::{Hash, Hasher}, mem, + result::Result as StdResult, }; +/// The unique identifier of each state group. +/// +/// This is created when a state group is added to the database by +/// hashing the entire state. +pub type StateHashId = String; + +/// This identifier consists of roomId + count. It represents a +/// unique event, it will never be overwritten or removed. +pub type PduId = IVec; + pub struct Rooms { pub edus: edus::RoomEdus, pub(super) pduid_pdu: sled::Tree, // PduId = RoomId + Count pub(super) eventid_pduid: sled::Tree, pub(super) roomid_pduleaves: sled::Tree, - pub(super) roomstateid_pdu: sled::Tree, // RoomStateId = Room + StateType + StateKey - pub(super) alias_roomid: sled::Tree, pub(super) aliasid_alias: sled::Tree, // AliasId = RoomId + Count pub(super) publicroomids: sled::Tree, @@ -42,9 +54,263 @@ pub struct Rooms { pub(super) userroomid_invited: sled::Tree, pub(super) roomuserid_invited: sled::Tree, pub(super) userroomid_left: sled::Tree, + + // STATE TREES + /// This holds the full current state, including the latest event. + pub(super) roomstateid_pduid: sled::Tree, // RoomStateId = Room + StateType + StateKey + /// This holds the full room state minus the latest event. + pub(super) pduid_statehash: sled::Tree, // PDU id -> StateHash + /// Also holds the full room state minus the latest event. + pub(super) stateid_pduid: sled::Tree, // StateId = StateHash + (EventType, StateKey) } +impl StateStore for Rooms { + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> StdResult { + let pid = self + .eventid_pduid + .get(event_id.as_bytes()) + .map_err(|e| e.to_string())? + .ok_or_else(|| "PDU via room_id and event_id not found in the db.".to_owned())?; + + utils::deserialize( + &self + .pduid_pdu + .get(pid) + .map_err(|e| e.to_string())? + .ok_or_else(|| "PDU via pduid not found in db.".to_owned())?, + ) + .and_then(|pdu: StateEvent| { + // conduit's PDU's always contain a room_id but some + // of ruma's do not so this must be an Option + if pdu.room_id() == Some(room_id) { + Ok(pdu) + } else { + Err(Error::bad_database("Found PDU for incorrect room in db.")) + } + }) + .map_err(|e| e.to_string()) + } +} + +// These are the methods related to STATE resolution. impl Rooms { + /// Generates a new StateHash and associates it with the incoming event. + /// + /// This adds all current state events (not including the incoming event) + /// to `stateid_pduid` and adds the incoming event to `pduid_statehash`. + /// The incoming event is the `pdu_id` passed to this method. + pub fn append_state_pdu(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result { + let state_hash = self.new_state_hash_id(room_id)?; + let state = self.current_state_pduids(room_id)?; + + let mut key = state_hash.as_bytes().to_vec(); + key.push(0xff); + + // TODO eventually we could avoid writing to the DB so much on every event + // by keeping track of the delta and write that every so often + for ((ev_ty, state_key), pid) in state { + let mut state_id = key.to_vec(); + state_id.extend_from_slice(ev_ty.to_string().as_bytes()); + key.push(0xff); + state_id.extend_from_slice(state_key.expect("state event").as_bytes()); + key.push(0xff); + + self.stateid_pduid.insert(&state_id, &pid)?; + } + + // This event's state does not include the event itself. `current_state_pduids` + // uses `roomstateid_pduid` before the current event is inserted to the tree so the state + // will be everything up to but not including the incoming event. + self.pduid_statehash.insert(pdu_id, state_hash.as_bytes())?; + + Ok(state_hash) + } + + /// Builds a `StateMap` by iterating over all keys that start + /// with `state_hash`, this gives the full state at event "x". + pub fn get_statemap_by_hash(&self, state_hash: StateHashId) -> Result> { + self.stateid_pduid + .scan_prefix(state_hash.as_bytes()) + .values() + .map(|pduid| { + self.pduid_pdu.get(&pduid?)?.map_or_else( + || Err(Error::bad_database("Failed to find StateMap.")), + |b| { + serde_json::from_slice::(&b) + .map_err(|_| Error::bad_database("Invalid PDU in db.")) + }, + ) + }) + .map(|pdu| { + let pdu = pdu?; + Ok(((pdu.kind, pdu.state_key), pdu.event_id)) + }) + .collect::>>() + } + + // TODO make this return Result + /// Fetches the previous StateHash ID to `current`. + pub fn prev_state_hash(&self, current: StateHashId) -> Option { + let mut found = false; + for pair in self.pduid_statehash.iter().rev() { + let prev = utils::string_from_bytes(&pair.ok()?.1).ok()?; + if current == prev { + found = true; + } + if current != prev && found { + return Some(prev); + } + } + None + } + + /// Fetch the current State using the `roomstateid_pduid` tree. + pub fn current_state_pduids(&self, room_id: &RoomId) -> Result> { + // TODO this could also scan roomstateid_pduid if we passed in room_id ? + self.roomstateid_pduid + .scan_prefix(room_id.as_bytes()) + .values() + .map(|pduid| { + let pduid = &pduid?; + self.pduid_pdu.get(pduid)?.map_or_else( + || { + Err(Error::bad_database( + "Failed to find current state of pduid's.", + )) + }, + |b| { + Ok(( + serde_json::from_slice::(&b) + .map_err(|_| Error::bad_database("Invalid PDU in db."))?, + pduid.clone(), + )) + }, + ) + }) + .map(|pair| { + let (pdu, id) = pair?; + Ok(((pdu.kind, pdu.state_key), id)) + }) + .collect::>>() + } + + /// Returns the last state hash key added to the db. + pub fn current_state_hash(&self, room_id: &RoomId) -> Result { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xff); + + // We must check here because this method is called outside and before + // `append_state_pdu` so the DB can be empty + if self.pduid_statehash.scan_prefix(prefix).next().is_none() { + // TODO use ring crate to hash + return Ok(room_id.as_str().to_owned()); + } + + self.pduid_statehash + .iter() + .next_back() + .map(|pair| { + utils::string_from_bytes(&pair?.1) + .map_err(|_| Error::bad_database("Invalid state hash string in db.")) + }) + .ok_or_else(|| Error::bad_database("No PDU's found for this room."))? + } + + /// This fetches auth event_ids from the current state using the + /// full `roomstateid_pdu` tree. + pub fn get_auth_event_ids( + &self, + room_id: &RoomId, + kind: &EventType, + sender: &UserId, + state_key: Option<&str>, + content: serde_json::Value, + ) -> Result> { + let auth_events = state_res::auth_types_for_event( + kind.clone(), + sender, + state_key.map(|s| s.to_string()), + content, + ); + + let mut events = vec![]; + for (event_type, state_key) in auth_events { + if let Some(state_key) = state_key.as_ref() { + if let Some(id) = self.room_state_get(room_id, &event_type, state_key)? { + events.push(id.event_id); + } + } + } + Ok(events) + } + + // This fetches auth events from the current state using the + /// full `roomstateid_pdu` tree. + pub fn get_auth_events( + &self, + room_id: &RoomId, + kind: &EventType, + sender: &UserId, + state_key: Option<&str>, + content: serde_json::Value, + ) -> Result> { + let auth_events = state_res::auth_types_for_event( + kind.clone(), + sender, + state_key.map(|s| s.to_string()), + content, + ); + + let mut events = StateMap::new(); + for (event_type, state_key) in auth_events { + if let Some(s_key) = state_key.as_ref() { + if let Some(pdu) = self.room_state_get(room_id, &event_type, s_key)? { + events.insert((event_type, state_key), pdu); + } + } + } + Ok(events) + } + + /// Generate a new StateHash. + /// + /// A unique hash made from hashing the current states pduid's. + /// Because `append_state_pdu` handles the empty state db case it does not + /// have to be here. + fn new_state_hash_id(&self, room_id: &RoomId) -> Result { + // Use hashed roomId as the first StateHash key for first state event in room + if self + .pduid_statehash + .scan_prefix(room_id.as_bytes()) + .next() + .is_none() + { + // TODO use ring crate to hash + return Ok(room_id.as_str().to_owned()); + } + + let pdu_ids_to_hash = self + .pduid_statehash + .scan_prefix(room_id.as_bytes()) + .values() + .next_back() + .unwrap() // We just checked if the tree was empty + .map(|hash| { + self.stateid_pduid + .scan_prefix(hash) + .values() + // pduid is roomId + count so just hash the whole thing + .map(|pid| Ok(pid?.to_vec())) + .collect::>>>() + })??; + + let mut hasher = DefaultHasher::new(); + pdu_ids_to_hash.hash(&mut hasher); + let hash = hasher.finish().to_string(); + // TODO not sure how you want to hash this + Ok(hash) + } + /// Checks if a room exists. pub fn exists(&self, room_id: &RoomId) -> Result { let mut prefix = room_id.to_string().as_bytes().to_vec(); @@ -64,16 +330,20 @@ impl Rooms { room_id: &RoomId, ) -> Result> { let mut hashmap = HashMap::new(); - for pdu in self - .roomstateid_pdu - .scan_prefix(&room_id.to_string().as_bytes()) - .values() - .map(|value| { - Ok::<_, Error>( - serde_json::from_slice::(&value?) + for pdu in + self.roomstateid_pduid + .scan_prefix(&room_id.to_string().as_bytes()) + .values() + .map(|value| { + Ok::<_, Error>( + serde_json::from_slice::( + &self.pduid_pdu.get(value?)?.ok_or_else(|| { + Error::bad_database("PDU not found for ID in db.") + })?, + ) .map_err(|_| Error::bad_database("Invalid PDU in db."))?, - ) - }) + ) + }) { let pdu = pdu?; let state_key = pdu.state_key.clone().ok_or_else(|| { @@ -95,16 +365,20 @@ impl Rooms { prefix.extend_from_slice(&event_type.to_string().as_bytes()); let mut hashmap = HashMap::new(); - for pdu in self - .roomstateid_pdu - .scan_prefix(&prefix) - .values() - .map(|value| { - Ok::<_, Error>( - serde_json::from_slice::(&value?) + for pdu in + self.roomstateid_pduid + .scan_prefix(&prefix) + .values() + .map(|value| { + Ok::<_, Error>( + serde_json::from_slice::( + &self.pduid_pdu.get(value?)?.ok_or_else(|| { + Error::bad_database("PDU not found for ID in db.") + })?, + ) .map_err(|_| Error::bad_database("Invalid PDU in db."))?, - ) - }) + ) + }) { let pdu = pdu?; let state_key = pdu.state_key.clone().ok_or_else(|| { @@ -115,23 +389,28 @@ impl Rooms { Ok(hashmap) } - /// Returns the full room state. + /// Returns a single PDU in `room_id` with key (`event_type`, `state_key`). pub fn room_state_get( &self, room_id: &RoomId, event_type: &EventType, state_key: &str, ) -> Result> { - let mut key = room_id.to_string().as_bytes().to_vec(); + let mut key = room_id.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(&event_type.to_string().as_bytes()); key.push(0xff); key.extend_from_slice(&state_key.as_bytes()); - self.roomstateid_pdu.get(&key)?.map_or(Ok(None), |value| { + self.roomstateid_pduid.get(&key)?.map_or(Ok(None), |value| { Ok::<_, Error>(Some( - serde_json::from_slice::(&value) - .map_err(|_| Error::bad_database("Invalid PDU in db."))?, + serde_json::from_slice::( + &self + .pduid_pdu + .get(value)? + .ok_or_else(|| Error::bad_database("PDU not found for ID in db."))?, + ) + .map_err(|_| Error::bad_database("Invalid PDU in db."))?, )) }) } @@ -139,7 +418,7 @@ impl Rooms { /// Returns the `count` of this pdu's id. pub fn get_pdu_count(&self, event_id: &EventId) -> Result> { self.eventid_pduid - .get(event_id.to_string().as_bytes())? + .get(event_id.as_bytes())? .map_or(Ok(None), |pdu_id| { Ok(Some( utils::u64_from_bytes( @@ -153,7 +432,7 @@ impl Rooms { /// Returns the json of a pdu. pub fn get_pdu_json(&self, event_id: &EventId) -> Result> { self.eventid_pduid - .get(event_id.to_string().as_bytes())? + .get(event_id.as_bytes())? .map_or(Ok(None), |pdu_id| { Ok(Some( serde_json::from_slice(&self.pduid_pdu.get(pdu_id)?.ok_or_else(|| { @@ -174,7 +453,7 @@ impl Rooms { /// Returns the pdu. pub fn get_pdu(&self, event_id: &EventId) -> Result> { self.eventid_pduid - .get(event_id.to_string().as_bytes())? + .get(event_id.as_bytes())? .map_or(Ok(None), |pdu_id| { Ok(Some( serde_json::from_slice(&self.pduid_pdu.get(pdu_id)?.ok_or_else(|| { @@ -238,16 +517,15 @@ impl Rooms { /// Replace the leaves of a room with a new event. pub fn replace_pdu_leaves(&self, room_id: &RoomId, event_id: &EventId) -> Result<()> { - let mut prefix = room_id.to_string().as_bytes().to_vec(); + let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); for key in self.roomid_pduleaves.scan_prefix(&prefix).keys() { self.roomid_pduleaves.remove(key?)?; } - prefix.extend_from_slice(event_id.to_string().as_bytes()); - self.roomid_pduleaves - .insert(&prefix, &*event_id.to_string())?; + prefix.extend_from_slice(event_id.as_bytes()); + self.roomid_pduleaves.insert(&prefix, event_id.as_bytes())?; Ok(()) } @@ -272,6 +550,14 @@ impl Rooms { // TODO: Make sure this isn't called twice in parallel let prev_events = self.get_pdu_leaves(&room_id)?; + let auth_events = self.get_auth_events( + &room_id, + &event_type, + &sender, + state_key.as_deref(), + content.clone(), + )?; + // Is the event authorized? if let Some(state_key) = &state_key { let power_levels = self @@ -333,138 +619,24 @@ impl Rooms { // Don't allow encryption events when it's disabled !globals.encryption_disabled() } - EventType::RoomMember => { - let target_user_id = UserId::try_from(&**state_key).map_err(|_| { - Error::BadRequest( - ErrorKind::InvalidParam, - "State key of member event does not contain user id.", - ) - })?; - - let current_membership = self - .room_state_get( - &room_id, - &EventType::RoomMember, - &target_user_id.to_string(), - )? - .map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| { - Ok(serde_json::from_value::>( - pdu.content, - ) - .expect("Raw::from_value always works.") - .deserialize() - .map_err(|_| Error::bad_database("Invalid Member event in db."))? - .membership) - })?; - - let target_membership = - serde_json::from_value::>(content.clone()) - .expect("Raw::from_value always works.") - .deserialize() - .map_err(|_| Error::bad_database("Invalid Member event in db."))? - .membership; - - let target_power = power_levels.users.get(&target_user_id).map_or_else( - || { - if target_membership != member::MembershipState::Join { - None - } else { - Some(&power_levels.users_default) - } - }, - // If it's okay, wrap with Some(_) - Some, - ); - - let join_rules = - self.room_state_get(&room_id, &EventType::RoomJoinRules, "")? - .map_or(Ok::<_, Error>(join_rules::JoinRule::Public), |pdu| { - Ok(serde_json::from_value::< - Raw, - >(pdu.content) - .expect("Raw::from_value always works.") - .deserialize() - .map_err(|_| { - Error::bad_database("Database contains invalid JoinRules event") - })? - .join_rule) - })?; - - if target_membership == member::MembershipState::Join { - let mut prev_events = prev_events.iter(); - let prev_event = self - .get_pdu(prev_events.next().ok_or(Error::BadRequest( - ErrorKind::Unknown, - "Membership can't be the first event", - ))?)? - .ok_or_else(|| { - Error::bad_database("PDU leaf points to invalid event!") - })?; - if prev_event.kind == EventType::RoomCreate - && prev_event.prev_events.is_empty() - { - true - } else if sender != target_user_id { - false - } else if let member::MembershipState::Ban = current_membership { - false - } else { - join_rules == join_rules::JoinRule::Invite - && (current_membership == member::MembershipState::Join - || current_membership == member::MembershipState::Invite) - || join_rules == join_rules::JoinRule::Public - } - } else if target_membership == member::MembershipState::Invite { - if let Some(third_party_invite_json) = content.get("third_party_invite") { - if current_membership == member::MembershipState::Ban { - false - } else { - let _third_party_invite = - serde_json::from_value::( - third_party_invite_json.clone(), - ) - .map_err(|_| { - Error::BadRequest( - ErrorKind::InvalidParam, - "ThirdPartyInvite is invalid", - ) - })?; - todo!("handle third party invites"); - } - } else if sender_membership != member::MembershipState::Join - || current_membership == member::MembershipState::Join - || current_membership == member::MembershipState::Ban - { - false - } else { - sender_power - .filter(|&p| p >= &power_levels.invite) - .is_some() - } - } else if target_membership == member::MembershipState::Leave { - if sender == target_user_id { - current_membership == member::MembershipState::Join - || current_membership == member::MembershipState::Invite - } else if sender_membership != member::MembershipState::Join - || current_membership == member::MembershipState::Ban - && sender_power.filter(|&p| p < &power_levels.ban).is_some() - { - false - } else { - sender_power.filter(|&p| p >= &power_levels.kick).is_some() - && target_power < sender_power - } - } else if target_membership == member::MembershipState::Ban { - if sender_membership != member::MembershipState::Join { - false - } else { - sender_power.filter(|&p| p >= &power_levels.ban).is_some() - && target_power < sender_power - } - } else { - false - } - } + EventType::RoomMember => event_auth::is_membership_change_allowed( + // TODO this is a bit of a hack but not sure how to have a type + // declared in `state_res` crate be + Requester { + prev_event_ids: prev_events.to_owned(), + room_id: &room_id, + content: &content, + state_key: Some(state_key.to_owned()), + sender: &sender, + }, + &auth_events + .iter() + .map(|((ty, key), pdu)| { + Ok(((ty.clone(), key.clone()), pdu.convert_for_state_res()?)) + }) + .collect::>>()?, + ) + .ok_or(Error::Conflict("Found incoming PDU with invalid data."))?, EventType::RoomCreate => prev_events.is_empty(), // Not allow any of the following events if the sender is not joined. _ if sender_membership != member::MembershipState::Join => false, @@ -474,7 +646,7 @@ impl Rooms { >= &power_levels.state_default } } { - error!("Unauthorized"); + error!("Unauthorized {}", event_type); // Not authorized return Err(Error::BadRequest( ErrorKind::Forbidden, @@ -483,7 +655,7 @@ impl Rooms { } } else if !self.is_joined(&sender, &room_id)? { // TODO: auth rules apply to all events, not only those with a state key - error!("Unauthorized"); + error!("Unauthorized {}", event_type); return Err(Error::BadRequest( ErrorKind::Forbidden, "Event is not authorized", @@ -524,7 +696,10 @@ impl Rooms { depth: depth .try_into() .map_err(|_| Error::bad_database("Depth is invalid"))?, - auth_events: Vec::new(), + auth_events: auth_events + .into_iter() + .map(|(_, pdu)| pdu.event_id) + .collect(), redacts: redacts.clone(), unsigned, hashes: ruma::events::pdu::EventHash { @@ -564,15 +739,19 @@ impl Rooms { self.pduid_pdu.insert(&pdu_id, &*pdu_json.to_string())?; self.eventid_pduid - .insert(pdu.event_id.to_string(), pdu_id.clone())?; + .insert(pdu.event_id.to_string(), &*pdu_id)?; - if let Some(state_key) = pdu.state_key { - let mut key = room_id.to_string().as_bytes().to_vec(); + if let Some(state_key) = &pdu.state_key { + // We call this first because our StateHash relies on the + // state before the new event + self.append_state_pdu(&room_id, &pdu_id)?; + + let mut key = room_id.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(pdu.kind.to_string().as_bytes()); key.push(0xff); key.extend_from_slice(state_key.as_bytes()); - self.roomstateid_pdu.insert(key, &*pdu_json.to_string())?; + self.roomstateid_pduid.insert(key, pdu_id.as_slice())?; } match event_type { diff --git a/src/database/uiaa.rs b/src/database/uiaa.rs index cece8db..e318f43 100644 --- a/src/database/uiaa.rs +++ b/src/database/uiaa.rs @@ -2,7 +2,7 @@ use crate::{Error, Result}; use ruma::{ api::client::{ error::ErrorKind, - r0::uiaa::{AuthData, UiaaInfo}, + r0::uiaa::{IncomingAuthData, UiaaInfo}, }, DeviceId, UserId, }; @@ -26,12 +26,12 @@ impl Uiaa { &self, user_id: &UserId, device_id: &DeviceId, - auth: &AuthData, + auth: &IncomingAuthData, uiaainfo: &UiaaInfo, users: &super::users::Users, globals: &super::globals::Globals, ) -> Result<(bool, UiaaInfo)> { - if let AuthData::DirectRequest { + if let IncomingAuthData::DirectRequest { kind, session, auth_parameters, diff --git a/src/pdu.rs b/src/pdu.rs index 9936802..5485f23 100644 --- a/src/pdu.rs +++ b/src/pdu.rs @@ -177,6 +177,35 @@ impl PduEvent { } } +impl PduEvent { + pub fn convert_for_state_res(&self) -> Result { + serde_json::from_value(json!({ + "event_id": self.event_id, + "room_id": self.room_id, + "sender": self.sender, + "origin": self.origin, + "origin_server_ts": self.origin_server_ts, + "type": self.kind, + "content": self.content, + "state_key": self.state_key, + "prev_events": self.prev_events + .iter() + .map(|id| (id, EventHash { sha256: "hello".into() })) + .collect::>(), + "depth": self.depth, + "auth_events": self.auth_events + .iter() + .map(|id| (id, EventHash { sha256: "hello".into() })) + .collect::>(), + "redacts": self.redacts, + "unsigned": self.unsigned, + "hashes": self.hashes, + "signatures": self.signatures, + })) + .map_err(|_| Error::bad_database("Failed to convert PDU to ruma::Pdu type.")) + } +} + /// Build the start of a PDU in order to add it to the `Database`. #[derive(Debug)] pub struct PduBuilder { diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs index 8d86204..80e6e58 100644 --- a/src/ruma_wrapper.rs +++ b/src/ruma_wrapper.rs @@ -1,5 +1,8 @@ use crate::Error; -use ruma::identifiers::{DeviceId, UserId}; +use ruma::{ + api::IncomingRequest, + identifiers::{DeviceId, UserId}, +}; use std::{convert::TryInto, ops::Deref}; #[cfg(feature = "conduit_bin")] @@ -16,13 +19,12 @@ use { tokio::io::AsyncReadExt, Request, State, }, - ruma::api::IncomingRequest, std::io::Cursor, }; /// This struct converts rocket requests into ruma structs by converting them into http requests /// first. -pub struct Ruma { +pub struct Ruma { pub body: T, pub sender_id: Option, pub device_id: Option>, @@ -110,7 +112,7 @@ impl<'a, T: IncomingRequest> FromTransformedData<'a> for Ruma { } } -impl Deref for Ruma { +impl Deref for Ruma { type Target = T; fn deref(&self) -> &Self::Target { diff --git a/src/server_server.rs b/src/server_server.rs index f48f502..e47b50a 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -1,14 +1,17 @@ use crate::{client_server, ConduitResult, Database, Error, Result, Ruma}; use http::header::{HeaderValue, AUTHORIZATION}; use rocket::{get, post, put, response::content::Json, State}; -use ruma::api::federation::{ - directory::get_public_rooms, - discovery::{ - get_server_keys, get_server_version::v1 as get_server_version, ServerKey, VerifyKey, +use ruma::api::{ + client, + federation::{ + directory::get_public_rooms, + discovery::{ + get_server_keys, get_server_version::v1 as get_server_version, ServerKey, VerifyKey, + }, + transactions::send_transaction_message, }, - transactions::send_transaction_message, + OutgoingRequest, }; -use ruma::api::{client, OutgoingRequest}; use serde_json::json; use std::{ collections::BTreeMap, @@ -204,11 +207,11 @@ pub fn get_server_keys_deprecated(db: State<'_, Database>) -> Json { )] pub async fn get_public_rooms_route( db: State<'_, Database>, - body: Ruma, + body: Ruma, ) -> ConduitResult { let Ruma { body: - get_public_rooms::v1::Request { + get_public_rooms::v1::IncomingRequest { room_network: _room_network, // TODO limit, since, @@ -229,7 +232,7 @@ pub async fn get_public_rooms_route( body: client::r0::directory::get_public_rooms_filtered::IncomingRequest { filter: None, limit, - room_network: client::r0::directory::get_public_rooms_filtered::RoomNetwork::Matrix, + room_network: ruma::directory::RoomNetwork::Matrix, server: None, since, }, @@ -268,9 +271,9 @@ pub async fn get_public_rooms_route( feature = "conduit_bin", put("/_matrix/federation/v1/send/<_>", data = "") )] -pub fn send_transaction_message_route( - db: State<'_, Database>, - body: Ruma, +pub fn send_transaction_message_route<'a>( + _db: State<'a, Database>, + body: Ruma, ) -> ConduitResult { dbg!(&*body); Ok(send_transaction_message::v1::Response { diff --git a/src/utils.rs b/src/utils.rs index 8cf1b2c..77a7d1f 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,3 +1,4 @@ +use crate::Error; use argon2::{Config, Variant}; use cmp::Ordering; use rand::prelude::*; @@ -90,3 +91,8 @@ pub fn common_elements( .all(|b| b) })) } + +pub fn deserialize<'de, T: serde::Deserialize<'de>>(val: &'de sled::IVec) -> Result { + serde_json::from_slice::(val.as_ref()) + .map_err(|_| Error::bad_database("PDU in db is invalid.")) +}