Use sled::Tree::prefix_search for deviceids
This commit is contained in:
		
							parent
							
								
									b508b4d1e7
								
							
						
					
					
						commit
						dba6c46667
					
				
					 5 changed files with 207 additions and 102 deletions
				
			
		
							
								
								
									
										1
									
								
								rustfmt.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								rustfmt.toml
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1 @@ | |||
| merge_imports = true | ||||
							
								
								
									
										131
									
								
								src/data.rs
									
									
									
									
									
								
							
							
						
						
									
										131
									
								
								src/data.rs
									
									
									
									
									
								
							|  | @ -1,134 +1,115 @@ | |||
| use crate::utils; | ||||
| use directories::ProjectDirs; | ||||
| use log::debug; | ||||
| use crate::{utils, Database}; | ||||
| use ruma_events::collections::all::Event; | ||||
| use ruma_identifiers::{EventId, RoomId, UserId}; | ||||
| use std::convert::TryInto; | ||||
| 
 | ||||
| const USERID_PASSWORD: &str = "userid_password"; | ||||
| const USERID_DEVICEIDS: &str = "userid_deviceids"; | ||||
| const DEVICEID_TOKEN: &str = "deviceid_token"; | ||||
| const TOKEN_USERID: &str = "token_userid"; | ||||
| 
 | ||||
| pub struct Data(sled::Db); | ||||
| pub struct Data { | ||||
|     hostname: String, | ||||
|     db: Database, | ||||
| } | ||||
| 
 | ||||
