use super::{DatabaseEngine, Tree}; use crate::{database::Config, Result}; use parking_lot::{Mutex, MutexGuard, RwLock}; use rusqlite::{Connection, DatabaseName::Main, OptionalExtension}; use std::{ cell::RefCell, collections::HashMap, future::Future, path::{Path, PathBuf}, pin::Pin, sync::Arc, }; use thread_local::ThreadLocal; use tokio::sync::oneshot::Sender; use tracing::debug; thread_local! { static READ_CONNECTION: RefCell> = RefCell::new(None); static READ_CONNECTION_ITERATOR: RefCell> = RefCell::new(None); } struct PreparedStatementIterator<'a> { pub iterator: Box + 'a>, pub statement_ref: NonAliasingBox>, } impl Iterator for PreparedStatementIterator<'_> { type Item = TupleOfBytes; fn next(&mut self) -> Option { self.iterator.next() } } struct NonAliasingBox(*mut T); impl Drop for NonAliasingBox { fn drop(&mut self) { unsafe { Box::from_raw(self.0) }; } } pub struct Engine { writer: Mutex, read_conn_tls: ThreadLocal, read_iterator_conn_tls: ThreadLocal, path: PathBuf, cache_size_per_thread: u32, } impl Engine { fn prepare_conn(path: &Path, cache_size_kb: u32) -> Result { let conn = Connection::open(&path)?; conn.pragma_update(Some(Main), "page_size", &2048)?; conn.pragma_update(Some(Main), "journal_mode", &"WAL")?; conn.pragma_update(Some(Main), "synchronous", &"NORMAL")?; conn.pragma_update(Some(Main), "cache_size", &(-i64::from(cache_size_kb)))?; conn.pragma_update(Some(Main), "wal_autocheckpoint", &2000)?; Ok(conn) } fn write_lock(&self) -> MutexGuard<'_, Connection> { self.writer.lock() } fn read_lock<'a>(&'a self) -> &'a Connection { self.read_conn_tls .get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()) } fn read_lock_iterator<'a>(&'a self) -> &'a Connection { self.read_iterator_conn_tls .get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()) } pub fn flush_wal(self: &Arc) -> Result<()> { self.write_lock() .pragma_update(Some(Main), "wal_checkpoint", &"TRUNCATE")?; Ok(()) } } impl DatabaseEngine for Engine { fn open(config: &Config) -> Result> { let path = Path::new(&config.database_path).join("conduit.db"); // calculates cache-size per permanent connection // 1. convert MB to KiB // 2. divide by permanent connections + permanent iter connections + write connection // 3. round down to nearest integer let cache_size_per_thread: u32 = ((config.db_cache_capacity_mb * 1024.0) / ((num_cpus::get().max(1) * 2) + 1) as f64) as u32; let writer = Mutex::new(Self::prepare_conn(&path, cache_size_per_thread)?); let arc = Arc::new(Engine { writer, read_conn_tls: ThreadLocal::new(), read_iterator_conn_tls: ThreadLocal::new(), path, cache_size_per_thread, }); Ok(arc) } fn open_tree(self: &Arc, name: &str) -> Result> { self.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )", name), [])?; Ok(Arc::new(SqliteTable { engine: Arc::clone(self), name: name.to_owned(), watchers: RwLock::new(HashMap::new()), })) } fn flush(self: &Arc) -> Result<()> { // we enabled PRAGMA synchronous=normal, so this should not be necessary Ok(()) } } pub struct SqliteTable { engine: Arc, name: String, watchers: RwLock, Vec>>>, } type TupleOfBytes = (Vec, Vec); impl SqliteTable { #[tracing::instrument(skip(self, guard, key))] fn get_with_guard(&self, guard: &Connection, key: &[u8]) -> Result>> { //dbg!(&self.name); Ok(guard .prepare(format!("SELECT value FROM {} WHERE key = ?", self.name).as_str())? .query_row([key], |row| row.get(0)) .optional()?) } #[tracing::instrument(skip(self, guard, key, value))] fn insert_with_guard(&self, guard: &Connection, key: &[u8], value: &[u8]) -> Result<()> { //dbg!(&self.name); guard.execute( format!( "INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)", self.name ) .as_str(), [key, value], )?; Ok(()) } pub fn iter_with_guard<'a>( &'a self, guard: &'a Connection, ) -> Box + 'a> { let statement = Box::leak(Box::new( guard .prepare(&format!( "SELECT key, value FROM {} ORDER BY key ASC", &self.name )) .unwrap(), )); let statement_ref = NonAliasingBox(statement); //let name = self.name.clone(); let iterator = Box::new( statement .query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) .unwrap() .map(move |r| { //dbg!(&name); r.unwrap() }), ); Box::new(PreparedStatementIterator { iterator, statement_ref, }) } } impl Tree for SqliteTable { #[tracing::instrument(skip(self, key))] fn get(&self, key: &[u8]) -> Result>> { self.get_with_guard(&self.engine.read_lock(), key) } #[tracing::instrument(skip(self, key, value))] fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { let guard = self.engine.write_lock(); self.insert_with_guard(&guard, key, value)?; drop(guard); let watchers = self.watchers.read(); let mut triggered = Vec::new(); for length in 0..=key.len() { if watchers.contains_key(&key[..length]) { triggered.push(&key[..length]); } } drop(watchers); if !triggered.is_empty() { let mut watchers = self.watchers.write(); for prefix in triggered { if let Some(txs) = watchers.remove(prefix) { for tx in txs { let _ = tx.send(()); } } } }; Ok(()) } #[tracing::instrument(skip(self, iter))] fn insert_batch<'a>(&self, iter: &mut dyn Iterator, Vec)>) -> Result<()> { let guard = self.engine.write_lock(); guard.execute("BEGIN", [])?; for (key, value) in iter { self.insert_with_guard(&guard, &key, &value)?; } guard.execute("COMMIT", [])?; drop(guard); Ok(()) } #[tracing::instrument(skip(self, iter))] fn increment_batch<'a>(&self, iter: &mut dyn Iterator>) -> Result<()> { let guard = self.engine.write_lock(); guard.execute("BEGIN", [])?; for key in iter { let old = self.get_with_guard(&guard, &key)?; let new = crate::utils::increment(old.as_deref()) .expect("utils::increment always returns Some"); self.insert_with_guard(&guard, &key, &new)?; } guard.execute("COMMIT", [])?; drop(guard); Ok(()) } #[tracing::instrument(skip(self, key))] fn remove(&self, key: &[u8]) -> Result<()> { let guard = self.engine.write_lock(); guard.execute( format!("DELETE FROM {} WHERE key = ?", self.name).as_str(), [key], )?; Ok(()) } #[tracing::instrument(skip(self))] fn iter<'a>(&'a self) -> Box + 'a> { let guard = self.engine.read_lock_iterator(); self.iter_with_guard(&guard) } #[tracing::instrument(skip(self, from, backwards))] fn iter_from<'a>( &'a self, from: &[u8], backwards: bool, ) -> Box + 'a> { let guard = self.engine.read_lock_iterator(); let from = from.to_vec(); // TODO change interface? //let name = self.name.clone(); if backwards { let statement = Box::leak(Box::new( guard .prepare(&format!( "SELECT key, value FROM {} WHERE key <= ? ORDER BY key DESC", &self.name )) .unwrap(), )); let statement_ref = NonAliasingBox(statement); let iterator = Box::new( statement .query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) .unwrap() .map(move |r| { //dbg!(&name); r.unwrap() }), ); Box::new(PreparedStatementIterator { iterator, statement_ref, }) } else { let statement = Box::leak(Box::new( guard .prepare(&format!( "SELECT key, value FROM {} WHERE key >= ? ORDER BY key ASC", &self.name )) .unwrap(), )); let statement_ref = NonAliasingBox(statement); let iterator = Box::new( statement .query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) .unwrap() .map(move |r| { //dbg!(&name); r.unwrap() }), ); Box::new(PreparedStatementIterator { iterator, statement_ref, }) } } #[tracing::instrument(skip(self, key))] fn increment(&self, key: &[u8]) -> Result> { let guard = self.engine.write_lock(); let old = self.get_with_guard(&guard, key)?; let new = crate::utils::increment(old.as_deref()).expect("utils::increment always returns Some"); self.insert_with_guard(&guard, key, &new)?; Ok(new) } #[tracing::instrument(skip(self, prefix))] fn scan_prefix<'a>(&'a self, prefix: Vec) -> Box + 'a> { Box::new( self.iter_from(&prefix, false) .take_while(move |(key, _)| key.starts_with(&prefix)), ) } #[tracing::instrument(skip(self, prefix))] fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>> { let (tx, rx) = tokio::sync::oneshot::channel(); self.watchers .write() .entry(prefix.to_vec()) .or_default() .push(tx); Box::pin(async move { // Tx is never destroyed rx.await.unwrap(); }) } #[tracing::instrument(skip(self))] fn clear(&self) -> Result<()> { debug!("clear: running"); self.engine .write_lock() .execute(format!("DELETE FROM {}", self.name).as_str(), [])?; debug!("clear: ran"); Ok(()) } }