From 05821d6fd5bdf96b2e8615bb527cb07ec87f4c6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Wed, 30 Jun 2021 09:52:01 +0200 Subject: [PATCH] improvement: pdu cache, /sync cache --- Cargo.lock | 1 + Cargo.toml | 1 + src/client_server/directory.rs | 216 ++++++++++++++++---------------- src/client_server/membership.rs | 23 ++-- src/client_server/profile.rs | 6 +- src/client_server/room.rs | 6 +- src/client_server/state.rs | 10 +- src/client_server/sync.rs | 172 +++++++++++++++++++++---- src/database.rs | 2 + src/database/abstraction.rs | 2 +- src/database/globals.rs | 19 ++- src/database/pusher.rs | 2 +- src/database/rooms.rs | 69 ++++++---- src/error.rs | 31 +++-- src/ruma_wrapper.rs | 77 +++++++----- src/server_server.rs | 30 ++--- 16 files changed, 424 insertions(+), 243 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c3d7408..c9bce96 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -243,6 +243,7 @@ dependencies = [ "image", "jsonwebtoken", "log", + "lru-cache", "opentelemetry", "opentelemetry-jaeger", "pretty_env_logger", diff --git a/Cargo.toml b/Cargo.toml index 96260ec..bb44918 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -73,6 +73,7 @@ tracing-subscriber = "0.2.16" tracing-opentelemetry = "0.11.0" opentelemetry-jaeger = "0.11.0" pretty_env_logger = "0.4.0" +lru-cache = "0.1.2" [features] default = ["conduit_bin", "backend_sled"] diff --git a/src/client_server/directory.rs b/src/client_server/directory.rs index be5501a..1b6b1d7 100644 --- a/src/client_server/directory.rs +++ b/src/client_server/directory.rs @@ -200,84 +200,84 @@ pub async fn get_public_rooms_filtered_helper( } } - let mut all_rooms = db - .rooms - .public_rooms() - .map(|room_id| { - let room_id = room_id?; + let mut all_rooms = + db.rooms + .public_rooms() + .map(|room_id| { + let room_id = room_id?; - let chunk = PublicRoomsChunk { - aliases: Vec::new(), - canonical_alias: db - .rooms - .room_state_get(&room_id, &EventType::RoomCanonicalAlias, "")? - .map_or(Ok::<_, Error>(None), |s| { - Ok( - serde_json::from_value::< + let chunk = PublicRoomsChunk { + aliases: Vec::new(), + canonical_alias: db + .rooms + .room_state_get(&room_id, &EventType::RoomCanonicalAlias, "")? + .map_or(Ok::<_, Error>(None), |s| { + Ok(serde_json::from_value::< Raw, - >(s.content) + >(s.content.clone()) .expect("from_value::> can never fail") .deserialize() .map_err(|_| { Error::bad_database("Invalid canonical alias event in database.") })? - .alias, - ) - })?, - name: db - .rooms - .room_state_get(&room_id, &EventType::RoomName, "")? - .map_or(Ok::<_, Error>(None), |s| { - Ok( - serde_json::from_value::>(s.content) - .expect("from_value::> can never fail") - .deserialize() - .map_err(|_| { - Error::bad_database("Invalid room name event in database.") - })? - .name() - .map(|n| n.to_owned()), - ) - })?, - num_joined_members: (db.rooms.room_members(&room_id).count() as u32).into(), - topic: db - .rooms - .room_state_get(&room_id, &EventType::RoomTopic, "")? - .map_or(Ok::<_, Error>(None), |s| { - Ok(Some( - serde_json::from_value::>(s.content) + .alias) + })?, + name: db + .rooms + .room_state_get(&room_id, &EventType::RoomName, "")? + .map_or(Ok::<_, Error>(None), |s| { + Ok(serde_json::from_value::>( + s.content.clone(), + ) + .expect("from_value::> can never fail") + .deserialize() + .map_err(|_| { + Error::bad_database("Invalid room name event in database.") + })? + .name() + .map(|n| n.to_owned())) + })?, + num_joined_members: (db.rooms.room_members(&room_id).count() as u32).into(), + topic: db + .rooms + .room_state_get(&room_id, &EventType::RoomTopic, "")? + .map_or(Ok::<_, Error>(None), |s| { + Ok(Some( + serde_json::from_value::>( + s.content.clone(), + ) .expect("from_value::> can never fail") .deserialize() .map_err(|_| { Error::bad_database("Invalid room topic event in database.") })? .topic, - )) - })?, - world_readable: db - .rooms - .room_state_get(&room_id, &EventType::RoomHistoryVisibility, "")? - .map_or(Ok::<_, Error>(false), |s| { - Ok(serde_json::from_value::< - Raw, - >(s.content) - .expect("from_value::> can never fail") - .deserialize() - .map_err(|_| { - Error::bad_database( - "Invalid room history visibility event in database.", - ) - })? - .history_visibility - == history_visibility::HistoryVisibility::WorldReadable) - })?, - guest_can_join: db - .rooms - .room_state_get(&room_id, &EventType::RoomGuestAccess, "")? - .map_or(Ok::<_, Error>(false), |s| { - Ok( + )) + })?, + world_readable: db + .rooms + .room_state_get(&room_id, &EventType::RoomHistoryVisibility, "")? + .map_or(Ok::<_, Error>(false), |s| { + Ok(serde_json::from_value::< + Raw, + >(s.content.clone()) + .expect("from_value::> can never fail") + .deserialize() + .map_err(|_| { + Error::bad_database( + "Invalid room history visibility event in database.", + ) + })? + .history_visibility + == history_visibility::HistoryVisibility::WorldReadable) + })?, + guest_can_join: db + .rooms + .room_state_get(&room_id, &EventType::RoomGuestAccess, "")? + .map_or(Ok::<_, Error>(false), |s| { + Ok( serde_json::from_value::>( - s.content, + s.content.clone(), ) .expect("from_value::> can never fail") .deserialize() @@ -287,61 +287,63 @@ pub async fn get_public_rooms_filtered_helper( .guest_access == guest_access::GuestAccess::CanJoin, ) - })?, - avatar_url: db - .rooms - .room_state_get(&room_id, &EventType::RoomAvatar, "")? - .map(|s| { - Ok::<_, Error>( - serde_json::from_value::>(s.content) + })?, + avatar_url: db + .rooms + .room_state_get(&room_id, &EventType::RoomAvatar, "")? + .map(|s| { + Ok::<_, Error>( + serde_json::from_value::>( + s.content.clone(), + ) .expect("from_value::> can never fail") .deserialize() .map_err(|_| { Error::bad_database("Invalid room avatar event in database.") })? .url, - ) - }) - .transpose()? - // url is now an Option so we must flatten - .flatten(), - room_id, - }; - Ok(chunk) - }) - .filter_map(|r: Result<_>| r.ok()) // Filter out buggy rooms - .filter(|chunk| { - if let Some(query) = filter - .generic_search_term - .as_ref() - .map(|q| q.to_lowercase()) - { - if let Some(name) = &chunk.name { - if name.to_lowercase().contains(&query) { - return true; + ) + }) + .transpose()? + // url is now an Option so we must flatten + .flatten(), + room_id, + }; + Ok(chunk) + }) + .filter_map(|r: Result<_>| r.ok()) // Filter out buggy rooms + .filter(|chunk| { + if let Some(query) = filter + .generic_search_term + .as_ref() + .map(|q| q.to_lowercase()) + { + if let Some(name) = &chunk.name { + if name.to_lowercase().contains(&query) { + return true; + } } - } - if let Some(topic) = &chunk.topic { - if topic.to_lowercase().contains(&query) { - return true; + if let Some(topic) = &chunk.topic { + if topic.to_lowercase().contains(&query) { + return true; + } } - } - if let Some(canonical_alias) = &chunk.canonical_alias { - if canonical_alias.as_str().to_lowercase().contains(&query) { - return true; + if let Some(canonical_alias) = &chunk.canonical_alias { + if canonical_alias.as_str().to_lowercase().contains(&query) { + return true; + } } - } - false - } else { - // No search term - true - } - }) - // We need to collect all, so we can sort by member count - .collect::>(); + false + } else { + // No search term + true + } + }) + // We need to collect all, so we can sort by member count + .collect::>(); all_rooms.sort_by(|l, r| r.num_joined_members.cmp(&l.num_joined_members)); diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index 2dfa077..87fead2 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -189,7 +189,8 @@ pub async fn kick_user_route( ErrorKind::BadState, "Cannot kick member that's not in the room.", ))? - .content, + .content + .clone(), ) .expect("Raw::from_value always works") .deserialize() @@ -245,11 +246,12 @@ pub async fn ban_user_route( third_party_invite: None, }), |event| { - let mut event = - serde_json::from_value::>(event.content) - .expect("Raw::from_value always works") - .deserialize() - .map_err(|_| Error::bad_database("Invalid member event in database."))?; + let mut event = serde_json::from_value::>( + event.content.clone(), + ) + .expect("Raw::from_value always works") + .deserialize() + .map_err(|_| Error::bad_database("Invalid member event in database."))?; event.membership = ruma::events::room::member::MembershipState::Ban; Ok(event) }, @@ -295,7 +297,8 @@ pub async fn unban_user_route( ErrorKind::BadState, "Cannot unban a user who is not banned.", ))? - .content, + .content + .clone(), ) .expect("from_value::> can never fail") .deserialize() @@ -753,7 +756,7 @@ pub async fn invite_helper( let create_prev_event = if prev_events.len() == 1 && Some(&prev_events[0]) == create_event.as_ref().map(|c| &c.event_id) { - create_event.map(Arc::new) + create_event } else { None }; @@ -792,10 +795,10 @@ pub async fn invite_helper( let mut unsigned = BTreeMap::new(); if let Some(prev_pdu) = db.rooms.room_state_get(room_id, &kind, &state_key)? { - unsigned.insert("prev_content".to_owned(), prev_pdu.content); + unsigned.insert("prev_content".to_owned(), prev_pdu.content.clone()); unsigned.insert( "prev_sender".to_owned(), - serde_json::to_value(prev_pdu.sender).expect("UserId::to_value always works"), + serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"), ); } diff --git a/src/client_server/profile.rs b/src/client_server/profile.rs index 32bb608..4e9a37b 100644 --- a/src/client_server/profile.rs +++ b/src/client_server/profile.rs @@ -53,7 +53,8 @@ pub async fn set_displayname_route( room.", ) })? - .content, + .content + .clone(), ) .expect("from_value::> can never fail") .deserialize() @@ -154,7 +155,8 @@ pub async fn set_avatar_url_route( room.", ) })? - .content, + .content + .clone(), ) .expect("from_value::> can never fail") .deserialize() diff --git a/src/client_server/room.rs b/src/client_server/room.rs index 3f91324..b33b550 100644 --- a/src/client_server/room.rs +++ b/src/client_server/room.rs @@ -362,7 +362,8 @@ pub async fn upgrade_room_route( db.rooms .room_state_get(&body.room_id, &EventType::RoomCreate, "")? .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? - .content, + .content + .clone(), ) .expect("Raw::from_value always works") .deserialize() @@ -463,7 +464,8 @@ pub async fn upgrade_room_route( db.rooms .room_state_get(&body.room_id, &EventType::RoomPowerLevels, "")? .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? - .content, + .content + .clone(), ) .expect("database contains invalid PDU") .deserialize() diff --git a/src/client_server/state.rs b/src/client_server/state.rs index c431ac0..be52834 100644 --- a/src/client_server/state.rs +++ b/src/client_server/state.rs @@ -92,7 +92,7 @@ pub async fn get_state_events_route( db.rooms .room_state_get(&body.room_id, &EventType::RoomHistoryVisibility, "")? .map(|event| { - serde_json::from_value::(event.content) + serde_json::from_value::(event.content.clone()) .map_err(|_| { Error::bad_database( "Invalid room history visibility event in database.", @@ -139,7 +139,7 @@ pub async fn get_state_events_for_key_route( db.rooms .room_state_get(&body.room_id, &EventType::RoomHistoryVisibility, "")? .map(|event| { - serde_json::from_value::(event.content) + serde_json::from_value::(event.content.clone()) .map_err(|_| { Error::bad_database( "Invalid room history visibility event in database.", @@ -165,7 +165,7 @@ pub async fn get_state_events_for_key_route( ))?; Ok(get_state_events_for_key::Response { - content: serde_json::from_value(event.content) + content: serde_json::from_value(event.content.clone()) .map_err(|_| Error::bad_database("Invalid event content in database"))?, } .into()) @@ -190,7 +190,7 @@ pub async fn get_state_events_for_empty_key_route( db.rooms .room_state_get(&body.room_id, &EventType::RoomHistoryVisibility, "")? .map(|event| { - serde_json::from_value::(event.content) + serde_json::from_value::(event.content.clone()) .map_err(|_| { Error::bad_database( "Invalid room history visibility event in database.", @@ -216,7 +216,7 @@ pub async fn get_state_events_for_empty_key_route( ))?; Ok(get_state_events_for_key::Response { - content: serde_json::from_value(event.content) + content: serde_json::from_value(event.content.clone()) .map_err(|_| Error::bad_database("Invalid event content in database"))?, } .into()) diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index 1c078e9..69511fa 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -1,21 +1,22 @@ use super::State; -use crate::{ConduitResult, Database, Error, Result, Ruma}; +use crate::{ConduitResult, Database, Error, Result, Ruma, RumaResponse}; use log::error; use ruma::{ - api::client::r0::sync::sync_events, + api::client::r0::{sync::sync_events, uiaa::UiaaResponse}, events::{room::member::MembershipState, AnySyncEphemeralRoomEvent, EventType}, serde::Raw, - RoomId, UserId, + DeviceId, RoomId, UserId, }; - -#[cfg(feature = "conduit_bin")] -use rocket::{get, tokio}; use std::{ - collections::{hash_map, BTreeMap, HashMap, HashSet}, + collections::{btree_map::Entry, hash_map, BTreeMap, HashMap, HashSet}, convert::{TryFrom, TryInto}, sync::Arc, time::Duration, }; +use tokio::sync::watch::Sender; + +#[cfg(feature = "conduit_bin")] +use rocket::{get, tokio}; /// # `GET /_matrix/client/r0/sync` /// @@ -36,21 +37,134 @@ use std::{ pub async fn sync_events_route( db: State<'_, Arc>, body: Ruma>, -) -> ConduitResult { +) -> std::result::Result, RumaResponse> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + let mut rx = match db + .globals + .sync_receivers + .write() + .unwrap() + .entry((sender_user.clone(), sender_device.clone())) + { + Entry::Vacant(v) => { + let (tx, rx) = tokio::sync::watch::channel(None); + + tokio::spawn(sync_helper_wrapper( + Arc::clone(&db), + sender_user.clone(), + sender_device.clone(), + body.since.clone(), + body.full_state, + body.timeout, + tx, + )); + + v.insert((body.since.clone(), rx)).1.clone() + } + Entry::Occupied(mut o) => { + if o.get().0 != body.since { + let (tx, rx) = tokio::sync::watch::channel(None); + + tokio::spawn(sync_helper_wrapper( + Arc::clone(&db), + sender_user.clone(), + sender_device.clone(), + body.since.clone(), + body.full_state, + body.timeout, + tx, + )); + + o.insert((body.since.clone(), rx.clone())); + + rx + } else { + o.get().1.clone() + } + } + }; + + let we_have_to_wait = rx.borrow().is_none(); + if we_have_to_wait { + let _ = rx.changed().await; + } + + let result = match rx + .borrow() + .as_ref() + .expect("When sync channel changes it's always set to some") + { + Ok(response) => Ok(response.clone()), + Err(error) => Err(error.to_response()), + }; + + result +} + +pub async fn sync_helper_wrapper( + db: Arc, + sender_user: UserId, + sender_device: Box, + since: Option, + full_state: bool, + timeout: Option, + tx: Sender>>, +) { + let r = sync_helper( + Arc::clone(&db), + sender_user.clone(), + sender_device.clone(), + since.clone(), + full_state, + timeout, + ) + .await; + + if let Ok((_, caching_allowed)) = r { + if !caching_allowed { + match db + .globals + .sync_receivers + .write() + .unwrap() + .entry((sender_user, sender_device)) + { + Entry::Occupied(o) => { + // Only remove if the device didn't start a different /sync already + if o.get().0 == since { + o.remove(); + } + } + Entry::Vacant(_) => {} + } + } + } + + let _ = tx.send(Some(r.map(|(r, _)| r.into()))); +} + +async fn sync_helper( + db: Arc, + sender_user: UserId, + sender_device: Box, + since: Option, + full_state: bool, + timeout: Option, + // bool = caching allowed +) -> std::result::Result<(sync_events::Response, bool), Error> { // TODO: match body.set_presence { db.rooms.edus.ping_presence(&sender_user)?; // Setup watchers, so if there's no response, we can wait for them - let watcher = db.watch(sender_user, sender_device); + let watcher = db.watch(&sender_user, &sender_device); - let next_batch = db.globals.current_count()?.to_string(); + let next_batch = db.globals.current_count()?; + let next_batch_string = next_batch.to_string(); let mut joined_rooms = BTreeMap::new(); - let since = body - .since + let since = since .clone() .and_then(|string| string.parse().ok()) .unwrap_or(0); @@ -114,10 +228,11 @@ pub async fn sync_events_route( // since and the current room state, meaning there should be no updates. // The inner Option is None when there is an event, but there is no state hash associated // with it. This can happen for the RoomCreate event, so all updates should arrive. - let first_pdu_before_since = db.rooms.pdus_until(sender_user, &room_id, since).next(); + let first_pdu_before_since = db.rooms.pdus_until(&sender_user, &room_id, since).next(); + let pdus_after_since = db .rooms - .pdus_after(sender_user, &room_id, since) + .pdus_after(&sender_user, &room_id, since) .next() .is_some(); @@ -256,11 +371,11 @@ pub async fn sync_events_route( .flatten() .filter(|user_id| { // Don't send key updates from the sender to the sender - sender_user != user_id + &sender_user != user_id }) .filter(|user_id| { // Only send keys if the sender doesn't share an encrypted room with the target already - !share_encrypted_room(&db, sender_user, user_id, &room_id) + !share_encrypted_room(&db, &sender_user, user_id, &room_id) .unwrap_or(false) }), ); @@ -335,7 +450,7 @@ pub async fn sync_events_route( let state_events = if joined_since_last_sync { current_state - .into_iter() + .iter() .map(|(_, pdu)| pdu.to_sync_state_event()) .collect() } else { @@ -520,7 +635,7 @@ pub async fn sync_events_route( account_data: sync_events::RoomAccountData { events: Vec::new() }, timeline: sync_events::Timeline { limited: false, - prev_batch: Some(next_batch.clone()), + prev_batch: Some(next_batch_string.clone()), events: Vec::new(), }, state: sync_events::State { @@ -573,10 +688,10 @@ pub async fn sync_events_route( // Remove all to-device events the device received *last time* db.users - .remove_to_device_events(sender_user, sender_device, since)?; + .remove_to_device_events(&sender_user, &sender_device, since)?; let response = sync_events::Response { - next_batch, + next_batch: next_batch_string, rooms: sync_events::Rooms { leave: left_rooms, join: joined_rooms, @@ -604,20 +719,22 @@ pub async fn sync_events_route( changed: device_list_updates.into_iter().collect(), left: device_list_left.into_iter().collect(), }, - device_one_time_keys_count: if db.users.last_one_time_keys_update(sender_user)? > since + device_one_time_keys_count: if db.users.last_one_time_keys_update(&sender_user)? > since || since == 0 { - db.users.count_one_time_keys(sender_user, sender_device)? + db.users.count_one_time_keys(&sender_user, &sender_device)? } else { BTreeMap::new() }, to_device: sync_events::ToDevice { - events: db.users.get_to_device_events(sender_user, sender_device)?, + events: db + .users + .get_to_device_events(&sender_user, &sender_device)?, }, }; // TODO: Retry the endpoint instead of returning (waiting for #118) - if !body.full_state + if !full_state && response.rooms.is_empty() && response.presence.is_empty() && response.account_data.is_empty() @@ -627,14 +744,15 @@ pub async fn sync_events_route( { // Hang a few seconds so requests are not spammed // Stop hanging if new info arrives - let mut duration = body.timeout.unwrap_or_default(); + let mut duration = timeout.unwrap_or_default(); if duration.as_secs() > 30 { duration = Duration::from_secs(30); } let _ = tokio::time::timeout(duration, watcher).await; + Ok((response, false)) + } else { + Ok((response, since != next_batch)) // Only cache if we made progress } - - Ok(response.into()) } #[tracing::instrument(skip(db))] diff --git a/src/database.rs b/src/database.rs index 2846928..8968010 100644 --- a/src/database.rs +++ b/src/database.rs @@ -17,6 +17,7 @@ use crate::{utils, Error, Result}; use abstraction::DatabaseEngine; use directories::ProjectDirs; use log::error; +use lru_cache::LruCache; use rocket::futures::{channel::mpsc, stream::FuturesUnordered, StreamExt}; use ruma::{DeviceId, ServerName, UserId}; use serde::Deserialize; @@ -189,6 +190,7 @@ impl Database { eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, prevevent_parent: builder.open_tree("prevevent_parent")?, + pdu_cache: RwLock::new(LruCache::new(1_000_000)), }, account_data: account_data::AccountData { roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, diff --git a/src/database/abstraction.rs b/src/database/abstraction.rs index f81c9de..bf292eb 100644 --- a/src/database/abstraction.rs +++ b/src/database/abstraction.rs @@ -65,7 +65,7 @@ impl DatabaseEngine for SledEngine { sled::Config::default() .path(&config.database_path) .cache_capacity(config.cache_capacity as u64) - .use_compression(true) + .use_compression(false) .open()?, ))) } diff --git a/src/database/globals.rs b/src/database/globals.rs index 1ce87bd..4859ef4 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -1,8 +1,11 @@ -use crate::{database::Config, utils, Error, Result}; +use crate::{database::Config, utils, ConduitResult, Error, Result}; use log::{error, info}; use ruma::{ - api::federation::discovery::{ServerSigningKeys, VerifyKey}, - EventId, MilliSecondsSinceUnixEpoch, ServerName, ServerSigningKeyId, + api::{ + client::r0::sync::sync_events, + federation::discovery::{ServerSigningKeys, VerifyKey}, + }, + DeviceId, EventId, MilliSecondsSinceUnixEpoch, ServerName, ServerSigningKeyId, UserId, }; use rustls::{ServerCertVerifier, WebPKIVerifier}; use std::{ @@ -35,6 +38,15 @@ pub struct Globals { pub bad_event_ratelimiter: Arc>>, pub bad_signature_ratelimiter: Arc, RateLimitState>>>, pub servername_ratelimiter: Arc, Arc>>>, + pub sync_receivers: RwLock< + BTreeMap< + (UserId, Box), + ( + Option, + tokio::sync::watch::Receiver>>, + ), // since, rx + >, + >, } struct MatrixServerVerifier { @@ -153,6 +165,7 @@ impl Globals { bad_event_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), bad_signature_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), servername_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), + sync_receivers: RwLock::new(BTreeMap::new()), }; fs::create_dir_all(s.get_media_folder())?; diff --git a/src/database/pusher.rs b/src/database/pusher.rs index 358c3c9..a27bf2c 100644 --- a/src/database/pusher.rs +++ b/src/database/pusher.rs @@ -203,7 +203,7 @@ pub fn get_actions<'a>( .rooms .room_state_get(&pdu.room_id, &EventType::RoomPowerLevels, "")? .map(|ev| { - serde_json::from_value(ev.content) + serde_json::from_value(ev.content.clone()) .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) }) .transpose()? diff --git a/src/database/rooms.rs b/src/database/rooms.rs index f19d4b9..e23b804 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -5,6 +5,7 @@ use member::MembershipState; use crate::{pdu::PduBuilder, utils, Database, Error, PduEvent, Result}; use log::{debug, error, warn}; +use lru_cache::LruCache; use regex::Regex; use ring::digest; use ruma::{ @@ -23,7 +24,7 @@ use std::{ collections::{BTreeMap, HashMap, HashSet}, convert::{TryFrom, TryInto}, mem, - sync::Arc, + sync::{Arc, RwLock}, }; use super::{abstraction::Tree, admin::AdminCommand, pusher}; @@ -81,6 +82,8 @@ pub struct Rooms { /// RoomId + EventId -> Parent PDU EventId. pub(super) prevevent_parent: Arc, + + pub(super) pdu_cache: RwLock>>, } impl Rooms { @@ -105,8 +108,8 @@ impl Rooms { pub fn state_full( &self, shortstatehash: u64, - ) -> Result> { - Ok(self + ) -> Result>> { + let state = self .stateid_shorteventid .scan_prefix(shortstatehash.to_be_bytes().to_vec()) .map(|(_, bytes)| self.shorteventid_eventid.get(&bytes).ok().flatten()) @@ -133,7 +136,9 @@ impl Rooms { )) }) .filter_map(|r| r.ok()) - .collect()) + .collect(); + + Ok(state) } /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). @@ -179,7 +184,7 @@ impl Rooms { shortstatehash: u64, event_type: &EventType, state_key: &str, - ) -> Result> { + ) -> Result>> { self.state_get_id(shortstatehash, event_type, state_key)? .map_or(Ok(None), |event_id| self.get_pdu(&event_id)) } @@ -234,7 +239,7 @@ impl Rooms { let mut events = StateMap::new(); for (event_type, state_key) in auth_events { if let Some(pdu) = self.room_state_get(room_id, &event_type, &state_key)? { - events.insert((event_type, state_key), Arc::new(pdu)); + events.insert((event_type, state_key), pdu); } else { // This is okay because when creating a new room some events were not created yet debug!( @@ -396,7 +401,7 @@ impl Rooms { pub fn room_state_full( &self, room_id: &RoomId, - ) -> Result> { + ) -> Result>> { if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { self.state_full(current_shortstatehash) } else { @@ -426,7 +431,7 @@ impl Rooms { room_id: &RoomId, event_type: &EventType, state_key: &str, - ) -> Result> { + ) -> Result>> { if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { self.state_get(current_shortstatehash, event_type, state_key) } else { @@ -514,21 +519,42 @@ impl Rooms { /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - pub fn get_pdu(&self, event_id: &EventId) -> Result> { - self.eventid_pduid + pub fn get_pdu(&self, event_id: &EventId) -> Result>> { + if let Some(p) = self.pdu_cache.write().unwrap().get_mut(&event_id) { + return Ok(Some(Arc::clone(p))); + } + + if let Some(pdu) = self + .eventid_pduid .get(event_id.as_bytes())? .map_or_else::, _, _>( - || self.eventid_outlierpdu.get(event_id.as_bytes()), + || { + let r = self.eventid_outlierpdu.get(event_id.as_bytes()); + r + }, |pduid| { - Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { + let r = Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { Error::bad_database("Invalid pduid in eventid_pduid.") - })?)) + })?)); + r }, )? .map(|pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + let r = serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid PDU in db.")) + .map(Arc::new); + r }) - .transpose() + .transpose()? + { + self.pdu_cache + .write() + .unwrap() + .insert(event_id.clone(), Arc::clone(&pdu)); + Ok(Some(pdu)) + } else { + Ok(None) + } } /// Returns the pdu. @@ -663,7 +689,7 @@ impl Rooms { unsigned.insert( "prev_content".to_owned(), CanonicalJsonValue::Object( - utils::to_canonical_object(prev_state.content) + utils::to_canonical_object(prev_state.content.clone()) .expect("event is valid, we just created it"), ), ); @@ -1204,7 +1230,7 @@ impl Rooms { let create_prev_event = if prev_events.len() == 1 && Some(&prev_events[0]) == create_event.as_ref().map(|c| &c.event_id) { - create_event.map(Arc::new) + create_event } else { None }; @@ -1235,10 +1261,10 @@ impl Rooms { let mut unsigned = unsigned.unwrap_or_default(); if let Some(state_key) = &state_key { if let Some(prev_pdu) = self.room_state_get(&room_id, &event_type, &state_key)? { - unsigned.insert("prev_content".to_owned(), prev_pdu.content); + unsigned.insert("prev_content".to_owned(), prev_pdu.content.clone()); unsigned.insert( "prev_sender".to_owned(), - serde_json::to_value(prev_pdu.sender).expect("UserId::to_value always works"), + serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"), ); } } @@ -1583,7 +1609,7 @@ impl Rooms { .and_then(|create| { serde_json::from_value::< Raw, - >(create.content) + >(create.content.clone()) .expect("Raw::from_value always works") .deserialize() .ok() @@ -1764,7 +1790,8 @@ impl Rooms { ErrorKind::BadState, "Cannot leave a room you are not a member of.", ))? - .content, + .content + .clone(), ) .expect("from_value::> can never fail") .deserialize() diff --git a/src/error.rs b/src/error.rs index 4f363ff..501c77d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -61,7 +61,6 @@ pub enum Error { BadDatabase(&'static str), #[error("uiaa")] Uiaa(UiaaInfo), - #[error("{0}: {1}")] BadRequest(ErrorKind, &'static str), #[error("{0}")] @@ -80,19 +79,16 @@ impl Error { } } -#[cfg(feature = "conduit_bin")] -impl<'r, 'o> Responder<'r, 'o> for Error -where - 'o: 'r, -{ - fn respond_to(self, r: &'r Request<'_>) -> response::Result<'o> { +impl Error { + pub fn to_response(&self) -> RumaResponse { if let Self::Uiaa(uiaainfo) = self { - return RumaResponse::from(UiaaResponse::AuthResponse(uiaainfo)).respond_to(r); + return RumaResponse(UiaaResponse::AuthResponse(uiaainfo.clone())); } - if let Self::FederationError(origin, mut error) = self { + if let Self::FederationError(origin, error) = self { + let mut error = error.clone(); error.message = format!("Answer from {}: {}", origin, error.message); - return RumaResponse::from(error).respond_to(r); + return RumaResponse(UiaaResponse::MatrixError(error)); } let message = format!("{}", self); @@ -119,11 +115,20 @@ where warn!("{}: {}", status_code, message); - RumaResponse::from(RumaError { + RumaResponse(UiaaResponse::MatrixError(RumaError { kind, message, status_code, - }) - .respond_to(r) + })) + } +} + +#[cfg(feature = "conduit_bin")] +impl<'r, 'o> Responder<'r, 'o> for Error +where + 'o: 'r, +{ + fn respond_to(self, r: &'r Request<'_>) -> response::Result<'o> { + self.to_response().respond_to(r) } } diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs index 2912a57..8c22f79 100644 --- a/src/ruma_wrapper.rs +++ b/src/ruma_wrapper.rs @@ -1,6 +1,6 @@ use crate::Error; use ruma::{ - api::OutgoingResponse, + api::{client::r0::uiaa::UiaaResponse, OutgoingResponse}, identifiers::{DeviceId, UserId}, signatures::CanonicalJsonValue, Outgoing, ServerName, @@ -335,49 +335,60 @@ impl Deref for Ruma { /// This struct converts ruma responses into rocket http responses. pub type ConduitResult = std::result::Result, Error>; -pub struct RumaResponse(pub T); +pub fn response(response: RumaResponse) -> response::Result<'static> { + let http_response = response + .0 + .try_into_http_response::>() + .map_err(|_| Status::InternalServerError)?; -impl From for RumaResponse { + let mut response = rocket::response::Response::build(); + + let status = http_response.status(); + response.raw_status(status.into(), ""); + + for header in http_response.headers() { + response.raw_header(header.0.to_string(), header.1.to_str().unwrap().to_owned()); + } + + let http_body = http_response.into_body(); + + response.sized_body(http_body.len(), Cursor::new(http_body)); + + response.raw_header("Access-Control-Allow-Origin", "*"); + response.raw_header( + "Access-Control-Allow-Methods", + "GET, POST, PUT, DELETE, OPTIONS", + ); + response.raw_header( + "Access-Control-Allow-Headers", + "Origin, X-Requested-With, Content-Type, Accept, Authorization", + ); + response.raw_header("Access-Control-Max-Age", "86400"); + response.ok() +} + +#[derive(Clone)] +pub struct RumaResponse(pub T); + +impl From for RumaResponse { fn from(t: T) -> Self { Self(t) } } +impl From for RumaResponse { + fn from(t: Error) -> Self { + t.to_response() + } +} + #[cfg(feature = "conduit_bin")] impl<'r, 'o, T> Responder<'r, 'o> for RumaResponse where - T: Send + OutgoingResponse, 'o: 'r, + T: OutgoingResponse, { fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> { - let http_response = self - .0 - .try_into_http_response::>() - .map_err(|_| Status::InternalServerError)?; - - let mut response = rocket::response::Response::build(); - - let status = http_response.status(); - response.raw_status(status.into(), ""); - - for header in http_response.headers() { - response.raw_header(header.0.to_string(), header.1.to_str().unwrap().to_owned()); - } - - let http_body = http_response.into_body(); - - response.sized_body(http_body.len(), Cursor::new(http_body)); - - response.raw_header("Access-Control-Allow-Origin", "*"); - response.raw_header( - "Access-Control-Allow-Methods", - "GET, POST, PUT, DELETE, OPTIONS", - ); - response.raw_header( - "Access-Control-Allow-Headers", - "Origin, X-Requested-With, Content-Type, Accept, Authorization", - ); - response.raw_header("Access-Control-Max-Age", "86400"); - response.ok() + response(self) } } diff --git a/src/server_server.rs b/src/server_server.rs index 961cc9d..a9d8b8c 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -966,7 +966,7 @@ pub fn handle_incoming_pdu<'a>( auth_cache .get(&incoming_pdu.auth_events[0]) .cloned() - .filter(|maybe_create| **maybe_create == create_event) + .filter(|maybe_create| **maybe_create == *create_event) } else { None }; @@ -1181,15 +1181,12 @@ pub fn handle_incoming_pdu<'a>( let mut leaf_state = db .rooms .state_full(pdu_shortstatehash) - .map_err(|_| "Failed to ask db for room state.".to_owned())? - .into_iter() - .map(|(k, v)| (k, Arc::new(v))) - .collect::>(); + .map_err(|_| "Failed to ask db for room state.".to_owned())?; if let Some(state_key) = &leaf_pdu.state_key { // Now it's the state after let key = (leaf_pdu.kind.clone(), state_key.clone()); - leaf_state.insert(key, Arc::new(leaf_pdu)); + leaf_state.insert(key, leaf_pdu); } fork_states.insert(leaf_state); @@ -1209,10 +1206,7 @@ pub fn handle_incoming_pdu<'a>( let current_state = db .rooms .room_state_full(&room_id) - .map_err(|_| "Failed to load room state.".to_owned())? - .into_iter() - .map(|(k, v)| (k, Arc::new(v))) - .collect::>(); + .map_err(|_| "Failed to load room state.".to_owned())?; fork_states.insert(current_state.clone()); @@ -1424,7 +1418,7 @@ pub(crate) fn fetch_and_handle_events<'a>( auth_cache, ) .await?; - Arc::new(pdu) + pdu } None => { // d. Ask origin server over federation @@ -1838,7 +1832,7 @@ pub fn get_event_authorization_route( .difference(&auth_chain_ids) .cloned(), ); - auth_chain_ids.extend(pdu.auth_events.into_iter()); + auth_chain_ids.extend(pdu.auth_events.clone().into_iter()); let pdu_json = PduEvent::convert_to_outgoing_federation_event( db.rooms.get_pdu_json(&event_id)?.unwrap(), @@ -1901,7 +1895,7 @@ pub fn get_room_state_route( .difference(&auth_chain_ids) .cloned(), ); - auth_chain_ids.extend(pdu.auth_events.into_iter()); + auth_chain_ids.extend(pdu.auth_events.clone().into_iter()); let pdu_json = PduEvent::convert_to_outgoing_federation_event( db.rooms.get_pdu_json(&event_id)?.unwrap(), @@ -1954,7 +1948,7 @@ pub fn get_room_state_ids_route( .difference(&auth_chain_ids) .cloned(), ); - auth_chain_ids.extend(pdu.auth_events.into_iter()); + auth_chain_ids.extend(pdu.auth_events.clone().into_iter()); } else { warn!("Could not find pdu mentioned in auth events."); } @@ -2022,7 +2016,7 @@ pub fn create_join_event_template_route( let create_prev_event = if prev_events.len() == 1 && Some(&prev_events[0]) == create_event.as_ref().map(|c| &c.event_id) { - create_event.map(Arc::new) + create_event } else { None }; @@ -2066,10 +2060,10 @@ pub fn create_join_event_template_route( let mut unsigned = BTreeMap::new(); if let Some(prev_pdu) = db.rooms.room_state_get(&body.room_id, &kind, &state_key)? { - unsigned.insert("prev_content".to_owned(), prev_pdu.content); + unsigned.insert("prev_content".to_owned(), prev_pdu.content.clone()); unsigned.insert( "prev_sender".to_owned(), - serde_json::to_value(prev_pdu.sender).expect("UserId::to_value always works"), + serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"), ); } @@ -2220,7 +2214,7 @@ pub async fn create_join_event_route( .difference(&auth_chain_ids) .cloned(), ); - auth_chain_ids.extend(pdu.auth_events.into_iter()); + auth_chain_ids.extend(pdu.auth_events.clone().into_iter()); } else { warn!("Could not find pdu mentioned in auth events."); }