diff --git a/Cargo.toml b/Cargo.toml index eb43da5..e7ebadf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,8 +24,8 @@ ruma = { git = "https://github.com/ruma/ruma", rev = "b39537812c12caafcbf8b7bd74 # Used for long polling and federation sender, should be the same as rocket::tokio tokio = "1.2.0" # Used for storing data permanently -sled = { version = "0.34.6", features = ["compression", "no_metrics"] } -rocksdb = { version = "0.16.0", features = ["multi-threaded-cf"] } +sled = { version = "0.34.6", features = ["compression", "no_metrics"], optional = true } +rocksdb = { version = "0.16.0", features = ["multi-threaded-cf"], optional = true } #sled = { git = "https://github.com/spacejam/sled.git", rev = "e4640e0773595229f398438886f19bca6f7326a2", features = ["compression"] } # Used for the http request / response body type for Ruma endpoints used with reqwest @@ -75,7 +75,9 @@ opentelemetry-jaeger = "0.11.0" pretty_env_logger = "0.4.0" [features] -default = ["conduit_bin"] +default = ["conduit_bin", "backend_sled"] +backend_sled = ["sled"] +backend_rocksdb = ["rocksdb"] conduit_bin = [] # TODO: add rocket to this when it is optional [[bin]] diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index 92d7ace..a3f1389 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -621,7 +621,7 @@ async fn join_room_by_id_helper( &pdu, utils::to_canonical_object(&pdu).expect("Pdu is valid canonical object"), count, - pdu_id.into(), + &pdu_id, &[pdu.event_id.clone()], db, )?; diff --git a/src/database.rs b/src/database.rs index b5a25ea..e00bdcd 100644 --- a/src/database.rs +++ b/src/database.rs @@ -77,8 +77,12 @@ fn default_log() -> String { "info,state_res=warn,rocket=off,_=off,sled=off".to_owned() } +#[cfg(feature = "sled")] pub type Engine = abstraction::SledEngine; +#[cfg(feature = "rocksdb")] +pub type Engine = abstraction::RocksDbEngine; + pub struct Database { pub globals: globals::Globals, pub users: users::Users, diff --git a/src/database/abstraction.rs b/src/database/abstraction.rs index ad032fb..f81c9de 100644 --- a/src/database/abstraction.rs +++ b/src/database/abstraction.rs @@ -1,21 +1,19 @@ -use std::{ - collections::BTreeMap, - future::Future, - pin::Pin, - sync::{Arc, RwLock}, -}; - -use log::warn; -use rocksdb::{ - BoundColumnFamily, ColumnFamilyDescriptor, DBWithThreadMode, Direction, MultiThreaded, Options, -}; - use super::Config; use crate::{utils, Result}; +use log::warn; +use std::{future::Future, pin::Pin, sync::Arc}; +#[cfg(feature = "rocksdb")] +use std::{collections::BTreeMap, sync::RwLock}; + +#[cfg(feature = "sled")] pub struct SledEngine(sled::Db); +#[cfg(feature = "sled")] pub struct SledEngineTree(sled::Tree); -pub struct RocksDbEngine(rocksdb::DBWithThreadMode); + +#[cfg(feature = "rocksdb")] +pub struct RocksDbEngine(rocksdb::DBWithThreadMode); +#[cfg(feature = "rocksdb")] pub struct RocksDbEngineTree<'a> { db: Arc, name: &'a str, @@ -60,6 +58,7 @@ pub trait Tree: Send + Sync { } } +#[cfg(feature = "sled")] impl DatabaseEngine for SledEngine { fn open(config: &Config) -> Result> { Ok(Arc::new(SledEngine( @@ -76,6 +75,7 @@ impl DatabaseEngine for SledEngine { } } +#[cfg(feature = "sled")] impl Tree for SledEngineTree { fn get(&self, key: &[u8]) -> Result>> { Ok(self.0.get(key)?.map(|v| v.to_vec())) @@ -165,29 +165,42 @@ impl Tree for SledEngineTree { } } +#[cfg(feature = "rocksdb")] impl DatabaseEngine for RocksDbEngine { fn open(config: &Config) -> Result> { - let mut db_opts = Options::default(); + let mut db_opts = rocksdb::Options::default(); db_opts.create_if_missing(true); + db_opts.set_max_open_files(16); + db_opts.set_compaction_style(rocksdb::DBCompactionStyle::Level); + db_opts.set_compression_type(rocksdb::DBCompressionType::Snappy); + db_opts.set_target_file_size_base(256 << 20); + db_opts.set_write_buffer_size(256 << 20); - let cfs = DBWithThreadMode::::list_cf(&db_opts, &config.database_path) - .unwrap_or_default(); + let mut block_based_options = rocksdb::BlockBasedOptions::default(); + block_based_options.set_block_size(512 << 10); + db_opts.set_block_based_table_factory(&block_based_options); - let mut options = Options::default(); + let cfs = rocksdb::DBWithThreadMode::::list_cf( + &db_opts, + &config.database_path, + ) + .unwrap_or_default(); + + let mut options = rocksdb::Options::default(); options.set_merge_operator_associative("increment", utils::increment_rocksdb); - let db = DBWithThreadMode::::open_cf_descriptors( + let db = rocksdb::DBWithThreadMode::::open_cf_descriptors( &db_opts, &config.database_path, cfs.iter() - .map(|name| ColumnFamilyDescriptor::new(name, options.clone())), + .map(|name| rocksdb::ColumnFamilyDescriptor::new(name, options.clone())), )?; Ok(Arc::new(RocksDbEngine(db))) } fn open_tree(self: &Arc, name: &'static str) -> Result> { - let mut options = Options::default(); + let mut options = rocksdb::Options::default(); options.set_merge_operator_associative("increment", utils::increment_rocksdb); // Create if it doesn't exist @@ -201,12 +214,14 @@ impl DatabaseEngine for RocksDbEngine { } } +#[cfg(feature = "rocksdb")] impl RocksDbEngineTree<'_> { - fn cf(&self) -> BoundColumnFamily<'_> { + fn cf(&self) -> rocksdb::BoundColumnFamily<'_> { self.db.0.cf_handle(self.name).unwrap() } } +#[cfg(feature = "rocksdb")] impl Tree for RocksDbEngineTree<'_> { fn get(&self, key: &[u8]) -> Result>> { Ok(self.db.0.get_cf(self.cf(), key)?) @@ -260,15 +275,20 @@ impl Tree for RocksDbEngineTree<'_> { rocksdb::IteratorMode::From( from, if backwards { - Direction::Reverse + rocksdb::Direction::Reverse } else { - Direction::Forward + rocksdb::Direction::Forward }, ), )) } fn increment(&self, key: &[u8]) -> Result> { + let stats = rocksdb::perf::get_memory_usage_stats(Some(&[&self.db.0]), None).unwrap(); + dbg!(stats.mem_table_total); + dbg!(stats.mem_table_unflushed); + dbg!(stats.mem_table_readers_total); + dbg!(stats.cache_total); // TODO: atomic? let old = self.get(key)?; let new = utils::increment(old.as_deref()).unwrap(); @@ -285,7 +305,7 @@ impl Tree for RocksDbEngineTree<'_> { .0 .iterator_cf( self.cf(), - rocksdb::IteratorMode::From(&prefix, Direction::Forward), + rocksdb::IteratorMode::From(&prefix, rocksdb::Direction::Forward), ) .take_while(move |(k, _)| k.starts_with(&prefix)), ) diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 0a8239d..736ff4d 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -19,8 +19,6 @@ use ruma::{ state_res::{self, Event, RoomVersion, StateMap}, uint, EventId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId, }; -use sled::IVec; - use std::{ collections::{BTreeMap, HashMap, HashSet}, convert::{TryFrom, TryInto}, @@ -34,7 +32,7 @@ use super::{abstraction::Tree, admin::AdminCommand, pusher}; /// /// This is created when a state group is added to the database by /// hashing the entire state. -pub type StateHashId = IVec; +pub type StateHashId = Vec; pub struct Rooms { pub edus: edus::RoomEdus, @@ -665,7 +663,7 @@ impl Rooms { pdu: &PduEvent, mut pdu_json: CanonicalJsonObject, count: u64, - pdu_id: IVec, + pdu_id: &[u8], leaves: &[EventId], db: &Database, ) -> Result<()> { @@ -713,14 +711,13 @@ impl Rooms { self.reset_notification_counts(&pdu.sender, &pdu.room_id)?; self.pduid_pdu.insert( - &pdu_id, + pdu_id, &serde_json::to_vec(&pdu_json).expect("CanonicalJsonObject is always a valid"), )?; // This also replaces the eventid of any outliers with the correct // pduid, removing the place holder. - self.eventid_pduid - .insert(pdu.event_id.as_bytes(), &*pdu_id)?; + self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?; // See if the event matches any known pushers for user in db @@ -1360,7 +1357,7 @@ impl Rooms { &pdu, pdu_json, count, - pdu_id.clone().into(), + &pdu_id, // Since this PDU references all pdu_leaves we can update the leaves // of the room &[pdu.event_id.clone()], diff --git a/src/database/sending.rs b/src/database/sending.rs index 77f6ed7..ecf0761 100644 --- a/src/database/sending.rs +++ b/src/database/sending.rs @@ -91,8 +91,6 @@ enum TransactionStatus { impl Sending { pub fn start_handler(&self, db: Arc, mut receiver: mpsc::UnboundedReceiver>) { - let db = db.clone(); - tokio::spawn(async move { let mut futures = FuturesUnordered::new(); diff --git a/src/error.rs b/src/error.rs index 10a48b7..4f363ff 100644 --- a/src/error.rs +++ b/src/error.rs @@ -23,11 +23,13 @@ pub type Result = std::result::Result; #[derive(Error, Debug)] pub enum Error { + #[cfg(feature = "sled")] #[error("There was a problem with the connection to the sled database.")] SledError { #[from] source: sled::Error, }, + #[cfg(feature = "rocksdb")] #[error("There was a problem with the connection to the rocksdb database: {source}")] RocksDbError { #[from] diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs index ba2c37e..2912a57 100644 --- a/src/ruma_wrapper.rs +++ b/src/ruma_wrapper.rs @@ -1,14 +1,15 @@ -use crate::{Database, Error}; +use crate::Error; use ruma::{ api::OutgoingResponse, identifiers::{DeviceId, UserId}, - Outgoing, + signatures::CanonicalJsonValue, + Outgoing, ServerName, }; -use std::{ops::Deref, sync::Arc}; +use std::ops::Deref; #[cfg(feature = "conduit_bin")] use { - crate::server_server, + crate::{server_server, Database}, log::{debug, warn}, rocket::{ data::{self, ByteUnit, Data, FromData}, @@ -18,14 +19,11 @@ use { tokio::io::AsyncReadExt, Request, State, }, - ruma::{ - api::{AuthScheme, IncomingRequest}, - signatures::CanonicalJsonValue, - ServerName, - }, + ruma::api::{AuthScheme, IncomingRequest}, std::collections::BTreeMap, std::convert::TryFrom, std::io::Cursor, + std::sync::Arc, }; /// This struct converts rocket requests into ruma structs by converting them into http requests diff --git a/src/server_server.rs b/src/server_server.rs index 7a338dc..2a445c2 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -1681,7 +1681,7 @@ pub(crate) fn append_incoming_pdu( pdu, pdu_json, count, - pdu_id.clone().into(), + &pdu_id, &new_room_leaves.into_iter().collect::>(), &db, )?; diff --git a/src/utils.rs b/src/utils.rs index f59afb3..0c8fb5c 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,7 +1,6 @@ use argon2::{Config, Variant}; use cmp::Ordering; use rand::prelude::*; -use rocksdb::MergeOperands; use ruma::serde::{try_from_json_map, CanonicalJsonError, CanonicalJsonObject}; use std::{ cmp, @@ -16,10 +15,11 @@ pub fn millis_since_unix_epoch() -> u64 { .as_millis() as u64 } +#[cfg(feature = "rocksdb")] pub fn increment_rocksdb( _new_key: &[u8], old: Option<&[u8]>, - _operands: &mut MergeOperands, + _operands: &mut rocksdb::MergeOperands, ) -> Option> { increment(old) }