From cfaa900e8306b35c2d9b719b96bfce29ac37f58f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Sun, 18 Jul 2021 20:43:39 +0200 Subject: [PATCH] improvement: auth chain cache --- Cargo.lock | 36 ++-- Cargo.toml | 7 +- src/client_server/membership.rs | 4 +- src/client_server/sync.rs | 6 +- src/database.rs | 5 +- src/database/abstraction/rocksdb.rs | 6 +- src/database/abstraction/sqlite.rs | 6 +- src/database/globals.rs | 24 +-- src/database/rooms.rs | 23 ++- src/server_server.rs | 260 +++++++++++++++------------- 10 files changed, 201 insertions(+), 176 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 828d7cc..8bfba0e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2015,7 +2015,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.2.0" -source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" +source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386" dependencies = [ "assign", "js_int", @@ -2036,7 +2036,7 @@ dependencies = [ [[package]] name = "ruma-api" version = "0.17.1" -source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" +source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386" dependencies = [ "bytes", "http", @@ -2052,7 +2052,7 @@ dependencies = [ [[package]] name = "ruma-api-macros" version = "0.17.1" -source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" +source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -2063,7 +2063,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.3.0" -source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" +source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386" dependencies = [ "ruma-api", "ruma-common", @@ -2077,7 +2077,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.11.0" -source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" +source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386" dependencies = [ "assign", "bytes", @@ -2097,7 +2097,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.5.4" -source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" +source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386" dependencies = [ "indexmap", "js_int", @@ -2112,7 +2112,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.23.2" -source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" +source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386" dependencies = [ "indoc", "js_int", @@ -2128,7 +2128,7 @@ dependencies = [ [[package]] name = "ruma-events-macros" version = "0.23.2" -source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" +source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -2139,7 +2139,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.2.0" -source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" +source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386" dependencies = [ "js_int", "ruma-api", @@ -2154,7 +2154,7 @@ dependencies = [ [[package]] name = "ruma-identifiers" version = "0.19.4" -source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" +source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386" dependencies = [ "paste", "rand 0.8.4", @@ -2168,7 +2168,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-macros" version = "0.19.4" -source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" +source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386" dependencies = [ "quote", "ruma-identifiers-validation", @@ -2178,12 +2178,12 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.4.0" -source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" +source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386" [[package]] name = "ruma-identity-service-api" version = "0.2.0" -source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" +source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386" dependencies = [ "js_int", "ruma-api", @@ -2196,7 +2196,7 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" version = "0.2.0" -source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" +source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386" dependencies = [ "js_int", "ruma-api", @@ -2211,7 +2211,7 @@ dependencies = [ [[package]] name = "ruma-serde" version = "0.4.1" -source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" +source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386" dependencies = [ "bytes", "form_urlencoded", @@ -2225,7 +2225,7 @@ dependencies = [ [[package]] name = "ruma-serde-macros" version = "0.4.1" -source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" +source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -2236,7 +2236,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.8.0" -source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" +source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386" dependencies = [ "base64 0.13.0", "ed25519-dalek", @@ -2253,7 +2253,7 @@ dependencies = [ [[package]] name = "ruma-state-res" version = "0.2.0" -source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" +source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386" dependencies = [ "itertools 0.10.1", "js_int", diff --git a/Cargo.toml b/Cargo.toml index fd72d0e..64d67a1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,8 @@ edition = "2018" rocket = { version = "0.5.0-rc.1", features = ["tls"] } # Used to handle requests # Used for matrix spec type definitions and helpers -ruma = { git = "https://github.com/ruma/ruma", rev = "c29c2b16ec114fa655e2b70bdd53c82e35859005", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } +#ruma = { git = "https://github.com/ruma/ruma", rev = "c29c2b16ec114fa655e2b70bdd53c82e35859005", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } +ruma = { git = "https://github.com/timokoesters/ruma", rev = "a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } #ruma = { path = "../ruma/crates/ruma", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } # Used for long polling and federation sender, should be the same as rocket::tokio @@ -119,5 +120,5 @@ maintainer-scripts = "debian/" systemd-units = { unit-name = "matrix-conduit" } # For flamegraphs: -[profile.release] -debug = true +#[profile.release] +#debug = true diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index ef141f0..d8c2781 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -29,7 +29,7 @@ use ruma::{ uint, EventId, RoomId, RoomVersionId, ServerName, UserId, }; use std::{ - collections::{btree_map::Entry, BTreeMap, HashSet}, + collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}, convert::{TryFrom, TryInto}, sync::{Arc, RwLock}, time::{Duration, Instant}, @@ -607,7 +607,7 @@ async fn join_room_by_id_helper( let pdu = PduEvent::from_id_val(&event_id, join_event.clone()) .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; - let mut state = BTreeMap::new(); + let mut state = HashMap::new(); let pub_key_map = RwLock::new(BTreeMap::new()); for result in futures::future::join_all( diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index 3beddad..65922be 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -7,7 +7,7 @@ use ruma::{ DeviceId, RoomId, UserId, }; use std::{ - collections::{btree_map::Entry, hash_map, BTreeMap, HashMap, HashSet}, + collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}, convert::{TryFrom, TryInto}, sync::Arc, time::Duration, @@ -622,10 +622,10 @@ async fn sync_helper( .presence_since(&room_id, since, &db.rooms, &db.globals)? { match presence_updates.entry(user_id) { - hash_map::Entry::Vacant(v) => { + Entry::Vacant(v) => { v.insert(presence); } - hash_map::Entry::Occupied(mut o) => { + Entry::Occupied(mut o) => { let p = o.get_mut(); // Update existing presence event with more info diff --git a/src/database.rs b/src/database.rs index 27b9eb6..e359a5f 100644 --- a/src/database.rs +++ b/src/database.rs @@ -33,7 +33,7 @@ use std::{ io::Write, ops::Deref, path::Path, - sync::{Arc, RwLock}, + sync::{Arc, Mutex, RwLock}, }; use tokio::sync::{OwnedRwLockReadGuard, RwLock as TokioRwLock, Semaphore}; @@ -292,7 +292,8 @@ impl Database { eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, prevevent_parent: builder.open_tree("prevevent_parent")?, - pdu_cache: RwLock::new(LruCache::new(10_000)), + pdu_cache: Mutex::new(LruCache::new(100_000)), + auth_chain_cache: Mutex::new(LruCache::new(100_000)), }, account_data: account_data::AccountData { roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, diff --git a/src/database/abstraction/rocksdb.rs b/src/database/abstraction/rocksdb.rs index b996130..4699b2d 100644 --- a/src/database/abstraction/rocksdb.rs +++ b/src/database/abstraction/rocksdb.rs @@ -5,14 +5,14 @@ use std::{future::Future, pin::Pin, sync::Arc}; use super::{DatabaseEngine, Tree}; -use std::{collections::BTreeMap, sync::RwLock}; +use std::{collections::HashMap, sync::RwLock}; pub struct Engine(rocksdb::DBWithThreadMode); pub struct RocksDbEngineTree<'a> { db: Arc, name: &'a str, - watchers: RwLock, Vec>>>, + watchers: RwLock, Vec>>>, } impl DatabaseEngine for Engine { @@ -58,7 +58,7 @@ impl DatabaseEngine for Engine { Ok(Arc::new(RocksDbEngineTree { name, db: Arc::clone(self), - watchers: RwLock::new(BTreeMap::new()), + watchers: RwLock::new(HashMap::new()), })) } } diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index 8100ed9..8cc6a8d 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -7,7 +7,7 @@ use log::debug; use parking_lot::{Mutex, MutexGuard, RwLock}; use rusqlite::{params, Connection, DatabaseName::Main, OptionalExtension}; use std::{ - collections::BTreeMap, + collections::HashMap, future::Future, ops::Deref, path::{Path, PathBuf}, @@ -206,7 +206,7 @@ impl DatabaseEngine for Engine { Ok(Arc::new(SqliteTable { engine: Arc::clone(self), name: name.to_owned(), - watchers: RwLock::new(BTreeMap::new()), + watchers: RwLock::new(HashMap::new()), })) } @@ -266,7 +266,7 @@ impl Engine { pub struct SqliteTable { engine: Arc, name: String, - watchers: RwLock, Vec>>>, + watchers: RwLock, Vec>>>, } type TupleOfBytes = (Vec, Vec); diff --git a/src/database/globals.rs b/src/database/globals.rs index 0e72297..fbd41a3 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -41,12 +41,12 @@ pub struct Globals { dns_resolver: TokioAsyncResolver, jwt_decoding_key: Option>, pub(super) server_signingkeys: Arc, - pub bad_event_ratelimiter: Arc>>, - pub bad_signature_ratelimiter: Arc, RateLimitState>>>, - pub servername_ratelimiter: Arc, Arc>>>, - pub sync_receivers: RwLock), SyncHandle>>, - pub roomid_mutex: RwLock>>>, - pub roomid_mutex_federation: RwLock>>>, // this lock will be held longer + pub bad_event_ratelimiter: Arc>>, + pub bad_signature_ratelimiter: Arc, RateLimitState>>>, + pub servername_ratelimiter: Arc, Arc>>>, + pub sync_receivers: RwLock), SyncHandle>>, + pub roomid_mutex: RwLock>>>, + pub roomid_mutex_federation: RwLock>>>, // this lock will be held longer pub rotate: RotationHandler, } @@ -196,12 +196,12 @@ impl Globals { tls_name_override, server_signingkeys, jwt_decoding_key, - bad_event_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), - bad_signature_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), - servername_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), - roomid_mutex: RwLock::new(BTreeMap::new()), - roomid_mutex_federation: RwLock::new(BTreeMap::new()), - sync_receivers: RwLock::new(BTreeMap::new()), + bad_event_ratelimiter: Arc::new(RwLock::new(HashMap::new())), + bad_signature_ratelimiter: Arc::new(RwLock::new(HashMap::new())), + servername_ratelimiter: Arc::new(RwLock::new(HashMap::new())), + roomid_mutex: RwLock::new(HashMap::new()), + roomid_mutex_federation: RwLock::new(HashMap::new()), + sync_receivers: RwLock::new(HashMap::new()), rotate: RotationHandler::new(), }; diff --git a/src/database/rooms.rs b/src/database/rooms.rs index f6f5021..fa121bd 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -25,7 +25,7 @@ use std::{ collections::{BTreeMap, BTreeSet, HashMap, HashSet}, convert::{TryFrom, TryInto}, mem, - sync::{Arc, RwLock}, + sync::{Arc, Mutex}, }; use super::{abstraction::Tree, admin::AdminCommand, pusher}; @@ -84,7 +84,8 @@ pub struct Rooms { /// RoomId + EventId -> Parent PDU EventId. pub(super) prevevent_parent: Arc, - pub(super) pdu_cache: RwLock>>, + pub(super) pdu_cache: Mutex>>, + pub(super) auth_chain_cache: Mutex>>, } impl Rooms { @@ -109,7 +110,7 @@ impl Rooms { pub fn state_full( &self, shortstatehash: u64, - ) -> Result>> { + ) -> Result>> { let state = self .stateid_shorteventid .scan_prefix(shortstatehash.to_be_bytes().to_vec()) @@ -282,7 +283,7 @@ impl Rooms { pub fn force_state( &self, room_id: &RoomId, - state: BTreeMap<(EventType, String), EventId>, + state: HashMap<(EventType, String), EventId>, db: &Database, ) -> Result<()> { let state_hash = self.calculate_hash( @@ -402,11 +403,11 @@ impl Rooms { pub fn room_state_full( &self, room_id: &RoomId, - ) -> Result>> { + ) -> Result>> { if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { self.state_full(current_shortstatehash) } else { - Ok(BTreeMap::new()) + Ok(HashMap::new()) } } @@ -542,7 +543,7 @@ impl Rooms { /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. pub fn get_pdu(&self, event_id: &EventId) -> Result>> { - if let Some(p) = self.pdu_cache.write().unwrap().get_mut(&event_id) { + if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(&event_id) { return Ok(Some(Arc::clone(p))); } @@ -568,7 +569,7 @@ impl Rooms { .transpose()? { self.pdu_cache - .write() + .lock() .unwrap() .insert(event_id.clone(), Arc::clone(&pdu)); Ok(Some(pdu)) @@ -2520,4 +2521,10 @@ impl Rooms { Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) } + + pub fn auth_chain_cache( + &self, + ) -> std::sync::MutexGuard<'_, LruCache>> { + self.auth_chain_cache.lock().unwrap() + } } diff --git a/src/server_server.rs b/src/server_server.rs index bfb3e72..39a1847 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -6,6 +6,7 @@ use crate::{ use get_profile_information::v1::ProfileField; use http::header::{HeaderValue, AUTHORIZATION, HOST}; use log::{debug, error, info, trace, warn}; +use lru_cache::LruCache; use regex::Regex; use rocket::response::content::Json; use ruma::{ @@ -52,7 +53,7 @@ use ruma::{ ServerSigningKeyId, UserId, }; use std::{ - collections::{btree_map::Entry, BTreeMap, BTreeSet, HashSet}, + collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}, convert::{TryFrom, TryInto}, fmt::Debug, future::Future, @@ -931,7 +932,7 @@ pub fn handle_incoming_pdu<'a>( ); // Build map of auth events - let mut auth_events = BTreeMap::new(); + let mut auth_events = HashMap::new(); for id in &incoming_pdu.auth_events { let auth_event = db .rooms @@ -1097,7 +1098,7 @@ pub fn handle_incoming_pdu<'a>( Err(_) => return Err("Failed to fetch state events.".to_owned()), }; - let mut state = BTreeMap::new(); + let mut state = HashMap::new(); for pdu in state_vec { match state.entry((pdu.kind.clone(), pdu.state_key.clone().ok_or_else(|| "Found non-state pdu in state events.".to_owned())?)) { Entry::Vacant(v) => { @@ -1173,7 +1174,8 @@ pub fn handle_incoming_pdu<'a>( } } - let mut fork_states = BTreeSet::new(); + let mut extremity_statehashes = Vec::new(); + for id in &extremities { match db .rooms @@ -1181,30 +1183,19 @@ pub fn handle_incoming_pdu<'a>( .map_err(|_| "Failed to ask db for pdu.".to_owned())? { Some(leaf_pdu) => { - let pdu_shortstatehash = db - .rooms - .pdu_shortstatehash(&leaf_pdu.event_id) - .map_err(|_| "Failed to ask db for pdu state hash.".to_owned())? - .ok_or_else(|| { - error!( - "Found extremity pdu with no statehash in db: {:?}", - leaf_pdu - ); - "Found pdu with no statehash in db.".to_owned() - })?; - - let mut leaf_state = db - .rooms - .state_full(pdu_shortstatehash) - .map_err(|_| "Failed to ask db for room state.".to_owned())?; - - if let Some(state_key) = &leaf_pdu.state_key { - // Now it's the state after - let key = (leaf_pdu.kind.clone(), state_key.clone()); - leaf_state.insert(key, leaf_pdu); - } - - fork_states.insert(leaf_state); + extremity_statehashes.push(( + db.rooms + .pdu_shortstatehash(&leaf_pdu.event_id) + .map_err(|_| "Failed to ask db for pdu state hash.".to_owned())? + .ok_or_else(|| { + error!( + "Found extremity pdu with no statehash in db: {:?}", + leaf_pdu + ); + "Found pdu with no statehash in db.".to_owned() + })?, + Some(leaf_pdu), + )); } _ => { error!("Missing state snapshot for {:?}", id); @@ -1218,12 +1209,36 @@ pub fn handle_incoming_pdu<'a>( // don't just trust a set of state we got from a remote). // We do this by adding the current state to the list of fork states + let current_statehash = db + .rooms + .current_shortstatehash(&room_id) + .map_err(|_| "Failed to load current state hash.".to_owned())? + .expect("every room has state"); + let current_state = db .rooms - .room_state_full(&room_id) - .map_err(|_| "Failed to load room state.".to_owned())?; + .state_full(current_statehash) + .map_err(|_| "Failed to load room state.")?; - fork_states.insert(current_state.clone()); + extremity_statehashes.push((current_statehash.clone(), None)); + + let mut fork_states = Vec::new(); + for (statehash, leaf_pdu) in extremity_statehashes { + let mut leaf_state = db + .rooms + .state_full(statehash) + .map_err(|_| "Failed to ask db for room state.".to_owned())?; + + if let Some(leaf_pdu) = leaf_pdu { + if let Some(state_key) = &leaf_pdu.state_key { + // Now it's the state after + let key = (leaf_pdu.kind.clone(), state_key.clone()); + leaf_state.insert(key, leaf_pdu); + } + } + + fork_states.push(leaf_state); + } // We also add state after incoming event to the fork states extremities.insert(incoming_pdu.event_id.clone()); @@ -1234,9 +1249,7 @@ pub fn handle_incoming_pdu<'a>( incoming_pdu.clone(), ); } - fork_states.insert(state_after.clone()); - - let fork_states = fork_states.into_iter().collect::>(); + fork_states.push(state_after.clone()); let mut update_state = false; // 14. Use state resolution to find new room state @@ -1254,17 +1267,31 @@ pub fn handle_incoming_pdu<'a>( // We do need to force an update to this room's state update_state = true; - match state_res::StateResolution::resolve( + let fork_states = &fork_states + .into_iter() + .map(|map| { + map.into_iter() + .map(|(k, v)| (k, v.event_id.clone())) + .collect::>() + }) + .collect::>(); + + let auth_chain_t = Instant::now(); + let mut auth_chain_sets = Vec::new(); + for state in fork_states { + auth_chain_sets.push( + get_auth_chain(state.iter().map(|(_, id)| id.clone()).collect(), db) + .map_err(|_| "Failed to load auth chain.".to_owned())?, + ); + } + dbg!(auth_chain_t.elapsed()); + + let state_res_t = Instant::now(); + let state = match state_res::StateResolution::resolve( &room_id, room_version_id, - &fork_states - .into_iter() - .map(|map| { - map.into_iter() - .map(|(k, v)| (k, v.event_id.clone())) - .collect::>() - }) - .collect::>(), + fork_states, + auth_chain_sets, |id| { let res = db.rooms.get_pdu(id); if let Err(e) = &res { @@ -1277,7 +1304,9 @@ pub fn handle_incoming_pdu<'a>( Err(_) => { return Err("State resolution failed, either an event could not be found or deserialization".into()); } - } + }; + dbg!(state_res_t.elapsed()); + state }; // 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" it @@ -1696,6 +1725,42 @@ async fn append_incoming_pdu( Ok(pdu_id) } +fn get_auth_chain(starting_events: Vec, db: &Database) -> Result> { + let mut auth_chain_cache = db.rooms.auth_chain_cache(); + + let mut auth_chain = HashSet::new(); + + for event in starting_events { + auth_chain.extend(get_auth_chain_recursive(&event, &mut auth_chain_cache, db)?); + } + + Ok(auth_chain) +} + +fn get_auth_chain_recursive( + event_id: &EventId, + auth_chain_cache: &mut std::sync::MutexGuard<'_, LruCache>>, + db: &Database, +) -> Result> { + if let Some(cached) = auth_chain_cache.get_mut(event_id) { + return Ok(cached.clone()); + } + + let mut auth_chain = HashSet::new(); + + if let Some(pdu) = db.rooms.get_pdu(&event_id)? { + for auth_event in &pdu.auth_events { + auth_chain.extend(get_auth_chain_recursive(&auth_event, auth_chain_cache, db)?); + } + } else { + warn!("Could not find pdu mentioned in auth events."); + } + + auth_chain_cache.insert(event_id.clone(), auth_chain.clone()); + + Ok(auth_chain) +} + #[cfg_attr( feature = "conduit_bin", get("/_matrix/federation/v1/event/<_>", data = "") @@ -1783,35 +1848,20 @@ pub fn get_event_authorization_route( return Err(Error::bad_config("Federation is disabled.")); } - let mut auth_chain = Vec::new(); - let mut auth_chain_ids = BTreeSet::::new(); - let mut todo = BTreeSet::new(); - todo.insert(body.event_id.clone()); + let auth_chain_ids = get_auth_chain(vec![body.event_id.clone()], &db)?; - while let Some(event_id) = todo.iter().next().cloned() { - if let Some(pdu) = db.rooms.get_pdu(&event_id)? { - todo.extend( - pdu.auth_events - .clone() - .into_iter() - .collect::>() - .difference(&auth_chain_ids) - .cloned(), - ); - auth_chain_ids.extend(pdu.auth_events.clone().into_iter()); - - let pdu_json = PduEvent::convert_to_outgoing_federation_event( - db.rooms.get_pdu_json(&event_id)?.unwrap(), - ); - auth_chain.push(pdu_json); - } else { - warn!("Could not find pdu mentioned in auth events."); - } - - todo.remove(&event_id); + Ok(get_event_authorization::v1::Response { + auth_chain: auth_chain_ids + .into_iter() + .map(|id| { + Ok::<_, Error>(PduEvent::convert_to_outgoing_federation_event( + db.rooms.get_pdu_json(&id)?.unwrap(), + )) + }) + .filter_map(|r| r.ok()) + .collect(), } - - Ok(get_event_authorization::v1::Response { auth_chain }.into()) + .into()) } #[cfg_attr( @@ -1846,35 +1896,21 @@ pub fn get_room_state_route( }) .collect(); - let mut auth_chain = Vec::new(); - let mut auth_chain_ids = BTreeSet::::new(); - let mut todo = BTreeSet::new(); - todo.insert(body.event_id.clone()); + let auth_chain_ids = get_auth_chain(vec![body.event_id.clone()], &db)?; - while let Some(event_id) = todo.iter().next().cloned() { - if let Some(pdu) = db.rooms.get_pdu(&event_id)? { - todo.extend( - pdu.auth_events - .clone() - .into_iter() - .collect::>() - .difference(&auth_chain_ids) - .cloned(), - ); - auth_chain_ids.extend(pdu.auth_events.clone().into_iter()); - - let pdu_json = PduEvent::convert_to_outgoing_federation_event( - db.rooms.get_pdu_json(&event_id)?.unwrap(), - ); - auth_chain.push(pdu_json); - } else { - warn!("Could not find pdu mentioned in auth events."); - } - - todo.remove(&event_id); + Ok(get_room_state::v1::Response { + auth_chain: auth_chain_ids + .into_iter() + .map(|id| { + Ok::<_, Error>(PduEvent::convert_to_outgoing_federation_event( + db.rooms.get_pdu_json(&id)?.unwrap(), + )) + }) + .filter_map(|r| r.ok()) + .collect(), + pdus, } - - Ok(get_room_state::v1::Response { auth_chain, pdus }.into()) + .into()) } #[cfg_attr( @@ -1904,27 +1940,7 @@ pub fn get_room_state_ids_route( .into_iter() .collect(); - let mut auth_chain_ids = BTreeSet::::new(); - let mut todo = BTreeSet::new(); - todo.insert(body.event_id.clone()); - - while let Some(event_id) = todo.iter().next().cloned() { - if let Some(pdu) = db.rooms.get_pdu(&event_id)? { - todo.extend( - pdu.auth_events - .clone() - .into_iter() - .collect::>() - .difference(&auth_chain_ids) - .cloned(), - ); - auth_chain_ids.extend(pdu.auth_events.clone().into_iter()); - } else { - warn!("Could not find pdu mentioned in auth events."); - } - - todo.remove(&event_id); - } + let auth_chain_ids = get_auth_chain(vec![body.event_id.clone()], &db)?; Ok(get_room_state_ids::v1::Response { auth_chain_ids: auth_chain_ids.into_iter().collect(), @@ -2182,8 +2198,8 @@ pub async fn create_join_event_route( let state_ids = db.rooms.state_full_ids(shortstatehash)?; - let mut auth_chain_ids = BTreeSet::::new(); - let mut todo = state_ids.iter().cloned().collect::>(); + let mut auth_chain_ids = HashSet::::new(); + let mut todo = state_ids.iter().cloned().collect::>(); while let Some(event_id) = todo.iter().next().cloned() { if let Some(pdu) = db.rooms.get_pdu(&event_id)? { @@ -2191,7 +2207,7 @@ pub async fn create_join_event_route( pdu.auth_events .clone() .into_iter() - .collect::>() + .collect::>() .difference(&auth_chain_ids) .cloned(), );