From 6606e41dde413af64278e52ee2a376377c8c035e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Sat, 19 Dec 2020 16:00:11 +0100 Subject: [PATCH] feat: improved state store --- src/client_server/alias.rs | 32 +++--- src/client_server/directory.rs | 28 ++--- src/client_server/media.rs | 60 ++++++----- src/client_server/membership.rs | 55 +++++----- src/client_server/read_marker.rs | 19 +++- src/client_server/sync.rs | 111 ++++++++++--------- src/database.rs | 13 ++- src/database/globals.rs | 10 +- src/database/rooms.rs | 178 +++++++++++++++++-------------- src/database/sending.rs | 137 ++++++++++++++++++------ src/error.rs | 2 +- src/main.rs | 9 +- src/server_server.rs | 2 +- 13 files changed, 405 insertions(+), 251 deletions(-) diff --git a/src/client_server/alias.rs b/src/client_server/alias.rs index ec73ffc..498e882 100644 --- a/src/client_server/alias.rs +++ b/src/client_server/alias.rs @@ -1,5 +1,5 @@ use super::State; -use crate::{appservice_server, server_server, ConduitResult, Database, Error, Ruma}; +use crate::{ConduitResult, Database, Error, Ruma}; use ruma::{ api::{ appservice, @@ -66,12 +66,14 @@ pub async fn get_alias_helper( room_alias: &RoomAliasId, ) -> ConduitResult { if room_alias.server_name() != db.globals.server_name() { - let response = server_server::send_request( - &db.globals, - room_alias.server_name().to_owned(), - federation::query::get_room_information::v1::Request { room_alias }, - ) - .await?; + let response = db + .sending + .send_federation_request( + &db.globals, + room_alias.server_name().to_owned(), + federation::query::get_room_information::v1::Request { room_alias }, + ) + .await?; return Ok(get_alias::Response::new(response.room_id, response.servers).into()); } @@ -81,13 +83,15 @@ pub async fn get_alias_helper( Some(r) => room_id = Some(r), None => { for (_id, registration) in db.appservice.iter_all().filter_map(|r| r.ok()) { - if appservice_server::send_request( - &db.globals, - registration, - appservice::query::query_room_alias::v1::Request { room_alias }, - ) - .await - .is_ok() + if db + .sending + .send_appservice_request( + &db.globals, + registration, + appservice::query::query_room_alias::v1::Request { room_alias }, + ) + .await + .is_ok() { room_id = Some(db.rooms.id_from_alias(&room_alias)?.ok_or_else(|| { Error::bad_config("Appservice lied to us. Room does not exist.") diff --git a/src/client_server/directory.rs b/src/client_server/directory.rs index 559071a..fa5db3a 100644 --- a/src/client_server/directory.rs +++ b/src/client_server/directory.rs @@ -1,5 +1,5 @@ use super::State; -use crate::{server_server, ConduitResult, Database, Error, Result, Ruma}; +use crate::{ConduitResult, Database, Error, Result, Ruma}; use log::info; use ruma::{ api::{ @@ -133,19 +133,21 @@ pub async fn get_public_rooms_filtered_helper( .clone() .filter(|server| *server != db.globals.server_name().as_str()) { - let response = server_server::send_request( - &db.globals, - other_server.to_owned(), - federation::directory::get_public_rooms_filtered::v1::Request { - limit, - since: since.as_deref(), - filter: Filter { - generic_search_term: filter.generic_search_term.as_deref(), + let response = db + .sending + .send_federation_request( + &db.globals, + other_server.to_owned(), + federation::directory::get_public_rooms_filtered::v1::Request { + limit, + since: since.as_deref(), + filter: Filter { + generic_search_term: filter.generic_search_term.as_deref(), + }, + room_network: RoomNetwork::Matrix, }, - room_network: RoomNetwork::Matrix, - }, - ) - .await?; + ) + .await?; return Ok(get_public_rooms_filtered::Response { chunk: response diff --git a/src/client_server/media.rs b/src/client_server/media.rs index 0776c9e..156040b 100644 --- a/src/client_server/media.rs +++ b/src/client_server/media.rs @@ -1,7 +1,5 @@ use super::State; -use crate::{ - database::media::FileMeta, server_server, utils, ConduitResult, Database, Error, Ruma, -}; +use crate::{database::media::FileMeta, utils, ConduitResult, Database, Error, Ruma}; use ruma::api::client::{ error::ErrorKind, r0::media::{create_content, get_content, get_content_thumbnail, get_media_config}, @@ -45,7 +43,11 @@ pub async fn create_content_route( db.flush().await?; - Ok(create_content::Response { content_uri: mxc, blurhash: None }.into()) + Ok(create_content::Response { + content_uri: mxc, + blurhash: None, + } + .into()) } #[cfg_attr( @@ -71,16 +73,18 @@ pub async fn get_content_route( } .into()) } else if &*body.server_name != db.globals.server_name() && body.allow_remote { - let get_content_response = server_server::send_request( - &db.globals, - body.server_name.clone(), - get_content::Request { - allow_remote: false, - server_name: &body.server_name, - media_id: &body.media_id, - }, - ) - .await?; + let get_content_response = db + .sending + .send_federation_request( + &db.globals, + body.server_name.clone(), + get_content::Request { + allow_remote: false, + server_name: &body.server_name, + media_id: &body.media_id, + }, + ) + .await?; db.media.create( mxc, @@ -118,19 +122,21 @@ pub async fn get_content_thumbnail_route( )? { Ok(get_content_thumbnail::Response { file, content_type }.into()) } else if &*body.server_name != db.globals.server_name() && body.allow_remote { - let get_thumbnail_response = server_server::send_request( - &db.globals, - body.server_name.clone(), - get_content_thumbnail::Request { - allow_remote: false, - height: body.height, - width: body.width, - method: body.method, - server_name: &body.server_name, - media_id: &body.media_id, - }, - ) - .await?; + let get_thumbnail_response = db + .sending + .send_federation_request( + &db.globals, + body.server_name.clone(), + get_content_thumbnail::Request { + allow_remote: false, + height: body.height, + width: body.width, + method: body.method, + server_name: &body.server_name, + media_id: &body.media_id, + }, + ) + .await?; db.media.upload_thumbnail( mxc, diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index 46548d5..e8d57bc 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -2,7 +2,7 @@ use super::State; use crate::{ client_server, pdu::{PduBuilder, PduEvent}, - server_server, utils, ConduitResult, Database, Error, Result, Ruma, + utils, ConduitResult, Database, Error, Result, Ruma, }; use log::warn; use ruma::{ @@ -401,9 +401,10 @@ pub async fn get_member_events_route( Ok(get_member_events::Response { chunk: db .rooms - .room_state_type(&body.room_id, &EventType::RoomMember)? - .values() - .map(|pdu| pdu.to_member_event()) + .room_state_full(&body.room_id)? + .iter() + .filter(|(key, _)| key.0 == EventType::RoomMember) + .map(|(_, pdu)| pdu.to_member_event()) .collect(), } .into()) @@ -463,16 +464,18 @@ async fn join_room_by_id_helper( )); for remote_server in servers { - let make_join_response = server_server::send_request( - &db.globals, - remote_server.clone(), - federation::membership::create_join_event_template::v1::Request { - room_id, - user_id: sender_user, - ver: &[RoomVersionId::Version5, RoomVersionId::Version6], - }, - ) - .await; + let make_join_response = db + .sending + .send_federation_request( + &db.globals, + remote_server.clone(), + federation::membership::create_join_event_template::v1::Request { + room_id, + user_id: sender_user, + ver: &[RoomVersionId::Version5, RoomVersionId::Version6], + }, + ) + .await; make_join_response_and_server = make_join_response.map(|r| (r, remote_server)); @@ -540,16 +543,18 @@ async fn join_room_by_id_helper( // It has enough fields to be called a proper event now let join_event = join_event_stub; - let send_join_response = server_server::send_request( - &db.globals, - remote_server.clone(), - federation::membership::create_join_event::v2::Request { - room_id, - event_id: &event_id, - pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()), - }, - ) - .await?; + let send_join_response = db + .sending + .send_federation_request( + &db.globals, + remote_server.clone(), + federation::membership::create_join_event::v2::Request { + room_id, + event_id: &event_id, + pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()), + }, + ) + .await?; let add_event_id = |pdu: &Raw| -> Result<(EventId, CanonicalJsonObject)> { let mut value = serde_json::from_str(pdu.json().get()) @@ -694,7 +699,7 @@ async fn join_room_by_id_helper( } } - db.rooms.force_state(room_id, state)?; + db.rooms.force_state(room_id, state, &db.globals)?; } else { let event = member::MemberEventContent { membership: member::MembershipState::Join, diff --git a/src/client_server/read_marker.rs b/src/client_server/read_marker.rs index f3e7211..0c4ec1a 100644 --- a/src/client_server/read_marker.rs +++ b/src/client_server/read_marker.rs @@ -1,7 +1,9 @@ use super::State; use crate::{ConduitResult, Database, Error, Ruma}; use ruma::{ - api::client::{error::ErrorKind, r0::read_marker::set_read_marker}, + api::client::{ + error::ErrorKind, r0::capabilities::get_capabilities, r0::read_marker::set_read_marker, + }, events::{AnyEphemeralRoomEvent, AnyEvent, EventType}, }; @@ -76,3 +78,18 @@ pub async fn set_read_marker_route( Ok(set_read_marker::Response.into()) } + +#[cfg_attr( + feature = "conduit_bin", + post("/_matrix/client/r0/rooms/<_>/receipt/<_>/<_>", data = "") +)] +pub async fn set_receipt_route( + db: State<'_, Database>, + body: Ruma, +) -> ConduitResult { + let _sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + db.flush().await?; + + Ok(set_read_marker::Response.into()) +} diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index d7c24dc..8213651 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -102,9 +102,15 @@ pub async fn sync_events_route( } // Database queries: - let encrypted_room = db - .rooms - .room_state_get(&room_id, &EventType::RoomEncryption, "")? + + let current_state = db.rooms.room_state_full(&room_id)?; + let current_members = current_state + .iter() + .filter(|(key, _)| key.0 == EventType::RoomMember) + .map(|(key, value)| (&key.1, value)) // Only keep state key + .collect::>(); + let encrypted_room = current_state + .get(&(EventType::RoomEncryption, "".to_owned())) .is_some(); // These type is Option>. The outer Option is None when there is no event between @@ -117,45 +123,45 @@ pub async fn sync_events_route( .as_ref() .map(|pdu| db.rooms.pdu_state_hash(&pdu.as_ref().ok()?.0).ok()?); - let since_members = since_state_hash.as_ref().map(|state_hash| { - state_hash.as_ref().and_then(|state_hash| { - db.rooms - .state_type(&state_hash, &EventType::RoomMember) - .ok() - }) + let since_state = since_state_hash.as_ref().map(|state_hash| { + state_hash + .as_ref() + .and_then(|state_hash| db.rooms.state_full(&room_id, &state_hash).ok()) }); - let since_encryption = since_state_hash.as_ref().map(|state_hash| { - state_hash.as_ref().and_then(|state_hash| { - db.rooms - .state_get(&state_hash, &EventType::RoomEncryption, "") - .ok() - }) + let since_encryption = since_state.as_ref().map(|state| { + state + .as_ref() + .map(|state| state.get(&(EventType::RoomEncryption, "".to_owned()))) }); - let current_members = db.rooms.room_state_type(&room_id, &EventType::RoomMember)?; - // Calculations: let new_encrypted_room = encrypted_room && since_encryption.map_or(false, |encryption| encryption.is_none()); - let send_member_count = since_members.as_ref().map_or(false, |since_members| { - since_members.as_ref().map_or(true, |since_members| { - current_members.len() != since_members.len() + let send_member_count = since_state.as_ref().map_or(false, |since_state| { + since_state.as_ref().map_or(true, |since_state| { + current_members.len() + != since_state + .iter() + .filter(|(key, _)| key.0 == EventType::RoomMember) + .count() }) }); - let since_sender_member = since_members.as_ref().map(|since_members| { - since_members.as_ref().and_then(|members| { - members.get(sender_user.as_str()).and_then(|pdu| { - serde_json::from_value::>( - pdu.content.clone(), - ) - .expect("Raw::from_value always works") - .deserialize() - .map_err(|_| Error::bad_database("Invalid PDU in database.")) - .ok() - }) + let since_sender_member = since_state.as_ref().map(|since_state| { + since_state.as_ref().and_then(|state| { + state + .get(&(EventType::RoomMember, sender_user.as_str().to_owned())) + .and_then(|pdu| { + serde_json::from_value::< + Raw, + >(pdu.content.clone()) + .expect("Raw::from_value always works") + .deserialize() + .map_err(|_| Error::bad_database("Invalid PDU in database.")) + .ok() + }) }) }); @@ -170,30 +176,32 @@ pub async fn sync_events_route( .membership; let since_membership = - since_members + since_state .as_ref() - .map_or(MembershipState::Join, |members| { - members + .map_or(MembershipState::Join, |since_state| { + since_state .as_ref() - .and_then(|members| { - members.get(&user_id).and_then(|since_member| { - serde_json::from_value::< - Raw, - >( - since_member.content.clone() - ) - .expect("Raw::from_value always works") - .deserialize() - .map_err(|_| { - Error::bad_database("Invalid PDU in database.") + .and_then(|since_state| { + since_state + .get(&(EventType::RoomMember, user_id.clone())) + .and_then(|since_member| { + serde_json::from_value::< + Raw, + >( + since_member.content.clone() + ) + .expect("Raw::from_value always works") + .deserialize() + .map_err(|_| { + Error::bad_database("Invalid PDU in database.") + }) + .ok() }) - .ok() - }) }) .map_or(MembershipState::Leave, |member| member.membership) }); - let user_id = UserId::try_from(user_id) + let user_id = UserId::try_from(user_id.clone()) .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; match (since_membership, current_membership) { @@ -456,7 +464,12 @@ pub async fn sync_events_route( }) .and_then(|state_hash| { db.rooms - .state_get(&state_hash, &EventType::RoomMember, sender_user.as_str()) + .state_get( + &room_id, + &state_hash, + &EventType::RoomMember, + sender_user.as_str(), + ) .ok()? .ok_or_else(|| Error::bad_database("State hash in db doesn't have a state.")) .ok() diff --git a/src/database.rs b/src/database.rs index 5150517..99bba83 100644 --- a/src/database.rs +++ b/src/database.rs @@ -20,6 +20,7 @@ use serde::Deserialize; use std::collections::HashMap; use std::sync::{Arc, RwLock}; use std::{convert::TryInto, fs::remove_dir_all}; +use tokio::sync::Semaphore; #[derive(Clone, Deserialize)] pub struct Config { @@ -30,6 +31,8 @@ pub struct Config { cache_capacity: u64, #[serde(default = "default_max_request_size")] max_request_size: u32, + #[serde(default = "default_max_concurrent_requests")] + max_concurrent_requests: u16, #[serde(default)] registration_disabled: bool, #[serde(default)] @@ -39,7 +42,9 @@ pub struct Config { } fn default_server_name() -> Box { - "localhost".try_into().expect("") + "localhost" + .try_into() + .expect("localhost is valid servername") } fn default_cache_capacity() -> u64 { @@ -50,6 +55,10 @@ fn default_max_request_size() -> u32 { 20 * 1024 * 1024 // Default to 20 MB } +fn default_max_concurrent_requests() -> u16 { + 4 +} + #[derive(Clone)] pub struct Database { pub globals: globals::Globals, @@ -159,6 +168,7 @@ impl Database { roomuserid_invited: db.open_tree("roomuserid_invited")?, userroomid_left: db.open_tree("userroomid_left")?, + statekey_short: db.open_tree("statekey_short")?, stateid_pduid: db.open_tree("stateid_pduid")?, pduid_statehash: db.open_tree("pduid_statehash")?, roomid_statehash: db.open_tree("roomid_statehash")?, @@ -180,6 +190,7 @@ impl Database { sending: sending::Sending { servernamepduids: db.open_tree("servernamepduids")?, servercurrentpdus: db.open_tree("servercurrentpdus")?, + maximum_requests: Arc::new(Semaphore::new(10)), }, admin: admin::Admin { sender: admin_sender, diff --git a/src/database/globals.rs b/src/database/globals.rs index e913c0f..485650f 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -4,6 +4,7 @@ use ruma::ServerName; use std::collections::HashMap; use std::sync::Arc; use std::sync::RwLock; +use std::time::Duration; use trust_dns_resolver::TokioAsyncResolver; pub const COUNTER: &str = "c"; @@ -54,11 +55,18 @@ impl Globals { } }; + let reqwest_client = reqwest::Client::builder() + .connect_timeout(Duration::from_secs(30)) + .timeout(Duration::from_secs(60 * 3)) + .pool_max_idle_per_host(1) + .build() + .unwrap(); + Ok(Self { globals, config, keypair: Arc::new(keypair), - reqwest_client: reqwest::Client::new(), + reqwest_client, dns_resolver: TokioAsyncResolver::tokio_from_system_conf() .await .map_err(|_| { diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 3e2a17f..3f096a9 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -62,7 +62,8 @@ pub struct Rooms { /// Remember the state hash at events in the past. pub(super) pduid_statehash: sled::Tree, /// The state for a given state hash. - pub(super) stateid_pduid: sled::Tree, // StateId = StateHash + EventType + StateKey + pub(super) statekey_short: sled::Tree, // StateKey = EventType + StateKey, Short = Count + pub(super) stateid_pduid: sled::Tree, // StateId = StateHash + Short, PduId = Count (without roomid) } impl StateStore for Rooms { @@ -106,21 +107,28 @@ impl StateStore for Rooms { impl Rooms { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. - pub fn state_full(&self, state_hash: &StateHashId) -> Result> { + pub fn state_full( + &self, + room_id: &RoomId, + state_hash: &StateHashId, + ) -> Result> { self.stateid_pduid .scan_prefix(&state_hash) .values() - .map(|pduid| { - self.pduid_pdu.get(&pduid?)?.map_or_else( - || Err(Error::bad_database("Failed to find StateMap.")), + .map(|pduid_short| { + let mut pduid = room_id.as_bytes().to_vec(); + pduid.push(0xff); + pduid.extend_from_slice(&pduid_short?); + self.pduid_pdu.get(&pduid)?.map_or_else( + || Err(Error::bad_database("Failed to find PDU in state snapshot.")), |b| { serde_json::from_slice::(&b) .map_err(|_| Error::bad_database("Invalid PDU in db.")) }, ) }) + .filter_map(|r| r.ok()) .map(|pdu| { - let pdu = pdu?; Ok(( ( pdu.kind.clone(), @@ -135,64 +143,45 @@ impl Rooms { .collect::>>() } - /// Returns all state entries for this type. - pub fn state_type( - &self, - state_hash: &StateHashId, - event_type: &EventType, - ) -> Result> { - let mut prefix = state_hash.to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(&event_type.to_string().as_bytes()); - prefix.push(0xff); - - let mut hashmap = HashMap::new(); - for pdu in self - .stateid_pduid - .scan_prefix(&prefix) - .values() - .map(|pdu_id| { - Ok::<_, Error>( - serde_json::from_slice::(&self.pduid_pdu.get(pdu_id?)?.ok_or_else( - || Error::bad_database("PDU in state not found in database."), - )?) - .map_err(|_| Error::bad_database("Invalid PDU bytes in room state."))?, - ) - }) - { - let pdu = pdu?; - let state_key = pdu.state_key.clone().ok_or_else(|| { - Error::bad_database("Room state contains event without state_key.") - })?; - hashmap.insert(state_key, pdu); - } - Ok(hashmap) - } - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). pub fn state_get( &self, + room_id: &RoomId, state_hash: &StateHashId, event_type: &EventType, state_key: &str, ) -> Result> { - let mut key = state_hash.to_vec(); - key.push(0xff); - key.extend_from_slice(&event_type.to_string().as_bytes()); + let mut key = event_type.to_string().as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(&state_key.as_bytes()); - self.stateid_pduid.get(&key)?.map_or(Ok(None), |pdu_id| { - Ok::<_, Error>(Some(( - pdu_id.clone(), - serde_json::from_slice::( - &self.pduid_pdu.get(&pdu_id)?.ok_or_else(|| { - Error::bad_database("PDU in state not found in database.") - })?, - ) - .map_err(|_| Error::bad_database("Invalid PDU bytes in room state."))?, - ))) - }) + let short = self.statekey_short.get(&key)?; + + if let Some(short) = short { + let mut stateid = state_hash.to_vec(); + stateid.push(0xff); + stateid.extend_from_slice(&short); + + self.stateid_pduid + .get(&stateid)? + .map_or(Ok(None), |pdu_id_short| { + let mut pdu_id = room_id.as_bytes().to_vec(); + pdu_id.push(0xff); + pdu_id.extend_from_slice(&pdu_id_short); + + Ok::<_, Error>(Some(( + pdu_id.clone().into(), + serde_json::from_slice::( + &self.pduid_pdu.get(&pdu_id)?.ok_or_else(|| { + Error::bad_database("PDU in state not found in database.") + })?, + ) + .map_err(|_| Error::bad_database("Invalid PDU bytes in room state."))?, + ))) + }) + } else { + return Ok(None); + } } /// Returns the last state hash key added to the db. @@ -260,6 +249,7 @@ impl Rooms { &self, room_id: &RoomId, state: HashMap<(EventType, String), Vec>, + globals: &super::globals::Globals, ) -> Result<()> { let state_hash = self.calculate_hash(&state.values().map(|pdu_id| &**pdu_id).collect::>())?; @@ -267,11 +257,29 @@ impl Rooms { prefix.push(0xff); for ((event_type, state_key), pdu_id) in state { + let mut statekey = event_type.as_ref().as_bytes().to_vec(); + statekey.push(0xff); + statekey.extend_from_slice(&state_key.as_bytes()); + + let short = match self.statekey_short.get(&statekey)? { + Some(short) => utils::u64_from_bytes(&short) + .map_err(|_| Error::bad_database("Invalid short bytes in statekey_short."))?, + None => { + let short = globals.next_count()?; + self.statekey_short + .insert(&statekey, &short.to_be_bytes())?; + short + } + }; + + let pdu_id_short = pdu_id + .splitn(2, |&b| b == 0xff) + .nth(1) + .ok_or_else(|| Error::bad_database("Invalid pduid in state."))?; + let mut state_id = prefix.clone(); - state_id.extend_from_slice(&event_type.as_ref().as_bytes()); - state_id.push(0xff); - state_id.extend_from_slice(&state_key.as_bytes()); - self.stateid_pduid.insert(state_id, pdu_id)?; + state_id.extend_from_slice(&short.to_be_bytes()); + self.stateid_pduid.insert(state_id, pdu_id_short)?; } self.roomid_statehash @@ -283,25 +291,12 @@ impl Rooms { /// Returns the full room state. pub fn room_state_full(&self, room_id: &RoomId) -> Result> { if let Some(current_state_hash) = self.current_state_hash(room_id)? { - self.state_full(¤t_state_hash) + self.state_full(&room_id, ¤t_state_hash) } else { Ok(BTreeMap::new()) } } - /// Returns all state entries for this type. - pub fn room_state_type( - &self, - room_id: &RoomId, - event_type: &EventType, - ) -> Result> { - if let Some(current_state_hash) = self.current_state_hash(room_id)? { - self.state_type(¤t_state_hash, event_type) - } else { - Ok(HashMap::new()) - } - } - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). pub fn room_state_get( &self, @@ -310,7 +305,7 @@ impl Rooms { state_key: &str, ) -> Result> { if let Some(current_state_hash) = self.current_state_hash(room_id)? { - self.state_get(¤t_state_hash, event_type, state_key) + self.state_get(&room_id, ¤t_state_hash, event_type, state_key) } else { Ok(None) } @@ -593,7 +588,12 @@ impl Rooms { /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `pduid_statehash`. /// The incoming event is the `pdu_id` passed to this method. - pub fn append_to_state(&self, new_pdu_id: &[u8], new_pdu: &PduEvent) -> Result { + pub fn append_to_state( + &self, + new_pdu_id: &[u8], + new_pdu: &PduEvent, + globals: &super::globals::Globals, + ) -> Result { let old_state = if let Some(old_state_hash) = self.roomid_statehash.get(new_pdu.room_id.as_bytes())? { // Store state for event. The state does not include the event itself. @@ -608,7 +608,7 @@ impl Rooms { self.stateid_pduid .scan_prefix(&prefix) .filter_map(|pdu| pdu.map_err(|e| error!("{}", e)).ok()) - // Chop the old state_hash out leaving behind the (EventType, StateKey) + // Chop the old state_hash out leaving behind the short key (u64) .map(|(k, v)| (k.subslice(prefix.len(), k.len() - prefix.len()), v)) .collect::>() } else { @@ -620,7 +620,23 @@ impl Rooms { let mut pdu_key = new_pdu.kind.as_ref().as_bytes().to_vec(); pdu_key.push(0xff); pdu_key.extend_from_slice(state_key.as_bytes()); - new_state.insert(pdu_key.into(), new_pdu_id.into()); + + let short = match self.statekey_short.get(&pdu_key)? { + Some(short) => utils::u64_from_bytes(&short) + .map_err(|_| Error::bad_database("Invalid short bytes in statekey_short."))?, + None => { + let short = globals.next_count()?; + self.statekey_short.insert(&pdu_key, &short.to_be_bytes())?; + short + } + }; + + let new_pdu_id_short = new_pdu_id + .splitn(2, |&b| b == 0xff) + .nth(1) + .ok_or_else(|| Error::bad_database("Invalid pduid in state."))?; + + new_state.insert((&short.to_be_bytes()).into(), new_pdu_id_short.into()); let new_state_hash = self.calculate_hash(&new_state.values().map(|b| &**b).collect::>())?; @@ -628,12 +644,10 @@ impl Rooms { let mut key = new_state_hash.to_vec(); key.push(0xff); - // TODO: we could avoid writing to the DB on every state event by keeping - // track of the delta and write that every so often - for (key_without_prefix, pdu_id) in new_state { + for (short, short_pdu_id) in new_state { let mut state_id = key.clone(); - state_id.extend_from_slice(&key_without_prefix); - self.stateid_pduid.insert(&state_id, &pdu_id)?; + state_id.extend_from_slice(&short); + self.stateid_pduid.insert(&state_id, &short_pdu_id)?; } self.roomid_statehash @@ -887,7 +901,7 @@ impl Rooms { // We append to state before appending the pdu, so we don't have a moment in time with the // pdu without it's state. This is okay because append_pdu can't fail. - self.append_to_state(&pdu_id, &pdu)?; + self.append_to_state(&pdu_id, &pdu, &globals)?; self.append_pdu( &pdu, diff --git a/src/database/sending.rs b/src/database/sending.rs index 7ce7d63..f21b154 100644 --- a/src/database/sending.rs +++ b/src/database/sending.rs @@ -1,21 +1,29 @@ -use std::{collections::HashMap, convert::TryFrom, time::SystemTime}; +use std::{ + collections::HashMap, + convert::TryFrom, + fmt::Debug, + sync::Arc, + time::{Duration, Instant, SystemTime}, +}; use crate::{appservice_server, server_server, utils, Error, PduEvent, Result}; use federation::transactions::send_transaction_message; use log::warn; use rocket::futures::stream::{FuturesUnordered, StreamExt}; use ruma::{ - api::{appservice, federation}, + api::{appservice, federation, OutgoingRequest}, ServerName, }; use sled::IVec; use tokio::select; +use tokio::sync::Semaphore; #[derive(Clone)] pub struct Sending { /// The state for a given state hash. pub(super) servernamepduids: sled::Tree, // ServernamePduId = (+)ServerName + PduId pub(super) servercurrentpdus: sled::Tree, // ServerCurrentPdus = (+)ServerName + PduId (pduid can be empty for reservation) + pub(super) maximum_requests: Arc, } impl Sending { @@ -40,35 +48,7 @@ impl Sending { for (server, pdu, is_appservice) in servercurrentpdus .iter() .filter_map(|r| r.ok()) - .map(|(key, _)| { - let mut parts = key.splitn(2, |&b| b == 0xff); - let server = parts.next().expect("splitn always returns one element"); - let pdu = parts.next().ok_or_else(|| { - Error::bad_database("Invalid bytes in servercurrentpdus.") - })?; - - let server = utils::string_from_bytes(&server).map_err(|_| { - Error::bad_database("Invalid server bytes in server_currenttransaction") - })?; - - // Appservices start with a plus - let (server, is_appservice) = if server.starts_with("+") { - (&server[1..], true) - } else { - (&*server, false) - }; - - Ok::<_, Error>(( - Box::::try_from(server).map_err(|_| { - Error::bad_database( - "Invalid server string in server_currenttransaction", - ) - })?, - IVec::from(pdu), - is_appservice, - )) - }) - .filter_map(|r| r.ok()) + .filter_map(|(key, _)| Self::parse_servercurrentpdus(key).ok()) .filter(|(_, pdu, _)| !pdu.is_empty()) // Skip reservation key .take(50) // This should not contain more than 50 anyway @@ -90,6 +70,8 @@ impl Sending { )); } + let mut last_failed_try: HashMap, (u32, Instant)> = HashMap::new(); + let mut subscriber = servernamepduids.watch_prefix(b""); loop { select! { @@ -140,9 +122,24 @@ impl Sending { // servercurrentpdus with the prefix should be empty now } } - Err((server, _is_appservice, e)) => { - warn!("Couldn't send transaction to {}: {}", server, e) - // TODO: exponential backoff + Err((server, is_appservice, e)) => { + warn!("Couldn't send transaction to {}: {}", server, e); + let mut prefix = if is_appservice { + "+".as_bytes().to_vec() + } else { + Vec::new() + }; + prefix.extend_from_slice(server.as_bytes()); + prefix.push(0xff); + last_failed_try.insert(server.clone(), match last_failed_try.get(&server) { + Some(last_failed) => { + (last_failed.0+1, Instant::now()) + }, + None => { + (1, Instant::now()) + } + }); + servercurrentpdus.remove(&prefix).unwrap(); } }; }, @@ -174,8 +171,19 @@ impl Sending { .ok() .map(|pdu_id| (server, is_appservice, pdu_id)) ) - // TODO: exponential backoff .filter(|(server, is_appservice, _)| { + if last_failed_try.get(server).map_or(false, |(tries, instant)| { + // Fail if a request has failed recently (exponential backoff) + let mut min_elapsed_duration = Duration::from_secs(60) * *tries * *tries; + if min_elapsed_duration > Duration::from_secs(60*60*24) { + min_elapsed_duration = Duration::from_secs(60*60*24); + } + + instant.elapsed() < min_elapsed_duration + }) { + return false; + } + let mut prefix = if *is_appservice { "+".as_bytes().to_vec() } else { @@ -308,4 +316,63 @@ impl Sending { .map_err(|e| (server, is_appservice, e)) } } + + fn parse_servercurrentpdus(key: IVec) -> Result<(Box, IVec, bool)> { + let mut parts = key.splitn(2, |&b| b == 0xff); + let server = parts.next().expect("splitn always returns one element"); + let pdu = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + + let server = utils::string_from_bytes(&server).map_err(|_| { + Error::bad_database("Invalid server bytes in server_currenttransaction") + })?; + + // Appservices start with a plus + let (server, is_appservice) = if server.starts_with("+") { + (&server[1..], true) + } else { + (&*server, false) + }; + + Ok::<_, Error>(( + Box::::try_from(server).map_err(|_| { + Error::bad_database("Invalid server string in server_currenttransaction") + })?, + IVec::from(pdu), + is_appservice, + )) + } + + pub async fn send_federation_request( + &self, + globals: &crate::database::globals::Globals, + destination: Box, + request: T, + ) -> Result + where + T: Debug, + { + let permit = self.maximum_requests.acquire().await; + let response = server_server::send_request(globals, destination, request).await; + drop(permit); + + response + } + + pub async fn send_appservice_request( + &self, + globals: &crate::database::globals::Globals, + registration: serde_yaml::Value, + request: T, + ) -> Result + where + T: Debug, + { + let permit = self.maximum_requests.acquire().await; + let response = appservice_server::send_request(globals, registration, request).await; + drop(permit); + + response + } } diff --git a/src/error.rs b/src/error.rs index 7d4a751..d8e9d02 100644 --- a/src/error.rs +++ b/src/error.rs @@ -121,7 +121,7 @@ impl log::Log for ConduitLogger { fn log(&self, record: &log::Record<'_>) { let output = format!("{} - {}", record.level(), record.args()); - println!("{}", output); + eprintln!("{}", output); if self.enabled(record.metadata()) && record diff --git a/src/main.rs b/src/main.rs index 9574894..38a2ec9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,7 +18,7 @@ pub use pdu::PduEvent; pub use rocket::State; pub use ruma_wrapper::{ConduitResult, Ruma, RumaResponse}; -use rocket::{fairing::AdHoc, routes}; +use rocket::{catch, catchers, fairing::AdHoc, routes, Request}; fn setup_rocket() -> rocket::Rocket { // Force log level off, so we can use our own logger @@ -70,6 +70,7 @@ fn setup_rocket() -> rocket::Rocket { client_server::get_backup_key_sessions_route, client_server::get_backup_keys_route, client_server::set_read_marker_route, + client_server::set_receipt_route, client_server::create_typing_event_route, client_server::create_room_route, client_server::redact_event_route, @@ -134,6 +135,7 @@ fn setup_rocket() -> rocket::Rocket { server_server::get_profile_information_route, ], ) + .register(catchers![not_found_catcher]) .attach(AdHoc::on_attach("Config", |rocket| async { let data = Database::load_or_create(rocket.figment().extract().expect("config is valid")) @@ -157,3 +159,8 @@ fn setup_rocket() -> rocket::Rocket { async fn main() { setup_rocket().launch().await.unwrap(); } + +#[catch(404)] +fn not_found_catcher(_req: &'_ Request<'_>) -> String { + "404 Not Found".to_owned() +} diff --git a/src/server_server.rs b/src/server_server.rs index 7d12c54..0653959 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -490,7 +490,7 @@ pub async fn send_transaction_message_route<'a>( pdu_id.push(0xff); pdu_id.extend_from_slice(&count.to_be_bytes()); - db.rooms.append_to_state(&pdu_id, &pdu)?; + db.rooms.append_to_state(&pdu_id, &pdu, &db.globals)?; db.rooms.append_pdu( &pdu,