| impl Data { | ||||
|     /// Load an existing database or create a new one.
 | ||||
|     pub fn load_or_create() -> Self { | ||||
|         Data( | ||||
|             sled::open( | ||||
|                 ProjectDirs::from("xyz", "koesters", "matrixserver") | ||||
|                     .unwrap() | ||||
|                     .data_dir(), | ||||
|             ) | ||||
|             .unwrap(), | ||||
|         ) | ||||
|     pub fn load_or_create(hostname: &str) -> Self { | ||||
|         Self { | ||||
|             hostname: hostname.to_owned(), | ||||
|             db: Database::load_or_create(hostname), | ||||
|         } | ||||
| 
 | ||||
|     /// Set the hostname of the server. Warning: Hostname changes will likely break things.
 | ||||
|     pub fn set_hostname(&self, hostname: &str) { | ||||
|         self.0.insert("hostname", hostname).unwrap(); | ||||
|     } | ||||
| 
 | ||||
|     /// Get the hostname of the server.
 | ||||
|     pub fn hostname(&self) -> String { | ||||
|         utils::bytes_to_string(&self.0.get("hostname").unwrap().unwrap()) | ||||
|     pub fn hostname(&self) -> &str { | ||||
|         &self.hostname | ||||
|     } | ||||
| 
 | ||||
|     /// Check if a user has an account by looking for an assigned password.
 | ||||
|     pub fn user_exists(&self, user_id: &UserId) -> bool { | ||||
|         self.0 | ||||
|             .open_tree(USERID_PASSWORD) | ||||
|             .unwrap() | ||||
|         self.db | ||||
|             .userid_password | ||||
|             .contains_key(user_id.to_string()) | ||||
|             .unwrap() | ||||
|     } | ||||
| 
 | ||||
|     /// Create a new user account by assigning them a password.
 | ||||
|     pub fn user_add(&self, user_id: &UserId, password: Option<String>) { | ||||
|         self.0 | ||||
|             .open_tree(USERID_PASSWORD) | ||||
|             .unwrap() | ||||
|         self.db | ||||
|             .userid_password | ||||
|             .insert(user_id.to_string(), &*password.unwrap_or_default()) | ||||
|             .unwrap(); | ||||
|     } | ||||
| 
 | ||||
|     /// Find out which user an access token belongs to.
 | ||||
|     pub fn user_from_token(&self, token: &str) -> Option<UserId> { | ||||
|         self.0 | ||||
|             .open_tree(TOKEN_USERID) | ||||
|             .unwrap() | ||||
|         self.db | ||||
|             .token_userid | ||||
|             .get(token) | ||||
|             .unwrap() | ||||
|             .and_then(|bytes| (*utils::bytes_to_string(&bytes)).try_into().ok()) | ||||
|             .and_then(|bytes| (*utils::string_from_bytes(&bytes)).try_into().ok()) | ||||
|     } | ||||
| 
 | ||||
|     /// Checks if the given password is equal to the one in the database.
 | ||||
|     pub fn password_get(&self, user_id: &UserId) -> Option<String> { | ||||
|         self.0 | ||||
|             .open_tree(USERID_PASSWORD) | ||||
|             .unwrap() | ||||
|         self.db | ||||
|             .userid_password | ||||
|             .get(user_id.to_string()) | ||||
|             .unwrap() | ||||
|             .map(|bytes| utils::bytes_to_string(&bytes)) | ||||
|             .map(|bytes| utils::string_from_bytes(&bytes)) | ||||
|     } | ||||
| 
 | ||||
|     /// Add a new device to a user.
 | ||||
|     pub fn device_add(&self, user_id: &UserId, device_id: &str) { | ||||
|         self.0 | ||||
|             .open_tree(USERID_DEVICEIDS) | ||||
|             .unwrap() | ||||
|             .insert(user_id.to_string(), device_id) | ||||
|             .unwrap(); | ||||
|         if self | ||||
|             .db | ||||
|             .userid_deviceids | ||||
|             .get_iter(&user_id.to_string().as_bytes()) | ||||
|             .filter_map(|item| item.ok()) | ||||
|             .map(|(_key, value)| value) | ||||
|             .all(|device| device != device_id) | ||||
|         { | ||||
|             self.db | ||||
|                 .userid_deviceids | ||||
|                 .add(user_id.to_string().as_bytes(), device_id.into()); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /// Replace the access token of one device.
 | ||||
|     pub fn token_replace(&self, user_id: &UserId, device_id: &String, token: String) { | ||||
|         // Make sure the device id belongs to the user
 | ||||
|         debug_assert!(self | ||||
|             .0 | ||||
|             .open_tree(USERID_DEVICEIDS) | ||||
|             .unwrap() | ||||
|             .get(&user_id.to_string()) // Does the user exist?
 | ||||
|             .unwrap() | ||||
|             .map(|bytes| utils::bytes_to_vec(&bytes)) | ||||
|             .filter(|devices| devices.contains(device_id)) // Does the user have that device?
 | ||||
|             .is_some()); | ||||
|             .db | ||||
|             .userid_deviceids | ||||
|             .get_iter(&user_id.to_string().as_bytes()) | ||||
|             .filter_map(|item| item.ok()) | ||||
|             .map(|(_key, value)| value) | ||||
|             .any(|device| device == device_id.as_bytes())); // Does the user have that device?
 | ||||
| 
 | ||||
|         // Remove old token
 | ||||
|         if let Some(old_token) = self | ||||
|             .0 | ||||
|             .open_tree(DEVICEID_TOKEN) | ||||
|             .unwrap() | ||||
|             .get(device_id) | ||||
|             .unwrap() | ||||
|         { | ||||
|             self.0 | ||||
|                 .open_tree(TOKEN_USERID) | ||||
|                 .unwrap() | ||||
|                 .remove(old_token) | ||||
|                 .unwrap(); | ||||
|             // It will be removed from DEVICEID_TOKEN by the insert later
 | ||||
|         if let Some(old_token) = self.db.deviceid_token.get(device_id).unwrap() { | ||||
|             self.db.token_userid.remove(old_token).unwrap(); | ||||
|             // It will be removed from deviceid_token by the insert later
 | ||||
|         } | ||||
| 
 | ||||
|         // Assign token to device_id
 | ||||
|         self.0 | ||||
|             .open_tree(DEVICEID_TOKEN) | ||||
|             .unwrap() | ||||
|             .insert(device_id, &*token) | ||||
|             .unwrap(); | ||||
|         self.db.deviceid_token.insert(device_id, &*token).unwrap(); | ||||
| 
 | ||||
|         // Assign token to user
 | ||||
|         self.0 | ||||
|             .open_tree(TOKEN_USERID) | ||||
|             .unwrap() | ||||
|         self.db | ||||
|             .token_userid | ||||
|             .insert(token, &*user_id.to_string()) | ||||
|             .unwrap(); | ||||
|     } | ||||
| 
 | ||||
|     /// Create a new room event.
 | ||||
|     pub fn event_add(&self, event: &Event, room_id: &RoomId, event_id: &EventId) { | ||||
|         debug!("{}", serde_json::to_string(event).unwrap()); | ||||
|         todo!(); | ||||
|     pub fn event_add(&self, room_id: &RoomId, event_id: &EventId, event: &Event) { | ||||
|         let mut key = room_id.to_string().as_bytes().to_vec(); | ||||
|         key.extend_from_slice(event_id.to_string().as_bytes()); | ||||
|         self.db | ||||
|             .roomid_eventid_event | ||||
|             .insert(&key, &*serde_json::to_string(event).unwrap()) | ||||
|             .unwrap(); | ||||
|     } | ||||
| 
 | ||||
|     pub fn debug(&self) { | ||||
|         self.db.debug(); | ||||
|     } | ||||
| } | ||||
|  |  | |||
							
								
								
									
										117
									
								
								src/database.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										117
									
								
								src/database.rs
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,117 @@ | |||
| use crate::utils; | ||||
| use directories::ProjectDirs; | ||||
| use sled::IVec; | ||||
| 
 | ||||
| pub struct MultiValue(sled::Tree); | ||||
| 
 | ||||
| impl MultiValue { | ||||
|     /// Get an iterator over all values.
 | ||||
|     pub fn iter_all(&self) -> sled::Iter { | ||||
|         self.0.iter() | ||||
|     } | ||||
| 
 | ||||
|     /// Get an iterator over all values of this id.
 | ||||
|     pub fn get_iter(&self, id: &[u8]) -> sled::Iter { | ||||
|         // Data keys start with d
 | ||||
|         let mut key = vec![b'd']; | ||||
|         key.extend_from_slice(id.as_ref()); | ||||
|         key.push(0xff); // Add delimiter so we don't find usernames starting with the same id
 | ||||
| 
 | ||||
|         self.0.scan_prefix(key) | ||||
|     } | ||||
| 
 | ||||
|     /// Add another value to the id.
 | ||||
|     pub fn add(&self, id: &[u8], value: IVec) { | ||||
|         // The new value will need a new index. We store the last used index in 'n' + id
 | ||||
|         let mut count_key: Vec<u8> = vec![b'n']; | ||||
|         count_key.extend_from_slice(id.as_ref()); | ||||
| 
 | ||||
|         // Increment the last index and use that
 | ||||
|         let index = self | ||||
|             .0 | ||||
|             .update_and_fetch(&count_key, utils::increment) | ||||
|             .unwrap() | ||||
|             .unwrap(); | ||||
| 
 | ||||
|         // Data keys start with d
 | ||||
|         let mut key = vec![b'd']; | ||||
|         key.extend_from_slice(id.as_ref()); | ||||
|         key.push(0xff); | ||||
|         key.extend_from_slice(&index); | ||||
| 
 | ||||
|         self.0.insert(key, value).unwrap(); | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| pub struct Database { | ||||
|     pub userid_password: sled::Tree, | ||||
|     pub userid_deviceids: MultiValue, | ||||
|     pub deviceid_token: sled::Tree, | ||||
|     pub token_userid: sled::Tree, | ||||
|     pub roomid_eventid_event: sled::Tree, | ||||
|     _db: sled::Db, | ||||
| } | ||||
| 
 | ||||
| impl Database { | ||||
|     /// Load an existing database or create a new one.
 | ||||
|     pub fn load_or_create(hostname: &str) -> Self { | ||||
|         let mut path = ProjectDirs::from("xyz", "koesters", "matrixserver") | ||||
|             .unwrap() | ||||
|             .data_dir() | ||||
|             .to_path_buf(); | ||||
|         path.push(hostname); | ||||
|         let db = sled::open(&path).unwrap(); | ||||
| 
 | ||||
|         Self { | ||||
|             userid_password: db.open_tree("userid_password").unwrap(), | ||||
|             userid_deviceids: MultiValue(db.open_tree("userid_deviceids").unwrap()), | ||||
|             deviceid_token: db.open_tree("deviceid_token").unwrap(), | ||||
|             token_userid: db.open_tree("token_userid").unwrap(), | ||||
|             roomid_eventid_event: db.open_tree("roomid_eventid_event").unwrap(), | ||||
|             _db: db, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn debug(&self) { | ||||
|         println!("# UserId -> Password:"); | ||||
|         for (k, v) in self.userid_password.iter().map(|r| r.unwrap()) { | ||||
|             println!( | ||||
|                 "{} -> {}", | ||||
|                 String::from_utf8_lossy(&k), | ||||
|                 String::from_utf8_lossy(&v), | ||||
|             ); | ||||
|         } | ||||
|         println!("# UserId -> DeviceIds:"); | ||||
|         for (k, v) in self.userid_deviceids.iter_all().map(|r| r.unwrap()) { | ||||
|             println!( | ||||
|                 "{} -> {}", | ||||
|                 String::from_utf8_lossy(&k), | ||||
|                 String::from_utf8_lossy(&v), | ||||
|             ); | ||||
|         } | ||||
|         println!("# DeviceId -> Token:"); | ||||
|         for (k, v) in self.deviceid_token.iter().map(|r| r.unwrap()) { | ||||
|             println!( | ||||
|                 "{} -> {}", | ||||
|                 String::from_utf8_lossy(&k), | ||||
|                 String::from_utf8_lossy(&v), | ||||
|             ); | ||||
|         } | ||||
|         println!("# Token -> UserId:"); | ||||
|         for (k, v) in self.token_userid.iter().map(|r| r.unwrap()) { | ||||
|             println!( | ||||
|                 "{} -> {}", | ||||
|                 String::from_utf8_lossy(&k), | ||||
|                 String::from_utf8_lossy(&v), | ||||
|             ); | ||||
|         } | ||||
|         println!("# RoomId + EventId -> Event:"); | ||||
|         for (k, v) in self.roomid_eventid_event.iter().map(|r| r.unwrap()) { | ||||
|             println!( | ||||
|                 "{} -> {}", | ||||
|                 String::from_utf8_lossy(&k), | ||||
|                 String::from_utf8_lossy(&v), | ||||
|             ); | ||||
|         } | ||||
|     } | ||||
| } | ||||
							
								
								
									
										26
									
								
								src/main.rs
									
									
									
									
									
								
							
							
						
						
									
										26
									
								
								src/main.rs
									
									
									
									
									
								
							|  | @ -1,9 +1,12 @@ | |||
| #![feature(proc_macro_hygiene, decl_macro)] | ||||
| mod data; | ||||
| mod database; | ||||
| mod ruma_wrapper; | ||||
| mod utils; | ||||
| 
 | ||||
| pub use data::Data; | ||||
| pub use database::Database; | ||||
| 
 | ||||
| use log::debug; | ||||
| use rocket::{get, post, put, routes, State}; | ||||
| use ruma_client_api::{ | ||||
|  | @ -14,13 +17,14 @@ use ruma_client_api::{ | |||
|     }, | ||||
|     unversioned::get_supported_versions, | ||||
| }; | ||||
| use ruma_events::collections::all::Event; | ||||
| use ruma_events::room::message::MessageEvent; | ||||
| use ruma_events::{collections::all::Event, room::message::MessageEvent}; | ||||
| use ruma_identifiers::{EventId, UserId}; | ||||
| use ruma_wrapper::{MatrixResult, Ruma}; | ||||
| use serde_json::map::Map; | ||||
| use std::convert::TryFrom; | ||||
| use std::{collections::HashMap, convert::TryInto}; | ||||
| use std::{ | ||||
|     collections::HashMap, | ||||
|     convert::{TryFrom, TryInto}, | ||||
| }; | ||||
| 
 | ||||
| #[get("/_matrix/client/versions")] | ||||
| fn get_supported_versions_route() -> MatrixResult<get_supported_versions::Response> { | ||||
|  | @ -90,7 +94,7 @@ fn register_route( | |||
| 
 | ||||
|     MatrixResult(Ok(register::Response { | ||||
|         access_token: token, | ||||
|         home_server: data.hostname(), | ||||
|         home_server: data.hostname().to_owned(), | ||||
|         user_id, | ||||
|         device_id, | ||||
|     })) | ||||
|  | @ -153,7 +157,7 @@ fn login_route(data: State<Data>, body: Ruma<login::Request>) -> MatrixResult<lo | |||
|         .clone() | ||||
|         .unwrap_or("TODO:randomdeviceid".to_owned()); | ||||
| 
 | ||||
|     // Add device (TODO: We might not want to call it when using an existing device)
 | ||||
|     // Add device
 | ||||
|     data.device_add(&user_id, &device_id); | ||||
| 
 | ||||
|     // Generate a new token for the device
 | ||||
|  | @ -163,7 +167,7 @@ fn login_route(data: State<Data>, body: Ruma<login::Request>) -> MatrixResult<lo | |||
|     return MatrixResult(Ok(login::Response { | ||||
|         user_id, | ||||
|         access_token: token, | ||||
|         home_server: Some(data.hostname()), | ||||
|         home_server: Some(data.hostname().to_owned()), | ||||
|         device_id, | ||||
|         well_known: None, | ||||
|     })); | ||||
|  | @ -217,6 +221,8 @@ fn create_message_event_route( | |||
|     // Generate event id
 | ||||
|     let event_id = EventId::try_from("$TODOrandomeventid:localhost").unwrap(); | ||||
|     data.event_add( | ||||
|         &body.room_id, | ||||
|         &event_id, | ||||
|         &Event::RoomMessage(MessageEvent { | ||||
|             content: body.data.clone().into_result().unwrap(), | ||||
|             event_id: event_id.clone(), | ||||
|  | @ -225,8 +231,6 @@ fn create_message_event_route( | |||
|             sender: body.user_id.clone().expect("user is authenticated"), | ||||
|             unsigned: Map::default(), | ||||
|         }), | ||||
|         &body.room_id, | ||||
|         &event_id, | ||||
|     ); | ||||
| 
 | ||||
|     MatrixResult(Ok(create_message_event::Response { event_id })) | ||||
|  | @ -239,8 +243,8 @@ fn main() { | |||
|     } | ||||
|     pretty_env_logger::init(); | ||||
| 
 | ||||
|     let data = Data::load_or_create(); | ||||
|     data.set_hostname("localhost"); | ||||
|     let data = Data::load_or_create("localhost"); | ||||
|     data.debug(); | ||||
| 
 | ||||
|     rocket::ignite() | ||||
|         .mount( | ||||
|  |  | |||
							
								
								
									
										32
									
								
								src/utils.rs
									
									
									
									
									
								
							
							
						
						
									
										32
									
								
								src/utils.rs
									
									
									
									
									
								
							|  | @ -1,4 +1,7 @@ | |||
| use std::time::{SystemTime, UNIX_EPOCH}; | ||||
| use std::{ | ||||
|     convert::TryInto, | ||||
|     time::{SystemTime, UNIX_EPOCH}, | ||||
| }; | ||||
| 
 | ||||
| pub fn millis_since_unix_epoch() -> js_int::UInt { | ||||
|     (SystemTime::now() | ||||
|  | @ -8,20 +11,19 @@ pub fn millis_since_unix_epoch() -> js_int::UInt { | |||
|         .into() | ||||
| } | ||||
| 
 | ||||
| pub fn bytes_to_string(bytes: &[u8]) -> String { | ||||
|     String::from_utf8(bytes.to_vec()).expect("convert bytes to string") | ||||
| pub fn increment(old: Option<&[u8]>) -> Option<Vec<u8>> { | ||||
|     let number = match old { | ||||
|         Some(bytes) => { | ||||
|             let array: [u8; 8] = bytes.try_into().unwrap(); | ||||
|             let number = u64::from_be_bytes(array); | ||||
|             number + 1 | ||||
|         } | ||||
|         None => 0, | ||||
|     }; | ||||
| 
 | ||||
|     Some(number.to_be_bytes().to_vec()) | ||||
| } | ||||
| 
 | ||||
| pub fn vec_to_bytes(vec: Vec<String>) -> Vec<u8> { | ||||
|     vec.into_iter() | ||||
|         .map(|string| string.into_bytes()) | ||||
|         .collect::<Vec<Vec<u8>>>() | ||||
|         .join(&0) | ||||
| } | ||||
| 
 | ||||
| pub fn bytes_to_vec(bytes: &[u8]) -> Vec<String> { | ||||
|     bytes | ||||
|         .split(|&b| b == 0) | ||||
|         .map(|bytes_string| bytes_to_string(bytes_string)) | ||||
|         .collect::<Vec<String>>() | ||||
| pub fn string_from_bytes(bytes: &[u8]) -> String { | ||||
|     String::from_utf8(bytes.to_vec()).expect("bytes are valid utf8") | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue