Merge branch 'rocksdb' into 'master'

Swappable database backend

See merge request famedly/conduit!98
next
Timo Kösters 2021-06-12 14:25:03 +00:00
commit 8c6bcc47bf
47 changed files with 1613 additions and 1047 deletions

106
Cargo.lock generated
View File

@ -103,6 +103,25 @@ version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "383d29d513d8764dcdc42ea295d979eb99c3c9f00607b3692cf68a431f7dca72" checksum = "383d29d513d8764dcdc42ea295d979eb99c3c9f00607b3692cf68a431f7dca72"
[[package]]
name = "bindgen"
version = "0.57.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd4865004a46a0aafb2a0a5eb19d3c9fc46ee5f063a6cfc605c69ac9ecf5263d"
dependencies = [
"bitflags",
"cexpr",
"clang-sys",
"lazy_static",
"lazycell",
"peeking_take_while",
"proc-macro2",
"quote",
"regex",
"rustc-hash",
"shlex",
]
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "1.2.1" version = "1.2.1"
@ -162,6 +181,15 @@ dependencies = [
"jobserver", "jobserver",
] ]
[[package]]
name = "cexpr"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4aedb84272dbe89af497cf81375129abda4fc0a9e7c5d317498c15cc30c0d27"
dependencies = [
"nom",
]
[[package]] [[package]]
name = "cfg-if" name = "cfg-if"
version = "0.1.10" version = "0.1.10"
@ -187,6 +215,17 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "clang-sys"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "853eda514c284c2287f4bf20ae614f8781f40a81d32ecda6e91449304dfe077c"
dependencies = [
"glob",
"libc",
"libloading",
]
[[package]] [[package]]
name = "color_quant" name = "color_quant"
version = "1.1.0" version = "1.1.0"
@ -212,6 +251,7 @@ dependencies = [
"reqwest", "reqwest",
"ring", "ring",
"rocket", "rocket",
"rocksdb",
"ruma", "ruma",
"rust-argon2", "rust-argon2",
"rustls", "rustls",
@ -1008,12 +1048,40 @@ version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.95" version = "0.2.95"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "789da6d93f1b866ffe175afc5322a4d76c038605a1c3319bb57b06967ca98a36" checksum = "789da6d93f1b866ffe175afc5322a4d76c038605a1c3319bb57b06967ca98a36"
[[package]]
name = "libloading"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f84d96438c15fcd6c3f244c8fce01d1e2b9c6b5623e9c711dc9286d8fc92d6a"
dependencies = [
"cfg-if 1.0.0",
"winapi",
]
[[package]]
name = "librocksdb-sys"
version = "6.17.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5da125e1c0f22c7cae785982115523a0738728498547f415c9054cb17c7e89f9"
dependencies = [
"bindgen",
"cc",
"glob",
"libc",
]
[[package]] [[package]]
name = "linked-hash-map" name = "linked-hash-map"
version = "0.5.4" version = "0.5.4"
@ -1158,6 +1226,16 @@ dependencies = [
"version_check", "version_check",
] ]
[[package]]
name = "nom"
version = "5.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ffb4262d26ed83a1c0a33a38fe2bb15797329c85770da05e6b828ddb782627af"
dependencies = [
"memchr",
"version_check",
]
[[package]] [[package]]
name = "ntapi" name = "ntapi"
version = "0.3.6" version = "0.3.6"
@ -1339,6 +1417,12 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "peeking_take_while"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
[[package]] [[package]]
name = "pem" name = "pem"
version = "0.8.3" version = "0.8.3"
@ -1777,6 +1861,16 @@ dependencies = [
"uncased", "uncased",
] ]
[[package]]
name = "rocksdb"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c749134fda8bfc90d0de643d59bfc841dcb3ac8a1062e12b6754bd60235c48b3"
dependencies = [
"libc",
"librocksdb-sys",
]
[[package]] [[package]]
name = "ruma" name = "ruma"
version = "0.1.2" version = "0.1.2"
@ -2046,6 +2140,12 @@ dependencies = [
"crossbeam-utils", "crossbeam-utils",
] ]
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]] [[package]]
name = "rustc_version" name = "rustc_version"
version = "0.2.3" version = "0.2.3"
@ -2245,6 +2345,12 @@ dependencies = [
"lazy_static", "lazy_static",
] ]
[[package]]
name = "shlex"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fdf1b9db47230893d76faad238fd6097fd6d6a9245cd7a4d90dbd639536bbd2"
[[package]] [[package]]
name = "signal-hook-registry" name = "signal-hook-registry"
version = "1.3.0" version = "1.3.0"

View File

