Use sled::Tree::prefix_search for deviceids
parent
b508b4d1e7
commit
dba6c46667
|
@ -0,0 +1 @@
|
||||||
|
merge_imports = true
|
131
src/data.rs
131
src/data.rs
|
@ -1,134 +1,115 @@
|
||||||
use crate::utils;
|
use crate::{utils, Database};
|
||||||
use directories::ProjectDirs;
|
|
||||||
use log::debug;
|
|
||||||
use ruma_events::collections::all::Event;
|
use ruma_events::collections::all::Event;
|
||||||
use ruma_identifiers::{EventId, RoomId, UserId};
|
use ruma_identifiers::{EventId, RoomId, UserId};
|
||||||
use std::convert::TryInto;
|
use std::convert::TryInto;
|
||||||
|
|
||||||
const USERID_PASSWORD: &str = "userid_password";
|
pub struct Data {
|
||||||
const USERID_DEVICEIDS: &str = "userid_deviceids";
|
hostname: String,
|
||||||
const DEVICEID_TOKEN: &str = "deviceid_token";
|
db: Database,
|
||||||
const TOKEN_USERID: &str = "token_userid";
|
}
|
||||||
|
|
||||||
pub struct Data(sled::Db);
|
|
||||||
|
|
||||||
impl Data {
|
impl Data {
|
||||||
/// Load an existing database or create a new one.
|
/// Load an existing database or create a new one.
|
||||||
pub fn load_or_create() -> Self {
|
pub fn load_or_create(hostname: &str) -> Self {
|
||||||
Data(
|
Self {
|
||||||
sled::open(
|
hostname: hostname.to_owned(),
|
||||||
ProjectDirs::from("xyz", "koesters", "matrixserver")
|
db: Database::load_or_create(hostname),
|
||||||
.unwrap()
|
|
||||||
.data_dir(),
|
|
||||||
)
|
|
||||||
.unwrap(),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set the hostname of the server. Warning: Hostname changes will likely break things.
|
|
||||||
pub fn set_hostname(&self, hostname: &str) {
|
|
||||||
self.0.insert("hostname", hostname).unwrap();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the hostname of the server.
|
/// Get the hostname of the server.
|
||||||
pub fn hostname(&self) -> String {
|
pub fn hostname(&self) -> &str {
|
||||||
utils::bytes_to_string(&self.0.get("hostname").unwrap().unwrap())
|
&self.hostname
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check if a user has an account by looking for an assigned password.
|
/// Check if a user has an account by looking for an assigned password.
|
||||||
pub fn user_exists(&self, user_id: &UserId) -> bool {
|
pub fn user_exists(&self, user_id: &UserId) -> bool {
|
||||||
self.0
|
self.db
|
||||||
.open_tree(USERID_PASSWORD)
|
.userid_password
|
||||||
.unwrap()
|
|
||||||
.contains_key(user_id.to_string())
|
.contains_key(user_id.to_string())
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new user account by assigning them a password.
|
/// Create a new user account by assigning them a password.
|
||||||
pub fn user_add(&self, user_id: &UserId, password: Option<String>) {
|
pub fn user_add(&self, user_id: &UserId, password: Option<String>) {
|
||||||
self.0
|
self.db
|
||||||
.open_tree(USERID_PASSWORD)
|
.userid_password
|
||||||
.unwrap()
|
|
||||||
.insert(user_id.to_string(), &*password.unwrap_or_default())
|
.insert(user_id.to_string(), &*password.unwrap_or_default())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Find out which user an access token belongs to.
|
/// Find out which user an access token belongs to.
|
||||||
pub fn user_from_token(&self, token: &str) -> Option<UserId> {
|
pub fn user_from_token(&self, token: &str) -> Option<UserId> {
|
||||||
self.0
|
self.db
|
||||||
.open_tree(TOKEN_USERID)
|
.token_userid
|
||||||
.unwrap()
|
|
||||||
.get(token)
|
.get(token)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.and_then(|bytes| (*utils::bytes_to_string(&bytes)).try_into().ok())
|
.and_then(|bytes| (*utils::string_from_bytes(&bytes)).try_into().ok())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Checks if the given password is equal to the one in the database.
|
/// Checks if the given password is equal to the one in the database.
|
||||||
pub fn password_get(&self, user_id: &UserId) -> Option<String> {
|
pub fn password_get(&self, user_id: &UserId) -> Option<String> {
|
||||||
self.0
|
self.db
|
||||||
.open_tree(USERID_PASSWORD)
|
.userid_password
|
||||||
.unwrap()
|
|
||||||
.get(user_id.to_string())
|
.get(user_id.to_string())
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.map(|bytes| utils::bytes_to_string(&bytes))
|
.map(|bytes| utils::string_from_bytes(&bytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add a new device to a user.
|
/// Add a new device to a user.
|
||||||
pub fn device_add(&self, user_id: &UserId, device_id: &str) {
|
pub fn device_add(&self, user_id: &UserId, device_id: &str) {
|
||||||
self.0
|
if self
|
||||||
.open_tree(USERID_DEVICEIDS)
|
.db
|
||||||
.unwrap()
|
.userid_deviceids
|
||||||
.insert(user_id.to_string(), device_id)
|
.get_iter(&user_id.to_string().as_bytes())
|
||||||
.unwrap();
|
.filter_map(|item| item.ok())
|
||||||
|
.map(|(_key, value)| value)
|
||||||
|
.all(|device| device != device_id)
|
||||||
|
{
|
||||||
|
self.db
|
||||||
|
.userid_deviceids
|
||||||
|
.add(user_id.to_string().as_bytes(), device_id.into());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Replace the access token of one device.
|
/// Replace the access token of one device.
|
||||||
pub fn token_replace(&self, user_id: &UserId, device_id: &String, token: String) {
|
pub fn token_replace(&self, user_id: &UserId, device_id: &String, token: String) {
|
||||||
// Make sure the device id belongs to the user
|
// Make sure the device id belongs to the user
|
||||||
debug_assert!(self
|
debug_assert!(self
|
||||||
.0
|
.db
|
||||||
.open_tree(USERID_DEVICEIDS)
|
.userid_deviceids
|
||||||
.unwrap()
|
.get_iter(&user_id.to_string().as_bytes())
|
||||||
.get(&user_id.to_string()) // Does the user exist?
|
.filter_map(|item| item.ok())
|
||||||
.unwrap()
|
.map(|(_key, value)| value)
|
||||||
.map(|bytes| utils::bytes_to_vec(&bytes))
|
.any(|device| device == device_id.as_bytes())); // Does the user have that device?
|
||||||
.filter(|devices| devices.contains(device_id)) // Does the user have that device?
|
|
||||||
.is_some());
|
|
||||||
|
|
||||||
// Remove old token
|
// Remove old token
|
||||||
if let Some(old_token) = self
|
if let Some(old_token) = self.db.deviceid_token.get(device_id).unwrap() {
|
||||||
.0
|
self.db.token_userid.remove(old_token).unwrap();
|
||||||
.open_tree(DEVICEID_TOKEN)
|
// It will be removed from deviceid_token by the insert later
|
||||||
.unwrap()
|
|
||||||
.get(device_id)
|
|
||||||
.unwrap()
|
|
||||||
{
|
|
||||||
self.0
|
|
||||||
.open_tree(TOKEN_USERID)
|
|
||||||
.unwrap()
|
|
||||||
.remove(old_token)
|
|
||||||
.unwrap();
|
|
||||||
// It will be removed from DEVICEID_TOKEN by the insert later
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Assign token to device_id
|
// Assign token to device_id
|
||||||
self.0
|
self.db.deviceid_token.insert(device_id, &*token).unwrap();
|
||||||
.open_tree(DEVICEID_TOKEN)
|
|
||||||
.unwrap()
|
|
||||||
.insert(device_id, &*token)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// Assign token to user
|
// Assign token to user
|
||||||
self.0
|
self.db
|
||||||
.open_tree(TOKEN_USERID)
|
.token_userid
|
||||||
.unwrap()
|
|
||||||
.insert(token, &*user_id.to_string())
|
.insert(token, &*user_id.to_string())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new room event.
|
/// Create a new room event.
|
||||||
pub fn event_add(&self, event: &Event, room_id: &RoomId, event_id: &EventId) {
|
pub fn event_add(&self, room_id: &RoomId, event_id: &EventId, event: &Event) {
|
||||||
debug!("{}", serde_json::to_string(event).unwrap());
|
let mut key = room_id.to_string().as_bytes().to_vec();
|
||||||
todo!();
|
key.extend_from_slice(event_id.to_string().as_bytes());
|
||||||
|
self.db
|
||||||
|
.roomid_eventid_event
|
||||||
|
.insert(&key, &*serde_json::to_string(event).unwrap())
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn debug(&self) {
|
||||||
|
self.db.debug();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,117 @@
|
||||||
|
use crate::utils;
|
||||||
|
use directories::ProjectDirs;
|
||||||
|
use sled::IVec;
|
||||||
|
|
||||||
|
pub struct MultiValue(sled::Tree);
|
||||||
|
|
||||||
|
impl MultiValue {
|
||||||
|
/// Get an iterator over all values.
|
||||||
|
pub fn iter_all(&self) -> sled::Iter {
|
||||||
|
self.0.iter()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get an iterator over all values of this id.
|
||||||
|
pub fn get_iter(&self, id: &[u8]) -> sled::Iter {
|
||||||
|
// Data keys start with d
|
||||||
|
let mut key = vec![b'd'];
|
||||||
|
key.extend_from_slice(id.as_ref());
|
||||||
|
key.push(0xff); // Add delimiter so we don't find usernames starting with the same id
|
||||||
|
|
||||||
|
self.0.scan_prefix(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add another value to the id.
|
||||||
|
pub fn add(&self, id: &[u8], value: IVec) {
|
||||||
|
// The new value will need a new index. We store the last used index in 'n' + id
|
||||||
|
let mut count_key: Vec<u8> = vec![b'n'];
|
||||||
|
count_key.extend_from_slice(id.as_ref());
|
||||||
|
|
||||||
|
// Increment the last index and use that
|
||||||
|
let index = self
|
||||||
|
.0
|
||||||
|
.update_and_fetch(&count_key, utils::increment)
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Data keys start with d
|
||||||
|
let mut key = vec![b'd'];
|
||||||
|
key.extend_from_slice(id.as_ref());
|
||||||
|
key.push(0xff);
|
||||||
|
key.extend_from_slice(&index);
|
||||||
|
|
||||||
|
self.0.insert(key, value).unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Database {
|
||||||
|
pub userid_password: sled::Tree,
|
||||||
|
pub userid_deviceids: MultiValue,
|
||||||
|
pub deviceid_token: sled::Tree,
|
||||||
|
pub token_userid: sled::Tree,
|
||||||
|
pub roomid_eventid_event: sled::Tree,
|
||||||
|
_db: sled::Db,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Database {
|
||||||
|
/// Load an existing database or create a new one.
|
||||||
|
pub fn load_or_create(hostname: &str) -> Self {
|
||||||
|
let mut path = ProjectDirs::from("xyz", "koesters", "matrixserver")
|
||||||
|
.unwrap()
|
||||||
|
.data_dir()
|
||||||
|
.to_path_buf();
|
||||||
|
path.push(hostname);
|
||||||
|
let db = sled::open(&path).unwrap();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
userid_password: db.open_tree("userid_password").unwrap(),
|
||||||
|
userid_deviceids: MultiValue(db.open_tree("userid_deviceids").unwrap()),
|
||||||
|
deviceid_token: db.open_tree("deviceid_token").unwrap(),
|
||||||
|
token_userid: db.open_tree("token_userid").unwrap(),
|
||||||
|
roomid_eventid_event: db.open_tree("roomid_eventid_event").unwrap(),
|
||||||
|
_db: db,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn debug(&self) {
|
||||||
|
println!("# UserId -> Password:");
|
||||||
|
for (k, v) in self.userid_password.iter().map(|r| r.unwrap()) {
|
||||||
|
println!(
|
||||||
|
"{} -> {}",
|
||||||
|
String::from_utf8_lossy(&k),
|
||||||
|
String::from_utf8_lossy(&v),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
println!("# UserId -> DeviceIds:");
|
||||||
|
for (k, v) in self.userid_deviceids.iter_all().map(|r| r.unwrap()) {
|
||||||
|
println!(
|
||||||
|
"{} -> {}",
|
||||||
|
String::from_utf8_lossy(&k),
|
||||||
|
String::from_utf8_lossy(&v),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
println!("# DeviceId -> Token:");
|
||||||
|
for (k, v) in self.deviceid_token.iter().map(|r| r.unwrap()) {
|
||||||
|
println!(
|
||||||
|
"{} -> {}",
|
||||||
|
String::from_utf8_lossy(&k),
|
||||||
|
String::from_utf8_lossy(&v),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
println!("# Token -> UserId:");
|
||||||
|
for (k, v) in self.token_userid.iter().map(|r| r.unwrap()) {
|
||||||
|
println!(
|
||||||
|
"{} -> {}",
|
||||||
|
String::from_utf8_lossy(&k),
|
||||||
|
String::from_utf8_lossy(&v),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
println!("# RoomId + EventId -> Event:");
|
||||||
|
for (k, v) in self.roomid_eventid_event.iter().map(|r| r.unwrap()) {
|
||||||
|
println!(
|
||||||
|
"{} -> {}",
|
||||||
|
String::from_utf8_lossy(&k),
|
||||||
|
String::from_utf8_lossy(&v),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
26
src/main.rs
26
src/main.rs
|
@ -1,9 +1,12 @@
|
||||||
#![feature(proc_macro_hygiene, decl_macro)]
|
#![feature(proc_macro_hygiene, decl_macro)]
|
||||||
mod data;
|
mod data;
|
||||||
|
mod database;
|
||||||
mod ruma_wrapper;
|
mod ruma_wrapper;
|
||||||
mod utils;
|
mod utils;
|
||||||
|
|
||||||
pub use data::Data;
|
pub use data::Data;
|
||||||
|
pub use database::Database;
|
||||||
|
|
||||||
use log::debug;
|
use log::debug;
|
||||||
use rocket::{get, post, put, routes, State};
|
use rocket::{get, post, put, routes, State};
|
||||||
use ruma_client_api::{
|
use ruma_client_api::{
|
||||||
|
@ -14,13 +17,14 @@ use ruma_client_api::{
|
||||||
},
|
},
|
||||||
unversioned::get_supported_versions,
|
unversioned::get_supported_versions,
|
||||||
};
|
};
|
||||||
use ruma_events::collections::all::Event;
|
use ruma_events::{collections::all::Event, room::message::MessageEvent};
|
||||||
use ruma_events::room::message::MessageEvent;
|
|
||||||
use ruma_identifiers::{EventId, UserId};
|
use ruma_identifiers::{EventId, UserId};
|
||||||
use ruma_wrapper::{MatrixResult, Ruma};
|
use ruma_wrapper::{MatrixResult, Ruma};
|
||||||
use serde_json::map::Map;
|
use serde_json::map::Map;
|
||||||
use std::convert::TryFrom;
|
use std::{
|
||||||
use std::{collections::HashMap, convert::TryInto};
|
collections::HashMap,
|
||||||
|
convert::{TryFrom, TryInto},
|
||||||
|
};
|
||||||
|
|
||||||
#[get("/_matrix/client/versions")]
|
#[get("/_matrix/client/versions")]
|
||||||
fn get_supported_versions_route() -> MatrixResult<get_supported_versions::Response> {
|
fn get_supported_versions_route() -> MatrixResult<get_supported_versions::Response> {
|
||||||
|
@ -90,7 +94,7 @@ fn register_route(
|
||||||
|
|
||||||
MatrixResult(Ok(register::Response {
|
MatrixResult(Ok(register::Response {
|
||||||
access_token: token,
|
access_token: token,
|
||||||
home_server: data.hostname(),
|
home_server: data.hostname().to_owned(),
|
||||||
user_id,
|
user_id,
|
||||||
device_id,
|
device_id,
|
||||||
}))
|
}))
|
||||||
|
@ -153,7 +157,7 @@ fn login_route(data: State<Data>, body: Ruma<login::Request>) -> MatrixResult<lo
|
||||||
.clone()
|
.clone()
|
||||||
.unwrap_or("TODO:randomdeviceid".to_owned());
|
.unwrap_or("TODO:randomdeviceid".to_owned());
|
||||||
|
|
||||||
// Add device (TODO: We might not want to call it when using an existing device)
|
// Add device
|
||||||
data.device_add(&user_id, &device_id);
|
data.device_add(&user_id, &device_id);
|
||||||
|
|
||||||
// Generate a new token for the device
|
// Generate a new token for the device
|
||||||
|
@ -163,7 +167,7 @@ fn login_route(data: State<Data>, body: Ruma<login::Request>) -> MatrixResult<lo
|
||||||
return MatrixResult(Ok(login::Response {
|
return MatrixResult(Ok(login::Response {
|
||||||
user_id,
|
user_id,
|
||||||
access_token: token,
|
access_token: token,
|
||||||
home_server: Some(data.hostname()),
|
home_server: Some(data.hostname().to_owned()),
|
||||||
device_id,
|
device_id,
|
||||||
well_known: None,
|
well_known: None,
|
||||||
}));
|
}));
|
||||||
|
@ -217,6 +221,8 @@ fn create_message_event_route(
|
||||||
// Generate event id
|
// Generate event id
|
||||||
let event_id = EventId::try_from("$TODOrandomeventid:localhost").unwrap();
|
let event_id = EventId::try_from("$TODOrandomeventid:localhost").unwrap();
|
||||||
data.event_add(
|
data.event_add(
|
||||||
|
&body.room_id,
|
||||||
|
&event_id,
|
||||||
&Event::RoomMessage(MessageEvent {
|
&Event::RoomMessage(MessageEvent {
|
||||||
content: body.data.clone().into_result().unwrap(),
|
content: body.data.clone().into_result().unwrap(),
|
||||||
event_id: event_id.clone(),
|
event_id: event_id.clone(),
|
||||||
|
@ -225,8 +231,6 @@ fn create_message_event_route(
|
||||||
sender: body.user_id.clone().expect("user is authenticated"),
|
sender: body.user_id.clone().expect("user is authenticated"),
|
||||||
unsigned: Map::default(),
|
unsigned: Map::default(),
|
||||||
}),
|
}),
|
||||||
&body.room_id,
|
|
||||||
&event_id,
|
|
||||||
);
|
);
|
||||||
|
|
||||||
MatrixResult(Ok(create_message_event::Response { event_id }))
|
MatrixResult(Ok(create_message_event::Response { event_id }))
|
||||||
|
@ -239,8 +243,8 @@ fn main() {
|
||||||
}
|
}
|
||||||
pretty_env_logger::init();
|
pretty_env_logger::init();
|
||||||
|
|
||||||
let data = Data::load_or_create();
|
let data = Data::load_or_create("localhost");
|
||||||
data.set_hostname("localhost");
|
data.debug();
|
||||||
|
|
||||||
rocket::ignite()
|
rocket::ignite()
|
||||||
.mount(
|
.mount(
|
||||||
|
|
32
src/utils.rs
32
src/utils.rs
|
@ -1,4 +1,7 @@
|
||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
use std::{
|
||||||
|
convert::TryInto,
|
||||||
|
time::{SystemTime, UNIX_EPOCH},
|
||||||
|
};
|
||||||
|
|
||||||
pub fn millis_since_unix_epoch() -> js_int::UInt {
|
pub fn millis_since_unix_epoch() -> js_int::UInt {
|
||||||
(SystemTime::now()
|
(SystemTime::now()
|
||||||
|
@ -8,20 +11,19 @@ pub fn millis_since_unix_epoch() -> js_int::UInt {
|
||||||
.into()
|
.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn bytes_to_string(bytes: &[u8]) -> String {
|
pub fn increment(old: Option<&[u8]>) -> Option<Vec<u8>> {
|
||||||
String::from_utf8(bytes.to_vec()).expect("convert bytes to string")
|
let number = match old {
|
||||||
|
Some(bytes) => {
|
||||||
|
let array: [u8; 8] = bytes.try_into().unwrap();
|
||||||
|
let number = u64::from_be_bytes(array);
|
||||||
|
number + 1
|
||||||
|
}
|
||||||
|
None => 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
Some(number.to_be_bytes().to_vec())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn vec_to_bytes(vec: Vec<String>) -> Vec<u8> {
|
pub fn string_from_bytes(bytes: &[u8]) -> String {
|
||||||
vec.into_iter()
|
String::from_utf8(bytes.to_vec()).expect("bytes are valid utf8")
|
||||||
.map(|string| string.into_bytes())
|
|
||||||
.collect::<Vec<Vec<u8>>>()
|
|
||||||
.join(&0)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn bytes_to_vec(bytes: &[u8]) -> Vec<String> {
|
|
||||||
bytes
|
|
||||||
.split(|&b| b == 0)
|
|
||||||
.map(|bytes_string| bytes_to_string(bytes_string))
|
|
||||||
.collect::<Vec<String>>()
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue