diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..7d2cf54 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1 @@ +merge_imports = true diff --git a/src/data.rs b/src/data.rs index 0fa24d4..b7b9845 100644 --- a/src/data.rs +++ b/src/data.rs @@ -1,134 +1,115 @@ -use crate::utils; -use directories::ProjectDirs; -use log::debug; +use crate::{utils, Database}; use ruma_events::collections::all::Event; use ruma_identifiers::{EventId, RoomId, UserId}; use std::convert::TryInto; -const USERID_PASSWORD: &str = "userid_password"; -const USERID_DEVICEIDS: &str = "userid_deviceids"; -const DEVICEID_TOKEN: &str = "deviceid_token"; -const TOKEN_USERID: &str = "token_userid"; - -pub struct Data(sled::Db); +pub struct Data { + hostname: String, + db: Database, +} impl Data { /// Load an existing database or create a new one. - pub fn load_or_create() -> Self { - Data( - sled::open( - ProjectDirs::from("xyz", "koesters", "matrixserver") - .unwrap() - .data_dir(), - ) - .unwrap(), - ) - } - - /// Set the hostname of the server. Warning: Hostname changes will likely break things. - pub fn set_hostname(&self, hostname: &str) { - self.0.insert("hostname", hostname).unwrap(); + pub fn load_or_create(hostname: &str) -> Self { + Self { + hostname: hostname.to_owned(), + db: Database::load_or_create(hostname), + } } /// Get the hostname of the server. - pub fn hostname(&self) -> String { - utils::bytes_to_string(&self.0.get("hostname").unwrap().unwrap()) + pub fn hostname(&self) -> &str { + &self.hostname } /// Check if a user has an account by looking for an assigned password. pub fn user_exists(&self, user_id: &UserId) -> bool { - self.0 - .open_tree(USERID_PASSWORD) - .unwrap() + self.db + .userid_password .contains_key(user_id.to_string()) .unwrap() } /// Create a new user account by assigning them a password. pub fn user_add(&self, user_id: &UserId, password: Option) { - self.0 - .open_tree(USERID_PASSWORD) - .unwrap() + self.db + .userid_password .insert(user_id.to_string(), &*password.unwrap_or_default()) .unwrap(); } /// Find out which user an access token belongs to. pub fn user_from_token(&self, token: &str) -> Option { - self.0 - .open_tree(TOKEN_USERID) - .unwrap() + self.db + .token_userid .get(token) .unwrap() - .and_then(|bytes| (*utils::bytes_to_string(&bytes)).try_into().ok()) + .and_then(|bytes| (*utils::string_from_bytes(&bytes)).try_into().ok()) } /// Checks if the given password is equal to the one in the database. pub fn password_get(&self, user_id: &UserId) -> Option { - self.0 - .open_tree(USERID_PASSWORD) - .unwrap() + self.db + .userid_password .get(user_id.to_string()) .unwrap() - .map(|bytes| utils::bytes_to_string(&bytes)) + .map(|bytes| utils::string_from_bytes(&bytes)) } /// Add a new device to a user. pub fn device_add(&self, user_id: &UserId, device_id: &str) { - self.0 - .open_tree(USERID_DEVICEIDS) - .unwrap() - .insert(user_id.to_string(), device_id) - .unwrap(); + if self + .db + .userid_deviceids + .get_iter(&user_id.to_string().as_bytes()) + .filter_map(|item| item.ok()) + .map(|(_key, value)| value) + .all(|device| device != device_id) + { + self.db + .userid_deviceids + .add(user_id.to_string().as_bytes(), device_id.into()); + } } /// Replace the access token of one device. pub fn token_replace(&self, user_id: &UserId, device_id: &String, token: String) { // Make sure the device id belongs to the user debug_assert!(self - .0 - .open_tree(USERID_DEVICEIDS) - .unwrap() - .get(&user_id.to_string()) // Does the user exist? - .unwrap() - .map(|bytes| utils::bytes_to_vec(&bytes)) - .filter(|devices| devices.contains(device_id)) // Does the user have that device? - .is_some()); + .db + .userid_deviceids + .get_iter(&user_id.to_string().as_bytes()) + .filter_map(|item| item.ok()) + .map(|(_key, value)| value) + .any(|device| device == device_id.as_bytes())); // Does the user have that device? // Remove old token - if let Some(old_token) = self - .0 - .open_tree(DEVICEID_TOKEN) - .unwrap() - .get(device_id) - .unwrap() - { - self.0 - .open_tree(TOKEN_USERID) - .unwrap() - .remove(old_token) - .unwrap(); - // It will be removed from DEVICEID_TOKEN by the insert later + if let Some(old_token) = self.db.deviceid_token.get(device_id).unwrap() { + self.db.token_userid.remove(old_token).unwrap(); + // It will be removed from deviceid_token by the insert later } // Assign token to device_id - self.0 - .open_tree(DEVICEID_TOKEN) - .unwrap() - .insert(device_id, &*token) - .unwrap(); + self.db.deviceid_token.insert(device_id, &*token).unwrap(); // Assign token to user - self.0 - .open_tree(TOKEN_USERID) - .unwrap() + self.db + .token_userid .insert(token, &*user_id.to_string()) .unwrap(); } /// Create a new room event. - pub fn event_add(&self, event: &Event, room_id: &RoomId, event_id: &EventId) { - debug!("{}", serde_json::to_string(event).unwrap()); - todo!(); + pub fn event_add(&self, room_id: &RoomId, event_id: &EventId, event: &Event) { + let mut key = room_id.to_string().as_bytes().to_vec(); + key.extend_from_slice(event_id.to_string().as_bytes()); + self.db + .roomid_eventid_event + .insert(&key, &*serde_json::to_string(event).unwrap()) + .unwrap(); + } + + pub fn debug(&self) { + self.db.debug(); } } diff --git a/src/database.rs b/src/database.rs new file mode 100644 index 0000000..34ed72b --- /dev/null +++ b/src/database.rs @@ -0,0 +1,117 @@ +use crate::utils; +use directories::ProjectDirs; +use sled::IVec; + +pub struct MultiValue(sled::Tree); + +impl MultiValue { + /// Get an iterator over all values. + pub fn iter_all(&self) -> sled::Iter { + self.0.iter() + } + + /// Get an iterator over all values of this id. + pub fn get_iter(&self, id: &[u8]) -> sled::Iter { + // Data keys start with d + let mut key = vec![b'd']; + key.extend_from_slice(id.as_ref()); + key.push(0xff); // Add delimiter so we don't find usernames starting with the same id + + self.0.scan_prefix(key) + } + + /// Add another value to the id. + pub fn add(&self, id: &[u8], value: IVec) { + // The new value will need a new index. We store the last used index in 'n' + id + let mut count_key: Vec = vec![b'n']; + count_key.extend_from_slice(id.as_ref()); + + // Increment the last index and use that + let index = self + .0 + .update_and_fetch(&count_key, utils::increment) + .unwrap() + .unwrap(); + + // Data keys start with d + let mut key = vec![b'd']; + key.extend_from_slice(id.as_ref()); + key.push(0xff); + key.extend_from_slice(&index); + + self.0.insert(key, value).unwrap(); + } +} + +pub struct Database { + pub userid_password: sled::Tree, + pub userid_deviceids: MultiValue, + pub deviceid_token: sled::Tree, + pub token_userid: sled::Tree, + pub roomid_eventid_event: sled::Tree, + _db: sled::Db, +} + +impl Database { + /// Load an existing database or create a new one. + pub fn load_or_create(hostname: &str) -> Self { + let mut path = ProjectDirs::from("xyz", "koesters", "matrixserver") + .unwrap() + .data_dir() + .to_path_buf(); + path.push(hostname); + let db = sled::open(&path).unwrap(); + + Self { + userid_password: db.open_tree("userid_password").unwrap(), + userid_deviceids: MultiValue(db.open_tree("userid_deviceids").unwrap()), + deviceid_token: db.open_tree("deviceid_token").unwrap(), + token_userid: db.open_tree("token_userid").unwrap(), + roomid_eventid_event: db.open_tree("roomid_eventid_event").unwrap(), + _db: db, + } + } + + pub fn debug(&self) { + println!("# UserId -> Password:"); + for (k, v) in self.userid_password.iter().map(|r| r.unwrap()) { + println!( + "{} -> {}", + String::from_utf8_lossy(&k), + String::from_utf8_lossy(&v), + ); + } + println!("# UserId -> DeviceIds:"); + for (k, v) in self.userid_deviceids.iter_all().map(|r| r.unwrap()) { + println!( + "{} -> {}", + String::from_utf8_lossy(&k), + String::from_utf8_lossy(&v), + ); + } + println!("# DeviceId -> Token:"); + for (k, v) in self.deviceid_token.iter().map(|r| r.unwrap()) { + println!( + "{} -> {}", + String::from_utf8_lossy(&k), + String::from_utf8_lossy(&v), + ); + } + println!("# Token -> UserId:"); + for (k, v) in self.token_userid.iter().map(|r| r.unwrap()) { + println!( + "{} -> {}", + String::from_utf8_lossy(&k), + String::from_utf8_lossy(&v), + ); + } + println!("# RoomId + EventId -> Event:"); + for (k, v) in self.roomid_eventid_event.iter().map(|r| r.unwrap()) { + println!( + "{} -> {}", + String::from_utf8_lossy(&k), + String::from_utf8_lossy(&v), + ); + } + } +} diff --git a/src/main.rs b/src/main.rs index 06f7ca3..cf1f37f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,12 @@ #![feature(proc_macro_hygiene, decl_macro)] mod data; +mod database; mod ruma_wrapper; mod utils; pub use data::Data; +pub use database::Database; + use log::debug; use rocket::{get, post, put, routes, State}; use ruma_client_api::{ @@ -14,13 +17,14 @@ use ruma_client_api::{ }, unversioned::get_supported_versions, }; -use ruma_events::collections::all::Event; -use ruma_events::room::message::MessageEvent; +use ruma_events::{collections::all::Event, room::message::MessageEvent}; use ruma_identifiers::{EventId, UserId}; use ruma_wrapper::{MatrixResult, Ruma}; use serde_json::map::Map; -use std::convert::TryFrom; -use std::{collections::HashMap, convert::TryInto}; +use std::{ + collections::HashMap, + convert::{TryFrom, TryInto}, +}; #[get("/_matrix/client/versions")] fn get_supported_versions_route() -> MatrixResult { @@ -90,7 +94,7 @@ fn register_route( MatrixResult(Ok(register::Response { access_token: token, - home_server: data.hostname(), + home_server: data.hostname().to_owned(), user_id, device_id, })) @@ -153,7 +157,7 @@ fn login_route(data: State, body: Ruma) -> MatrixResult, body: Ruma) -> MatrixResult js_int::UInt { (SystemTime::now() @@ -8,20 +11,19 @@ pub fn millis_since_unix_epoch() -> js_int::UInt { .into() } -pub fn bytes_to_string(bytes: &[u8]) -> String { - String::from_utf8(bytes.to_vec()).expect("convert bytes to string") +pub fn increment(old: Option<&[u8]>) -> Option> { + let number = match old { + Some(bytes) => { + let array: [u8; 8] = bytes.try_into().unwrap(); + let number = u64::from_be_bytes(array); + number + 1 + } + None => 0, + }; + + Some(number.to_be_bytes().to_vec()) } -pub fn vec_to_bytes(vec: Vec) -> Vec { - vec.into_iter() - .map(|string| string.into_bytes()) - .collect::>>() - .join(&0) -} - -pub fn bytes_to_vec(bytes: &[u8]) -> Vec { - bytes - .split(|&b| b == 0) - .map(|bytes_string| bytes_to_string(bytes_string)) - .collect::>() +pub fn string_from_bytes(bytes: &[u8]) -> String { + String::from_utf8(bytes.to_vec()).expect("bytes are valid utf8") }