@ -24,7 +24,8 @@ ruma = { git = "https://github.com/ruma/ruma", rev = "b39537812c12caafcbf8b7bd74
# Used for long polling and federation sender, should be the same as rocket::tokio # Used for long polling and federation sender, should be the same as rocket::tokio
tokio = "1.2.0" tokio = "1.2.0"
# Used for storing data permanently # Used for storing data permanently
sled = { version = "0.34.6", features = ["compression", "no_metrics"] } sled = { version = "0.34.6", features = ["compression", "no_metrics"], optional = true }
rocksdb = { version = "0.16.0", features = ["multi-threaded-cf"], optional = true }
#sled = { git = "https://github.com/spacejam/sled.git", rev = "e4640e0773595229f398438886f19bca6f7326a2", features = ["compression"] } #sled = { git = "https://github.com/spacejam/sled.git", rev = "e4640e0773595229f398438886f19bca6f7326a2", features = ["compression"] }
# Used for the http request / response body type for Ruma endpoints used with reqwest # Used for the http request / response body type for Ruma endpoints used with reqwest
@ -74,7 +75,9 @@ opentelemetry-jaeger = "0.11.0"
pretty_env_logger = "0.4.0" pretty_env_logger = "0.4.0"
[features] [features]
default = ["conduit_bin"] default = ["conduit_bin", "backend_sled"]
backend_sled = ["sled"]
backend_rocksdb = ["rocksdb"]
conduit_bin = [] # TODO: add rocket to this when it is optional conduit_bin = [] # TODO: add rocket to this when it is optional
[[bin]] [[bin]]

View File

@ -1,4 +1,4 @@
use std::{collections::BTreeMap, convert::TryInto}; use std::{collections::BTreeMap, convert::TryInto, sync::Arc};
use super::{State, DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; use super::{State, DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH};
use crate::{pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; use crate::{pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma};
@ -42,7 +42,7 @@ const GUEST_NAME_LENGTH: usize = 10;
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_register_available_route( pub async fn get_register_available_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_username_availability::Request<'_>>, body: Ruma<get_username_availability::Request<'_>>,
) -> ConduitResult<get_username_availability::Response> { ) -> ConduitResult<get_username_availability::Response> {
// Validate user id // Validate user id
@ -85,7 +85,7 @@ pub async fn get_register_available_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn register_route( pub async fn register_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<register::Request<'_>>, body: Ruma<register::Request<'_>>,
) -> ConduitResult<register::Response> { ) -> ConduitResult<register::Response> {
if !db.globals.allow_registration() { if !db.globals.allow_registration() {
@ -227,7 +227,7 @@ pub async fn register_route(
)?; )?;
// If this is the first user on this server, create the admins room // If this is the first user on this server, create the admins room
if db.users.count() == 1 { if db.users.count()? == 1 {
// Create a user for the server // Create a user for the server
let conduit_user = UserId::parse_with_server_name("conduit", db.globals.server_name()) let conduit_user = UserId::parse_with_server_name("conduit", db.globals.server_name())
.expect("@conduit:server_name is valid"); .expect("@conduit:server_name is valid");
@ -506,7 +506,7 @@ pub async fn register_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn change_password_route( pub async fn change_password_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<change_password::Request<'_>>, body: Ruma<change_password::Request<'_>>,
) -> ConduitResult<change_password::Response> { ) -> ConduitResult<change_password::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -598,7 +598,7 @@ pub async fn whoami_route(body: Ruma<whoami::Request>) -> ConduitResult<whoami::
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn deactivate_route( pub async fn deactivate_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<deactivate::Request<'_>>, body: Ruma<deactivate::Request<'_>>,
) -> ConduitResult<deactivate::Response> { ) -> ConduitResult<deactivate::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");

View File

@ -1,3 +1,5 @@
use std::sync::Arc;
use super::State; use super::State;
use crate::{ConduitResult, Database, Error, Ruma}; use crate::{ConduitResult, Database, Error, Ruma};
use regex::Regex; use regex::Regex;
@ -22,7 +24,7 @@ use rocket::{delete, get, put};
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn create_alias_route( pub async fn create_alias_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<create_alias::Request<'_>>, body: Ruma<create_alias::Request<'_>>,
) -> ConduitResult<create_alias::Response> { ) -> ConduitResult<create_alias::Response> {
if db.rooms.id_from_alias(&body.room_alias)?.is_some() { if db.rooms.id_from_alias(&body.room_alias)?.is_some() {
@ -43,7 +45,7 @@ pub async fn create_alias_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn delete_alias_route( pub async fn delete_alias_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<delete_alias::Request<'_>>, body: Ruma<delete_alias::Request<'_>>,
) -> ConduitResult<delete_alias::Response> { ) -> ConduitResult<delete_alias::Response> {
db.rooms.set_alias(&body.room_alias, None, &db.globals)?; db.rooms.set_alias(&body.room_alias, None, &db.globals)?;
@ -59,7 +61,7 @@ pub async fn delete_alias_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_alias_route( pub async fn get_alias_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_alias::Request<'_>>, body: Ruma<get_alias::Request<'_>>,
) -> ConduitResult<get_alias::Response> { ) -> ConduitResult<get_alias::Response> {
get_alias_helper(&db, &body.room_alias).await get_alias_helper(&db, &body.room_alias).await
@ -86,7 +88,8 @@ pub async fn get_alias_helper(
match db.rooms.id_from_alias(&room_alias)? { match db.rooms.id_from_alias(&room_alias)? {
Some(r) => room_id = Some(r), Some(r) => room_id = Some(r),
None => { None => {
for (_id, registration) in db.appservice.iter_all().filter_map(|r| r.ok()) { let iter = db.appservice.iter_all()?;
for (_id, registration) in iter.filter_map(|r| r.ok()) {
let aliases = registration let aliases = registration
.get("namespaces") .get("namespaces")
.and_then(|ns| ns.get("aliases")) .and_then(|ns| ns.get("aliases"))

View File

@ -1,3 +1,5 @@
use std::sync::Arc;
use super::State; use super::State;
use crate::{ConduitResult, Database, Error, Ruma}; use crate::{ConduitResult, Database, Error, Ruma};
use ruma::api::client::{ use ruma::api::client::{
@ -19,7 +21,7 @@ use rocket::{delete, get, post, put};
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn create_backup_route( pub async fn create_backup_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<create_backup::Request>, body: Ruma<create_backup::Request>,
) -> ConduitResult<create_backup::Response> { ) -> ConduitResult<create_backup::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -38,7 +40,7 @@ pub async fn create_backup_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn update_backup_route( pub async fn update_backup_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<update_backup::Request<'_>>, body: Ruma<update_backup::Request<'_>>,
) -> ConduitResult<update_backup::Response> { ) -> ConduitResult<update_backup::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -56,7 +58,7 @@ pub async fn update_backup_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_latest_backup_route( pub async fn get_latest_backup_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_latest_backup::Request>, body: Ruma<get_latest_backup::Request>,
) -> ConduitResult<get_latest_backup::Response> { ) -> ConduitResult<get_latest_backup::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -84,7 +86,7 @@ pub async fn get_latest_backup_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_backup_route( pub async fn get_backup_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_backup::Request<'_>>, body: Ruma<get_backup::Request<'_>>,
) -> ConduitResult<get_backup::Response> { ) -> ConduitResult<get_backup::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -111,7 +113,7 @@ pub async fn get_backup_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn delete_backup_route( pub async fn delete_backup_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<delete_backup::Request<'_>>, body: Ruma<delete_backup::Request<'_>>,
) -> ConduitResult<delete_backup::Response> { ) -> ConduitResult<delete_backup::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -130,7 +132,7 @@ pub async fn delete_backup_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn add_backup_keys_route( pub async fn add_backup_keys_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<add_backup_keys::Request<'_>>, body: Ruma<add_backup_keys::Request<'_>>,
) -> ConduitResult<add_backup_keys::Response> { ) -> ConduitResult<add_backup_keys::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -164,7 +166,7 @@ pub async fn add_backup_keys_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn add_backup_key_sessions_route( pub async fn add_backup_key_sessions_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<add_backup_key_sessions::Request<'_>>, body: Ruma<add_backup_key_sessions::Request<'_>>,
) -> ConduitResult<add_backup_key_sessions::Response> { ) -> ConduitResult<add_backup_key_sessions::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -196,7 +198,7 @@ pub async fn add_backup_key_sessions_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn add_backup_key_session_route( pub async fn add_backup_key_session_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<add_backup_key_session::Request<'_>>, body: Ruma<add_backup_key_session::Request<'_>>,
) -> ConduitResult<add_backup_key_session::Response> { ) -> ConduitResult<add_backup_key_session::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -225,7 +227,7 @@ pub async fn add_backup_key_session_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_backup_keys_route( pub async fn get_backup_keys_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_backup_keys::Request<'_>>, body: Ruma<get_backup_keys::Request<'_>>,
) -> ConduitResult<get_backup_keys::Response> { ) -> ConduitResult<get_backup_keys::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -241,14 +243,14 @@ pub async fn get_backup_keys_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_backup_key_sessions_route( pub async fn get_backup_key_sessions_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_backup_key_sessions::Request<'_>>, body: Ruma<get_backup_key_sessions::Request<'_>>,
) -> ConduitResult<get_backup_key_sessions::Response> { ) -> ConduitResult<get_backup_key_sessions::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sessions = db let sessions = db
.key_backups .key_backups
.get_room(&sender_user, &body.version, &body.room_id); .get_room(&sender_user, &body.version, &body.room_id)?;
Ok(get_backup_key_sessions::Response { sessions }.into()) Ok(get_backup_key_sessions::Response { sessions }.into())
} }
@ -259,7 +261,7 @@ pub async fn get_backup_key_sessions_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_backup_key_session_route( pub async fn get_backup_key_session_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_backup_key_session::Request<'_>>, body: Ruma<get_backup_key_session::Request<'_>>,
) -> ConduitResult<get_backup_key_session::Response> { ) -> ConduitResult<get_backup_key_session::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -281,7 +283,7 @@ pub async fn get_backup_key_session_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn delete_backup_keys_route( pub async fn delete_backup_keys_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<delete_backup_keys::Request<'_>>, body: Ruma<delete_backup_keys::Request<'_>>,
) -> ConduitResult<delete_backup_keys::Response> { ) -> ConduitResult<delete_backup_keys::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -304,7 +306,7 @@ pub async fn delete_backup_keys_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn delete_backup_key_sessions_route( pub async fn delete_backup_key_sessions_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<delete_backup_key_sessions::Request<'_>>, body: Ruma<delete_backup_key_sessions::Request<'_>>,
) -> ConduitResult<delete_backup_key_sessions::Response> { ) -> ConduitResult<delete_backup_key_sessions::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -327,7 +329,7 @@ pub async fn delete_backup_key_sessions_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn delete_backup_key_session_route( pub async fn delete_backup_key_session_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<delete_backup_key_session::Request<'_>>, body: Ruma<delete_backup_key_session::Request<'_>>,
) -> ConduitResult<delete_backup_key_session::Response> { ) -> ConduitResult<delete_backup_key_session::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");

View File

@ -1,3 +1,5 @@
use std::sync::Arc;
use super::State; use super::State;
use crate::{ConduitResult, Database, Error, Ruma}; use crate::{ConduitResult, Database, Error, Ruma};
use ruma::{ use ruma::{
@ -23,7 +25,7 @@ use rocket::{get, put};
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn set_global_account_data_route( pub async fn set_global_account_data_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<set_global_account_data::Request<'_>>, body: Ruma<set_global_account_data::Request<'_>>,
) -> ConduitResult<set_global_account_data::Response> { ) -> ConduitResult<set_global_account_data::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -58,7 +60,7 @@ pub async fn set_global_account_data_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn set_room_account_data_route( pub async fn set_room_account_data_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<set_room_account_data::Request<'_>>, body: Ruma<set_room_account_data::Request<'_>>,
) -> ConduitResult<set_room_account_data::Response> { ) -> ConduitResult<set_room_account_data::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -90,7 +92,7 @@ pub async fn set_room_account_data_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_global_account_data_route( pub async fn get_global_account_data_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_global_account_data::Request<'_>>, body: Ruma<get_global_account_data::Request<'_>>,
) -> ConduitResult<get_global_account_data::Response> { ) -> ConduitResult<get_global_account_data::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -117,7 +119,7 @@ pub async fn get_global_account_data_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_room_account_data_route( pub async fn get_room_account_data_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_room_account_data::Request<'_>>, body: Ruma<get_room_account_data::Request<'_>>,
) -> ConduitResult<get_room_account_data::Response> { ) -> ConduitResult<get_room_account_data::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");

View File

@ -1,7 +1,7 @@
use super::State; use super::State;
use crate::{ConduitResult, Database, Error, Ruma}; use crate::{ConduitResult, Database, Error, Ruma};
use ruma::api::client::{error::ErrorKind, r0::context::get_context}; use ruma::api::client::{error::ErrorKind, r0::context::get_context};
use std::convert::TryFrom; use std::{convert::TryFrom, sync::Arc};
#[cfg(feature = "conduit_bin")] #[cfg(feature = "conduit_bin")]
use rocket::get; use rocket::get;
@ -12,7 +12,7 @@ use rocket::get;
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_context_route( pub async fn get_context_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_context::Request<'_>>, body: Ruma<get_context::Request<'_>>,
) -> ConduitResult<get_context::Response> { ) -> ConduitResult<get_context::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");

View File

@ -1,3 +1,5 @@
use std::sync::Arc;
use super::State; use super::State;
use crate::{utils, ConduitResult, Database, Error, Ruma}; use crate::{utils, ConduitResult, Database, Error, Ruma};
use ruma::api::client::{ use ruma::api::client::{
@ -18,7 +20,7 @@ use rocket::{delete, get, post, put};
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_devices_route( pub async fn get_devices_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_devices::Request>, body: Ruma<get_devices::Request>,
) -> ConduitResult<get_devices::Response> { ) -> ConduitResult<get_devices::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -38,7 +40,7 @@ pub async fn get_devices_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_device_route( pub async fn get_device_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_device::Request<'_>>, body: Ruma<get_device::Request<'_>>,
) -> ConduitResult<get_device::Response> { ) -> ConduitResult<get_device::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -57,7 +59,7 @@ pub async fn get_device_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn update_device_route( pub async fn update_device_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<update_device::Request<'_>>, body: Ruma<update_device::Request<'_>>,
) -> ConduitResult<update_device::Response> { ) -> ConduitResult<update_device::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -83,7 +85,7 @@ pub async fn update_device_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn delete_device_route( pub async fn delete_device_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<delete_device::Request<'_>>, body: Ruma<delete_device::Request<'_>>,
) -> ConduitResult<delete_device::Response> { ) -> ConduitResult<delete_device::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -137,7 +139,7 @@ pub async fn delete_device_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn delete_devices_route( pub async fn delete_devices_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<delete_devices::Request<'_>>, body: Ruma<delete_devices::Request<'_>>,
) -> ConduitResult<delete_devices::Response> { ) -> ConduitResult<delete_devices::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");

View File

@ -1,3 +1,5 @@
use std::sync::Arc;
use super::State; use super::State;
use crate::{ConduitResult, Database, Error, Result, Ruma}; use crate::{ConduitResult, Database, Error, Result, Ruma};
use log::info; use log::info;
@ -33,7 +35,7 @@ use rocket::{get, post, put};
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_public_rooms_filtered_route( pub async fn get_public_rooms_filtered_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_public_rooms_filtered::Request<'_>>, body: Ruma<get_public_rooms_filtered::Request<'_>>,
) -> ConduitResult<get_public_rooms_filtered::Response> { ) -> ConduitResult<get_public_rooms_filtered::Response> {
get_public_rooms_filtered_helper( get_public_rooms_filtered_helper(
@ -53,7 +55,7 @@ pub async fn get_public_rooms_filtered_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_public_rooms_route( pub async fn get_public_rooms_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_public_rooms::Request<'_>>, body: Ruma<get_public_rooms::Request<'_>>,
) -> ConduitResult<get_public_rooms::Response> { ) -> ConduitResult<get_public_rooms::Response> {
let response = get_public_rooms_filtered_helper( let response = get_public_rooms_filtered_helper(
@ -82,7 +84,7 @@ pub async fn get_public_rooms_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn set_room_visibility_route( pub async fn set_room_visibility_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<set_room_visibility::Request<'_>>, body: Ruma<set_room_visibility::Request<'_>>,
) -> ConduitResult<set_room_visibility::Response> { ) -> ConduitResult<set_room_visibility::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -112,7 +114,7 @@ pub async fn set_room_visibility_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_room_visibility_route( pub async fn get_room_visibility_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_room_visibility::Request<'_>>, body: Ruma<get_room_visibility::Request<'_>>,
) -> ConduitResult<get_room_visibility::Response> { ) -> ConduitResult<get_room_visibility::Response> {
Ok(get_room_visibility::Response { Ok(get_room_visibility::Response {

View File

@ -14,7 +14,10 @@ use ruma::{
encryption::UnsignedDeviceInfo, encryption::UnsignedDeviceInfo,
DeviceId, DeviceKeyAlgorithm, UserId, DeviceId, DeviceKeyAlgorithm, UserId,
}; };
use std::collections::{BTreeMap, HashSet}; use std::{
collections::{BTreeMap, HashSet},
sync::Arc,
};
#[cfg(feature = "conduit_bin")] #[cfg(feature = "conduit_bin")]
use rocket::{get, post}; use rocket::{get, post};
@ -25,7 +28,7 @@ use rocket::{get, post};
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn upload_keys_route( pub async fn upload_keys_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<upload_keys::Request>, body: Ruma<upload_keys::Request>,
) -> ConduitResult<upload_keys::Response> { ) -> ConduitResult<upload_keys::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -74,7 +77,7 @@ pub async fn upload_keys_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_keys_route( pub async fn get_keys_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_keys::Request<'_>>, body: Ruma<get_keys::Request<'_>>,
) -> ConduitResult<get_keys::Response> { ) -> ConduitResult<get_keys::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -95,7 +98,7 @@ pub async fn get_keys_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn claim_keys_route( pub async fn claim_keys_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<claim_keys::Request>, body: Ruma<claim_keys::Request>,
) -> ConduitResult<claim_keys::Response> { ) -> ConduitResult<claim_keys::Response> {
let response = claim_keys_helper(&body.one_time_keys, &db)?; let response = claim_keys_helper(&body.one_time_keys, &db)?;
@ -111,7 +114,7 @@ pub async fn claim_keys_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn upload_signing_keys_route( pub async fn upload_signing_keys_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<upload_signing_keys::Request<'_>>, body: Ruma<upload_signing_keys::Request<'_>>,
) -> ConduitResult<upload_signing_keys::Response> { ) -> ConduitResult<upload_signing_keys::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -174,7 +177,7 @@ pub async fn upload_signing_keys_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn upload_signatures_route( pub async fn upload_signatures_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<upload_signatures::Request>, body: Ruma<upload_signatures::Request>,
) -> ConduitResult<upload_signatures::Response> { ) -> ConduitResult<upload_signatures::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -235,7 +238,7 @@ pub async fn upload_signatures_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_key_changes_route( pub async fn get_key_changes_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_key_changes::Request<'_>>, body: Ruma<get_key_changes::Request<'_>>,
) -> ConduitResult<get_key_changes::Response> { ) -> ConduitResult<get_key_changes::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");

View File

@ -7,14 +7,14 @@ use ruma::api::client::{
#[cfg(feature = "conduit_bin")] #[cfg(feature = "conduit_bin")]
use rocket::{get, post}; use rocket::{get, post};
use std::convert::TryInto; use std::{convert::TryInto, sync::Arc};
const MXC_LENGTH: usize = 32; const MXC_LENGTH: usize = 32;
#[cfg_attr(feature = "conduit_bin", get("/_matrix/media/r0/config"))] #[cfg_attr(feature = "conduit_bin", get("/_matrix/media/r0/config"))]
#[tracing::instrument(skip(db))] #[tracing::instrument(skip(db))]
pub async fn get_media_config_route( pub async fn get_media_config_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
) -> ConduitResult<get_media_config::Response> { ) -> ConduitResult<get_media_config::Response> {
Ok(get_media_config::Response { Ok(get_media_config::Response {
upload_size: db.globals.max_request_size().into(), upload_size: db.globals.max_request_size().into(),
@ -28,7 +28,7 @@ pub async fn get_media_config_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn create_content_route( pub async fn create_content_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<create_content::Request<'_>>, body: Ruma<create_content::Request<'_>>,
) -> ConduitResult<create_content::Response> { ) -> ConduitResult<create_content::Response> {
let mxc = format!( let mxc = format!(
@ -36,16 +36,20 @@ pub async fn create_content_route(
db.globals.server_name(), db.globals.server_name(),
utils::random_string(MXC_LENGTH) utils::random_string(MXC_LENGTH)
); );
db.media.create(
mxc.clone(), db.media
&body .create(
.filename mxc.clone(),
.as_ref() &db.globals,
.map(|filename| "inline; filename=".to_owned() + filename) &body
.as_deref(), .filename
&body.content_type.as_deref(), .as_ref()
&body.file, .map(|filename| "inline; filename=".to_owned() + filename)
)?; .as_deref(),
&body.content_type.as_deref(),
&body.file,
)
.await?;
db.flush().await?; db.flush().await?;
@ -62,7 +66,7 @@ pub async fn create_content_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_content_route( pub async fn get_content_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_content::Request<'_>>, body: Ruma<get_content::Request<'_>>,
) -> ConduitResult<get_content::Response> { ) -> ConduitResult<get_content::Response> {
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
@ -71,7 +75,7 @@ pub async fn get_content_route(
content_disposition, content_disposition,
content_type, content_type,
file, file,
}) = db.media.get(&mxc)? }) = db.media.get(&db.globals, &mxc).await?
{ {
Ok(get_content::Response { Ok(get_content::Response {
file, file,
@ -93,12 +97,15 @@ pub async fn get_content_route(
) )
.await?; .await?;
db.media.create( db.media
mxc, .create(
&get_content_response.content_disposition.as_deref(), mxc,
&get_content_response.content_type.as_deref(), &db.globals,
&get_content_response.file, &get_content_response.content_disposition.as_deref(),
)?; &get_content_response.content_type.as_deref(),
&get_content_response.file,
)
.await?;
Ok(get_content_response.into()) Ok(get_content_response.into())
} else { } else {
@ -112,22 +119,27 @@ pub async fn get_content_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_content_thumbnail_route( pub async fn get_content_thumbnail_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_content_thumbnail::Request<'_>>, body: Ruma<get_content_thumbnail::Request<'_>>,
) -> ConduitResult<get_content_thumbnail::Response> { ) -> ConduitResult<get_content_thumbnail::Response> {
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
if let Some(FileMeta { if let Some(FileMeta {
content_type, file, .. content_type, file, ..
}) = db.media.get_thumbnail( }) = db
mxc.clone(), .media
body.width .get_thumbnail(
.try_into() mxc.clone(),
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, &db.globals,
body.height body.width
.try_into() .try_into()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?,
)? { body.height
.try_into()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?,
)
.await?
{
Ok(get_content_thumbnail::Response { file, content_type }.into()) Ok(get_content_thumbnail::Response { file, content_type }.into())
} else if &*body.server_name != db.globals.server_name() && body.allow_remote { } else if &*body.server_name != db.globals.server_name() && body.allow_remote {
let get_thumbnail_response = db let get_thumbnail_response = db
@ -146,14 +158,17 @@ pub async fn get_content_thumbnail_route(
) )
.await?; .await?;
db.media.upload_thumbnail( db.media
mxc, .upload_thumbnail(
&None, mxc,
&get_thumbnail_response.content_type, &db.globals,
body.width.try_into().expect("all UInts are valid u32s"), &None,
body.height.try_into().expect("all UInts are valid u32s"), &get_thumbnail_response.content_type,
&get_thumbnail_response.file, body.width.try_into().expect("all UInts are valid u32s"),
)?; body.height.try_into().expect("all UInts are valid u32s"),
&get_thumbnail_response.file,
)
.await?;
Ok(get_thumbnail_response.into()) Ok(get_thumbnail_response.into())
} else { } else {

View File

@ -44,7 +44,7 @@ use rocket::{get, post};
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn join_room_by_id_route( pub async fn join_room_by_id_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<join_room_by_id::Request<'_>>, body: Ruma<join_room_by_id::Request<'_>>,
) -> ConduitResult<join_room_by_id::Response> { ) -> ConduitResult<join_room_by_id::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -81,7 +81,7 @@ pub async fn join_room_by_id_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn join_room_by_id_or_alias_route( pub async fn join_room_by_id_or_alias_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<join_room_by_id_or_alias::Request<'_>>, body: Ruma<join_room_by_id_or_alias::Request<'_>>,
) -> ConduitResult<join_room_by_id_or_alias::Response> { ) -> ConduitResult<join_room_by_id_or_alias::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -135,7 +135,7 @@ pub async fn join_room_by_id_or_alias_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn leave_room_route( pub async fn leave_room_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<leave_room::Request<'_>>, body: Ruma<leave_room::Request<'_>>,
) -> ConduitResult<leave_room::Response> { ) -> ConduitResult<leave_room::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -153,7 +153,7 @@ pub async fn leave_room_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn invite_user_route( pub async fn invite_user_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<invite_user::Request<'_>>, body: Ruma<invite_user::Request<'_>>,
) -> ConduitResult<invite_user::Response> { ) -> ConduitResult<invite_user::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -173,7 +173,7 @@ pub async fn invite_user_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn kick_user_route( pub async fn kick_user_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<kick_user::Request<'_>>, body: Ruma<kick_user::Request<'_>>,
) -> ConduitResult<kick_user::Response> { ) -> ConduitResult<kick_user::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -222,7 +222,7 @@ pub async fn kick_user_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn ban_user_route( pub async fn ban_user_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<ban_user::Request<'_>>, body: Ruma<ban_user::Request<'_>>,
) -> ConduitResult<ban_user::Response> { ) -> ConduitResult<ban_user::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -279,7 +279,7 @@ pub async fn ban_user_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn unban_user_route( pub async fn unban_user_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<unban_user::Request<'_>>, body: Ruma<unban_user::Request<'_>>,
) -> ConduitResult<unban_user::Response> { ) -> ConduitResult<unban_user::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -327,7 +327,7 @@ pub async fn unban_user_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn forget_room_route( pub async fn forget_room_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<forget_room::Request<'_>>, body: Ruma<forget_room::Request<'_>>,
) -> ConduitResult<forget_room::Response> { ) -> ConduitResult<forget_room::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -345,7 +345,7 @@ pub async fn forget_room_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn joined_rooms_route( pub async fn joined_rooms_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<joined_rooms::Request>, body: Ruma<joined_rooms::Request>,
) -> ConduitResult<joined_rooms::Response> { ) -> ConduitResult<joined_rooms::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -366,7 +366,7 @@ pub async fn joined_rooms_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_member_events_route( pub async fn get_member_events_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_member_events::Request<'_>>, body: Ruma<get_member_events::Request<'_>>,
) -> ConduitResult<get_member_events::Response> { ) -> ConduitResult<get_member_events::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -396,7 +396,7 @@ pub async fn get_member_events_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn joined_members_route( pub async fn joined_members_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<joined_members::Request<'_>>, body: Ruma<joined_members::Request<'_>>,
) -> ConduitResult<joined_members::Response> { ) -> ConduitResult<joined_members::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -621,7 +621,7 @@ async fn join_room_by_id_helper(
&pdu, &pdu,
utils::to_canonical_object(&pdu).expect("Pdu is valid canonical object"), utils::to_canonical_object(&pdu).expect("Pdu is valid canonical object"),
count, count,
pdu_id.into(), &pdu_id,
&[pdu.event_id.clone()], &[pdu.event_id.clone()],
db, db,
)?; )?;

View File

@ -11,6 +11,7 @@ use ruma::{
use std::{ use std::{
collections::BTreeMap, collections::BTreeMap,
convert::{TryFrom, TryInto}, convert::{TryFrom, TryInto},
sync::Arc,
}; };
#[cfg(feature = "conduit_bin")] #[cfg(feature = "conduit_bin")]
@ -22,7 +23,7 @@ use rocket::{get, put};
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn send_message_event_route( pub async fn send_message_event_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<send_message_event::Request<'_>>, body: Ruma<send_message_event::Request<'_>>,
) -> ConduitResult<send_message_event::Response> { ) -> ConduitResult<send_message_event::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -85,7 +86,7 @@ pub async fn send_message_event_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_message_events_route( pub async fn get_message_events_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_message_events::Request<'_>>, body: Ruma<get_message_events::Request<'_>>,
) -> ConduitResult<get_message_events::Response> { ) -> ConduitResult<get_message_events::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");

View File

@ -1,7 +1,7 @@
use super::State; use super::State;
use crate::{utils, ConduitResult, Database, Ruma}; use crate::{utils, ConduitResult, Database, Ruma};
use ruma::api::client::r0::presence::{get_presence, set_presence}; use ruma::api::client::r0::presence::{get_presence, set_presence};
use std::{convert::TryInto, time::Duration}; use std::{convert::TryInto, sync::Arc, time::Duration};
#[cfg(feature = "conduit_bin")] #[cfg(feature = "conduit_bin")]
use rocket::{get, put}; use rocket::{get, put};
@ -12,7 +12,7 @@ use rocket::{get, put};
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn set_presence_route( pub async fn set_presence_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<set_presence::Request<'_>>, body: Ruma<set_presence::Request<'_>>,
) -> ConduitResult<set_presence::Response> { ) -> ConduitResult<set_presence::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -53,7 +53,7 @@ pub async fn set_presence_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_presence_route( pub async fn get_presence_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_presence::Request<'_>>, body: Ruma<get_presence::Request<'_>>,
) -> ConduitResult<get_presence::Response> { ) -> ConduitResult<get_presence::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -62,7 +62,7 @@ pub async fn get_presence_route(
for room_id in db for room_id in db
.rooms .rooms
.get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()]) .get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])?
{ {
let room_id = room_id?; let room_id = room_id?;

View File

@ -13,7 +13,7 @@ use ruma::{
#[cfg(feature = "conduit_bin")] #[cfg(feature = "conduit_bin")]
use rocket::{get, put}; use rocket::{get, put};
use std::convert::TryInto; use std::{convert::TryInto, sync::Arc};
#[cfg_attr( #[cfg_attr(
feature = "conduit_bin", feature = "conduit_bin",
@ -21,7 +21,7 @@ use std::convert::TryInto;
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn set_displayname_route( pub async fn set_displayname_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<set_display_name::Request<'_>>, body: Ruma<set_display_name::Request<'_>>,
) -> ConduitResult<set_display_name::Response> { ) -> ConduitResult<set_display_name::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -107,7 +107,7 @@ pub async fn set_displayname_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_displayname_route( pub async fn get_displayname_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_display_name::Request<'_>>, body: Ruma<get_display_name::Request<'_>>,
) -> ConduitResult<get_display_name::Response> { ) -> ConduitResult<get_display_name::Response> {
Ok(get_display_name::Response { Ok(get_display_name::Response {
@ -122,7 +122,7 @@ pub async fn get_displayname_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn set_avatar_url_route( pub async fn set_avatar_url_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<set_avatar_url::Request<'_>>, body: Ruma<set_avatar_url::Request<'_>>,
) -> ConduitResult<set_avatar_url::Response> { ) -> ConduitResult<set_avatar_url::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -208,7 +208,7 @@ pub async fn set_avatar_url_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_avatar_url_route( pub async fn get_avatar_url_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_avatar_url::Request<'_>>, body: Ruma<get_avatar_url::Request<'_>>,
) -> ConduitResult<get_avatar_url::Response> { ) -> ConduitResult<get_avatar_url::Response> {
Ok(get_avatar_url::Response { Ok(get_avatar_url::Response {
@ -223,7 +223,7 @@ pub async fn get_avatar_url_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_profile_route( pub async fn get_profile_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_profile::Request<'_>>, body: Ruma<get_profile::Request<'_>>,
) -> ConduitResult<get_profile::Response> { ) -> ConduitResult<get_profile::Response> {
if !db.users.exists(&body.user_id)? { if !db.users.exists(&body.user_id)? {

View File

@ -1,3 +1,5 @@
use std::sync::Arc;
use super::State; use super::State;
use crate::{ConduitResult, Database, Error, Ruma}; use crate::{ConduitResult, Database, Error, Ruma};
use ruma::{ use ruma::{
@ -22,7 +24,7 @@ use rocket::{delete, get, post, put};
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_pushrules_all_route( pub async fn get_pushrules_all_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_pushrules_all::Request>, body: Ruma<get_pushrules_all::Request>,
) -> ConduitResult<get_pushrules_all::Response> { ) -> ConduitResult<get_pushrules_all::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -47,7 +49,7 @@ pub async fn get_pushrules_all_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_pushrule_route( pub async fn get_pushrule_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_pushrule::Request<'_>>, body: Ruma<get_pushrule::Request<'_>>,
) -> ConduitResult<get_pushrule::Response> { ) -> ConduitResult<get_pushrule::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -101,7 +103,7 @@ pub async fn get_pushrule_route(
)] )]
#[tracing::instrument(skip(db, req))] #[tracing::instrument(skip(db, req))]
pub async fn set_pushrule_route( pub async fn set_pushrule_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
req: Ruma<set_pushrule::Request<'_>>, req: Ruma<set_pushrule::Request<'_>>,
) -> ConduitResult<set_pushrule::Response> { ) -> ConduitResult<set_pushrule::Response> {
let sender_user = req.sender_user.as_ref().expect("user is authenticated"); let sender_user = req.sender_user.as_ref().expect("user is authenticated");
@ -204,7 +206,7 @@ pub async fn set_pushrule_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_pushrule_actions_route( pub async fn get_pushrule_actions_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_pushrule_actions::Request<'_>>, body: Ruma<get_pushrule_actions::Request<'_>>,
) -> ConduitResult<get_pushrule_actions::Response> { ) -> ConduitResult<get_pushrule_actions::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -263,7 +265,7 @@ pub async fn get_pushrule_actions_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn set_pushrule_actions_route( pub async fn set_pushrule_actions_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<set_pushrule_actions::Request<'_>>, body: Ruma<set_pushrule_actions::Request<'_>>,
) -> ConduitResult<set_pushrule_actions::Response> { ) -> ConduitResult<set_pushrule_actions::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -337,7 +339,7 @@ pub async fn set_pushrule_actions_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_pushrule_enabled_route( pub async fn get_pushrule_enabled_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_pushrule_enabled::Request<'_>>, body: Ruma<get_pushrule_enabled::Request<'_>>,
) -> ConduitResult<get_pushrule_enabled::Response> { ) -> ConduitResult<get_pushrule_enabled::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -398,7 +400,7 @@ pub async fn get_pushrule_enabled_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn set_pushrule_enabled_route( pub async fn set_pushrule_enabled_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<set_pushrule_enabled::Request<'_>>, body: Ruma<set_pushrule_enabled::Request<'_>>,
) -> ConduitResult<set_pushrule_enabled::Response> { ) -> ConduitResult<set_pushrule_enabled::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -477,7 +479,7 @@ pub async fn set_pushrule_enabled_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn delete_pushrule_route( pub async fn delete_pushrule_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<delete_pushrule::Request<'_>>, body: Ruma<delete_pushrule::Request<'_>>,
) -> ConduitResult<delete_pushrule::Response> { ) -> ConduitResult<delete_pushrule::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -546,7 +548,7 @@ pub async fn delete_pushrule_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_pushers_route( pub async fn get_pushers_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_pushers::Request>, body: Ruma<get_pushers::Request>,
) -> ConduitResult<get_pushers::Response> { ) -> ConduitResult<get_pushers::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -563,7 +565,7 @@ pub async fn get_pushers_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn set_pushers_route( pub async fn set_pushers_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<set_pusher::Request>, body: Ruma<set_pusher::Request>,
) -> ConduitResult<set_pusher::Response> { ) -> ConduitResult<set_pusher::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");

View File

@ -12,7 +12,7 @@ use ruma::{
#[cfg(feature = "conduit_bin")] #[cfg(feature = "conduit_bin")]
use rocket::post; use rocket::post;
use std::collections::BTreeMap; use std::{collections::BTreeMap, sync::Arc};
#[cfg_attr( #[cfg_attr(
feature = "conduit_bin", feature = "conduit_bin",
@ -20,7 +20,7 @@ use std::collections::BTreeMap;
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn set_read_marker_route( pub async fn set_read_marker_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<set_read_marker::Request<'_>>, body: Ruma<set_read_marker::Request<'_>>,
) -> ConduitResult<set_read_marker::Response> { ) -> ConduitResult<set_read_marker::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -87,7 +87,7 @@ pub async fn set_read_marker_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn create_receipt_route( pub async fn create_receipt_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<create_receipt::Request<'_>>, body: Ruma<create_receipt::Request<'_>>,
) -> ConduitResult<create_receipt::Response> { ) -> ConduitResult<create_receipt::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");

View File

@ -4,6 +4,7 @@ use ruma::{
api::client::r0::redact::redact_event, api::client::r0::redact::redact_event,
events::{room::redaction, EventType}, events::{room::redaction, EventType},
}; };
use std::sync::Arc;
#[cfg(feature = "conduit_bin")] #[cfg(feature = "conduit_bin")]
use rocket::put; use rocket::put;
@ -14,7 +15,7 @@ use rocket::put;
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn redact_event_route( pub async fn redact_event_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<redact_event::Request<'_>>, body: Ruma<redact_event::Request<'_>>,
) -> ConduitResult<redact_event::Response> { ) -> ConduitResult<redact_event::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");

View File

@ -13,7 +13,7 @@ use ruma::{
serde::Raw, serde::Raw,
RoomAliasId, RoomId, RoomVersionId, RoomAliasId, RoomId, RoomVersionId,
}; };
use std::{cmp::max, collections::BTreeMap, convert::TryFrom}; use std::{cmp::max, collections::BTreeMap, convert::TryFrom, sync::Arc};
#[cfg(feature = "conduit_bin")] #[cfg(feature = "conduit_bin")]
use rocket::{get, post}; use rocket::{get, post};
@ -24,7 +24,7 @@ use rocket::{get, post};
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn create_room_route( pub async fn create_room_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<create_room::Request<'_>>, body: Ruma<create_room::Request<'_>>,
) -> ConduitResult<create_room::Response> { ) -> ConduitResult<create_room::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -304,7 +304,7 @@ pub async fn create_room_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_room_event_route( pub async fn get_room_event_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_room_event::Request<'_>>, body: Ruma<get_room_event::Request<'_>>,
) -> ConduitResult<get_room_event::Response> { ) -> ConduitResult<get_room_event::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -332,7 +332,7 @@ pub async fn get_room_event_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn upgrade_room_route( pub async fn upgrade_room_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<upgrade_room::Request<'_>>, body: Ruma<upgrade_room::Request<'_>>,
_room_id: String, _room_id: String,
) -> ConduitResult<upgrade_room::Response> { ) -> ConduitResult<upgrade_room::Response> {

View File

@ -1,6 +1,7 @@
use super::State; use super::State;
use crate::{ConduitResult, Database, Error, Ruma}; use crate::{ConduitResult, Database, Error, Ruma};
use ruma::api::client::{error::ErrorKind, r0::search::search_events}; use ruma::api::client::{error::ErrorKind, r0::search::search_events};
use std::sync::Arc;
#[cfg(feature = "conduit_bin")] #[cfg(feature = "conduit_bin")]
use rocket::post; use rocket::post;
@ -13,7 +14,7 @@ use std::collections::BTreeMap;
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn search_events_route( pub async fn search_events_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<search_events::Request<'_>>, body: Ruma<search_events::Request<'_>>,
) -> ConduitResult<search_events::Response> { ) -> ConduitResult<search_events::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");

View File

@ -1,3 +1,5 @@
use std::sync::Arc;
use super::{State, DEVICE_ID_LENGTH, TOKEN_LENGTH}; use super::{State, DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::{utils, ConduitResult, Database, Error, Ruma}; use crate::{utils, ConduitResult, Database, Error, Ruma};
use log::info; use log::info;
@ -50,7 +52,7 @@ pub async fn get_login_types_route() -> ConduitResult<get_login_types::Response>
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn login_route( pub async fn login_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<login::Request<'_>>, body: Ruma<login::Request<'_>>,
) -> ConduitResult<login::Response> { ) -> ConduitResult<login::Response> {
// Validate login method // Validate login method
@ -167,7 +169,7 @@ pub async fn login_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn logout_route( pub async fn logout_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<logout::Request>, body: Ruma<logout::Request>,
) -> ConduitResult<logout::Response> { ) -> ConduitResult<logout::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -195,7 +197,7 @@ pub async fn logout_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn logout_all_route( pub async fn logout_all_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<logout_all::Request>, body: Ruma<logout_all::Request>,
) -> ConduitResult<logout_all::Response> { ) -> ConduitResult<logout_all::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");

View File

@ -1,3 +1,5 @@
use std::sync::Arc;
use super::State; use super::State;
use crate::{pdu::PduBuilder, ConduitResult, Database, Error, Result, Ruma}; use crate::{pdu::PduBuilder, ConduitResult, Database, Error, Result, Ruma};
use ruma::{ use ruma::{
@ -25,7 +27,7 @@ use rocket::{get, put};
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn send_state_event_for_key_route( pub async fn send_state_event_for_key_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<send_state_event::Request<'_>>, body: Ruma<send_state_event::Request<'_>>,
) -> ConduitResult<send_state_event::Response> { ) -> ConduitResult<send_state_event::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -51,7 +53,7 @@ pub async fn send_state_event_for_key_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn send_state_event_for_empty_key_route( pub async fn send_state_event_for_empty_key_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<send_state_event::Request<'_>>, body: Ruma<send_state_event::Request<'_>>,
) -> ConduitResult<send_state_event::Response> { ) -> ConduitResult<send_state_event::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -77,7 +79,7 @@ pub async fn send_state_event_for_empty_key_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_state_events_route( pub async fn get_state_events_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_state_events::Request<'_>>, body: Ruma<get_state_events::Request<'_>>,
) -> ConduitResult<get_state_events::Response> { ) -> ConduitResult<get_state_events::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -124,7 +126,7 @@ pub async fn get_state_events_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_state_events_for_key_route( pub async fn get_state_events_for_key_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_state_events_for_key::Request<'_>>, body: Ruma<get_state_events_for_key::Request<'_>>,
) -> ConduitResult<get_state_events_for_key::Response> { ) -> ConduitResult<get_state_events_for_key::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -175,7 +177,7 @@ pub async fn get_state_events_for_key_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_state_events_for_empty_key_route( pub async fn get_state_events_for_empty_key_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_state_events_for_key::Request<'_>>, body: Ruma<get_state_events_for_key::Request<'_>>,
) -> ConduitResult<get_state_events_for_key::Response> { ) -> ConduitResult<get_state_events_for_key::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");

View File

@ -1,5 +1,5 @@
use super::State; use super::State;
use crate::{ConduitResult, Database, Error, Ruma}; use crate::{ConduitResult, Database, Error, Result, Ruma};
use log::error; use log::error;
use ruma::{ use ruma::{
api::client::r0::sync::sync_events, api::client::r0::sync::sync_events,
@ -13,6 +13,7 @@ use rocket::{get, tokio};
use std::{ use std::{
collections::{hash_map, BTreeMap, HashMap, HashSet}, collections::{hash_map, BTreeMap, HashMap, HashSet},
convert::{TryFrom, TryInto}, convert::{TryFrom, TryInto},
sync::Arc,
time::Duration, time::Duration,
}; };
@ -33,7 +34,7 @@ use std::{
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn sync_events_route( pub async fn sync_events_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<sync_events::Request<'_>>, body: Ruma<sync_events::Request<'_>>,
) -> ConduitResult<sync_events::Response> { ) -> ConduitResult<sync_events::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -71,18 +72,23 @@ pub async fn sync_events_route(
let mut non_timeline_pdus = db let mut non_timeline_pdus = db
.rooms .rooms
.pdus_since(&sender_user, &room_id, since)? .pdus_until(&sender_user, &room_id, u64::MAX)
.filter_map(|r| { .filter_map(|r| {
// Filter out buggy events
if r.is_err() { if r.is_err() {
error!("Bad pdu in pdus_since: {:?}", r); error!("Bad pdu in pdus_since: {:?}", r);
} }
r.ok() r.ok()
}); // Filter out buggy events })
.take_while(|(pduid, _)| {
db.rooms
.pdu_count(pduid)
.map_or(false, |count| count > since)
});
// Take the last 10 events for the timeline // Take the last 10 events for the timeline
let timeline_pdus = non_timeline_pdus let timeline_pdus = non_timeline_pdus
.by_ref() .by_ref()
.rev()
.take(10) .take(10)
.collect::<Vec<_>>() .collect::<Vec<_>>()
.into_iter() .into_iter()
@ -226,7 +232,7 @@ pub async fn sync_events_route(
match (since_membership, current_membership) { match (since_membership, current_membership) {
(MembershipState::Leave, MembershipState::Join) => { (MembershipState::Leave, MembershipState::Join) => {
// A new user joined an encrypted room // A new user joined an encrypted room
if !share_encrypted_room(&db, &sender_user, &user_id, &room_id) { if !share_encrypted_room(&db, &sender_user, &user_id, &room_id)? {
device_list_updates.insert(user_id); device_list_updates.insert(user_id);
} }
} }
@ -257,6 +263,7 @@ pub async fn sync_events_route(
.filter(|user_id| { .filter(|user_id| {
// Only send keys if the sender doesn't share an encrypted room with the target already // Only send keys if the sender doesn't share an encrypted room with the target already
!share_encrypted_room(&db, sender_user, user_id, &room_id) !share_encrypted_room(&db, sender_user, user_id, &room_id)
.unwrap_or(false)
}), }),
); );
} }
@ -274,7 +281,7 @@ pub async fn sync_events_route(
for hero in db for hero in db
.rooms .rooms
.all_pdus(&sender_user, &room_id)? .all_pdus(&sender_user, &room_id)
.filter_map(|pdu| pdu.ok()) // Ignore all broken pdus .filter_map(|pdu| pdu.ok()) // Ignore all broken pdus
.filter(|(_, pdu)| pdu.kind == EventType::RoomMember) .filter(|(_, pdu)| pdu.kind == EventType::RoomMember)
.map(|(_, pdu)| { .map(|(_, pdu)| {
@ -411,7 +418,7 @@ pub async fn sync_events_route(
let mut edus = db let mut edus = db
.rooms .rooms
.edus .edus
.readreceipts_since(&room_id, since)? .readreceipts_since(&room_id, since)
.filter_map(|r| r.ok()) // Filter out buggy events .filter_map(|r| r.ok()) // Filter out buggy events
.map(|(_, _, v)| v) .map(|(_, _, v)| v)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -549,7 +556,7 @@ pub async fn sync_events_route(
for user_id in left_encrypted_users { for user_id in left_encrypted_users {
let still_share_encrypted_room = db let still_share_encrypted_room = db
.rooms .rooms
.get_shared_rooms(vec![sender_user.clone(), user_id.clone()]) .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])?
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.filter_map(|other_room_id| { .filter_map(|other_room_id| {
Some( Some(
@ -639,9 +646,10 @@ fn share_encrypted_room(
sender_user: &UserId, sender_user: &UserId,
user_id: &UserId, user_id: &UserId,
ignore_room: &RoomId, ignore_room: &RoomId,
) -> bool { ) -> Result<bool> {
db.rooms Ok(db
.get_shared_rooms(vec![sender_user.clone(), user_id.clone()]) .rooms
.get_shared_rooms(vec![sender_user.clone(), user_id.clone()])?
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.filter(|room_id| room_id != ignore_room) .filter(|room_id| room_id != ignore_room)
.filter_map(|other_room_id| { .filter_map(|other_room_id| {
@ -652,5 +660,5 @@ fn share_encrypted_room(
.is_some(), .is_some(),
) )
}) })
.any(|encrypted| encrypted) .any(|encrypted| encrypted))
} }

View File

@ -4,7 +4,7 @@ use ruma::{
api::client::r0::tag::{create_tag, delete_tag, get_tags}, api::client::r0::tag::{create_tag, delete_tag, get_tags},
events::EventType, events::EventType,
}; };
use std::collections::BTreeMap; use std::{collections::BTreeMap, sync::Arc};
#[cfg(feature = "conduit_bin")] #[cfg(feature = "conduit_bin")]
use rocket::{delete, get, put}; use rocket::{delete, get, put};
@ -15,7 +15,7 @@ use rocket::{delete, get, put};
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn update_tag_route( pub async fn update_tag_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<create_tag::Request<'_>>, body: Ruma<create_tag::Request<'_>>,
) -> ConduitResult<create_tag::Response> { ) -> ConduitResult<create_tag::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -52,7 +52,7 @@ pub async fn update_tag_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn delete_tag_route( pub async fn delete_tag_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<delete_tag::Request<'_>>, body: Ruma<delete_tag::Request<'_>>,
) -> ConduitResult<delete_tag::Response> { ) -> ConduitResult<delete_tag::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -86,7 +86,7 @@ pub async fn delete_tag_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_tags_route( pub async fn get_tags_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_tags::Request<'_>>, body: Ruma<get_tags::Request<'_>>,
) -> ConduitResult<get_tags::Response> { ) -> ConduitResult<get_tags::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");

View File

@ -1,3 +1,5 @@
use std::sync::Arc;
use super::State; use super::State;
use crate::{ConduitResult, Database, Error, Ruma}; use crate::{ConduitResult, Database, Error, Ruma};
use ruma::api::client::{ use ruma::api::client::{
@ -14,7 +16,7 @@ use rocket::put;
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn send_event_to_device_route( pub async fn send_event_to_device_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<send_event_to_device::Request<'_>>, body: Ruma<send_event_to_device::Request<'_>>,
) -> ConduitResult<send_event_to_device::Response> { ) -> ConduitResult<send_event_to_device::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");

View File

@ -1,3 +1,5 @@
use std::sync::Arc;
use super::State; use super::State;
use crate::{utils, ConduitResult, Database, Ruma}; use crate::{utils, ConduitResult, Database, Ruma};
use create_typing_event::Typing; use create_typing_event::Typing;
@ -12,7 +14,7 @@ use rocket::put;
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub fn create_typing_event_route( pub fn create_typing_event_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<create_typing_event::Request<'_>>, body: Ruma<create_typing_event::Request<'_>>,
) -> ConduitResult<create_typing_event::Response> { ) -> ConduitResult<create_typing_event::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");

View File

@ -1,3 +1,5 @@
use std::sync::Arc;
use super::State; use super::State;
use crate::{ConduitResult, Database, Ruma}; use crate::{ConduitResult, Database, Ruma};
use ruma::api::client::r0::user_directory::search_users; use ruma::api::client::r0::user_directory::search_users;
@ -11,7 +13,7 @@ use rocket::post;
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn search_users_route( pub async fn search_users_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<search_users::Request<'_>>, body: Ruma<search_users::Request<'_>>,
) -> ConduitResult<search_users::Response> { ) -> ConduitResult<search_users::Response> {
let limit = u64::from(body.limit) as usize; let limit = u64::from(body.limit) as usize;

View File

@ -1,3 +1,5 @@
pub mod abstraction;
pub mod account_data; pub mod account_data;
pub mod admin; pub mod admin;
pub mod appservice; pub mod appservice;
@ -12,15 +14,16 @@ pub mod uiaa;
pub mod users; pub mod users;
use crate::{utils, Error, Result}; use crate::{utils, Error, Result};
use abstraction::DatabaseEngine;
use directories::ProjectDirs; use directories::ProjectDirs;
use futures::StreamExt; use log::error;
use log::{error, info}; use rocket::futures::{channel::mpsc, stream::FuturesUnordered, StreamExt};
use rocket::futures::{self, channel::mpsc};
use ruma::{DeviceId, ServerName, UserId}; use ruma::{DeviceId, ServerName, UserId};
use serde::Deserialize; use serde::Deserialize;
use std::{ use std::{
collections::HashMap, collections::HashMap,
fs::remove_dir_all, fs::{self, remove_dir_all},
io::Write,
sync::{Arc, RwLock}, sync::{Arc, RwLock},
}; };
use tokio::sync::Semaphore; use tokio::sync::Semaphore;
@ -74,7 +77,12 @@ fn default_log() -> String {
"info,state_res=warn,rocket=off,_=off,sled=off".to_owned() "info,state_res=warn,rocket=off,_=off,sled=off".to_owned()
} }
#[derive(Clone)] #[cfg(feature = "sled")]
pub type Engine = abstraction::SledEngine;
#[cfg(feature = "rocksdb")]
pub type Engine = abstraction::RocksDbEngine;
pub struct Database { pub struct Database {
pub globals: globals::Globals, pub globals: globals::Globals,
pub users: users::Users, pub users: users::Users,
@ -88,7 +96,6 @@ pub struct Database {
pub admin: admin::Admin, pub admin: admin::Admin,
pub appservice: appservice::Appservice, pub appservice: appservice::Appservice,
pub pusher: pusher::PushData, pub pusher: pusher::PushData,
pub _db: sled::Db,
} }
impl Database { impl Database {
@ -105,126 +112,126 @@ impl Database {
} }
/// Load an existing database or create a new one. /// Load an existing database or create a new one.
pub async fn load_or_create(config: Config) -> Result<Self> { pub async fn load_or_create(config: Config) -> Result<Arc<Self>> {
let db = sled::Config::default() let builder = Engine::open(&config)?;
.path(&config.database_path)
.cache_capacity(config.cache_capacity as u64)
.use_compression(true)
.open()?;
if config.max_request_size < 1024 { if config.max_request_size < 1024 {
eprintln!("ERROR: Max request size is less than 1KB. Please increase it."); eprintln!("ERROR: Max request size is less than 1KB. Please increase it.");
} }
let (admin_sender, admin_receiver) = mpsc::unbounded(); let (admin_sender, admin_receiver) = mpsc::unbounded();
let (sending_sender, sending_receiver) = mpsc::unbounded();
let db = Self { let db = Arc::new(Self {
users: users::Users { users: users::Users {
userid_password: db.open_tree("userid_password")?, userid_password: builder.open_tree("userid_password")?,
userid_displayname: db.open_tree("userid_displayname")?, userid_displayname: builder.open_tree("userid_displayname")?,
userid_avatarurl: db.open_tree("userid_avatarurl")?, userid_avatarurl: builder.open_tree("userid_avatarurl")?,
userdeviceid_token: db.open_tree("userdeviceid_token")?, userdeviceid_token: builder.open_tree("userdeviceid_token")?,
userdeviceid_metadata: db.open_tree("userdeviceid_metadata")?, userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?,
userid_devicelistversion: db.open_tree("userid_devicelistversion")?, userid_devicelistversion: builder.open_tree("userid_devicelistversion")?,
token_userdeviceid: db.open_tree("token_userdeviceid")?, token_userdeviceid: builder.open_tree("token_userdeviceid")?,
onetimekeyid_onetimekeys: db.open_tree("onetimekeyid_onetimekeys")?, onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?,
userid_lastonetimekeyupdate: db.open_tree("userid_lastonetimekeyupdate")?, userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?,
keychangeid_userid: db.open_tree("keychangeid_userid")?, keychangeid_userid: builder.open_tree("keychangeid_userid")?,
keyid_key: db.open_tree("keyid_key")?, keyid_key: builder.open_tree("keyid_key")?,
userid_masterkeyid: db.open_tree("userid_masterkeyid")?, userid_masterkeyid: builder.open_tree("userid_masterkeyid")?,
userid_selfsigningkeyid: db.open_tree("userid_selfsigningkeyid")?, userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?,
userid_usersigningkeyid: db.open_tree("userid_usersigningkeyid")?, userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?,
todeviceid_events: db.open_tree("todeviceid_events")?, todeviceid_events: builder.open_tree("todeviceid_events")?,
}, },
uiaa: uiaa::Uiaa { uiaa: uiaa::Uiaa {
userdevicesessionid_uiaainfo: db.open_tree("userdevicesessionid_uiaainfo")?, userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?,
userdevicesessionid_uiaarequest: db.open_tree("userdevicesessionid_uiaarequest")?, userdevicesessionid_uiaarequest: builder
.open_tree("userdevicesessionid_uiaarequest")?,
}, },
rooms: rooms::Rooms { rooms: rooms::Rooms {
edus: rooms::RoomEdus { edus: rooms::RoomEdus {
readreceiptid_readreceipt: db.open_tree("readreceiptid_readreceipt")?, readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?,
roomuserid_privateread: db.open_tree("roomuserid_privateread")?, // "Private" read receipt roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt
roomuserid_lastprivatereadupdate: db roomuserid_lastprivatereadupdate: builder
.open_tree("roomuserid_lastprivatereadupdate")?, .open_tree("roomuserid_lastprivatereadupdate")?,
typingid_userid: db.open_tree("typingid_userid")?, typingid_userid: builder.open_tree("typingid_userid")?,
roomid_lasttypingupdate: db.open_tree("roomid_lasttypingupdate")?, roomid_lasttypingupdate: builder.open_tree("roomid_lasttypingupdate")?,
presenceid_presence: db.open_tree("presenceid_presence")?, presenceid_presence: builder.open_tree("presenceid_presence")?,
userid_lastpresenceupdate: db.open_tree("userid_lastpresenceupdate")?, userid_lastpresenceupdate: builder.open_tree("userid_lastpresenceupdate")?,
}, },
pduid_pdu: db.open_tree("pduid_pdu")?, pduid_pdu: builder.open_tree("pduid_pdu")?,
eventid_pduid: db.open_tree("eventid_pduid")?, eventid_pduid: builder.open_tree("eventid_pduid")?,
roomid_pduleaves: db.open_tree("roomid_pduleaves")?, roomid_pduleaves: builder.open_tree("roomid_pduleaves")?,
alias_roomid: db.open_tree("alias_roomid")?, alias_roomid: builder.open_tree("alias_roomid")?,
aliasid_alias: db.open_tree("aliasid_alias")?, aliasid_alias: builder.open_tree("aliasid_alias")?,
publicroomids: db.open_tree("publicroomids")?, publicroomids: builder.open_tree("publicroomids")?,
tokenids: db.open_tree("tokenids")?, tokenids: builder.open_tree("tokenids")?,
roomserverids: db.open_tree("roomserverids")?, roomserverids: builder.open_tree("roomserverids")?,
serverroomids: db.open_tree("serverroomids")?, serverroomids: builder.open_tree("serverroomids")?,
userroomid_joined: db.open_tree("userroomid_joined")?, userroomid_joined: builder.open_tree("userroomid_joined")?,
roomuserid_joined: db.open_tree("roomuserid_joined")?, roomuserid_joined: builder.open_tree("roomuserid_joined")?,
roomuseroncejoinedids: db.open_tree("roomuseroncejoinedids")?, roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?,
userroomid_invitestate: db.open_tree("userroomid_invitestate")?, userroomid_invitestate: builder.open_tree("userroomid_invitestate")?,
roomuserid_invitecount: db.open_tree("roomuserid_invitecount")?, roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?,
userroomid_leftstate: db.open_tree("userroomid_leftstate")?, userroomid_leftstate: builder.open_tree("userroomid_leftstate")?,
roomuserid_leftcount: db.open_tree("roomuserid_leftcount")?, roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?,
userroomid_notificationcount: db.open_tree("userroomid_notificationcount")?, userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?,
userroomid_highlightcount: db.open_tree("userroomid_highlightcount")?, userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?,
statekey_shortstatekey: db.open_tree("statekey_shortstatekey")?, statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?,
stateid_shorteventid: db.open_tree("stateid_shorteventid")?, stateid_shorteventid: builder.open_tree("stateid_shorteventid")?,
eventid_shorteventid: db.open_tree("eventid_shorteventid")?, eventid_shorteventid: builder.open_tree("eventid_shorteventid")?,
shorteventid_eventid: db.open_tree("shorteventid_eventid")?, shorteventid_eventid: builder.open_tree("shorteventid_eventid")?,
shorteventid_shortstatehash: db.open_tree("shorteventid_shortstatehash")?, shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?,
roomid_shortstatehash: db.open_tree("roomid_shortstatehash")?, roomid_shortstatehash: builder.open_tree("roomid_shortstatehash")?,
statehash_shortstatehash: db.open_tree("statehash_shortstatehash")?, statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?,
eventid_outlierpdu: db.open_tree("eventid_outlierpdu")?, eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?,
prevevent_parent: db.open_tree("prevevent_parent")?, prevevent_parent: builder.open_tree("prevevent_parent")?,
}, },
account_data: account_data::AccountData { account_data: account_data::AccountData {
roomuserdataid_accountdata: db.open_tree("roomuserdataid_accountdata")?, roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?,
}, },
media: media::Media { media: media::Media {
mediaid_file: db.open_tree("mediaid_file")?, mediaid_file: builder.open_tree("mediaid_file")?,
}, },
key_backups: key_backups::KeyBackups { key_backups: key_backups::KeyBackups {
backupid_algorithm: db.open_tree("backupid_algorithm")?, backupid_algorithm: builder.open_tree("backupid_algorithm")?,
backupid_etag: db.open_tree("backupid_etag")?, backupid_etag: builder.open_tree("backupid_etag")?,
backupkeyid_backup: db.open_tree("backupkeyid_backup")?, backupkeyid_backup: builder.open_tree("backupkeyid_backup")?,
}, },
transaction_ids: transaction_ids::TransactionIds { transaction_ids: transaction_ids::TransactionIds {
userdevicetxnid_response: db.open_tree("userdevicetxnid_response")?, userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?,
}, },
sending: sending::Sending { sending: sending::Sending {
servername_educount: db.open_tree("servername_educount")?, servername_educount: builder.open_tree("servername_educount")?,
servernamepduids: db.open_tree("servernamepduids")?, servernamepduids: builder.open_tree("servernamepduids")?,
servercurrentevents: db.open_tree("servercurrentevents")?, servercurrentevents: builder.open_tree("servercurrentevents")?,
maximum_requests: Arc::new(Semaphore::new(config.max_concurrent_requests as usize)), maximum_requests: Arc::new(Semaphore::new(config.max_concurrent_requests as usize)),
sender: sending_sender,
}, },
admin: admin::Admin { admin: admin::Admin {
sender: admin_sender, sender: admin_sender,
}, },
appservice: appservice::Appservice { appservice: appservice::Appservice {
cached_registrations: Arc::new(RwLock::new(HashMap::new())), cached_registrations: Arc::new(RwLock::new(HashMap::new())),
id_appserviceregistrations: db.open_tree("id_appserviceregistrations")?, id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?,
},
pusher: pusher::PushData {
senderkey_pusher: builder.open_tree("senderkey_pusher")?,
}, },
pusher: pusher::PushData::new(&db)?,
globals: globals::Globals::load( globals: globals::Globals::load(
db.open_tree("global")?, builder.open_tree("global")?,
db.open_tree("server_signingkeys")?, builder.open_tree("server_signingkeys")?,
config, config,
)?, )?,
_db: db, });
};
// MIGRATIONS // MIGRATIONS
// TODO: database versions of new dbs should probably not be 0
if db.globals.database_version()? < 1 { if db.globals.database_version()? < 1 {
for roomserverid in db.rooms.roomserverids.iter().keys() { for (roomserverid, _) in db.rooms.roomserverids.iter() {
let roomserverid = roomserverid?;
let mut parts = roomserverid.split(|&b| b == 0xff); let mut parts = roomserverid.split(|&b| b == 0xff);
let room_id = parts.next().expect("split always returns one element"); let room_id = parts.next().expect("split always returns one element");
let servername = match parts.next() { let servername = match parts.next() {
@ -238,37 +245,55 @@ impl Database {
serverroomid.push(0xff); serverroomid.push(0xff);
serverroomid.extend_from_slice(room_id); serverroomid.extend_from_slice(room_id);
db.rooms.serverroomids.insert(serverroomid, &[])?; db.rooms.serverroomids.insert(&serverroomid, &[])?;
} }
db.globals.bump_database_version(1)?; db.globals.bump_database_version(1)?;
info!("Migration: 0 -> 1 finished"); println!("Migration: 0 -> 1 finished");
} }
if db.globals.database_version()? < 2 { if db.globals.database_version()? < 2 {
// We accidentally inserted hashed versions of "" into the db instead of just "" // We accidentally inserted hashed versions of "" into the db instead of just ""
for userid_password in db.users.userid_password.iter() { for (userid, password) in db.users.userid_password.iter() {
let (userid, password) = userid_password?;
let password = utils::string_from_bytes(&password); let password = utils::string_from_bytes(&password);
if password.map_or(false, |password| { let empty_hashed_password = password.map_or(false, |password| {
argon2::verify_encoded(&password, b"").unwrap_or(false) argon2::verify_encoded(&password, b"").unwrap_or(false)
}) { });
db.users.userid_password.insert(userid, b"")?;
if empty_hashed_password {
db.users.userid_password.insert(&userid, b"")?;
} }
} }
db.globals.bump_database_version(2)?; db.globals.bump_database_version(2)?;
info!("Migration: 1 -> 2 finished"); println!("Migration: 1 -> 2 finished");
} }
if db.globals.database_version()? < 3 {
// Move media to filesystem
for (key, content) in db.media.mediaid_file.iter() {
if content.len() == 0 {
continue;
}
let path = db.globals.get_media_file(&key);
let mut file = fs::File::create(path)?;
file.write_all(&content)?;
db.media.mediaid_file.insert(&key, &[])?;
}
db.globals.bump_database_version(3)?;
println!("Migration: 2 -> 3 finished");
}
// This data is probably outdated // This data is probably outdated
db.rooms.edus.presenceid_presence.clear()?; db.rooms.edus.presenceid_presence.clear()?;
db.admin.start_handler(db.clone(), admin_receiver); db.admin.start_handler(Arc::clone(&db), admin_receiver);
db.sending.start_handler(Arc::clone(&db), sending_receiver);
Ok(db) Ok(db)
} }
@ -282,7 +307,7 @@ impl Database {
userdeviceid_prefix.extend_from_slice(device_id.as_bytes()); userdeviceid_prefix.extend_from_slice(device_id.as_bytes());
userdeviceid_prefix.push(0xff); userdeviceid_prefix.push(0xff);
let mut futures = futures::stream::FuturesUnordered::new(); let mut futures = FuturesUnordered::new();
// Return when *any* user changed his key // Return when *any* user changed his key
// TODO: only send for user they share a room with // TODO: only send for user they share a room with

329
src/database/abstraction.rs Normal file
View File

@ -0,0 +1,329 @@
use super::Config;
use crate::{utils, Result};
use log::warn;
use std::{future::Future, pin::Pin, sync::Arc};
#[cfg(feature = "rocksdb")]
use std::{collections::BTreeMap, sync::RwLock};
#[cfg(feature = "sled")]
pub struct SledEngine(sled::Db);
#[cfg(feature = "sled")]
pub struct SledEngineTree(sled::Tree);
#[cfg(feature = "rocksdb")]
pub struct RocksDbEngine(rocksdb::DBWithThreadMode<rocksdb::MultiThreaded>);
#[cfg(feature = "rocksdb")]
pub struct RocksDbEngineTree<'a> {
db: Arc<RocksDbEngine>,
name: &'a str,
watchers: RwLock<BTreeMap<Vec<u8>, Vec<tokio::sync::oneshot::Sender<()>>>>,
}
pub trait DatabaseEngine: Sized {
fn open(config: &Config) -> Result<Arc<Self>>;
fn open_tree(self: &Arc<Self>, name: &'static str) -> Result<Arc<dyn Tree>>;
}
pub trait Tree: Send + Sync {
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>;
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>;
fn remove(&self, key: &[u8]) -> Result<()>;
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + Send + Sync + 'a>;
fn iter_from<'a>(
&'a self,
from: &[u8],
backwards: bool,
) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + 'a>;
fn increment(&self, key: &[u8]) -> Result<Vec<u8>>;
fn scan_prefix<'a>(
&'a self,
prefix: Vec<u8>,
) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + Send + 'a>;
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
fn clear(&self) -> Result<()> {
for (key, _) in self.iter() {
self.remove(&key)?;
}
Ok(())
}
}
#[cfg(feature = "sled")]
impl DatabaseEngine for SledEngine {
fn open(config: &Config) -> Result<Arc<Self>> {
Ok(Arc::new(SledEngine(
sled::Config::default()
.path(&config.database_path)
.cache_capacity(config.cache_capacity as u64)
.use_compression(true)
.open()?,
)))
}
fn open_tree(self: &Arc<Self>, name: &'static str) -> Result<Arc<dyn Tree>> {
Ok(Arc::new(SledEngineTree(self.0.open_tree(name)?)))
}
}
#[cfg(feature = "sled")]
impl Tree for SledEngineTree {
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
Ok(self.0.get(key)?.map(|v| v.to_vec()))
}
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
self.0.insert(key, value)?;
Ok(())
}
fn remove(&self, key: &[u8]) -> Result<()> {
self.0.remove(key)?;
Ok(())
}
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + Send + Sync + 'a> {
Box::new(
self.0
.iter()
.filter_map(|r| {
if let Err(e) = &r {
warn!("Error: {}", e);
}
r.ok()
})
.map(|(k, v)| (k.to_vec().into(), v.to_vec().into())),
)
}
fn iter_from(
&self,
from: &[u8],
backwards: bool,
) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)>> {
let iter = if backwards {
self.0.range(..from)
} else {
self.0.range(from..)
};
let iter = iter
.filter_map(|r| {
if let Err(e) = &r {
warn!("Error: {}", e);
}
r.ok()
})
.map(|(k, v)| (k.to_vec().into(), v.to_vec().into()));
if backwards {
Box::new(iter.rev())
} else {
Box::new(iter)
}
}
fn increment(&self, key: &[u8]) -> Result<Vec<u8>> {
Ok(self
.0
.update_and_fetch(key, utils::increment)
.map(|o| o.expect("increment always sets a value").to_vec())?)
}
fn scan_prefix<'a>(
&'a self,
prefix: Vec<u8>,
) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + Send + 'a> {
let iter = self
.0
.scan_prefix(prefix)
.filter_map(|r| {
if let Err(e) = &r {
warn!("Error: {}", e);
}
r.ok()
})
.map(|(k, v)| (k.to_vec().into(), v.to_vec().into()));
Box::new(iter)
}
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
let prefix = prefix.to_vec();
Box::pin(async move {
self.0.watch_prefix(prefix).await;
})
}
}
#[cfg(feature = "rocksdb")]
impl DatabaseEngine for RocksDbEngine {
fn open(config: &Config) -> Result<Arc<Self>> {
let mut db_opts = rocksdb::Options::default();
db_opts.create_if_missing(true);
db_opts.set_max_open_files(16);
db_opts.set_compaction_style(rocksdb::DBCompactionStyle::Level);
db_opts.set_compression_type(rocksdb::DBCompressionType::Snappy);
db_opts.set_target_file_size_base(256 << 20);
db_opts.set_write_buffer_size(256 << 20);
let mut block_based_options = rocksdb::BlockBasedOptions::default();
block_based_options.set_block_size(512 << 10);
db_opts.set_block_based_table_factory(&block_based_options);
let cfs = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::list_cf(
&db_opts,
&config.database_path,
)
.unwrap_or_default();
let mut options = rocksdb::Options::default();
options.set_merge_operator_associative("increment", utils::increment_rocksdb);
let db = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::open_cf_descriptors(
&db_opts,
&config.database_path,
cfs.iter()
.map(|name| rocksdb::ColumnFamilyDescriptor::new(name, options.clone())),
)?;
Ok(Arc::new(RocksDbEngine(db)))
}
fn open_tree(self: &Arc<Self>, name: &'static str) -> Result<Arc<dyn Tree>> {
let mut options = rocksdb::Options::default();
options.set_merge_operator_associative("increment", utils::increment_rocksdb);
// Create if it doesn't exist
let _ = self.0.create_cf(name, &options);
Ok(Arc::new(RocksDbEngineTree {
name,
db: Arc::clone(self),
watchers: RwLock::new(BTreeMap::new()),
}))
}
}
#[cfg(feature = "rocksdb")]
impl RocksDbEngineTree<'_> {
fn cf(&self) -> rocksdb::BoundColumnFamily<'_> {
self.db.0.cf_handle(self.name).unwrap()
}
}
#[cfg(feature = "rocksdb")]
impl Tree for RocksDbEngineTree<'_> {
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
Ok(self.db.0.get_cf(self.cf(), key)?)
}
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
let watchers = self.watchers.read().unwrap();
let mut triggered = Vec::new();
for length in 0..=key.len() {
if watchers.contains_key(&key[..length]) {
triggered.push(&key[..length]);
}
}
drop(watchers);
if !triggered.is_empty() {
let mut watchers = self.watchers.write().unwrap();
for prefix in triggered {
if let Some(txs) = watchers.remove(prefix) {
for tx in txs {
let _ = tx.send(());
}
}
}
}
Ok(self.db.0.put_cf(self.cf(), key, value)?)
}
fn remove(&self, key: &[u8]) -> Result<()> {
Ok(self.db.0.delete_cf(self.cf(), key)?)
}
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + Send + Sync + 'a> {
Box::new(
self.db
.0
.iterator_cf(self.cf(), rocksdb::IteratorMode::Start),
)
}
fn iter_from<'a>(
&'a self,
from: &[u8],
backwards: bool,
) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + 'a> {
Box::new(self.db.0.iterator_cf(
self.cf(),
rocksdb::IteratorMode::From(
from,
if backwards {
rocksdb::Direction::Reverse
} else {
rocksdb::Direction::Forward
},
),
))
}
fn increment(&self, key: &[u8]) -> Result<Vec<u8>> {
let stats = rocksdb::perf::get_memory_usage_stats(Some(&[&self.db.0]), None).unwrap();
dbg!(stats.mem_table_total);
dbg!(stats.mem_table_unflushed);
dbg!(stats.mem_table_readers_total);
dbg!(stats.cache_total);
// TODO: atomic?
let old = self.get(key)?;
let new = utils::increment(old.as_deref()).unwrap();
self.insert(key, &new)?;
Ok(new)
}
fn scan_prefix<'a>(
&'a self,
prefix: Vec<u8>,
) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + Send + 'a> {
Box::new(
self.db
.0
.iterator_cf(
self.cf(),
rocksdb::IteratorMode::From(&prefix, rocksdb::Direction::Forward),
)
.take_while(move |(k, _)| k.starts_with(&prefix)),
)
}
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
let (tx, rx) = tokio::sync::oneshot::channel();
self.watchers
.write()
.unwrap()
.entry(prefix.to_vec())
.or_default()
.push(tx);
Box::pin(async move {
// Tx is never destroyed
rx.await.unwrap();
})
}
}

View File

@ -6,12 +6,12 @@ use ruma::{
RoomId, UserId, RoomId, UserId,
}; };
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
use sled::IVec; use std::{collections::HashMap, convert::TryFrom, sync::Arc};
use std::{collections::HashMap, convert::TryFrom};
use super::abstraction::Tree;
#[derive(Clone)]
pub struct AccountData { pub struct AccountData {
pub(super) roomuserdataid_accountdata: sled::Tree, // RoomUserDataId = Room + User + Count + Type pub(super) roomuserdataid_accountdata: Arc<dyn Tree>, // RoomUserDataId = Room + User + Count + Type
} }
impl AccountData { impl AccountData {
@ -34,9 +34,8 @@ impl AccountData {
prefix.push(0xff); prefix.push(0xff);
// Remove old entry // Remove old entry
if let Some(previous) = self.find_event(room_id, user_id, &event_type) { if let Some((old_key, _)) = self.find_event(room_id, user_id, &event_type)? {
let (old_key, _) = previous?; self.roomuserdataid_accountdata.remove(&old_key)?;
self.roomuserdataid_accountdata.remove(old_key)?;
} }
let mut key = prefix; let mut key = prefix;
@ -52,8 +51,10 @@ impl AccountData {
)); ));
} }
self.roomuserdataid_accountdata self.roomuserdataid_accountdata.insert(
.insert(key, &*json.to_string())?; &key,
&serde_json::to_vec(&json).expect("to_vec always works on json values"),
)?;
Ok(()) Ok(())
} }
@ -65,9 +66,8 @@ impl AccountData {
user_id: &UserId, user_id: &UserId,
kind: EventType, kind: EventType,
) -> Result<Option<T>> { ) -> Result<Option<T>> {
self.find_event(room_id, user_id, &kind) self.find_event(room_id, user_id, &kind)?
.map(|r| { .map(|(_, v)| {
let (_, v) = r?;
serde_json::from_slice(&v).map_err(|_| Error::bad_database("could not deserialize")) serde_json::from_slice(&v).map_err(|_| Error::bad_database("could not deserialize"))
}) })
.transpose() .transpose()
@ -98,8 +98,7 @@ impl AccountData {
for r in self for r in self
.roomuserdataid_accountdata .roomuserdataid_accountdata
.range(&*first_possible..) .iter_from(&first_possible, false)
.filter_map(|r| r.ok())
.take_while(move |(k, _)| k.starts_with(&prefix)) .take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(k, v)| { .map(|(k, v)| {
Ok::<_, Error>(( Ok::<_, Error>((
@ -128,7 +127,7 @@ impl AccountData {
room_id: Option<&RoomId>, room_id: Option<&RoomId>,
user_id: &UserId, user_id: &UserId,
kind: &EventType, kind: &EventType,
) -> Option<Result<(IVec, IVec)>> { ) -> Result<Option<(Box<[u8]>, Box<[u8]>)>> {
let mut prefix = room_id let mut prefix = room_id
.map(|r| r.to_string()) .map(|r| r.to_string())
.unwrap_or_default() .unwrap_or_default()
@ -137,23 +136,21 @@ impl AccountData {
prefix.push(0xff); prefix.push(0xff);
prefix.extend_from_slice(&user_id.as_bytes()); prefix.extend_from_slice(&user_id.as_bytes());
prefix.push(0xff); prefix.push(0xff);
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
let kind = kind.clone(); let kind = kind.clone();
self.roomuserdataid_accountdata Ok(self
.scan_prefix(prefix) .roomuserdataid_accountdata
.rev() .iter_from(&last_possible_key, true)
.find(move |r| { .take_while(move |(k, _)| k.starts_with(&prefix))
r.as_ref() .find(move |(k, _)| {
.map(|(k, _)| { k.rsplit(|&b| b == 0xff)
k.rsplit(|&b| b == 0xff) .next()
.next() .map(|current_event_type| current_event_type == kind.as_ref().as_bytes())
.map(|current_event_type| {
current_event_type == kind.as_ref().as_bytes()
})
.unwrap_or(false)
})
.unwrap_or(false) .unwrap_or(false)
}) }))
.map(|r| Ok(r?))
} }
} }

View File

@ -1,6 +1,9 @@
use std::convert::{TryFrom, TryInto}; use std::{
convert::{TryFrom, TryInto},
sync::Arc,
};
use crate::pdu::PduBuilder; use crate::{pdu::PduBuilder, Database};
use log::warn; use log::warn;
use rocket::futures::{channel::mpsc, stream::StreamExt}; use rocket::futures::{channel::mpsc, stream::StreamExt};
use ruma::{ use ruma::{
@ -22,7 +25,7 @@ pub struct Admin {
impl Admin { impl Admin {
pub fn start_handler( pub fn start_handler(
&self, &self,
db: super::Database, db: Arc<Database>,
mut receiver: mpsc::UnboundedReceiver<AdminCommand>, mut receiver: mpsc::UnboundedReceiver<AdminCommand>,
) { ) {
tokio::spawn(async move { tokio::spawn(async move {
@ -73,14 +76,17 @@ impl Admin {
db.appservice.register_appservice(yaml).unwrap(); // TODO handle error db.appservice.register_appservice(yaml).unwrap(); // TODO handle error
} }
AdminCommand::ListAppservices => { AdminCommand::ListAppservices => {
let appservices = db.appservice.iter_ids().collect::<Vec<_>>(); if let Ok(appservices) = db.appservice.iter_ids().map(|ids| ids.collect::<Vec<_>>()) {
let count = appservices.len(); let count = appservices.len();
let output = format!( let output = format!(
"Appservices ({}): {}", "Appservices ({}): {}",
count, count,
appservices.into_iter().filter_map(|r| r.ok()).collect::<Vec<_>>().join(", ") appservices.into_iter().filter_map(|r| r.ok()).collect::<Vec<_>>().join(", ")
); );
send_message(message::MessageEventContent::text_plain(output)); send_message(message::MessageEventContent::text_plain(output));
} else {
send_message(message::MessageEventContent::text_plain("Failed to get appservices."));
}
} }
AdminCommand::SendMessage(message) => { AdminCommand::SendMessage(message) => {
send_message(message); send_message(message);
@ -93,6 +99,6 @@ impl Admin {
} }
pub fn send(&self, command: AdminCommand) { pub fn send(&self, command: AdminCommand) {
self.sender.unbounded_send(command).unwrap() self.sender.unbounded_send(command).unwrap();
} }
} }

View File

@ -4,18 +4,21 @@ use std::{
sync::{Arc, RwLock}, sync::{Arc, RwLock},
}; };
#[derive(Clone)] use super::abstraction::Tree;
pub struct Appservice { pub struct Appservice {
pub(super) cached_registrations: Arc<RwLock<HashMap<String, serde_yaml::Value>>>, pub(super) cached_registrations: Arc<RwLock<HashMap<String, serde_yaml::Value>>>,
pub(super) id_appserviceregistrations: sled::Tree, pub(super) id_appserviceregistrations: Arc<dyn Tree>,
} }
impl Appservice { impl Appservice {
pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result<()> { pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result<()> {
// TODO: Rumaify // TODO: Rumaify
let id = yaml.get("id").unwrap().as_str().unwrap(); let id = yaml.get("id").unwrap().as_str().unwrap();
self.id_appserviceregistrations self.id_appserviceregistrations.insert(
.insert(id, serde_yaml::to_string(&yaml).unwrap().as_bytes())?; id.as_bytes(),
serde_yaml::to_string(&yaml).unwrap().as_bytes(),
)?;
self.cached_registrations self.cached_registrations
.write() .write()
.unwrap() .unwrap()
@ -33,7 +36,7 @@ impl Appservice {
|| { || {
Ok(self Ok(self
.id_appserviceregistrations .id_appserviceregistrations
.get(id)? .get(id.as_bytes())?
.map(|bytes| { .map(|bytes| {
Ok::<_, Error>(serde_yaml::from_slice(&bytes).map_err(|_| { Ok::<_, Error>(serde_yaml::from_slice(&bytes).map_err(|_| {
Error::bad_database( Error::bad_database(
@ -47,21 +50,25 @@ impl Appservice {
) )
} }
pub fn iter_ids(&self) -> impl Iterator<Item = Result<String>> { pub fn iter_ids<'a>(
self.id_appserviceregistrations.iter().keys().map(|id| { &'a self,
Ok(utils::string_from_bytes(&id?).map_err(|_| { ) -> Result<impl Iterator<Item = Result<String>> + Send + Sync + 'a> {
Ok(self.id_appserviceregistrations.iter().map(|(id, _)| {
Ok(utils::string_from_bytes(&id).map_err(|_| {
Error::bad_database("Invalid id bytes in id_appserviceregistrations.") Error::bad_database("Invalid id bytes in id_appserviceregistrations.")
})?) })?)
}) }))
} }
pub fn iter_all(&self) -> impl Iterator<Item = Result<(String, serde_yaml::Value)>> + '_ { pub fn iter_all(
self.iter_ids().filter_map(|id| id.ok()).map(move |id| { &self,
) -> Result<impl Iterator<Item = Result<(String, serde_yaml::Value)>> + '_ + Send + Sync> {
Ok(self.iter_ids()?.filter_map(|id| id.ok()).map(move |id| {
Ok(( Ok((
id.clone(), id.clone(),
self.get_registration(&id)? self.get_registration(&id)?
.expect("iter_ids only returns appservices that exist"), .expect("iter_ids only returns appservices that exist"),
)) ))
}) }))
} }
} }

View File

@ -7,28 +7,31 @@ use ruma::{
use rustls::{ServerCertVerifier, WebPKIVerifier}; use rustls::{ServerCertVerifier, WebPKIVerifier};
use std::{ use std::{
collections::{BTreeMap, HashMap}, collections::{BTreeMap, HashMap},
fs,
path::PathBuf,
sync::{Arc, RwLock}, sync::{Arc, RwLock},
time::{Duration, Instant}, time::{Duration, Instant},
}; };
use tokio::sync::Semaphore; use tokio::sync::Semaphore;
use trust_dns_resolver::TokioAsyncResolver; use trust_dns_resolver::TokioAsyncResolver;
pub const COUNTER: &str = "c"; use super::abstraction::Tree;
pub const COUNTER: &[u8] = b"c";
type WellKnownMap = HashMap<Box<ServerName>, (String, String)>; type WellKnownMap = HashMap<Box<ServerName>, (String, String)>;
type TlsNameMap = HashMap<String, webpki::DNSName>; type TlsNameMap = HashMap<String, webpki::DNSName>;
type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
#[derive(Clone)]
pub struct Globals { pub struct Globals {
pub actual_destination_cache: Arc<RwLock<WellKnownMap>>, // actual_destination, host pub actual_destination_cache: Arc<RwLock<WellKnownMap>>, // actual_destination, host
pub tls_name_override: Arc<RwLock<TlsNameMap>>, pub tls_name_override: Arc<RwLock<TlsNameMap>>,
pub(super) globals: sled::Tree, pub(super) globals: Arc<dyn Tree>,
config: Config, config: Config,
keypair: Arc<ruma::signatures::Ed25519KeyPair>, keypair: Arc<ruma::signatures::Ed25519KeyPair>,
reqwest_client: reqwest::Client, reqwest_client: reqwest::Client,
dns_resolver: TokioAsyncResolver, dns_resolver: TokioAsyncResolver,
jwt_decoding_key: Option<jsonwebtoken::DecodingKey<'static>>, jwt_decoding_key: Option<jsonwebtoken::DecodingKey<'static>>,
pub(super) server_signingkeys: sled::Tree, pub(super) server_signingkeys: Arc<dyn Tree>,
pub bad_event_ratelimiter: Arc<RwLock<BTreeMap<EventId, RateLimitState>>>, pub bad_event_ratelimiter: Arc<RwLock<BTreeMap<EventId, RateLimitState>>>,
pub bad_signature_ratelimiter: Arc<RwLock<BTreeMap<Vec<String>, RateLimitState>>>, pub bad_signature_ratelimiter: Arc<RwLock<BTreeMap<Vec<String>, RateLimitState>>>,
pub servername_ratelimiter: Arc<RwLock<BTreeMap<Box<ServerName>, Arc<Semaphore>>>>, pub servername_ratelimiter: Arc<RwLock<BTreeMap<Box<ServerName>, Arc<Semaphore>>>>,
@ -69,15 +72,20 @@ impl ServerCertVerifier for MatrixServerVerifier {
impl Globals { impl Globals {
pub fn load( pub fn load(
globals: sled::Tree, globals: Arc<dyn Tree>,
server_signingkeys: sled::Tree, server_signingkeys: Arc<dyn Tree>,
config: Config, config: Config,
) -> Result<Self> { ) -> Result<Self> {
let bytes = &*globals let keypair_bytes = globals.get(b"keypair")?.map_or_else(
.update_and_fetch("keypair", utils::generate_keypair)? || {
.expect("utils::generate_keypair always returns Some"); let keypair = utils::generate_keypair();
globals.insert(b"keypair", &keypair)?;
Ok::<_, Error>(keypair)
},
|s| Ok(s.to_vec()),
)?;
let mut parts = bytes.splitn(2, |&b| b == 0xff); let mut parts = keypair_bytes.splitn(2, |&b| b == 0xff);
let keypair = utils::string_from_bytes( let keypair = utils::string_from_bytes(
// 1. version // 1. version
@ -102,7 +110,7 @@ impl Globals {
Ok(k) => k, Ok(k) => k,
Err(e) => { Err(e) => {
error!("Keypair invalid. Deleting..."); error!("Keypair invalid. Deleting...");
globals.remove("keypair")?; globals.remove(b"keypair")?;
return Err(e); return Err(e);
} }
}; };
@ -130,7 +138,7 @@ impl Globals {
.as_ref() .as_ref()
.map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()).into_static()); .map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()).into_static());
Ok(Self { let s = Self {
globals, globals,
config, config,
keypair: Arc::new(keypair), keypair: Arc::new(keypair),
@ -145,7 +153,11 @@ impl Globals {
bad_event_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), bad_event_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())),
bad_signature_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), bad_signature_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())),
servername_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), servername_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())),
}) };
fs::create_dir_all(s.get_media_folder())?;
Ok(s)
} }
/// Returns this server's keypair. /// Returns this server's keypair.
@ -159,13 +171,8 @@ impl Globals {
} }
pub fn next_count(&self) -> Result<u64> { pub fn next_count(&self) -> Result<u64> {
Ok(utils::u64_from_bytes( Ok(utils::u64_from_bytes(&self.globals.increment(COUNTER)?)
&self .map_err(|_| Error::bad_database("Count has invalid bytes."))?)
.globals
.update_and_fetch(COUNTER, utils::increment)?
.expect("utils::increment will always put in a value"),
)
.map_err(|_| Error::bad_database("Count has invalid bytes."))?)
} }
pub fn current_count(&self) -> Result<u64> { pub fn current_count(&self) -> Result<u64> {
@ -211,21 +218,30 @@ impl Globals {
/// Remove the outdated keys and insert the new ones. /// Remove the outdated keys and insert the new ones.
/// ///
/// This doesn't actually check that the keys provided are newer than the old set. /// This doesn't actually check that the keys provided are newer than the old set.
pub fn add_signing_key(&self, origin: &ServerName, new_keys: &ServerSigningKeys) -> Result<()> { pub fn add_signing_key(&self, origin: &ServerName, new_keys: ServerSigningKeys) -> Result<()> {
self.server_signingkeys // Not atomic, but this is not critical
.update_and_fetch(origin.as_bytes(), |signingkeys| { let signingkeys = self.server_signingkeys.get(origin.as_bytes())?;
let mut keys = signingkeys
.and_then(|keys| serde_json::from_slice(keys).ok()) let mut keys = signingkeys
.unwrap_or_else(|| { .and_then(|keys| serde_json::from_slice(&keys).ok())
// Just insert "now", it doesn't matter .unwrap_or_else(|| {
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) // Just insert "now", it doesn't matter
}); ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
keys.verify_keys });
.extend(new_keys.verify_keys.clone().into_iter());
keys.old_verify_keys let ServerSigningKeys {
.extend(new_keys.old_verify_keys.clone().into_iter()); verify_keys,
Some(serde_json::to_vec(&keys).expect("serversigningkeys can be serialized")) old_verify_keys,
})?; ..
} = new_keys;
keys.verify_keys.extend(verify_keys.into_iter());
keys.old_verify_keys.extend(old_verify_keys.into_iter());
self.server_signingkeys.insert(
origin.as_bytes(),
&serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"),
)?;
Ok(()) Ok(())
} }
@ -254,14 +270,30 @@ impl Globals {
} }
pub fn database_version(&self) -> Result<u64> { pub fn database_version(&self) -> Result<u64> {
self.globals.get("version")?.map_or(Ok(0), |version| { self.globals.get(b"version")?.map_or(Ok(0), |version| {
utils::u64_from_bytes(&version) utils::u64_from_bytes(&version)
.map_err(|_| Error::bad_database("Database version id is invalid.")) .map_err(|_| Error::bad_database("Database version id is invalid."))
}) })
} }
pub fn bump_database_version(&self, new_version: u64) -> Result<()> { pub fn bump_database_version(&self, new_version: u64) -> Result<()> {
self.globals.insert("version", &new_version.to_be_bytes())?; self.globals
.insert(b"version", &new_version.to_be_bytes())?;
Ok(()) Ok(())
} }
pub fn get_media_folder(&self) -> PathBuf {
let mut r = PathBuf::new();
r.push(self.config.database_path.clone());
r.push("media");
r
}
pub fn get_media_file(&self, key: &[u8]) -> PathBuf {
let mut r = PathBuf::new();
r.push(self.config.database_path.clone());
r.push("media");
r.push(base64::encode_config(key, base64::URL_SAFE_NO_PAD));
r
}
} }

View File

@ -6,13 +6,14 @@ use ruma::{
}, },
RoomId, UserId, RoomId, UserId,
}; };
use std::{collections::BTreeMap, convert::TryFrom}; use std::{collections::BTreeMap, convert::TryFrom, sync::Arc};
use super::abstraction::Tree;
#[derive(Clone)]
pub struct KeyBackups { pub struct KeyBackups {
pub(super) backupid_algorithm: sled::Tree, // BackupId = UserId + Version(Count) pub(super) backupid_algorithm: Arc<dyn Tree>, // BackupId = UserId + Version(Count)
pub(super) backupid_etag: sled::Tree, // BackupId = UserId + Version(Count) pub(super) backupid_etag: Arc<dyn Tree>, // BackupId = UserId + Version(Count)
pub(super) backupkeyid_backup: sled::Tree, // BackupKeyId = UserId + Version + RoomId + SessionId pub(super) backupkeyid_backup: Arc<dyn Tree>, // BackupKeyId = UserId + Version + RoomId + SessionId
} }
impl KeyBackups { impl KeyBackups {
@ -30,8 +31,7 @@ impl KeyBackups {
self.backupid_algorithm.insert( self.backupid_algorithm.insert(
&key, &key,
&*serde_json::to_string(backup_metadata) &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
.expect("BackupAlgorithm::to_string always works"),
)?; )?;
self.backupid_etag self.backupid_etag
.insert(&key, &globals.next_count()?.to_be_bytes())?; .insert(&key, &globals.next_count()?.to_be_bytes())?;
@ -48,13 +48,8 @@ impl KeyBackups {
key.push(0xff); key.push(0xff);
for outdated_key in self for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
.backupkeyid_backup self.backupkeyid_backup.remove(&outdated_key)?;
.scan_prefix(&key)
.keys()
.filter_map(|r| r.ok())
{
self.backupkeyid_backup.remove(outdated_key)?;
} }
Ok(()) Ok(())
@ -80,8 +75,9 @@ impl KeyBackups {
self.backupid_algorithm.insert( self.backupid_algorithm.insert(
&key, &key,
&*serde_json::to_string(backup_metadata) &serde_json::to_string(backup_metadata)
.expect("BackupAlgorithm::to_string always works"), .expect("BackupAlgorithm::to_string always works")
.as_bytes(),
)?; )?;
self.backupid_etag self.backupid_etag
.insert(&key, &globals.next_count()?.to_be_bytes())?; .insert(&key, &globals.next_count()?.to_be_bytes())?;
@ -91,11 +87,14 @@ impl KeyBackups {
pub fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, BackupAlgorithm)>> { pub fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, BackupAlgorithm)>> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xff);
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
self.backupid_algorithm self.backupid_algorithm
.scan_prefix(&prefix) .iter_from(&last_possible_key, true)
.last() .take_while(move |(k, _)| k.starts_with(&prefix))
.map_or(Ok(None), |r| { .next()
let (key, value) = r?; .map_or(Ok(None), |(key, value)| {
let version = utils::string_from_bytes( let version = utils::string_from_bytes(
key.rsplit(|&b| b == 0xff) key.rsplit(|&b| b == 0xff)
.next() .next()
@ -117,10 +116,13 @@ impl KeyBackups {
key.push(0xff); key.push(0xff);
key.extend_from_slice(version.as_bytes()); key.extend_from_slice(version.as_bytes());
self.backupid_algorithm.get(key)?.map_or(Ok(None), |bytes| { self.backupid_algorithm
Ok(serde_json::from_slice(&bytes) .get(&key)?
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?) .map_or(Ok(None), |bytes| {
}) Ok(serde_json::from_slice(&bytes).map_err(|_| {
Error::bad_database("Algorithm in backupid_algorithm is invalid.")
})?)
})
} }
pub fn add_key( pub fn add_key(
@ -153,7 +155,7 @@ impl KeyBackups {
self.backupkeyid_backup.insert( self.backupkeyid_backup.insert(
&key, &key,
&*serde_json::to_string(&key_data).expect("KeyBackupData::to_string always works"), &serde_json::to_vec(&key_data).expect("KeyBackupData::to_vec always works"),
)?; )?;
Ok(()) Ok(())
@ -164,7 +166,7 @@ impl KeyBackups {
prefix.push(0xff); prefix.push(0xff);
prefix.extend_from_slice(version.as_bytes()); prefix.extend_from_slice(version.as_bytes());
Ok(self.backupkeyid_backup.scan_prefix(&prefix).count()) Ok(self.backupkeyid_backup.scan_prefix(prefix).count())
} }
pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> { pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
@ -194,33 +196,37 @@ impl KeyBackups {
let mut rooms = BTreeMap::<RoomId, RoomKeyBackup>::new(); let mut rooms = BTreeMap::<RoomId, RoomKeyBackup>::new();
for result in self.backupkeyid_backup.scan_prefix(&prefix).map(|r| { for result in self
let (key, value) = r?; .backupkeyid_backup
let mut parts = key.rsplit(|&b| b == 0xff); .scan_prefix(prefix)
.map(|(key, value)| {
let mut parts = key.rsplit(|&b| b == 0xff);
let session_id = utils::string_from_bytes( let session_id =
&parts utils::string_from_bytes(&parts.next().ok_or_else(|| {
.next() Error::bad_database("backupkeyid_backup key is invalid.")
.ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, })?)
) .map_err(|_| {
.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?; Error::bad_database("backupkeyid_backup session_id is invalid.")
})?;
let room_id = RoomId::try_from( let room_id = RoomId::try_from(
utils::string_from_bytes( utils::string_from_bytes(&parts.next().ok_or_else(|| {
&parts Error::bad_database("backupkeyid_backup key is invalid.")
.next() })?)
.ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?,
) )
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?, .map_err(|_| {
) Error::bad_database("backupkeyid_backup room_id is invalid room id.")
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid room id."))?; })?;
let key_data = serde_json::from_slice(&value).map_err(|_| { let key_data = serde_json::from_slice(&value).map_err(|_| {
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.") Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")
})?; })?;
Ok::<_, Error>((room_id, session_id, key_data)) Ok::<_, Error>((room_id, session_id, key_data))
}) { })
{
let (room_id, session_id, key_data) = result?; let (room_id, session_id, key_data) = result?;
rooms rooms
.entry(room_id) .entry(room_id)
@ -239,7 +245,7 @@ impl KeyBackups {
user_id: &UserId, user_id: &UserId,
version: &str, version: &str,
room_id: &RoomId, room_id: &RoomId,
) -> BTreeMap<String, KeyBackupData> { ) -> Result<BTreeMap<String, KeyBackupData>> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xff);
prefix.extend_from_slice(version.as_bytes()); prefix.extend_from_slice(version.as_bytes());
@ -247,10 +253,10 @@ impl KeyBackups {
prefix.extend_from_slice(room_id.as_bytes()); prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xff); prefix.push(0xff);
self.backupkeyid_backup Ok(self
.scan_prefix(&prefix) .backupkeyid_backup
.map(|r| { .scan_prefix(prefix)
let (key, value) = r?; .map(|(key, value)| {
let mut parts = key.rsplit(|&b| b == 0xff); let mut parts = key.rsplit(|&b| b == 0xff);
let session_id = let session_id =
@ -268,7 +274,7 @@ impl KeyBackups {
Ok::<_, Error>((session_id, key_data)) Ok::<_, Error>((session_id, key_data))
}) })
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.collect() .collect())
} }
pub fn get_session( pub fn get_session(
@ -302,13 +308,8 @@ impl KeyBackups {
key.extend_from_slice(&version.as_bytes()); key.extend_from_slice(&version.as_bytes());
key.push(0xff); key.push(0xff);
for outdated_key in self for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
.backupkeyid_backup self.backupkeyid_backup.remove(&outdated_key)?;
.scan_prefix(&key)
.keys()
.filter_map(|r| r.ok())
{
self.backupkeyid_backup.remove(outdated_key)?;
} }
Ok(()) Ok(())
@ -327,13 +328,8 @@ impl KeyBackups {
key.extend_from_slice(&room_id.as_bytes()); key.extend_from_slice(&room_id.as_bytes());
key.push(0xff); key.push(0xff);
for outdated_key in self for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
.backupkeyid_backup self.backupkeyid_backup.remove(&outdated_key)?;
.scan_prefix(&key)
.keys()
.filter_map(|r| r.ok())
{
self.backupkeyid_backup.remove(outdated_key)?;
} }
Ok(()) Ok(())
@ -354,13 +350,8 @@ impl KeyBackups {
key.push(0xff); key.push(0xff);
key.extend_from_slice(&session_id.as_bytes()); key.extend_from_slice(&session_id.as_bytes());
for outdated_key in self for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
.backupkeyid_backup self.backupkeyid_backup.remove(&outdated_key)?;
.scan_prefix(&key)
.keys()
.filter_map(|r| r.ok())
{
self.backupkeyid_backup.remove(outdated_key)?;
} }
Ok(()) Ok(())

View File

@ -1,7 +1,10 @@
use crate::database::globals::Globals;
use image::{imageops::FilterType, GenericImageView}; use image::{imageops::FilterType, GenericImageView};
use super::abstraction::Tree;
use crate::{utils, Error, Result}; use crate::{utils, Error, Result};
use std::mem; use std::{mem, sync::Arc};
use tokio::{fs::File, io::AsyncReadExt, io::AsyncWriteExt};
pub struct FileMeta { pub struct FileMeta {
pub content_disposition: Option<String>, pub content_disposition: Option<String>,
@ -9,16 +12,16 @@ pub struct FileMeta {
pub file: Vec<u8>, pub file: Vec<u8>,
} }
#[derive(Clone)]
pub struct Media { pub struct Media {
pub(super) mediaid_file: sled::Tree, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType pub(super) mediaid_file: Arc<dyn Tree>, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType
} }
impl Media { impl Media {
/// Uploads or replaces a file. /// Uploads a file.
pub fn create( pub async fn create(
&self, &self,
mxc: String, mxc: String,
globals: &Globals,
content_disposition: &Option<&str>, content_disposition: &Option<&str>,
content_type: &Option<&str>, content_type: &Option<&str>,
file: &[u8], file: &[u8],
@ -42,15 +45,19 @@ impl Media {
.unwrap_or_default(), .unwrap_or_default(),
); );
self.mediaid_file.insert(key, file)?; let path = globals.get_media_file(&key);
let mut f = File::create(path).await?;
f.write_all(file).await?;
self.mediaid_file.insert(&key, &[])?;
Ok(()) Ok(())
} }
/// Uploads or replaces a file thumbnail. /// Uploads or replaces a file thumbnail.
pub fn upload_thumbnail( pub async fn upload_thumbnail(
&self, &self,
mxc: String, mxc: String,
globals: &Globals,
content_disposition: &Option<String>, content_disposition: &Option<String>,
content_type: &Option<String>, content_type: &Option<String>,
width: u32, width: u32,
@ -76,21 +83,28 @@ impl Media {
.unwrap_or_default(), .unwrap_or_default(),
); );
self.mediaid_file.insert(key, file)?; let path = globals.get_media_file(&key);
let mut f = File::create(path).await?;
f.write_all(file).await?;
self.mediaid_file.insert(&key, &[])?;
Ok(()) Ok(())
} }
/// Downloads a file. /// Downloads a file.
pub fn get(&self, mxc: &str) -> Result<Option<FileMeta>> { pub async fn get(&self, globals: &Globals, mxc: &str) -> Result<Option<FileMeta>> {
let mut prefix = mxc.as_bytes().to_vec(); let mut prefix = mxc.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xff);
prefix.extend_from_slice(&0_u32.to_be_bytes()); // Width = 0 if it's not a thumbnail prefix.extend_from_slice(&0_u32.to_be_bytes()); // Width = 0 if it's not a thumbnail
prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail
prefix.push(0xff); prefix.push(0xff);
if let Some(r) = self.mediaid_file.scan_prefix(&prefix).next() { let mut iter = self.mediaid_file.scan_prefix(prefix);
let (key, file) = r?; if let Some((key, _)) = iter.next() {
let path = globals.get_media_file(&key);
let mut file = Vec::new();
File::open(path).await?.read_to_end(&mut file).await?;
let mut parts = key.rsplit(|&b| b == 0xff); let mut parts = key.rsplit(|&b| b == 0xff);
let content_type = parts let content_type = parts
@ -121,7 +135,7 @@ impl Media {
Ok(Some(FileMeta { Ok(Some(FileMeta {
content_disposition, content_disposition,
content_type, content_type,
file: file.to_vec(), file,
})) }))
} else { } else {
Ok(None) Ok(None)
@ -151,7 +165,13 @@ impl Media {
/// - Server creates the thumbnail and sends it to the user /// - Server creates the thumbnail and sends it to the user
/// ///
/// For width,height <= 96 the server uses another thumbnailing algorithm which crops the image afterwards. /// For width,height <= 96 the server uses another thumbnailing algorithm which crops the image afterwards.
pub fn get_thumbnail(&self, mxc: String, width: u32, height: u32) -> Result<Option<FileMeta>> { pub async fn get_thumbnail(
&self,
mxc: String,
globals: &Globals,
width: u32,
height: u32,
) -> Result<Option<FileMeta>> {
let (width, height, crop) = self let (width, height, crop) = self
.thumbnail_properties(width, height) .thumbnail_properties(width, height)
.unwrap_or((0, 0, false)); // 0, 0 because that's the original file .unwrap_or((0, 0, false)); // 0, 0 because that's the original file
@ -169,9 +189,11 @@ impl Media {
original_prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail original_prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail
original_prefix.push(0xff); original_prefix.push(0xff);
if let Some(r) = self.mediaid_file.scan_prefix(&thumbnail_prefix).next() { if let Some((key, _)) = self.mediaid_file.scan_prefix(thumbnail_prefix).next() {
// Using saved thumbnail // Using saved thumbnail
let (key, file) = r?; let path = globals.get_media_file(&key);
let mut file = Vec::new();
File::open(path).await?.read_to_end(&mut file).await?;
let mut parts = key.rsplit(|&b| b == 0xff); let mut parts = key.rsplit(|&b| b == 0xff);
let content_type = parts let content_type = parts
@ -202,10 +224,12 @@ impl Media {
content_type, content_type,
file: file.to_vec(), file: file.to_vec(),
})) }))
} else if let Some(r) = self.mediaid_file.scan_prefix(&original_prefix).next() { } else if let Some((key, _)) = self.mediaid_file.scan_prefix(original_prefix).next() {
// Generate a thumbnail // Generate a thumbnail
let path = globals.get_media_file(&key);
let mut file = Vec::new();
File::open(path).await?.read_to_end(&mut file).await?;
let (key, file) = r?;
let mut parts = key.rsplit(|&b| b == 0xff); let mut parts = key.rsplit(|&b| b == 0xff);
let content_type = parts let content_type = parts
@ -302,7 +326,11 @@ impl Media {
widthheight, widthheight,
); );
self.mediaid_file.insert(thumbnail_key, &*thumbnail_bytes)?; let path = globals.get_media_file(&thumbnail_key);
let mut f = File::create(path).await?;
f.write_all(&thumbnail_bytes).await?;
self.mediaid_file.insert(&thumbnail_key, &[])?;
Ok(Some(FileMeta { Ok(Some(FileMeta {
content_disposition, content_disposition,

View File

@ -14,23 +14,17 @@ use ruma::{
push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak}, push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak},
uint, UInt, UserId, uint, UInt, UserId,
}; };
use sled::IVec;
use std::{convert::TryFrom, fmt::Debug, mem}; use std::{convert::TryFrom, fmt::Debug, mem, sync::Arc};
use super::abstraction::Tree;
#[derive(Debug, Clone)]
pub struct PushData { pub struct PushData {
/// UserId + pushkey -> Pusher /// UserId + pushkey -> Pusher
pub(super) senderkey_pusher: sled::Tree, pub(super) senderkey_pusher: Arc<dyn Tree>,
} }
impl PushData { impl PushData {
pub fn new(db: &sled::Db) -> Result<Self> {
Ok(Self {
senderkey_pusher: db.open_tree("senderkey_pusher")?,
})
}
pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::Pusher) -> Result<()> { pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::Pusher) -> Result<()> {
let mut key = sender.as_bytes().to_vec(); let mut key = sender.as_bytes().to_vec();
key.push(0xff); key.push(0xff);
@ -40,14 +34,14 @@ impl PushData {
if pusher.kind.is_none() { if pusher.kind.is_none() {
return self return self
.senderkey_pusher .senderkey_pusher
.remove(key) .remove(&key)
.map(|_| ()) .map(|_| ())
.map_err(Into::into); .map_err(Into::into);
} }
self.senderkey_pusher.insert( self.senderkey_pusher.insert(
key, &key,
&*serde_json::to_string(&pusher).expect("Pusher is valid JSON string"), &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"),
)?; )?;
Ok(()) Ok(())
@ -69,23 +63,21 @@ impl PushData {
self.senderkey_pusher self.senderkey_pusher
.scan_prefix(prefix) .scan_prefix(prefix)
.values() .map(|(_, push)| {
.map(|push| {
let push = push.map_err(|_| Error::bad_database("Invalid push bytes in db."))?;
Ok(serde_json::from_slice(&*push) Ok(serde_json::from_slice(&*push)
.map_err(|_| Error::bad_database("Invalid Pusher in db."))?) .map_err(|_| Error::bad_database("Invalid Pusher in db."))?)
}) })
.collect() .collect()
} }
pub fn get_pusher_senderkeys(&self, sender: &UserId) -> impl Iterator<Item = Result<IVec>> { pub fn get_pusher_senderkeys<'a>(
&'a self,
sender: &UserId,
) -> impl Iterator<Item = Box<[u8]>> + 'a {
let mut prefix = sender.as_bytes().to_vec(); let mut prefix = sender.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xff);
self.senderkey_pusher self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| k)
.scan_prefix(prefix)
.keys()
.map(|r| Ok(r?))
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,4 @@
use crate::{utils, Error, Result}; use crate::{database::abstraction::Tree, utils, Error, Result};
use ruma::{ use ruma::{
events::{ events::{
presence::{PresenceEvent, PresenceEventContent}, presence::{PresenceEvent, PresenceEventContent},
@ -13,17 +13,17 @@ use std::{
collections::{HashMap, HashSet}, collections::{HashMap, HashSet},
convert::{TryFrom, TryInto}, convert::{TryFrom, TryInto},
mem, mem,
sync::Arc,
}; };
#[derive(Clone)]
pub struct RoomEdus { pub struct RoomEdus {
pub(in super::super) readreceiptid_readreceipt: sled::Tree, // ReadReceiptId = RoomId + Count + UserId pub(in super::super) readreceiptid_readreceipt: Arc<dyn Tree>, // ReadReceiptId = RoomId + Count + UserId
pub(in super::super) roomuserid_privateread: sled::Tree, // RoomUserId = Room + User, PrivateRead = Count pub(in super::super) roomuserid_privateread: Arc<dyn Tree>, // RoomUserId = Room + User, PrivateRead = Count
pub(in super::super) roomuserid_lastprivatereadupdate: sled::Tree, // LastPrivateReadUpdate = Count pub(in super::super) roomuserid_lastprivatereadupdate: Arc<dyn Tree>, // LastPrivateReadUpdate = Count
pub(in super::super) typingid_userid: sled::Tree, // TypingId = RoomId + TimeoutTime + Count pub(in super::super) typingid_userid: Arc<dyn Tree>, // TypingId = RoomId + TimeoutTime + Count
pub(in super::super) roomid_lasttypingupdate: sled::Tree, // LastRoomTypingUpdate = Count pub(in super::super) roomid_lasttypingupdate: Arc<dyn Tree>, // LastRoomTypingUpdate = Count
pub(in super::super) presenceid_presence: sled::Tree, // PresenceId = RoomId + Count + UserId pub(in super::super) presenceid_presence: Arc<dyn Tree>, // PresenceId = RoomId + Count + UserId
pub(in super::super) userid_lastpresenceupdate: sled::Tree, // LastPresenceUpdate = Count pub(in super::super) userid_lastpresenceupdate: Arc<dyn Tree>, // LastPresenceUpdate = Count
} }
impl RoomEdus { impl RoomEdus {
@ -38,15 +38,15 @@ impl RoomEdus {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xff);
let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
// Remove old entry // Remove old entry
if let Some(old) = self if let Some((old, _)) = self
.readreceiptid_readreceipt .readreceiptid_readreceipt
.scan_prefix(&prefix) .iter_from(&last_possible_key, true)
.keys() .take_while(|(key, _)| key.starts_with(&prefix))
.rev() .find(|(key, _)| {
.filter_map(|r| r.ok())
.take_while(|key| key.starts_with(&prefix))
.find(|key| {
key.rsplit(|&b| b == 0xff) key.rsplit(|&b| b == 0xff)
.next() .next()
.expect("rsplit always returns an element") .expect("rsplit always returns an element")
@ -54,7 +54,7 @@ impl RoomEdus {
}) })
{ {
// This is the old room_latest // This is the old room_latest
self.readreceiptid_readreceipt.remove(old)?; self.readreceiptid_readreceipt.remove(&old)?;
} }
let mut room_latest_id = prefix; let mut room_latest_id = prefix;
@ -63,8 +63,8 @@ impl RoomEdus {
room_latest_id.extend_from_slice(&user_id.as_bytes()); room_latest_id.extend_from_slice(&user_id.as_bytes());
self.readreceiptid_readreceipt.insert( self.readreceiptid_readreceipt.insert(
room_latest_id, &room_latest_id,
&*serde_json::to_string(&event).expect("EduEvent::to_string always works"), &serde_json::to_vec(&event).expect("EduEvent::to_string always works"),
)?; )?;
Ok(()) Ok(())
@ -72,13 +72,12 @@ impl RoomEdus {
/// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`. /// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`.
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn readreceipts_since( pub fn readreceipts_since<'a>(
&self, &'a self,
room_id: &RoomId, room_id: &RoomId,
since: u64, since: u64,
) -> Result< ) -> impl Iterator<Item = Result<(UserId, u64, Raw<ruma::events::AnySyncEphemeralRoomEvent>)>> + 'a
impl Iterator<Item = Result<(UserId, u64, Raw<ruma::events::AnySyncEphemeralRoomEvent>)>>, {
> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xff);
let prefix2 = prefix.clone(); let prefix2 = prefix.clone();
@ -86,10 +85,8 @@ impl RoomEdus {
let mut first_possible_edu = prefix.clone(); let mut first_possible_edu = prefix.clone();
first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since
Ok(self self.readreceiptid_readreceipt
.readreceiptid_readreceipt .iter_from(&first_possible_edu, false)
.range(&*first_possible_edu..)
.filter_map(|r| r.ok())
.take_while(move |(k, _)| k.starts_with(&prefix2)) .take_while(move |(k, _)| k.starts_with(&prefix2))
.map(move |(k, v)| { .map(move |(k, v)| {
let count = let count =
@ -115,7 +112,7 @@ impl RoomEdus {
serde_json::value::to_raw_value(&json).expect("json is valid raw value"), serde_json::value::to_raw_value(&json).expect("json is valid raw value"),
), ),
)) ))
})) })
} }
/// Sets a private read marker at `count`. /// Sets a private read marker at `count`.
@ -146,11 +143,13 @@ impl RoomEdus {
key.push(0xff); key.push(0xff);
key.extend_from_slice(&user_id.as_bytes()); key.extend_from_slice(&user_id.as_bytes());
self.roomuserid_privateread.get(key)?.map_or(Ok(None), |v| { self.roomuserid_privateread
Ok(Some(utils::u64_from_bytes(&v).map_err(|_| { .get(&key)?
Error::bad_database("Invalid private read marker bytes") .map_or(Ok(None), |v| {
})?)) Ok(Some(utils::u64_from_bytes(&v).map_err(|_| {
}) Error::bad_database("Invalid private read marker bytes")
})?))
})
} }
/// Returns the count of the last typing update in this room. /// Returns the count of the last typing update in this room.
@ -215,11 +214,10 @@ impl RoomEdus {
// Maybe there are multiple ones from calling roomtyping_add multiple times // Maybe there are multiple ones from calling roomtyping_add multiple times
for outdated_edu in self for outdated_edu in self
.typingid_userid .typingid_userid
.scan_prefix(&prefix) .scan_prefix(prefix)
.filter_map(|r| r.ok()) .filter(|(_, v)| &**v == user_id.as_bytes())
.filter(|(_, v)| v == user_id.as_bytes())
{ {
self.typingid_userid.remove(outdated_edu.0)?; self.typingid_userid.remove(&outdated_edu.0)?;
found_outdated = true; found_outdated = true;
} }
@ -247,10 +245,8 @@ impl RoomEdus {
// Find all outdated edus before inserting a new one // Find all outdated edus before inserting a new one
for outdated_edu in self for outdated_edu in self
.typingid_userid .typingid_userid
.scan_prefix(&prefix) .scan_prefix(prefix)
.keys() .map(|(key, _)| {
.map(|key| {
let key = key?;
Ok::<_, Error>(( Ok::<_, Error>((
key.clone(), key.clone(),
utils::u64_from_bytes( utils::u64_from_bytes(
@ -265,7 +261,7 @@ impl RoomEdus {
.take_while(|&(_, timestamp)| timestamp < current_timestamp) .take_while(|&(_, timestamp)| timestamp < current_timestamp)
{ {
// This is an outdated edu (time > timestamp) // This is an outdated edu (time > timestamp)
self.typingid_userid.remove(outdated_edu.0)?; self.typingid_userid.remove(&outdated_edu.0)?;
found_outdated = true; found_outdated = true;
} }
@ -309,10 +305,9 @@ impl RoomEdus {
for user_id in self for user_id in self
.typingid_userid .typingid_userid
.scan_prefix(prefix) .scan_prefix(prefix)
.values() .map(|(_, user_id)| {
.map(|user_id| {
Ok::<_, Error>( Ok::<_, Error>(
UserId::try_from(utils::string_from_bytes(&user_id?).map_err(|_| { UserId::try_from(utils::string_from_bytes(&user_id).map_err(|_| {
Error::bad_database("User ID in typingid_userid is invalid unicode.") Error::bad_database("User ID in typingid_userid is invalid unicode.")
})?) })?)
.map_err(|_| Error::bad_database("User ID in typingid_userid is invalid."))?, .map_err(|_| Error::bad_database("User ID in typingid_userid is invalid."))?,
@ -351,12 +346,12 @@ impl RoomEdus {
presence_id.extend_from_slice(&presence.sender.as_bytes()); presence_id.extend_from_slice(&presence.sender.as_bytes());
self.presenceid_presence.insert( self.presenceid_presence.insert(
presence_id, &presence_id,
&*serde_json::to_string(&presence).expect("PresenceEvent can be serialized"), &serde_json::to_vec(&presence).expect("PresenceEvent can be serialized"),
)?; )?;
self.userid_lastpresenceupdate.insert( self.userid_lastpresenceupdate.insert(
&user_id.as_bytes(), user_id.as_bytes(),
&utils::millis_since_unix_epoch().to_be_bytes(), &utils::millis_since_unix_epoch().to_be_bytes(),
)?; )?;
@ -403,7 +398,7 @@ impl RoomEdus {
presence_id.extend_from_slice(&user_id.as_bytes()); presence_id.extend_from_slice(&user_id.as_bytes());
self.presenceid_presence self.presenceid_presence
.get(presence_id)? .get(&presence_id)?
.map(|value| { .map(|value| {
let mut presence = serde_json::from_slice::<PresenceEvent>(&value) let mut presence = serde_json::from_slice::<PresenceEvent>(&value)
.map_err(|_| Error::bad_database("Invalid presence event in db."))?; .map_err(|_| Error::bad_database("Invalid presence event in db."))?;
@ -438,7 +433,6 @@ impl RoomEdus {
for (user_id_bytes, last_timestamp) in self for (user_id_bytes, last_timestamp) in self
.userid_lastpresenceupdate .userid_lastpresenceupdate
.iter() .iter()
.filter_map(|r| r.ok())
.filter_map(|(k, bytes)| { .filter_map(|(k, bytes)| {
Some(( Some((
k, k,
@ -468,8 +462,8 @@ impl RoomEdus {
presence_id.extend_from_slice(&user_id_bytes); presence_id.extend_from_slice(&user_id_bytes);
self.presenceid_presence.insert( self.presenceid_presence.insert(
presence_id, &presence_id,
&*serde_json::to_string(&PresenceEvent { &serde_json::to_vec(&PresenceEvent {
content: PresenceEventContent { content: PresenceEventContent {
avatar_url: None, avatar_url: None,
currently_active: None, currently_active: None,
@ -515,8 +509,7 @@ impl RoomEdus {
for (key, value) in self for (key, value) in self
.presenceid_presence .presenceid_presence
.range(&*first_possible_edu..) .iter_from(&*first_possible_edu, false)
.filter_map(|r| r.ok())
.take_while(|(key, _)| key.starts_with(&prefix)) .take_while(|(key, _)| key.starts_with(&prefix))
{ {
let user_id = UserId::try_from( let user_id = UserId::try_from(

View File

@ -12,7 +12,10 @@ use crate::{
use federation::transactions::send_transaction_message; use federation::transactions::send_transaction_message;
use log::{error, warn}; use log::{error, warn};
use ring::digest; use ring::digest;
use rocket::futures::stream::{FuturesUnordered, StreamExt}; use rocket::futures::{
channel::mpsc,
stream::{FuturesUnordered, StreamExt},
};
use ruma::{ use ruma::{
api::{ api::{
appservice, appservice,
@ -27,9 +30,10 @@ use ruma::{
receipt::ReceiptType, receipt::ReceiptType,
MilliSecondsSinceUnixEpoch, ServerName, UInt, UserId, MilliSecondsSinceUnixEpoch, ServerName, UInt, UserId,
}; };
use sled::IVec;
use tokio::{select, sync::Semaphore}; use tokio::{select, sync::Semaphore};
use super::abstraction::Tree;
#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum OutgoingKind { pub enum OutgoingKind {
Appservice(Box<ServerName>), Appservice(Box<ServerName>),
@ -70,13 +74,13 @@ pub enum SendingEventType {
Edu(Vec<u8>), Edu(Vec<u8>),
} }
#[derive(Clone)]
pub struct Sending { pub struct Sending {
/// The state for a given state hash. /// The state for a given state hash.
pub(super) servername_educount: sled::Tree, // EduCount: Count of last EDU sync pub(super) servername_educount: Arc<dyn Tree>, // EduCount: Count of last EDU sync
pub(super) servernamepduids: sled::Tree, // ServernamePduId = (+ / $)SenderKey / ServerName / UserId + PduId pub(super) servernamepduids: Arc<dyn Tree>, // ServernamePduId = (+ / $)SenderKey / ServerName / UserId + PduId
pub(super) servercurrentevents: sled::Tree, // ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / (*)EduEvent pub(super) servercurrentevents: Arc<dyn Tree>, // ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / (*)EduEvent
pub(super) maximum_requests: Arc<Semaphore>, pub(super) maximum_requests: Arc<Semaphore>,
pub sender: mpsc::UnboundedSender<Vec<u8>>,
} }
enum TransactionStatus { enum TransactionStatus {
@ -86,28 +90,23 @@ enum TransactionStatus {
} }
impl Sending { impl Sending {
pub fn start_handler(&self, db: &Database) { pub fn start_handler(&self, db: Arc<Database>, mut receiver: mpsc::UnboundedReceiver<Vec<u8>>) {
let servernamepduids = self.servernamepduids.clone();
let servercurrentevents = self.servercurrentevents.clone();
let db = db.clone();
tokio::spawn(async move { tokio::spawn(async move {
let mut futures = FuturesUnordered::new(); let mut futures = FuturesUnordered::new();
// Retry requests we could not finish yet
let mut subscriber = servernamepduids.watch_prefix(b"");
let mut current_transaction_status = HashMap::<Vec<u8>, TransactionStatus>::new(); let mut current_transaction_status = HashMap::<Vec<u8>, TransactionStatus>::new();
// Retry requests we could not finish yet
let mut initial_transactions = HashMap::<OutgoingKind, Vec<SendingEventType>>::new(); let mut initial_transactions = HashMap::<OutgoingKind, Vec<SendingEventType>>::new();
for (key, outgoing_kind, event) in servercurrentevents for (key, outgoing_kind, event) in
.iter() db.sending
.filter_map(|r| r.ok()) .servercurrentevents
.filter_map(|(key, _)| { .iter()
Self::parse_servercurrentevent(&key) .filter_map(|(key, _)| {
.ok() Self::parse_servercurrentevent(&key)
.map(|(k, e)| (key, k, e)) .ok()
}) .map(|(k, e)| (key, k, e))
})
{ {
let entry = initial_transactions let entry = initial_transactions
.entry(outgoing_kind.clone()) .entry(outgoing_kind.clone())
@ -118,7 +117,7 @@ impl Sending {
"Dropping some current events: {:?} {:?} {:?}", "Dropping some current events: {:?} {:?} {:?}",
key, outgoing_kind, event key, outgoing_kind, event
); );
servercurrentevents.remove(key).unwrap(); db.sending.servercurrentevents.remove(&key).unwrap();
continue; continue;
} }
@ -137,20 +136,16 @@ impl Sending {
match response { match response {
Ok(outgoing_kind) => { Ok(outgoing_kind) => {
let prefix = outgoing_kind.get_prefix(); let prefix = outgoing_kind.get_prefix();
for key in servercurrentevents for (key, _) in db.sending.servercurrentevents
.scan_prefix(&prefix) .scan_prefix(prefix.clone())
.keys()
.filter_map(|r| r.ok())
{ {
servercurrentevents.remove(key).unwrap(); db.sending.servercurrentevents.remove(&key).unwrap();
} }
// Find events that have been added since starting the last request // Find events that have been added since starting the last request
let new_events = servernamepduids let new_events = db.sending.servernamepduids
.scan_prefix(&prefix) .scan_prefix(prefix.clone())
.keys() .map(|(k, _)| {
.filter_map(|r| r.ok())
.map(|k| {
SendingEventType::Pdu(k[prefix.len()..].to_vec()) SendingEventType::Pdu(k[prefix.len()..].to_vec())
}) })
.take(30) .take(30)
@ -166,8 +161,8 @@ impl Sending {
SendingEventType::Pdu(b) | SendingEventType::Pdu(b) |
SendingEventType::Edu(b) => { SendingEventType::Edu(b) => {
current_key.extend_from_slice(&b); current_key.extend_from_slice(&b);
servercurrentevents.insert(&current_key, &[]).unwrap(); db.sending.servercurrentevents.insert(&current_key, &[]).unwrap();
servernamepduids.remove(&current_key).unwrap(); db.sending.servernamepduids.remove(&current_key).unwrap();
} }
} }
} }
@ -195,18 +190,15 @@ impl Sending {
} }
}; };
}, },
Some(event) = &mut subscriber => { Some(key) = receiver.next() => {
// New sled version: if let Ok((outgoing_kind, event)) = Self::parse_servercurrentevent(&key) {
//for (_tree, key, value_opt) in &event { if let Ok(Some(events)) = Self::select_events(
// if value_opt.is_none() { &outgoing_kind,
// continue; vec![(event, key)],
// } &mut current_transaction_status,
&db
if let sled::Event::Insert { key, .. } = event { ) {
if let Ok((outgoing_kind, event)) = Self::parse_servercurrentevent(&key) { futures.push(Self::handle_events(outgoing_kind, events, &db));
if let Some(events) = Self::select_events(&outgoing_kind, vec![(event, key)], &mut current_transaction_status, &servercurrentevents, &servernamepduids, &db) {
futures.push(Self::handle_events(outgoing_kind, events, &db));
}
} }
} }
} }
@ -217,12 +209,10 @@ impl Sending {
fn select_events( fn select_events(
outgoing_kind: &OutgoingKind, outgoing_kind: &OutgoingKind,
new_events: Vec<(SendingEventType, IVec)>, // Events we want to send: event and full key new_events: Vec<(SendingEventType, Vec<u8>)>, // Events we want to send: event and full key
current_transaction_status: &mut HashMap<Vec<u8>, TransactionStatus>, current_transaction_status: &mut HashMap<Vec<u8>, TransactionStatus>,
servercurrentevents: &sled::Tree,
servernamepduids: &sled::Tree,
db: &Database, db: &Database,
) -> Option<Vec<SendingEventType>> { ) -> Result<Option<Vec<SendingEventType>>> {
let mut retry = false; let mut retry = false;
let mut allow = true; let mut allow = true;
@ -252,29 +242,25 @@ impl Sending {
.or_insert(TransactionStatus::Running); .or_insert(TransactionStatus::Running);
if !allow { if !allow {
return None; return Ok(None);
} }
let mut events = Vec::new(); let mut events = Vec::new();
if retry { if retry {
// We retry the previous transaction // We retry the previous transaction
for key in servercurrentevents for (key, _) in db.sending.servercurrentevents.scan_prefix(prefix) {
.scan_prefix(&prefix)
.keys()
.filter_map(|r| r.ok())
{
if let Ok((_, e)) = Self::parse_servercurrentevent(&key) { if let Ok((_, e)) = Self::parse_servercurrentevent(&key) {
events.push(e); events.push(e);
} }
} }
} else { } else {
for (e, full_key) in new_events { for (e, full_key) in new_events {
servercurrentevents.insert(&full_key, &[]).unwrap(); db.sending.servercurrentevents.insert(&full_key, &[])?;
// If it was a PDU we have to unqueue it // If it was a PDU we have to unqueue it
// TODO: don't try to unqueue EDUs // TODO: don't try to unqueue EDUs
servernamepduids.remove(&full_key).unwrap(); db.sending.servernamepduids.remove(&full_key)?;
events.push(e); events.push(e);
} }
@ -284,13 +270,12 @@ impl Sending {
events.extend_from_slice(&select_edus); events.extend_from_slice(&select_edus);
db.sending db.sending
.servername_educount .servername_educount
.insert(server_name.as_bytes(), &last_count.to_be_bytes()) .insert(server_name.as_bytes(), &last_count.to_be_bytes())?;
.unwrap();
} }
} }
} }
Some(events) Ok(Some(events))
} }
pub fn select_edus(db: &Database, server: &ServerName) -> Result<(Vec<SendingEventType>, u64)> { pub fn select_edus(db: &Database, server: &ServerName) -> Result<(Vec<SendingEventType>, u64)> {
@ -307,7 +292,7 @@ impl Sending {
let mut max_edu_count = since; let mut max_edu_count = since;
'outer: for room_id in db.rooms.server_rooms(server) { 'outer: for room_id in db.rooms.server_rooms(server) {
let room_id = room_id?; let room_id = room_id?;
for r in db.rooms.edus.readreceipts_since(&room_id, since)? { for r in db.rooms.edus.readreceipts_since(&room_id, since) {
let (user_id, count, read_receipt) = r?; let (user_id, count, read_receipt) = r?;
if count > max_edu_count { if count > max_edu_count {
@ -372,12 +357,13 @@ impl Sending {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn send_push_pdu(&self, pdu_id: &[u8], senderkey: IVec) -> Result<()> { pub fn send_push_pdu(&self, pdu_id: &[u8], senderkey: Box<[u8]>) -> Result<()> {
let mut key = b"$".to_vec(); let mut key = b"$".to_vec();
key.extend_from_slice(&senderkey); key.extend_from_slice(&senderkey);
key.push(0xff); key.push(0xff);
key.extend_from_slice(pdu_id); key.extend_from_slice(pdu_id);
self.servernamepduids.insert(key, b"")?; self.servernamepduids.insert(&key, b"")?;
self.sender.unbounded_send(key).unwrap();
Ok(()) Ok(())
} }
@ -387,7 +373,8 @@ impl Sending {
let mut key = server.as_bytes().to_vec(); let mut key = server.as_bytes().to_vec();
key.push(0xff); key.push(0xff);
key.extend_from_slice(pdu_id); key.extend_from_slice(pdu_id);
self.servernamepduids.insert(key, b"")?; self.servernamepduids.insert(&key, b"")?;
self.sender.unbounded_send(key).unwrap();
Ok(()) Ok(())
} }
@ -398,7 +385,8 @@ impl Sending {
key.extend_from_slice(appservice_id.as_bytes()); key.extend_from_slice(appservice_id.as_bytes());
key.push(0xff); key.push(0xff);
key.extend_from_slice(pdu_id); key.extend_from_slice(pdu_id);
self.servernamepduids.insert(key, b"")?; self.servernamepduids.insert(&key, b"")?;
self.sender.unbounded_send(key).unwrap();
Ok(()) Ok(())
} }
@ -641,7 +629,7 @@ impl Sending {
} }
} }
fn parse_servercurrentevent(key: &IVec) -> Result<(OutgoingKind, SendingEventType)> { fn parse_servercurrentevent(key: &[u8]) -> Result<(OutgoingKind, SendingEventType)> {
// Appservices start with a plus // Appservices start with a plus
Ok::<_, Error>(if key.starts_with(b"+") { Ok::<_, Error>(if key.starts_with(b"+") {
let mut parts = key[1..].splitn(2, |&b| b == 0xff); let mut parts = key[1..].splitn(2, |&b| b == 0xff);

View File

@ -1,10 +1,12 @@
use std::sync::Arc;
use crate::Result; use crate::Result;
use ruma::{DeviceId, UserId}; use ruma::{DeviceId, UserId};
use sled::IVec;
#[derive(Clone)] use super::abstraction::Tree;
pub struct TransactionIds { pub struct TransactionIds {
pub(super) userdevicetxnid_response: sled::Tree, // Response can be empty (/sendToDevice) or the event id (/send) pub(super) userdevicetxnid_response: Arc<dyn Tree>, // Response can be empty (/sendToDevice) or the event id (/send)
} }
impl TransactionIds { impl TransactionIds {
@ -21,7 +23,7 @@ impl TransactionIds {
key.push(0xff); key.push(0xff);
key.extend_from_slice(txn_id.as_bytes()); key.extend_from_slice(txn_id.as_bytes());
self.userdevicetxnid_response.insert(key, data)?; self.userdevicetxnid_response.insert(&key, data)?;
Ok(()) Ok(())
} }
@ -31,7 +33,7 @@ impl TransactionIds {
user_id: &UserId, user_id: &UserId,
device_id: Option<&DeviceId>, device_id: Option<&DeviceId>,
txn_id: &str, txn_id: &str,
) -> Result<Option<IVec>> { ) -> Result<Option<Vec<u8>>> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xff);
key.extend_from_slice(device_id.map(|d| d.as_bytes()).unwrap_or_default()); key.extend_from_slice(device_id.map(|d| d.as_bytes()).unwrap_or_default());
@ -39,6 +41,6 @@ impl TransactionIds {
key.extend_from_slice(txn_id.as_bytes()); key.extend_from_slice(txn_id.as_bytes());
// If there's no entry, this is a new transaction // If there's no entry, this is a new transaction
Ok(self.userdevicetxnid_response.get(key)?) Ok(self.userdevicetxnid_response.get(&key)?)
} }
} }

View File

@ -1,3 +1,5 @@
use std::sync::Arc;
use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result}; use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result};
use ruma::{ use ruma::{
api::client::{ api::client::{
@ -8,10 +10,11 @@ use ruma::{
DeviceId, UserId, DeviceId, UserId,
}; };
#[derive(Clone)] use super::abstraction::Tree;
pub struct Uiaa { pub struct Uiaa {
pub(super) userdevicesessionid_uiaainfo: sled::Tree, // User-interactive authentication pub(super) userdevicesessionid_uiaainfo: Arc<dyn Tree>, // User-interactive authentication
pub(super) userdevicesessionid_uiaarequest: sled::Tree, // UiaaRequest = canonical json value pub(super) userdevicesessionid_uiaarequest: Arc<dyn Tree>, // UiaaRequest = canonical json value
} }
impl Uiaa { impl Uiaa {
@ -185,7 +188,7 @@ impl Uiaa {
self.userdevicesessionid_uiaarequest.insert( self.userdevicesessionid_uiaarequest.insert(
&userdevicesessionid, &userdevicesessionid,
&*serde_json::to_string(request).expect("json value to string always works"), &serde_json::to_vec(request).expect("json value to vec always works"),
)?; )?;
Ok(()) Ok(())
@ -233,7 +236,7 @@ impl Uiaa {
if let Some(uiaainfo) = uiaainfo { if let Some(uiaainfo) = uiaainfo {
self.userdevicesessionid_uiaainfo.insert( self.userdevicesessionid_uiaainfo.insert(
&userdevicesessionid, &userdevicesessionid,
&*serde_json::to_string(&uiaainfo).expect("UiaaInfo::to_string always works"), &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"),
)?; )?;
} else { } else {
self.userdevicesessionid_uiaainfo self.userdevicesessionid_uiaainfo

View File

@ -7,40 +7,41 @@ use ruma::{
serde::Raw, serde::Raw,
DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, UInt, UserId, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, UInt, UserId,
}; };
use std::{collections::BTreeMap, convert::TryFrom, mem}; use std::{collections::BTreeMap, convert::TryFrom, mem, sync::Arc};
use super::abstraction::Tree;
#[derive(Clone)]
pub struct Users { pub struct Users {
pub(super) userid_password: sled::Tree, pub(super) userid_password: Arc<dyn Tree>,
pub(super) userid_displayname: sled::Tree, pub(super) userid_displayname: Arc<dyn Tree>,
pub(super) userid_avatarurl: sled::Tree, pub(super) userid_avatarurl: Arc<dyn Tree>,
pub(super) userdeviceid_token: sled::Tree, pub(super) userdeviceid_token: Arc<dyn Tree>,
pub(super) userdeviceid_metadata: sled::Tree, // This is also used to check if a device exists pub(super) userdeviceid_metadata: Arc<dyn Tree>, // This is also used to check if a device exists
pub(super) userid_devicelistversion: sled::Tree, // DevicelistVersion = u64 pub(super) userid_devicelistversion: Arc<dyn Tree>, // DevicelistVersion = u64
pub(super) token_userdeviceid: sled::Tree, pub(super) token_userdeviceid: Arc<dyn Tree>,
pub(super) onetimekeyid_onetimekeys: sled::Tree, // OneTimeKeyId = UserId + DeviceKeyId pub(super) onetimekeyid_onetimekeys: Arc<dyn Tree>, // OneTimeKeyId = UserId + DeviceKeyId
pub(super) userid_lastonetimekeyupdate: sled::Tree, // LastOneTimeKeyUpdate = Count pub(super) userid_lastonetimekeyupdate: Arc<dyn Tree>, // LastOneTimeKeyUpdate = Count
pub(super) keychangeid_userid: sled::Tree, // KeyChangeId = UserId/RoomId + Count pub(super) keychangeid_userid: Arc<dyn Tree>, // KeyChangeId = UserId/RoomId + Count
pub(super) keyid_key: sled::Tree, // KeyId = UserId + KeyId (depends on key type) pub(super) keyid_key: Arc<dyn Tree>, // KeyId = UserId + KeyId (depends on key type)
pub(super) userid_masterkeyid: sled::Tree, pub(super) userid_masterkeyid: Arc<dyn Tree>,
pub(super) userid_selfsigningkeyid: sled::Tree, pub(super) userid_selfsigningkeyid: Arc<dyn Tree>,
pub(super) userid_usersigningkeyid: sled::Tree, pub(super) userid_usersigningkeyid: Arc<dyn Tree>,
pub(super) todeviceid_events: sled::Tree, // ToDeviceId = UserId + DeviceId + Count pub(super) todeviceid_events: Arc<dyn Tree>, // ToDeviceId = UserId + DeviceId + Count
} }
impl Users { impl Users {
/// Check if a user has an account on this homeserver. /// Check if a user has an account on this homeserver.
pub fn exists(&self, user_id: &UserId) -> Result<bool> { pub fn exists(&self, user_id: &UserId) -> Result<bool> {
Ok(self.userid_password.contains_key(user_id.to_string())?) Ok(self.userid_password.get(user_id.as_bytes())?.is_some())
} }
/// Check if account is deactivated /// Check if account is deactivated
pub fn is_deactivated(&self, user_id: &UserId) -> Result<bool> { pub fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
Ok(self Ok(self
.userid_password .userid_password
.get(user_id.to_string())? .get(user_id.as_bytes())?
.ok_or(Error::BadRequest( .ok_or(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"User does not exist.", "User does not exist.",
@ -55,14 +56,14 @@ impl Users {
} }
/// Returns the number of users registered on this server. /// Returns the number of users registered on this server.
pub fn count(&self) -> usize { pub fn count(&self) -> Result<usize> {
self.userid_password.iter().count() Ok(self.userid_password.iter().count())
} }
/// Find out which user an access token belongs to. /// Find out which user an access token belongs to.
pub fn find_from_token(&self, token: &str) -> Result<Option<(UserId, String)>> { pub fn find_from_token(&self, token: &str) -> Result<Option<(UserId, String)>> {
self.token_userdeviceid self.token_userdeviceid
.get(token)? .get(token.as_bytes())?
.map_or(Ok(None), |bytes| { .map_or(Ok(None), |bytes| {
let mut parts = bytes.split(|&b| b == 0xff); let mut parts = bytes.split(|&b| b == 0xff);
let user_bytes = parts.next().ok_or_else(|| { let user_bytes = parts.next().ok_or_else(|| {
@ -87,10 +88,10 @@ impl Users {
} }
/// Returns an iterator over all users on this homeserver. /// Returns an iterator over all users on this homeserver.
pub fn iter(&self) -> impl Iterator<Item = Result<UserId>> { pub fn iter<'a>(&'a self) -> impl Iterator<Item = Result<UserId>> + 'a {
self.userid_password.iter().keys().map(|bytes| { self.userid_password.iter().map(|(bytes, _)| {
Ok( Ok(
UserId::try_from(utils::string_from_bytes(&bytes?).map_err(|_| { UserId::try_from(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("User ID in userid_password is invalid unicode.") Error::bad_database("User ID in userid_password is invalid unicode.")
})?) })?)
.map_err(|_| Error::bad_database("User ID in userid_password is invalid."))?, .map_err(|_| Error::bad_database("User ID in userid_password is invalid."))?,
@ -101,7 +102,7 @@ impl Users {
/// Returns the password hash for the given user. /// Returns the password hash for the given user.
pub fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> { pub fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_password self.userid_password
.get(user_id.to_string())? .get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| { .map_or(Ok(None), |bytes| {
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Password hash in db is not valid string.") Error::bad_database("Password hash in db is not valid string.")
@ -113,7 +114,8 @@ impl Users {
pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
if let Some(password) = password { if let Some(password) = password {
if let Ok(hash) = utils::calculate_hash(&password) { if let Ok(hash) = utils::calculate_hash(&password) {
self.userid_password.insert(user_id.to_string(), &*hash)?; self.userid_password
.insert(user_id.as_bytes(), hash.as_bytes())?;
Ok(()) Ok(())
} else { } else {
Err(Error::BadRequest( Err(Error::BadRequest(
@ -122,7 +124,7 @@ impl Users {
)) ))
} }
} else { } else {
self.userid_password.insert(user_id.to_string(), "")?; self.userid_password.insert(user_id.as_bytes(), b"")?;
Ok(()) Ok(())
} }
} }
@ -130,7 +132,7 @@ impl Users {
/// Returns the displayname of a user on this homeserver. /// Returns the displayname of a user on this homeserver.
pub fn displayname(&self, user_id: &UserId) -> Result<Option<String>> { pub fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_displayname self.userid_displayname
.get(user_id.to_string())? .get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| { .map_or(Ok(None), |bytes| {
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Displayname in db is invalid.") Error::bad_database("Displayname in db is invalid.")
@ -142,9 +144,9 @@ impl Users {
pub fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> { pub fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> {
if let Some(displayname) = displayname { if let Some(displayname) = displayname {
self.userid_displayname self.userid_displayname
.insert(user_id.to_string(), &*displayname)?; .insert(user_id.as_bytes(), displayname.as_bytes())?;
} else { } else {
self.userid_displayname.remove(user_id.to_string())?; self.userid_displayname.remove(user_id.as_bytes())?;
} }
Ok(()) Ok(())
@ -153,7 +155,7 @@ impl Users {
/// Get a the avatar_url of a user. /// Get a the avatar_url of a user.
pub fn avatar_url(&self, user_id: &UserId) -> Result<Option<MxcUri>> { pub fn avatar_url(&self, user_id: &UserId) -> Result<Option<MxcUri>> {
self.userid_avatarurl self.userid_avatarurl
.get(user_id.to_string())? .get(user_id.as_bytes())?
.map(|bytes| { .map(|bytes| {
let s = utils::string_from_bytes(&bytes) let s = utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?; .map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?;
@ -166,9 +168,9 @@ impl Users {
pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<MxcUri>) -> Result<()> { pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<MxcUri>) -> Result<()> {
if let Some(avatar_url) = avatar_url { if let Some(avatar_url) = avatar_url {
self.userid_avatarurl self.userid_avatarurl
.insert(user_id.to_string(), avatar_url.to_string().as_str())?; .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?;
} else { } else {
self.userid_avatarurl.remove(user_id.to_string())?; self.userid_avatarurl.remove(user_id.as_bytes())?;
} }
Ok(()) Ok(())
@ -190,19 +192,17 @@ impl Users {
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
self.userid_devicelistversion self.userid_devicelistversion
.update_and_fetch(&user_id.as_bytes(), utils::increment)? .increment(user_id.as_bytes())?;
.expect("utils::increment will always put in a value");
self.userdeviceid_metadata.insert( self.userdeviceid_metadata.insert(
userdeviceid, &userdeviceid,
serde_json::to_string(&Device { &serde_json::to_vec(&Device {
device_id: device_id.into(), device_id: device_id.into(),
display_name: initial_device_display_name, display_name: initial_device_display_name,
last_seen_ip: None, // TODO last_seen_ip: None, // TODO
last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()), last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()),
}) })
.expect("Device::to_string never fails.") .expect("Device::to_string never fails."),
.as_bytes(),
)?; )?;
self.set_token(user_id, &device_id, token)?; self.set_token(user_id, &device_id, token)?;
@ -217,7 +217,8 @@ impl Users {
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
// Remove tokens // Remove tokens
if let Some(old_token) = self.userdeviceid_token.remove(&userdeviceid)? { if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? {
self.userdeviceid_token.remove(&userdeviceid)?;
self.token_userdeviceid.remove(&old_token)?; self.token_userdeviceid.remove(&old_token)?;
} }
@ -225,15 +226,14 @@ impl Users {
let mut prefix = userdeviceid.clone(); let mut prefix = userdeviceid.clone();
prefix.push(0xff); prefix.push(0xff);
for key in self.todeviceid_events.scan_prefix(&prefix).keys() { for (key, _) in self.todeviceid_events.scan_prefix(prefix) {
self.todeviceid_events.remove(key?)?; self.todeviceid_events.remove(&key)?;
} }
// TODO: Remove onetimekeys // TODO: Remove onetimekeys
self.userid_devicelistversion self.userid_devicelistversion
.update_and_fetch(&user_id.as_bytes(), utils::increment)? .increment(user_id.as_bytes())?;
.expect("utils::increment will always put in a value");
self.userdeviceid_metadata.remove(&userdeviceid)?; self.userdeviceid_metadata.remove(&userdeviceid)?;
@ -241,16 +241,18 @@ impl Users {
} }
/// Returns an iterator over all device ids of this user. /// Returns an iterator over all device ids of this user.
pub fn all_device_ids(&self, user_id: &UserId) -> impl Iterator<Item = Result<Box<DeviceId>>> { pub fn all_device_ids<'a>(
&'a self,
user_id: &UserId,
) -> impl Iterator<Item = Result<Box<DeviceId>>> + 'a {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xff);
// All devices have metadata // All devices have metadata
self.userdeviceid_metadata self.userdeviceid_metadata
.scan_prefix(prefix) .scan_prefix(prefix)
.keys() .map(|(bytes, _)| {
.map(|bytes| {
Ok(utils::string_from_bytes( Ok(utils::string_from_bytes(
&*bytes? &bytes
.rsplit(|&b| b == 0xff) .rsplit(|&b| b == 0xff)
.next() .next()
.ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?, .ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?,
@ -271,13 +273,15 @@ impl Users {
// Remove old token // Remove old token
if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? {
self.token_userdeviceid.remove(old_token)?; self.token_userdeviceid.remove(&old_token)?;
// It will be removed from userdeviceid_token by the insert later // It will be removed from userdeviceid_token by the insert later
} }
// Assign token to user device combination // Assign token to user device combination
self.userdeviceid_token.insert(&userdeviceid, &*token)?; self.userdeviceid_token
self.token_userdeviceid.insert(token, userdeviceid)?; .insert(&userdeviceid, token.as_bytes())?;
self.token_userdeviceid
.insert(token.as_bytes(), &userdeviceid)?;
Ok(()) Ok(())
} }
@ -309,8 +313,7 @@ impl Users {
self.onetimekeyid_onetimekeys.insert( self.onetimekeyid_onetimekeys.insert(
&key, &key,
&*serde_json::to_string(&one_time_key_value) &serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"),
.expect("OneTimeKey::to_string always works"),
)?; )?;
self.userid_lastonetimekeyupdate self.userid_lastonetimekeyupdate
@ -350,10 +353,9 @@ impl Users {
.insert(&user_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; .insert(&user_id.as_bytes(), &globals.next_count()?.to_be_bytes())?;
self.onetimekeyid_onetimekeys self.onetimekeyid_onetimekeys
.scan_prefix(&prefix) .scan_prefix(prefix)
.next() .next()
.map(|r| { .map(|(key, value)| {
let (key, value) = r?;
self.onetimekeyid_onetimekeys.remove(&key)?; self.onetimekeyid_onetimekeys.remove(&key)?;
Ok(( Ok((
@ -383,21 +385,20 @@ impl Users {
let mut counts = BTreeMap::new(); let mut counts = BTreeMap::new();
for algorithm in self for algorithm in
.onetimekeyid_onetimekeys self.onetimekeyid_onetimekeys
.scan_prefix(&userdeviceid) .scan_prefix(userdeviceid)
.keys() .map(|(bytes, _)| {
.map(|bytes| { Ok::<_, Error>(
Ok::<_, Error>( serde_json::from_slice::<DeviceKeyId>(
serde_json::from_slice::<DeviceKeyId>( &*bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| {
&*bytes?.rsplit(|&b| b == 0xff).next().ok_or_else(|| { Error::bad_database("OneTimeKey ID in db is invalid.")
Error::bad_database("OneTimeKey ID in db is invalid.") })?,
})?, )
.map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))?
.algorithm(),
) )
.map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))? })
.algorithm(),
)
})
{ {
*counts.entry(algorithm?).or_default() += UInt::from(1_u32); *counts.entry(algorithm?).or_default() += UInt::from(1_u32);
} }
@ -419,7 +420,7 @@ impl Users {
self.keyid_key.insert( self.keyid_key.insert(
&userdeviceid, &userdeviceid,
&*serde_json::to_string(&device_keys).expect("DeviceKeys::to_string always works"), &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"),
)?; )?;
self.mark_device_key_update(user_id, rooms, globals)?; self.mark_device_key_update(user_id, rooms, globals)?;
@ -460,11 +461,11 @@ impl Users {
self.keyid_key.insert( self.keyid_key.insert(
&master_key_key, &master_key_key,
&*serde_json::to_string(&master_key).expect("CrossSigningKey::to_string always works"), &serde_json::to_vec(&master_key).expect("CrossSigningKey::to_vec always works"),
)?; )?;
self.userid_masterkeyid self.userid_masterkeyid
.insert(&*user_id.to_string(), master_key_key)?; .insert(user_id.as_bytes(), &master_key_key)?;
// Self-signing key // Self-signing key
if let Some(self_signing_key) = self_signing_key { if let Some(self_signing_key) = self_signing_key {
@ -486,12 +487,12 @@ impl Users {
self.keyid_key.insert( self.keyid_key.insert(
&self_signing_key_key, &self_signing_key_key,
&*serde_json::to_string(&self_signing_key) &serde_json::to_vec(&self_signing_key)
.expect("CrossSigningKey::to_string always works"), .expect("CrossSigningKey::to_vec always works"),
)?; )?;
self.userid_selfsigningkeyid self.userid_selfsigningkeyid
.insert(&*user_id.to_string(), self_signing_key_key)?; .insert(user_id.as_bytes(), &self_signing_key_key)?;
} }
// User-signing key // User-signing key
@ -514,12 +515,12 @@ impl Users {
self.keyid_key.insert( self.keyid_key.insert(
&user_signing_key_key, &user_signing_key_key,
&*serde_json::to_string(&user_signing_key) &serde_json::to_vec(&user_signing_key)
.expect("CrossSigningKey::to_string always works"), .expect("CrossSigningKey::to_vec always works"),
)?; )?;
self.userid_usersigningkeyid self.userid_usersigningkeyid
.insert(&*user_id.to_string(), user_signing_key_key)?; .insert(user_id.as_bytes(), &user_signing_key_key)?;
} }
self.mark_device_key_update(user_id, rooms, globals)?; self.mark_device_key_update(user_id, rooms, globals)?;
@ -561,8 +562,7 @@ impl Users {
self.keyid_key.insert( self.keyid_key.insert(
&key, &key,
&*serde_json::to_string(&cross_signing_key) &serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"),
.expect("CrossSigningKey::to_string always works"),
)?; )?;
// TODO: Should we notify about this change? // TODO: Should we notify about this change?
@ -572,24 +572,20 @@ impl Users {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn keys_changed( pub fn keys_changed<'a>(
&self, &'a self,
user_or_room_id: &str, user_or_room_id: &str,
from: u64, from: u64,
to: Option<u64>, to: Option<u64>,
) -> impl Iterator<Item = Result<UserId>> { ) -> impl Iterator<Item = Result<UserId>> + 'a {
let mut prefix = user_or_room_id.as_bytes().to_vec(); let mut prefix = user_or_room_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xff);
let mut start = prefix.clone(); let mut start = prefix.clone();
start.extend_from_slice(&(from + 1).to_be_bytes()); start.extend_from_slice(&(from + 1).to_be_bytes());
let mut end = prefix.clone();
end.extend_from_slice(&to.unwrap_or(u64::MAX).to_be_bytes());
self.keychangeid_userid self.keychangeid_userid
.range(start..end) .iter_from(&start, false)
.filter_map(|r| r.ok())
.take_while(move |(k, _)| k.starts_with(&prefix)) .take_while(move |(k, _)| k.starts_with(&prefix))
.map(|(_, bytes)| { .map(|(_, bytes)| {
Ok( Ok(
@ -625,13 +621,13 @@ impl Users {
key.push(0xff); key.push(0xff);
key.extend_from_slice(&count); key.extend_from_slice(&count);
self.keychangeid_userid.insert(key, &*user_id.to_string())?; self.keychangeid_userid.insert(&key, user_id.as_bytes())?;
} }
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xff);
key.extend_from_slice(&count); key.extend_from_slice(&count);
self.keychangeid_userid.insert(key, &*user_id.to_string())?; self.keychangeid_userid.insert(&key, user_id.as_bytes())?;
Ok(()) Ok(())
} }
@ -645,7 +641,7 @@ impl Users {
key.push(0xff); key.push(0xff);
key.extend_from_slice(device_id.as_bytes()); key.extend_from_slice(device_id.as_bytes());
self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { Ok(Some(serde_json::from_slice(&bytes).map_err(|_| {
Error::bad_database("DeviceKeys in db are invalid.") Error::bad_database("DeviceKeys in db are invalid.")
})?)) })?))
@ -658,9 +654,9 @@ impl Users {
allowed_signatures: F, allowed_signatures: F,
) -> Result<Option<CrossSigningKey>> { ) -> Result<Option<CrossSigningKey>> {
self.userid_masterkeyid self.userid_masterkeyid
.get(user_id.to_string())? .get(user_id.as_bytes())?
.map_or(Ok(None), |key| { .map_or(Ok(None), |key| {
self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
let mut cross_signing_key = serde_json::from_slice::<CrossSigningKey>(&bytes) let mut cross_signing_key = serde_json::from_slice::<CrossSigningKey>(&bytes)
.map_err(|_| { .map_err(|_| {
Error::bad_database("CrossSigningKey in db is invalid.") Error::bad_database("CrossSigningKey in db is invalid.")
@ -685,9 +681,9 @@ impl Users {
allowed_signatures: F, allowed_signatures: F,
) -> Result<Option<CrossSigningKey>> { ) -> Result<Option<CrossSigningKey>> {
self.userid_selfsigningkeyid self.userid_selfsigningkeyid
.get(user_id.to_string())? .get(user_id.as_bytes())?
.map_or(Ok(None), |key| { .map_or(Ok(None), |key| {
self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
let mut cross_signing_key = serde_json::from_slice::<CrossSigningKey>(&bytes) let mut cross_signing_key = serde_json::from_slice::<CrossSigningKey>(&bytes)
.map_err(|_| { .map_err(|_| {
Error::bad_database("CrossSigningKey in db is invalid.") Error::bad_database("CrossSigningKey in db is invalid.")
@ -708,9 +704,9 @@ impl Users {
pub fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<CrossSigningKey>> { pub fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<CrossSigningKey>> {
self.userid_usersigningkeyid self.userid_usersigningkeyid
.get(user_id.to_string())? .get(user_id.as_bytes())?
.map_or(Ok(None), |key| { .map_or(Ok(None), |key| {
self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { Ok(Some(serde_json::from_slice(&bytes).map_err(|_| {
Error::bad_database("CrossSigningKey in db is invalid.") Error::bad_database("CrossSigningKey in db is invalid.")
})?)) })?))
@ -740,7 +736,7 @@ impl Users {
self.todeviceid_events.insert( self.todeviceid_events.insert(
&key, &key,
&*serde_json::to_string(&json).expect("Map::to_string always works"), &serde_json::to_vec(&json).expect("Map::to_vec always works"),
)?; )?;
Ok(()) Ok(())
@ -759,9 +755,9 @@ impl Users {
prefix.extend_from_slice(device_id.as_bytes()); prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xff); prefix.push(0xff);
for value in self.todeviceid_events.scan_prefix(&prefix).values() { for (_, value) in self.todeviceid_events.scan_prefix(prefix) {
events.push( events.push(
serde_json::from_slice(&*value?) serde_json::from_slice(&value)
.map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?, .map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?,
); );
} }
@ -786,10 +782,9 @@ impl Users {
for (key, _) in self for (key, _) in self
.todeviceid_events .todeviceid_events
.range(&*prefix..=&*last) .iter_from(&last, true)
.keys() .take_while(move |(k, _)| k.starts_with(&prefix))
.map(|key| { .map(|(key, _)| {
let key = key?;
Ok::<_, Error>(( Ok::<_, Error>((
key.clone(), key.clone(),
utils::u64_from_bytes(&key[key.len() - mem::size_of::<u64>()..key.len()]) utils::u64_from_bytes(&key[key.len() - mem::size_of::<u64>()..key.len()])
@ -799,7 +794,7 @@ impl Users {
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.take_while(|&(_, count)| count <= until) .take_while(|&(_, count)| count <= until)
{ {
self.todeviceid_events.remove(key)?; self.todeviceid_events.remove(&key)?;
} }
Ok(()) Ok(())
@ -819,14 +814,11 @@ impl Users {
assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some()); assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some());
self.userid_devicelistversion self.userid_devicelistversion
.update_and_fetch(&user_id.as_bytes(), utils::increment)? .increment(user_id.as_bytes())?;
.expect("utils::increment will always put in a value");
self.userdeviceid_metadata.insert( self.userdeviceid_metadata.insert(
userdeviceid, &userdeviceid,
serde_json::to_string(device) &serde_json::to_vec(device).expect("Device::to_string always works"),
.expect("Device::to_string always works")
.as_bytes(),
)?; )?;
Ok(()) Ok(())
@ -861,15 +853,17 @@ impl Users {
}) })
} }
pub fn all_devices_metadata(&self, user_id: &UserId) -> impl Iterator<Item = Result<Device>> { pub fn all_devices_metadata<'a>(
&'a self,
user_id: &UserId,
) -> impl Iterator<Item = Result<Device>> + 'a {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xff);
self.userdeviceid_metadata self.userdeviceid_metadata
.scan_prefix(key) .scan_prefix(key)
.values() .map(|(_, bytes)| {
.map(|bytes| { Ok(serde_json::from_slice::<Device>(&bytes).map_err(|_| {
Ok(serde_json::from_slice::<Device>(&bytes?).map_err(|_| {
Error::bad_database("Device in userdeviceid_metadata is invalid.") Error::bad_database("Device in userdeviceid_metadata is invalid.")
})?) })?)
}) })
@ -885,7 +879,7 @@ impl Users {
// Set the password to "" to indicate a deactivated account. Hashes will never result in an // Set the password to "" to indicate a deactivated account. Hashes will never result in an
// empty string, so the user will not be able to log in again. Systems like changing the // empty string, so the user will not be able to log in again. Systems like changing the
// password without logging in should check if the account is deactivated. // password without logging in should check if the account is deactivated.
self.userid_password.insert(user_id.to_string(), "")?; self.userid_password.insert(user_id.as_bytes(), &[])?;
// TODO: Unhook 3PID // TODO: Unhook 3PID
Ok(()) Ok(())

View File

@ -23,11 +23,18 @@ pub type Result<T> = std::result::Result<T, Error>;
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum Error { pub enum Error {
#[error("There was a problem with the connection to the database.")] #[cfg(feature = "sled")]
#[error("There was a problem with the connection to the sled database.")]
SledError { SledError {
#[from] #[from]
source: sled::Error, source: sled::Error,
}, },
#[cfg(feature = "rocksdb")]
#[error("There was a problem with the connection to the rocksdb database: {source}")]
RocksDbError {
#[from]
source: rocksdb::Error,
},
#[error("Could not generate an image.")] #[error("Could not generate an image.")]
ImageError { ImageError {
#[from] #[from]
@ -40,6 +47,11 @@ pub enum Error {
}, },
#[error("{0}")] #[error("{0}")]
FederationError(Box<ServerName>, RumaError), FederationError(Box<ServerName>, RumaError),
#[error("Could not do this io: {source}")]
IoError {
#[from]
source: std::io::Error,
},
#[error("{0}")] #[error("{0}")]
BadServerResponse(&'static str), BadServerResponse(&'static str),
#[error("{0}")] #[error("{0}")]

View File

@ -12,6 +12,8 @@ mod pdu;
mod ruma_wrapper; mod ruma_wrapper;
mod utils; mod utils;
use std::sync::Arc;
use database::Config; use database::Config;
pub use database::Database; pub use database::Database;
pub use error::{Error, Result}; pub use error::{Error, Result};
@ -31,7 +33,7 @@ use rocket::{
use tracing::span; use tracing::span;
use tracing_subscriber::{prelude::*, Registry}; use tracing_subscriber::{prelude::*, Registry};
fn setup_rocket(config: Figment, data: Database) -> rocket::Rocket<rocket::Build> { fn setup_rocket(config: Figment, data: Arc<Database>) -> rocket::Rocket<rocket::Build> {
rocket::custom(config) rocket::custom(config)
.manage(data) .manage(data)
.mount( .mount(
@ -197,8 +199,6 @@ async fn main() {
.await .await
.expect("config is valid"); .expect("config is valid");
db.sending.start_handler(&db);
if config.allow_jaeger { if config.allow_jaeger {
let (tracer, _uninstall) = opentelemetry_jaeger::new_pipeline() let (tracer, _uninstall) = opentelemetry_jaeger::new_pipeline()
.with_service_name("conduit") .with_service_name("conduit")

View File

@ -2,13 +2,14 @@ use crate::Error;
use ruma::{ use ruma::{
api::OutgoingResponse, api::OutgoingResponse,
identifiers::{DeviceId, UserId}, identifiers::{DeviceId, UserId},
Outgoing, signatures::CanonicalJsonValue,
Outgoing, ServerName,
}; };
use std::ops::Deref; use std::ops::Deref;
#[cfg(feature = "conduit_bin")] #[cfg(feature = "conduit_bin")]
use { use {
crate::server_server, crate::{server_server, Database},
log::{debug, warn}, log::{debug, warn},
rocket::{ rocket::{
data::{self, ByteUnit, Data, FromData}, data::{self, ByteUnit, Data, FromData},
@ -18,14 +19,11 @@ use {
tokio::io::AsyncReadExt, tokio::io::AsyncReadExt,
Request, State, Request, State,
}, },
ruma::{ ruma::api::{AuthScheme, IncomingRequest},
api::{AuthScheme, IncomingRequest},
signatures::CanonicalJsonValue,
ServerName,
},
std::collections::BTreeMap, std::collections::BTreeMap,
std::convert::TryFrom, std::convert::TryFrom,
std::io::Cursor, std::io::Cursor,
std::sync::Arc,
}; };
/// This struct converts rocket requests into ruma structs by converting them into http requests /// This struct converts rocket requests into ruma structs by converting them into http requests
@ -51,7 +49,7 @@ where
async fn from_data(request: &'a Request<'_>, data: Data) -> data::Outcome<Self, Self::Error> { async fn from_data(request: &'a Request<'_>, data: Data) -> data::Outcome<Self, Self::Error> {
let metadata = T::Incoming::METADATA; let metadata = T::Incoming::METADATA;
let db = request let db = request
.guard::<State<'_, crate::Database>>() .guard::<State<'_, Arc<Database>>>()
.await .await
.expect("database was loaded"); .expect("database was loaded");
@ -75,6 +73,7 @@ where
)) = db )) = db
.appservice .appservice
.iter_all() .iter_all()
.unwrap()
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.find(|(_id, registration)| { .find(|(_id, registration)| {
registration registration

View File

@ -433,7 +433,7 @@ pub async fn request_well_known(
#[cfg_attr(feature = "conduit_bin", get("/_matrix/federation/v1/version"))] #[cfg_attr(feature = "conduit_bin", get("/_matrix/federation/v1/version"))]
#[tracing::instrument(skip(db))] #[tracing::instrument(skip(db))]
pub fn get_server_version_route( pub fn get_server_version_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
) -> ConduitResult<get_server_version::v1::Response> { ) -> ConduitResult<get_server_version::v1::Response> {
if !db.globals.allow_federation() { if !db.globals.allow_federation() {
return Err(Error::bad_config("Federation is disabled.")); return Err(Error::bad_config("Federation is disabled."));
@ -451,7 +451,7 @@ pub fn get_server_version_route(
// Response type for this endpoint is Json because we need to calculate a signature for the response // Response type for this endpoint is Json because we need to calculate a signature for the response
#[cfg_attr(feature = "conduit_bin", get("/_matrix/key/v2/server"))] #[cfg_attr(feature = "conduit_bin", get("/_matrix/key/v2/server"))]
#[tracing::instrument(skip(db))] #[tracing::instrument(skip(db))]
pub fn get_server_keys_route(db: State<'_, Database>) -> Json<String> { pub fn get_server_keys_route(db: State<'_, Arc<Database>>) -> Json<String> {
if !db.globals.allow_federation() { if !db.globals.allow_federation() {
// TODO: Use proper types // TODO: Use proper types
return Json("Federation is disabled.".to_owned()); return Json("Federation is disabled.".to_owned());
@ -498,7 +498,7 @@ pub fn get_server_keys_route(db: State<'_, Database>) -> Json<String> {
#[cfg_attr(feature = "conduit_bin", get("/_matrix/key/v2/server/<_>"))] #[cfg_attr(feature = "conduit_bin", get("/_matrix/key/v2/server/<_>"))]
#[tracing::instrument(skip(db))] #[tracing::instrument(skip(db))]
pub fn get_server_keys_deprecated_route(db: State<'_, Database>) -> Json<String> { pub fn get_server_keys_deprecated_route(db: State<'_, Arc<Database>>) -> Json<String> {
get_server_keys_route(db) get_server_keys_route(db)
} }
@ -508,7 +508,7 @@ pub fn get_server_keys_deprecated_route(db: State<'_, Database>) -> Json<String>
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_public_rooms_filtered_route( pub async fn get_public_rooms_filtered_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_public_rooms_filtered::v1::Request<'_>>, body: Ruma<get_public_rooms_filtered::v1::Request<'_>>,
) -> ConduitResult<get_public_rooms_filtered::v1::Response> { ) -> ConduitResult<get_public_rooms_filtered::v1::Response> {
if !db.globals.allow_federation() { if !db.globals.allow_federation() {
@ -556,7 +556,7 @@ pub async fn get_public_rooms_filtered_route(
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn get_public_rooms_route( pub async fn get_public_rooms_route(
db: State<'_, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_public_rooms::v1::Request<'_>>, body: Ruma<get_public_rooms::v1::Request<'_>>,
) -> ConduitResult<get_public_rooms::v1::Response> { ) -> ConduitResult<get_public_rooms::v1::Response> {
if !db.globals.allow_federation() { if !db.globals.allow_federation() {
@ -603,8 +603,8 @@ pub async fn get_public_rooms_route(
put("/_matrix/federation/v1/send/<_>", data = "<body>") put("/_matrix/federation/v1/send/<_>", data = "<body>")
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn send_transaction_message_route<'a>( pub async fn send_transaction_message_route(
db: State<'a, Database>, db: State<'_, Arc<Database>>,
body: Ruma<send_transaction_message::v1::Request<'_>>, body: Ruma<send_transaction_message::v1::Request<'_>>,
) -> ConduitResult<send_transaction_message::v1::Response> { ) -> ConduitResult<send_transaction_message::v1::Response> {
if !db.globals.allow_federation() { if !db.globals.allow_federation() {
@ -1585,7 +1585,7 @@ pub(crate) async fn fetch_signing_keys(
.await .await
{ {
db.globals db.globals
.add_signing_key(origin, &get_keys_response.server_key)?; .add_signing_key(origin, get_keys_response.server_key.clone())?;
result.extend( result.extend(
get_keys_response get_keys_response
@ -1628,7 +1628,7 @@ pub(crate) async fn fetch_signing_keys(
{ {
trace!("Got signing keys: {:?}", keys); trace!("Got signing keys: {:?}", keys);
for k in keys.server_keys { for k in keys.server_keys {
db.globals.add_signing_key(origin, &k)?; db.globals.add_signing_key(origin, k.clone())?;
result.extend( result.extend(
k.verify_keys k.verify_keys
.into_iter() .into_iter()
@ -1681,12 +1681,12 @@ pub(crate) fn append_incoming_pdu(
pdu, pdu,
pdu_json, pdu_json,
count, count,
pdu_id.clone().into(), &pdu_id,
&new_room_leaves.into_iter().collect::<Vec<_>>(), &new_room_leaves.into_iter().collect::<Vec<_>>(),
&db, &db,
)?; )?;
for appservice in db.appservice.iter_all().filter_map(|r| r.ok()) { for appservice in db.appservice.iter_all()?.filter_map(|r| r.ok()) {
if let Some(namespaces) = appservice.1.get("namespaces") { if let Some(namespaces) = appservice.1.get("namespaces") {
let users = namespaces let users = namespaces
.get("users") .get("users")
@ -1758,8 +1758,8 @@ pub(crate) fn append_incoming_pdu(
get("/_matrix/federation/v1/event/<_>", data = "<body>") get("/_matrix/federation/v1/event/<_>", data = "<body>")
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub fn get_event_route<'a>( pub fn get_event_route(
db: State<'a, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_event::v1::Request<'_>>, body: Ruma<get_event::v1::Request<'_>>,
) -> ConduitResult<get_event::v1::Response> { ) -> ConduitResult<get_event::v1::Response> {
if !db.globals.allow_federation() { if !db.globals.allow_federation() {
@ -1783,8 +1783,8 @@ pub fn get_event_route<'a>(
post("/_matrix/federation/v1/get_missing_events/<_>", data = "<body>") post("/_matrix/federation/v1/get_missing_events/<_>", data = "<body>")
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub fn get_missing_events_route<'a>( pub fn get_missing_events_route(
db: State<'a, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_missing_events::v1::Request<'_>>, body: Ruma<get_missing_events::v1::Request<'_>>,
) -> ConduitResult<get_missing_events::v1::Response> { ) -> ConduitResult<get_missing_events::v1::Response> {
if !db.globals.allow_federation() { if !db.globals.allow_federation() {
@ -1832,8 +1832,8 @@ pub fn get_missing_events_route<'a>(
get("/_matrix/federation/v1/state_ids/<_>", data = "<body>") get("/_matrix/federation/v1/state_ids/<_>", data = "<body>")
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub fn get_room_state_ids_route<'a>( pub fn get_room_state_ids_route(
db: State<'a, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_room_state_ids::v1::Request<'_>>, body: Ruma<get_room_state_ids::v1::Request<'_>>,
) -> ConduitResult<get_room_state_ids::v1::Response> { ) -> ConduitResult<get_room_state_ids::v1::Response> {
if !db.globals.allow_federation() { if !db.globals.allow_federation() {
@ -1884,8 +1884,8 @@ pub fn get_room_state_ids_route<'a>(
get("/_matrix/federation/v1/make_join/<_>/<_>", data = "<body>") get("/_matrix/federation/v1/make_join/<_>/<_>", data = "<body>")
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub fn create_join_event_template_route<'a>( pub fn create_join_event_template_route(
db: State<'a, Database>, db: State<'_, Arc<Database>>,
body: Ruma<create_join_event_template::v1::Request<'_>>, body: Ruma<create_join_event_template::v1::Request<'_>>,
) -> ConduitResult<create_join_event_template::v1::Response> { ) -> ConduitResult<create_join_event_template::v1::Response> {
if !db.globals.allow_federation() { if !db.globals.allow_federation() {
@ -2055,8 +2055,8 @@ pub fn create_join_event_template_route<'a>(
put("/_matrix/federation/v2/send_join/<_>/<_>", data = "<body>") put("/_matrix/federation/v2/send_join/<_>/<_>", data = "<body>")
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn create_join_event_route<'a>( pub async fn create_join_event_route(
db: State<'a, Database>, db: State<'_, Arc<Database>>,
body: Ruma<create_join_event::v2::Request<'_>>, body: Ruma<create_join_event::v2::Request<'_>>,
) -> ConduitResult<create_join_event::v2::Response> { ) -> ConduitResult<create_join_event::v2::Response> {
if !db.globals.allow_federation() { if !db.globals.allow_federation() {
@ -2171,8 +2171,8 @@ pub async fn create_join_event_route<'a>(
put("/_matrix/federation/v2/invite/<_>/<_>", data = "<body>") put("/_matrix/federation/v2/invite/<_>/<_>", data = "<body>")
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn create_invite_route<'a>( pub async fn create_invite_route(
db: State<'a, Database>, db: State<'_, Arc<Database>>,
body: Ruma<create_invite::v2::Request>, body: Ruma<create_invite::v2::Request>,
) -> ConduitResult<create_invite::v2::Response> { ) -> ConduitResult<create_invite::v2::Response> {
if !db.globals.allow_federation() { if !db.globals.allow_federation() {
@ -2276,8 +2276,8 @@ pub async fn create_invite_route<'a>(
get("/_matrix/federation/v1/user/devices/<_>", data = "<body>") get("/_matrix/federation/v1/user/devices/<_>", data = "<body>")
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub fn get_devices_route<'a>( pub fn get_devices_route(
db: State<'a, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_devices::v1::Request<'_>>, body: Ruma<get_devices::v1::Request<'_>>,
) -> ConduitResult<get_devices::v1::Response> { ) -> ConduitResult<get_devices::v1::Response> {
if !db.globals.allow_federation() { if !db.globals.allow_federation() {
@ -2316,8 +2316,8 @@ pub fn get_devices_route<'a>(
get("/_matrix/federation/v1/query/directory", data = "<body>") get("/_matrix/federation/v1/query/directory", data = "<body>")
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub fn get_room_information_route<'a>( pub fn get_room_information_route(
db: State<'a, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_room_information::v1::Request<'_>>, body: Ruma<get_room_information::v1::Request<'_>>,
) -> ConduitResult<get_room_information::v1::Response> { ) -> ConduitResult<get_room_information::v1::Response> {
if !db.globals.allow_federation() { if !db.globals.allow_federation() {
@ -2344,8 +2344,8 @@ pub fn get_room_information_route<'a>(
get("/_matrix/federation/v1/query/profile", data = "<body>") get("/_matrix/federation/v1/query/profile", data = "<body>")
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub fn get_profile_information_route<'a>( pub fn get_profile_information_route(
db: State<'a, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_profile_information::v1::Request<'_>>, body: Ruma<get_profile_information::v1::Request<'_>>,
) -> ConduitResult<get_profile_information::v1::Response> { ) -> ConduitResult<get_profile_information::v1::Response> {
if !db.globals.allow_federation() { if !db.globals.allow_federation() {
@ -2378,8 +2378,8 @@ pub fn get_profile_information_route<'a>(
post("/_matrix/federation/v1/user/keys/query", data = "<body>") post("/_matrix/federation/v1/user/keys/query", data = "<body>")
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub fn get_keys_route<'a>( pub fn get_keys_route(
db: State<'a, Database>, db: State<'_, Arc<Database>>,
body: Ruma<get_keys::v1::Request>, body: Ruma<get_keys::v1::Request>,
) -> ConduitResult<get_keys::v1::Response> { ) -> ConduitResult<get_keys::v1::Response> {
if !db.globals.allow_federation() { if !db.globals.allow_federation() {
@ -2406,8 +2406,8 @@ pub fn get_keys_route<'a>(
post("/_matrix/federation/v1/user/keys/claim", data = "<body>") post("/_matrix/federation/v1/user/keys/claim", data = "<body>")
)] )]
#[tracing::instrument(skip(db, body))] #[tracing::instrument(skip(db, body))]
pub async fn claim_keys_route<'a>( pub async fn claim_keys_route(
db: State<'a, Database>, db: State<'_, Arc<Database>>,
body: Ruma<claim_keys::v1::Request>, body: Ruma<claim_keys::v1::Request>,
) -> ConduitResult<claim_keys::v1::Response> { ) -> ConduitResult<claim_keys::v1::Response> {
if !db.globals.allow_federation() { if !db.globals.allow_federation() {

View File

@ -15,6 +15,15 @@ pub fn millis_since_unix_epoch() -> u64 {
.as_millis() as u64 .as_millis() as u64
} }
#[cfg(feature = "rocksdb")]
pub fn increment_rocksdb(
_new_key: &[u8],
old: Option<&[u8]>,
_operands: &mut rocksdb::MergeOperands,
) -> Option<Vec<u8>> {
increment(old)
}
pub fn increment(old: Option<&[u8]>) -> Option<Vec<u8>> { pub fn increment(old: Option<&[u8]>) -> Option<Vec<u8>> {
let number = match old.map(|bytes| bytes.try_into()) { let number = match old.map(|bytes| bytes.try_into()) {
Some(Ok(bytes)) => { Some(Ok(bytes)) => {
@ -27,16 +36,14 @@ pub fn increment(old: Option<&[u8]>) -> Option<Vec<u8>> {
Some(number.to_be_bytes().to_vec()) Some(number.to_be_bytes().to_vec())
} }
pub fn generate_keypair(old: Option<&[u8]>) -> Option<Vec<u8>> { pub fn generate_keypair() -> Vec<u8> {
Some(old.map(|s| s.to_vec()).unwrap_or_else(|| { let mut value = random_string(8).as_bytes().to_vec();
let mut value = random_string(8).as_bytes().to_vec(); value.push(0xff);
value.push(0xff); value.extend_from_slice(
value.extend_from_slice( &ruma::signatures::Ed25519KeyPair::generate()
&ruma::signatures::Ed25519KeyPair::generate() .expect("Ed25519KeyPair generation always works (?)"),
.expect("Ed25519KeyPair generation always works (?)"), );
); value
value
}))
} }
/// Parses the bytes into an u64. /// Parses the bytes into an u64.