diff --git a/src/client_server/directory.rs b/src/client_server/directory.rs index 7cab1a7..589aacd 100644 --- a/src/client_server/directory.rs +++ b/src/client_server/directory.rs @@ -1,3 +1,5 @@ +use std::convert::TryInto; + use crate::{database::DatabaseGuard, ConduitResult, Database, Error, Result, Ruma}; use ruma::{ api::{ @@ -21,7 +23,7 @@ use ruma::{ serde::Raw, ServerName, UInt, }; -use tracing::info; +use tracing::{info, warn}; #[cfg(feature = "conduit_bin")] use rocket::{get, post, put}; @@ -234,7 +236,15 @@ pub async fn get_public_rooms_filtered_helper( .name .map(|n| n.to_owned().into())) })?, - num_joined_members: (db.rooms.room_members(&room_id).count() as u32).into(), + num_joined_members: db + .rooms + .room_joined_count(&room_id)? + .unwrap_or_else(|| { + warn!("Room {} has no member count", room_id); + 0 + }) + .try_into() + .expect("user count should not be that big"), topic: db .rooms .room_state_get(&room_id, &EventType::RoomTopic, "")? diff --git a/src/database.rs b/src/database.rs index 4f3d332..2e7e60c 100644 --- a/src/database.rs +++ b/src/database.rs @@ -24,10 +24,11 @@ use rocket::{ request::{FromRequest, Request}, Shutdown, State, }; -use ruma::{DeviceId, ServerName, UserId}; +use ruma::{DeviceId, RoomId, ServerName, UserId}; use serde::{de::IgnoredAny, Deserialize}; use std::{ collections::{BTreeMap, HashMap}, + convert::TryFrom, fs::{self, remove_dir_all}, io::Write, ops::Deref, @@ -252,6 +253,7 @@ impl Database { serverroomids: builder.open_tree("serverroomids")?, userroomid_joined: builder.open_tree("userroomid_joined")?, roomuserid_joined: builder.open_tree("roomuserid_joined")?, + roomid_joinedcount: builder.open_tree("roomid_joinedcount")?, roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?, userroomid_invitestate: builder.open_tree("userroomid_invitestate")?, roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?, @@ -271,8 +273,8 @@ impl Database { eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, referencedevents: builder.open_tree("referencedevents")?, - pdu_cache: Mutex::new(LruCache::new(1_000_000)), - auth_chain_cache: Mutex::new(LruCache::new(1_000_000)), + pdu_cache: Mutex::new(LruCache::new(0)), + auth_chain_cache: Mutex::new(LruCache::new(0)), }, account_data: account_data::AccountData { roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, @@ -423,6 +425,20 @@ impl Database { println!("Migration: 4 -> 5 finished"); } + + if db.globals.database_version()? < 9 { // TODO update to 6 + // Set room member count + for (roomid, _) in db.rooms.roomid_shortstatehash.iter() { + let room_id = + RoomId::try_from(utils::string_from_bytes(&roomid).unwrap()).unwrap(); + + db.rooms.update_joined_count(&room_id)?; + } + + db.globals.bump_database_version(6)?; + + println!("Migration: 5 -> 6 finished"); + } } let guard = db.read().await; diff --git a/src/database/abstraction/sled.rs b/src/database/abstraction/sled.rs index d99ce26..35ba1b2 100644 --- a/src/database/abstraction/sled.rs +++ b/src/database/abstraction/sled.rs @@ -39,12 +39,21 @@ impl Tree for SledEngineTree { Ok(()) } + #[tracing::instrument(skip(self, iter))] + fn insert_batch<'a>(&self, iter: &mut dyn Iterator, Vec)>) -> Result<()> { + for (key, value) in iter { + self.0.insert(key, value)?; + } + + Ok(()) + } + fn remove(&self, key: &[u8]) -> Result<()> { self.0.remove(key)?; Ok(()) } - fn iter<'a>(&'a self) -> Box, Vec)> + Send + 'a> { + fn iter<'a>(&'a self) -> Box, Vec)> + 'a> { Box::new( self.0 .iter() @@ -62,7 +71,7 @@ impl Tree for SledEngineTree { &self, from: &[u8], backwards: bool, - ) -> Box, Vec)> + Send> { + ) -> Box, Vec)>> { let iter = if backwards { self.0.range(..=from) } else { @@ -95,7 +104,7 @@ impl Tree for SledEngineTree { fn scan_prefix<'a>( &'a self, prefix: Vec, - ) -> Box, Vec)> + Send + 'a> { + ) -> Box, Vec)> + 'a> { let iter = self .0 .scan_prefix(prefix) diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index 72fb5f7..0dbb261 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -55,7 +55,6 @@ impl Engine { conn.pragma_update(Some(Main), "journal_mode", &"WAL")?; conn.pragma_update(Some(Main), "synchronous", &"NORMAL")?; conn.pragma_update(Some(Main), "cache_size", &(-i64::from(cache_size_kb)))?; - conn.pragma_update(Some(Main), "wal_autocheckpoint", &0)?; Ok(conn) } diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 549aa8c..10a6215 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -55,6 +55,7 @@ pub struct Rooms { pub(super) userroomid_joined: Arc, pub(super) roomuserid_joined: Arc, + pub(super) roomid_joinedcount: Arc, pub(super) roomuseroncejoinedids: Arc, pub(super) userroomid_invitestate: Arc, // InviteState = Vec> pub(super) roomuserid_invitecount: Arc, // InviteCount = Count @@ -1906,9 +1907,18 @@ impl Rooms { _ => {} } + self.update_joined_count(room_id)?; + Ok(()) } + pub fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { + self.roomid_joinedcount.insert( + room_id.as_bytes(), + &(self.room_members(&room_id).count() as u64).to_be_bytes(), + ) + } + pub async fn leave_room( &self, user_id: &UserId, @@ -2370,6 +2380,17 @@ impl Rooms { }) } + pub fn room_joined_count(&self, room_id: &RoomId) -> Result> { + Ok(self + .roomid_joinedcount + .get(room_id.as_bytes())? + .map(|b| { + utils::u64_from_bytes(&b) + .map_err(|_| Error::bad_database("Invalid joinedcount in db.")) + }) + .transpose()?) + } + /// Returns an iterator over all User IDs who ever joined a room. #[tracing::instrument(skip(self))] pub fn room_useroncejoined<'a>( diff --git a/src/server_server.rs b/src/server_server.rs index 0e595d4..4255f12 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -1,6 +1,6 @@ use crate::{ client_server::{self, claim_keys_helper, get_keys_helper}, - database::{abstraction::sqlite::MILLI, DatabaseGuard}, + database::{DatabaseGuard}, utils, ConduitResult, Database, Error, PduEvent, Result, Ruma, }; use get_profile_information::v1::ProfileField; @@ -1736,20 +1736,11 @@ fn get_auth_chain(starting_events: Vec, db: &Database) -> Result MILLI { - println!("auth chain for {} took {:?}", &event_id, elapsed) - } - cache = db.rooms.auth_chain_cache(); - cache.insert(vec![event_id.clone()], auth_chain.clone()); - full_auth_chain.extend(auth_chain); }; - } cache.insert(starting_events, full_auth_chain.clone());