Merge branch 'rocksdb' into 'master'
Swappable database backend See merge request famedly/conduit!98
This commit is contained in:
		
						commit
						8c6bcc47bf
					
				
					 47 changed files with 1613 additions and 1047 deletions
				
			
		
							
								
								
									
										106
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										106
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							|  | @ -103,6 +103,25 @@ version = "0.1.4" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
| checksum = "383d29d513d8764dcdc42ea295d979eb99c3c9f00607b3692cf68a431f7dca72" | checksum = "383d29d513d8764dcdc42ea295d979eb99c3c9f00607b3692cf68a431f7dca72" | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "bindgen" | ||||||
|  | version = "0.57.0" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "fd4865004a46a0aafb2a0a5eb19d3c9fc46ee5f063a6cfc605c69ac9ecf5263d" | ||||||
|  | dependencies = [ | ||||||
|  |  "bitflags", | ||||||
|  |  "cexpr", | ||||||
|  |  "clang-sys", | ||||||
|  |  "lazy_static", | ||||||
|  |  "lazycell", | ||||||
|  |  "peeking_take_while", | ||||||
|  |  "proc-macro2", | ||||||
|  |  "quote", | ||||||
|  |  "regex", | ||||||
|  |  "rustc-hash", | ||||||
|  |  "shlex", | ||||||
|  | ] | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "bitflags" | name = "bitflags" | ||||||
| version = "1.2.1" | version = "1.2.1" | ||||||
|  | @ -162,6 +181,15 @@ dependencies = [ | ||||||
|  "jobserver", |  "jobserver", | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "cexpr" | ||||||
|  | version = "0.4.0" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "f4aedb84272dbe89af497cf81375129abda4fc0a9e7c5d317498c15cc30c0d27" | ||||||
|  | dependencies = [ | ||||||
|  |  "nom", | ||||||
|  | ] | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "cfg-if" | name = "cfg-if" | ||||||
| version = "0.1.10" | version = "0.1.10" | ||||||
|  | @ -187,6 +215,17 @@ dependencies = [ | ||||||
|  "winapi", |  "winapi", | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "clang-sys" | ||||||
|  | version = "1.2.0" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "853eda514c284c2287f4bf20ae614f8781f40a81d32ecda6e91449304dfe077c" | ||||||
|  | dependencies = [ | ||||||
|  |  "glob", | ||||||
|  |  "libc", | ||||||
|  |  "libloading", | ||||||
|  | ] | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "color_quant" | name = "color_quant" | ||||||
| version = "1.1.0" | version = "1.1.0" | ||||||
|  | @ -212,6 +251,7 @@ dependencies = [ | ||||||
|  "reqwest", |  "reqwest", | ||||||
|  "ring", |  "ring", | ||||||
|  "rocket", |  "rocket", | ||||||
|  |  "rocksdb", | ||||||
|  "ruma", |  "ruma", | ||||||
|  "rust-argon2", |  "rust-argon2", | ||||||
|  "rustls", |  "rustls", | ||||||
|  | @ -1008,12 +1048,40 @@ version = "1.4.0" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
| checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" | checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "lazycell" | ||||||
|  | version = "1.3.0" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "libc" | name = "libc" | ||||||
| version = "0.2.95" | version = "0.2.95" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
| checksum = "789da6d93f1b866ffe175afc5322a4d76c038605a1c3319bb57b06967ca98a36" | checksum = "789da6d93f1b866ffe175afc5322a4d76c038605a1c3319bb57b06967ca98a36" | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "libloading" | ||||||
|  | version = "0.7.0" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "6f84d96438c15fcd6c3f244c8fce01d1e2b9c6b5623e9c711dc9286d8fc92d6a" | ||||||
|  | dependencies = [ | ||||||
|  |  "cfg-if 1.0.0", | ||||||
|  |  "winapi", | ||||||
|  | ] | ||||||
|  | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "librocksdb-sys" | ||||||
|  | version = "6.17.3" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "5da125e1c0f22c7cae785982115523a0738728498547f415c9054cb17c7e89f9" | ||||||
|  | dependencies = [ | ||||||
|  |  "bindgen", | ||||||
|  |  "cc", | ||||||
|  |  "glob", | ||||||
|  |  "libc", | ||||||
|  | ] | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "linked-hash-map" | name = "linked-hash-map" | ||||||
| version = "0.5.4" | version = "0.5.4" | ||||||
|  | @ -1158,6 +1226,16 @@ dependencies = [ | ||||||
|  "version_check", |  "version_check", | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "nom" | ||||||
|  | version = "5.1.2" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "ffb4262d26ed83a1c0a33a38fe2bb15797329c85770da05e6b828ddb782627af" | ||||||
|  | dependencies = [ | ||||||
|  |  "memchr", | ||||||
|  |  "version_check", | ||||||
|  | ] | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "ntapi" | name = "ntapi" | ||||||
| version = "0.3.6" | version = "0.3.6" | ||||||
|  | @ -1339,6 +1417,12 @@ dependencies = [ | ||||||
|  "syn", |  "syn", | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "peeking_take_while" | ||||||
|  | version = "0.1.2" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "pem" | name = "pem" | ||||||
| version = "0.8.3" | version = "0.8.3" | ||||||
|  | @ -1777,6 +1861,16 @@ dependencies = [ | ||||||
|  "uncased", |  "uncased", | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "rocksdb" | ||||||
|  | version = "0.16.0" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "c749134fda8bfc90d0de643d59bfc841dcb3ac8a1062e12b6754bd60235c48b3" | ||||||
|  | dependencies = [ | ||||||
|  |  "libc", | ||||||
|  |  "librocksdb-sys", | ||||||
|  | ] | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "ruma" | name = "ruma" | ||||||
| version = "0.1.2" | version = "0.1.2" | ||||||
|  | @ -2046,6 +2140,12 @@ dependencies = [ | ||||||
|  "crossbeam-utils", |  "crossbeam-utils", | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "rustc-hash" | ||||||
|  | version = "1.1.0" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "rustc_version" | name = "rustc_version" | ||||||
| version = "0.2.3" | version = "0.2.3" | ||||||
|  | @ -2245,6 +2345,12 @@ dependencies = [ | ||||||
|  "lazy_static", |  "lazy_static", | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "shlex" | ||||||
|  | version = "0.1.1" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "7fdf1b9db47230893d76faad238fd6097fd6d6a9245cd7a4d90dbd639536bbd2" | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "signal-hook-registry" | name = "signal-hook-registry" | ||||||
| version = "1.3.0" | version = "1.3.0" | ||||||
|  |  | ||||||
|  | @ -24,7 +24,8 @@ ruma = { git = "https://github.com/ruma/ruma", rev = "b39537812c12caafcbf8b7bd74 | ||||||
| # Used for long polling and federation sender, should be the same as rocket::tokio | # Used for long polling and federation sender, should be the same as rocket::tokio | ||||||
| tokio = "1.2.0" | tokio = "1.2.0" | ||||||
| # Used for storing data permanently | # Used for storing data permanently | ||||||
| sled = { version = "0.34.6", features = ["compression", "no_metrics"] } | sled = { version = "0.34.6", features = ["compression", "no_metrics"], optional = true } | ||||||
|  | rocksdb = { version = "0.16.0", features = ["multi-threaded-cf"], optional = true } | ||||||
| #sled = { git = "https://github.com/spacejam/sled.git", rev = "e4640e0773595229f398438886f19bca6f7326a2", features = ["compression"] } | #sled = { git = "https://github.com/spacejam/sled.git", rev = "e4640e0773595229f398438886f19bca6f7326a2", features = ["compression"] } | ||||||
| 
 | 
 | ||||||
| # Used for the http request / response body type for Ruma endpoints used with reqwest | # Used for the http request / response body type for Ruma endpoints used with reqwest | ||||||
|  | @ -74,7 +75,9 @@ opentelemetry-jaeger = "0.11.0" | ||||||
| pretty_env_logger = "0.4.0" | pretty_env_logger = "0.4.0" | ||||||
| 
 | 
 | ||||||
| [features] | [features] | ||||||
| default = ["conduit_bin"] | default = ["conduit_bin", "backend_sled"] | ||||||
|  | backend_sled = ["sled"] | ||||||
|  | backend_rocksdb = ["rocksdb"] | ||||||
| conduit_bin = [] # TODO: add rocket to this when it is optional | conduit_bin = [] # TODO: add rocket to this when it is optional | ||||||
| 
 | 
 | ||||||
| [[bin]] | [[bin]] | ||||||
|  |  | ||||||
|  | @ -1,4 +1,4 @@ | ||||||
| use std::{collections::BTreeMap, convert::TryInto}; | use std::{collections::BTreeMap, convert::TryInto, sync::Arc}; | ||||||
| 
 | 
 | ||||||
| use super::{State, DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; | use super::{State, DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; | ||||||
| use crate::{pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; | use crate::{pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; | ||||||
|  | @ -42,7 +42,7 @@ const GUEST_NAME_LENGTH: usize = 10; | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_register_available_route( | pub async fn get_register_available_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_username_availability::Request<'_>>, |     body: Ruma<get_username_availability::Request<'_>>, | ||||||
| ) -> ConduitResult<get_username_availability::Response> { | ) -> ConduitResult<get_username_availability::Response> { | ||||||
|     // Validate user id
 |     // Validate user id
 | ||||||
|  | @ -85,7 +85,7 @@ pub async fn get_register_available_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn register_route( | pub async fn register_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<register::Request<'_>>, |     body: Ruma<register::Request<'_>>, | ||||||
| ) -> ConduitResult<register::Response> { | ) -> ConduitResult<register::Response> { | ||||||
|     if !db.globals.allow_registration() { |     if !db.globals.allow_registration() { | ||||||
|  | @ -227,7 +227,7 @@ pub async fn register_route( | ||||||
|     )?; |     )?; | ||||||
| 
 | 
 | ||||||
|     // If this is the first user on this server, create the admins room
 |     // If this is the first user on this server, create the admins room
 | ||||||
|     if db.users.count() == 1 { |     if db.users.count()? == 1 { | ||||||
|         // Create a user for the server
 |         // Create a user for the server
 | ||||||
|         let conduit_user = UserId::parse_with_server_name("conduit", db.globals.server_name()) |         let conduit_user = UserId::parse_with_server_name("conduit", db.globals.server_name()) | ||||||
|             .expect("@conduit:server_name is valid"); |             .expect("@conduit:server_name is valid"); | ||||||
|  | @ -506,7 +506,7 @@ pub async fn register_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn change_password_route( | pub async fn change_password_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<change_password::Request<'_>>, |     body: Ruma<change_password::Request<'_>>, | ||||||
| ) -> ConduitResult<change_password::Response> { | ) -> ConduitResult<change_password::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -598,7 +598,7 @@ pub async fn whoami_route(body: Ruma<whoami::Request>) -> ConduitResult<whoami:: | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn deactivate_route( | pub async fn deactivate_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<deactivate::Request<'_>>, |     body: Ruma<deactivate::Request<'_>>, | ||||||
| ) -> ConduitResult<deactivate::Response> { | ) -> ConduitResult<deactivate::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  |  | ||||||
|  | @ -1,3 +1,5 @@ | ||||||
|  | use std::sync::Arc; | ||||||
|  | 
 | ||||||
| use super::State; | use super::State; | ||||||
| use crate::{ConduitResult, Database, Error, Ruma}; | use crate::{ConduitResult, Database, Error, Ruma}; | ||||||
| use regex::Regex; | use regex::Regex; | ||||||
|  | @ -22,7 +24,7 @@ use rocket::{delete, get, put}; | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn create_alias_route( | pub async fn create_alias_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<create_alias::Request<'_>>, |     body: Ruma<create_alias::Request<'_>>, | ||||||
| ) -> ConduitResult<create_alias::Response> { | ) -> ConduitResult<create_alias::Response> { | ||||||
|     if db.rooms.id_from_alias(&body.room_alias)?.is_some() { |     if db.rooms.id_from_alias(&body.room_alias)?.is_some() { | ||||||
|  | @ -43,7 +45,7 @@ pub async fn create_alias_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn delete_alias_route( | pub async fn delete_alias_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<delete_alias::Request<'_>>, |     body: Ruma<delete_alias::Request<'_>>, | ||||||
| ) -> ConduitResult<delete_alias::Response> { | ) -> ConduitResult<delete_alias::Response> { | ||||||
|     db.rooms.set_alias(&body.room_alias, None, &db.globals)?; |     db.rooms.set_alias(&body.room_alias, None, &db.globals)?; | ||||||
|  | @ -59,7 +61,7 @@ pub async fn delete_alias_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_alias_route( | pub async fn get_alias_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_alias::Request<'_>>, |     body: Ruma<get_alias::Request<'_>>, | ||||||
| ) -> ConduitResult<get_alias::Response> { | ) -> ConduitResult<get_alias::Response> { | ||||||
|     get_alias_helper(&db, &body.room_alias).await |     get_alias_helper(&db, &body.room_alias).await | ||||||
|  | @ -86,7 +88,8 @@ pub async fn get_alias_helper( | ||||||
|     match db.rooms.id_from_alias(&room_alias)? { |     match db.rooms.id_from_alias(&room_alias)? { | ||||||
|         Some(r) => room_id = Some(r), |         Some(r) => room_id = Some(r), | ||||||
|         None => { |         None => { | ||||||
|             for (_id, registration) in db.appservice.iter_all().filter_map(|r| r.ok()) { |             let iter = db.appservice.iter_all()?; | ||||||
|  |             for (_id, registration) in iter.filter_map(|r| r.ok()) { | ||||||
|                 let aliases = registration |                 let aliases = registration | ||||||
|                     .get("namespaces") |                     .get("namespaces") | ||||||
|                     .and_then(|ns| ns.get("aliases")) |                     .and_then(|ns| ns.get("aliases")) | ||||||
|  |  | ||||||
|  | @ -1,3 +1,5 @@ | ||||||
|  | use std::sync::Arc; | ||||||
|  | 
 | ||||||
| use super::State; | use super::State; | ||||||
| use crate::{ConduitResult, Database, Error, Ruma}; | use crate::{ConduitResult, Database, Error, Ruma}; | ||||||
| use ruma::api::client::{ | use ruma::api::client::{ | ||||||
|  | @ -19,7 +21,7 @@ use rocket::{delete, get, post, put}; | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn create_backup_route( | pub async fn create_backup_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<create_backup::Request>, |     body: Ruma<create_backup::Request>, | ||||||
| ) -> ConduitResult<create_backup::Response> { | ) -> ConduitResult<create_backup::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -38,7 +40,7 @@ pub async fn create_backup_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn update_backup_route( | pub async fn update_backup_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<update_backup::Request<'_>>, |     body: Ruma<update_backup::Request<'_>>, | ||||||
| ) -> ConduitResult<update_backup::Response> { | ) -> ConduitResult<update_backup::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -56,7 +58,7 @@ pub async fn update_backup_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_latest_backup_route( | pub async fn get_latest_backup_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_latest_backup::Request>, |     body: Ruma<get_latest_backup::Request>, | ||||||
| ) -> ConduitResult<get_latest_backup::Response> { | ) -> ConduitResult<get_latest_backup::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -84,7 +86,7 @@ pub async fn get_latest_backup_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_backup_route( | pub async fn get_backup_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_backup::Request<'_>>, |     body: Ruma<get_backup::Request<'_>>, | ||||||
| ) -> ConduitResult<get_backup::Response> { | ) -> ConduitResult<get_backup::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -111,7 +113,7 @@ pub async fn get_backup_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn delete_backup_route( | pub async fn delete_backup_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<delete_backup::Request<'_>>, |     body: Ruma<delete_backup::Request<'_>>, | ||||||
| ) -> ConduitResult<delete_backup::Response> { | ) -> ConduitResult<delete_backup::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -130,7 +132,7 @@ pub async fn delete_backup_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn add_backup_keys_route( | pub async fn add_backup_keys_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<add_backup_keys::Request<'_>>, |     body: Ruma<add_backup_keys::Request<'_>>, | ||||||
| ) -> ConduitResult<add_backup_keys::Response> { | ) -> ConduitResult<add_backup_keys::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -164,7 +166,7 @@ pub async fn add_backup_keys_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn add_backup_key_sessions_route( | pub async fn add_backup_key_sessions_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<add_backup_key_sessions::Request<'_>>, |     body: Ruma<add_backup_key_sessions::Request<'_>>, | ||||||
| ) -> ConduitResult<add_backup_key_sessions::Response> { | ) -> ConduitResult<add_backup_key_sessions::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -196,7 +198,7 @@ pub async fn add_backup_key_sessions_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn add_backup_key_session_route( | pub async fn add_backup_key_session_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<add_backup_key_session::Request<'_>>, |     body: Ruma<add_backup_key_session::Request<'_>>, | ||||||
| ) -> ConduitResult<add_backup_key_session::Response> { | ) -> ConduitResult<add_backup_key_session::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -225,7 +227,7 @@ pub async fn add_backup_key_session_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_backup_keys_route( | pub async fn get_backup_keys_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_backup_keys::Request<'_>>, |     body: Ruma<get_backup_keys::Request<'_>>, | ||||||
| ) -> ConduitResult<get_backup_keys::Response> { | ) -> ConduitResult<get_backup_keys::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -241,14 +243,14 @@ pub async fn get_backup_keys_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_backup_key_sessions_route( | pub async fn get_backup_key_sessions_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_backup_key_sessions::Request<'_>>, |     body: Ruma<get_backup_key_sessions::Request<'_>>, | ||||||
| ) -> ConduitResult<get_backup_key_sessions::Response> { | ) -> ConduitResult<get_backup_key_sessions::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
| 
 | 
 | ||||||
|     let sessions = db |     let sessions = db | ||||||
|         .key_backups |         .key_backups | ||||||
|         .get_room(&sender_user, &body.version, &body.room_id); |         .get_room(&sender_user, &body.version, &body.room_id)?; | ||||||
| 
 | 
 | ||||||
|     Ok(get_backup_key_sessions::Response { sessions }.into()) |     Ok(get_backup_key_sessions::Response { sessions }.into()) | ||||||
| } | } | ||||||
|  | @ -259,7 +261,7 @@ pub async fn get_backup_key_sessions_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_backup_key_session_route( | pub async fn get_backup_key_session_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_backup_key_session::Request<'_>>, |     body: Ruma<get_backup_key_session::Request<'_>>, | ||||||
| ) -> ConduitResult<get_backup_key_session::Response> { | ) -> ConduitResult<get_backup_key_session::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -281,7 +283,7 @@ pub async fn get_backup_key_session_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn delete_backup_keys_route( | pub async fn delete_backup_keys_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<delete_backup_keys::Request<'_>>, |     body: Ruma<delete_backup_keys::Request<'_>>, | ||||||
| ) -> ConduitResult<delete_backup_keys::Response> { | ) -> ConduitResult<delete_backup_keys::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -304,7 +306,7 @@ pub async fn delete_backup_keys_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn delete_backup_key_sessions_route( | pub async fn delete_backup_key_sessions_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<delete_backup_key_sessions::Request<'_>>, |     body: Ruma<delete_backup_key_sessions::Request<'_>>, | ||||||
| ) -> ConduitResult<delete_backup_key_sessions::Response> { | ) -> ConduitResult<delete_backup_key_sessions::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -327,7 +329,7 @@ pub async fn delete_backup_key_sessions_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn delete_backup_key_session_route( | pub async fn delete_backup_key_session_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<delete_backup_key_session::Request<'_>>, |     body: Ruma<delete_backup_key_session::Request<'_>>, | ||||||
| ) -> ConduitResult<delete_backup_key_session::Response> { | ) -> ConduitResult<delete_backup_key_session::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  |  | ||||||
|  | @ -1,3 +1,5 @@ | ||||||
|  | use std::sync::Arc; | ||||||
|  | 
 | ||||||
| use super::State; | use super::State; | ||||||
| use crate::{ConduitResult, Database, Error, Ruma}; | use crate::{ConduitResult, Database, Error, Ruma}; | ||||||
| use ruma::{ | use ruma::{ | ||||||
|  | @ -23,7 +25,7 @@ use rocket::{get, put}; | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn set_global_account_data_route( | pub async fn set_global_account_data_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<set_global_account_data::Request<'_>>, |     body: Ruma<set_global_account_data::Request<'_>>, | ||||||
| ) -> ConduitResult<set_global_account_data::Response> { | ) -> ConduitResult<set_global_account_data::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -58,7 +60,7 @@ pub async fn set_global_account_data_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn set_room_account_data_route( | pub async fn set_room_account_data_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<set_room_account_data::Request<'_>>, |     body: Ruma<set_room_account_data::Request<'_>>, | ||||||
| ) -> ConduitResult<set_room_account_data::Response> { | ) -> ConduitResult<set_room_account_data::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -90,7 +92,7 @@ pub async fn set_room_account_data_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_global_account_data_route( | pub async fn get_global_account_data_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_global_account_data::Request<'_>>, |     body: Ruma<get_global_account_data::Request<'_>>, | ||||||
| ) -> ConduitResult<get_global_account_data::Response> { | ) -> ConduitResult<get_global_account_data::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -117,7 +119,7 @@ pub async fn get_global_account_data_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_room_account_data_route( | pub async fn get_room_account_data_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_room_account_data::Request<'_>>, |     body: Ruma<get_room_account_data::Request<'_>>, | ||||||
| ) -> ConduitResult<get_room_account_data::Response> { | ) -> ConduitResult<get_room_account_data::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  |  | ||||||
|  | @ -1,7 +1,7 @@ | ||||||
| use super::State; | use super::State; | ||||||
| use crate::{ConduitResult, Database, Error, Ruma}; | use crate::{ConduitResult, Database, Error, Ruma}; | ||||||
| use ruma::api::client::{error::ErrorKind, r0::context::get_context}; | use ruma::api::client::{error::ErrorKind, r0::context::get_context}; | ||||||
| use std::convert::TryFrom; | use std::{convert::TryFrom, sync::Arc}; | ||||||
| 
 | 
 | ||||||
| #[cfg(feature = "conduit_bin")] | #[cfg(feature = "conduit_bin")] | ||||||
| use rocket::get; | use rocket::get; | ||||||
|  | @ -12,7 +12,7 @@ use rocket::get; | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_context_route( | pub async fn get_context_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_context::Request<'_>>, |     body: Ruma<get_context::Request<'_>>, | ||||||
| ) -> ConduitResult<get_context::Response> { | ) -> ConduitResult<get_context::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  |  | ||||||
|  | @ -1,3 +1,5 @@ | ||||||
|  | use std::sync::Arc; | ||||||
|  | 
 | ||||||
| use super::State; | use super::State; | ||||||
| use crate::{utils, ConduitResult, Database, Error, Ruma}; | use crate::{utils, ConduitResult, Database, Error, Ruma}; | ||||||
| use ruma::api::client::{ | use ruma::api::client::{ | ||||||
|  | @ -18,7 +20,7 @@ use rocket::{delete, get, post, put}; | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_devices_route( | pub async fn get_devices_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_devices::Request>, |     body: Ruma<get_devices::Request>, | ||||||
| ) -> ConduitResult<get_devices::Response> { | ) -> ConduitResult<get_devices::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -38,7 +40,7 @@ pub async fn get_devices_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_device_route( | pub async fn get_device_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_device::Request<'_>>, |     body: Ruma<get_device::Request<'_>>, | ||||||
| ) -> ConduitResult<get_device::Response> { | ) -> ConduitResult<get_device::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -57,7 +59,7 @@ pub async fn get_device_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn update_device_route( | pub async fn update_device_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<update_device::Request<'_>>, |     body: Ruma<update_device::Request<'_>>, | ||||||
| ) -> ConduitResult<update_device::Response> { | ) -> ConduitResult<update_device::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -83,7 +85,7 @@ pub async fn update_device_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn delete_device_route( | pub async fn delete_device_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<delete_device::Request<'_>>, |     body: Ruma<delete_device::Request<'_>>, | ||||||
| ) -> ConduitResult<delete_device::Response> { | ) -> ConduitResult<delete_device::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -137,7 +139,7 @@ pub async fn delete_device_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn delete_devices_route( | pub async fn delete_devices_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<delete_devices::Request<'_>>, |     body: Ruma<delete_devices::Request<'_>>, | ||||||
| ) -> ConduitResult<delete_devices::Response> { | ) -> ConduitResult<delete_devices::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  |  | ||||||
|  | @ -1,3 +1,5 @@ | ||||||
|  | use std::sync::Arc; | ||||||
|  | 
 | ||||||
| use super::State; | use super::State; | ||||||
| use crate::{ConduitResult, Database, Error, Result, Ruma}; | use crate::{ConduitResult, Database, Error, Result, Ruma}; | ||||||
| use log::info; | use log::info; | ||||||
|  | @ -33,7 +35,7 @@ use rocket::{get, post, put}; | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_public_rooms_filtered_route( | pub async fn get_public_rooms_filtered_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_public_rooms_filtered::Request<'_>>, |     body: Ruma<get_public_rooms_filtered::Request<'_>>, | ||||||
| ) -> ConduitResult<get_public_rooms_filtered::Response> { | ) -> ConduitResult<get_public_rooms_filtered::Response> { | ||||||
|     get_public_rooms_filtered_helper( |     get_public_rooms_filtered_helper( | ||||||
|  | @ -53,7 +55,7 @@ pub async fn get_public_rooms_filtered_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_public_rooms_route( | pub async fn get_public_rooms_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_public_rooms::Request<'_>>, |     body: Ruma<get_public_rooms::Request<'_>>, | ||||||
| ) -> ConduitResult<get_public_rooms::Response> { | ) -> ConduitResult<get_public_rooms::Response> { | ||||||
|     let response = get_public_rooms_filtered_helper( |     let response = get_public_rooms_filtered_helper( | ||||||
|  | @ -82,7 +84,7 @@ pub async fn get_public_rooms_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn set_room_visibility_route( | pub async fn set_room_visibility_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<set_room_visibility::Request<'_>>, |     body: Ruma<set_room_visibility::Request<'_>>, | ||||||
| ) -> ConduitResult<set_room_visibility::Response> { | ) -> ConduitResult<set_room_visibility::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -112,7 +114,7 @@ pub async fn set_room_visibility_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_room_visibility_route( | pub async fn get_room_visibility_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_room_visibility::Request<'_>>, |     body: Ruma<get_room_visibility::Request<'_>>, | ||||||
| ) -> ConduitResult<get_room_visibility::Response> { | ) -> ConduitResult<get_room_visibility::Response> { | ||||||
|     Ok(get_room_visibility::Response { |     Ok(get_room_visibility::Response { | ||||||
|  |  | ||||||
|  | @ -14,7 +14,10 @@ use ruma::{ | ||||||
|     encryption::UnsignedDeviceInfo, |     encryption::UnsignedDeviceInfo, | ||||||
|     DeviceId, DeviceKeyAlgorithm, UserId, |     DeviceId, DeviceKeyAlgorithm, UserId, | ||||||
| }; | }; | ||||||
| use std::collections::{BTreeMap, HashSet}; | use std::{ | ||||||
|  |     collections::{BTreeMap, HashSet}, | ||||||
|  |     sync::Arc, | ||||||
|  | }; | ||||||
| 
 | 
 | ||||||
| #[cfg(feature = "conduit_bin")] | #[cfg(feature = "conduit_bin")] | ||||||
| use rocket::{get, post}; | use rocket::{get, post}; | ||||||
|  | @ -25,7 +28,7 @@ use rocket::{get, post}; | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn upload_keys_route( | pub async fn upload_keys_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<upload_keys::Request>, |     body: Ruma<upload_keys::Request>, | ||||||
| ) -> ConduitResult<upload_keys::Response> { | ) -> ConduitResult<upload_keys::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -74,7 +77,7 @@ pub async fn upload_keys_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_keys_route( | pub async fn get_keys_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_keys::Request<'_>>, |     body: Ruma<get_keys::Request<'_>>, | ||||||
| ) -> ConduitResult<get_keys::Response> { | ) -> ConduitResult<get_keys::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -95,7 +98,7 @@ pub async fn get_keys_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn claim_keys_route( | pub async fn claim_keys_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<claim_keys::Request>, |     body: Ruma<claim_keys::Request>, | ||||||
| ) -> ConduitResult<claim_keys::Response> { | ) -> ConduitResult<claim_keys::Response> { | ||||||
|     let response = claim_keys_helper(&body.one_time_keys, &db)?; |     let response = claim_keys_helper(&body.one_time_keys, &db)?; | ||||||
|  | @ -111,7 +114,7 @@ pub async fn claim_keys_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn upload_signing_keys_route( | pub async fn upload_signing_keys_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<upload_signing_keys::Request<'_>>, |     body: Ruma<upload_signing_keys::Request<'_>>, | ||||||
| ) -> ConduitResult<upload_signing_keys::Response> { | ) -> ConduitResult<upload_signing_keys::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -174,7 +177,7 @@ pub async fn upload_signing_keys_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn upload_signatures_route( | pub async fn upload_signatures_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<upload_signatures::Request>, |     body: Ruma<upload_signatures::Request>, | ||||||
| ) -> ConduitResult<upload_signatures::Response> { | ) -> ConduitResult<upload_signatures::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -235,7 +238,7 @@ pub async fn upload_signatures_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_key_changes_route( | pub async fn get_key_changes_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_key_changes::Request<'_>>, |     body: Ruma<get_key_changes::Request<'_>>, | ||||||
| ) -> ConduitResult<get_key_changes::Response> { | ) -> ConduitResult<get_key_changes::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  |  | ||||||
|  | @ -7,14 +7,14 @@ use ruma::api::client::{ | ||||||
| 
 | 
 | ||||||
| #[cfg(feature = "conduit_bin")] | #[cfg(feature = "conduit_bin")] | ||||||
| use rocket::{get, post}; | use rocket::{get, post}; | ||||||
| use std::convert::TryInto; | use std::{convert::TryInto, sync::Arc}; | ||||||
| 
 | 
 | ||||||
| const MXC_LENGTH: usize = 32; | const MXC_LENGTH: usize = 32; | ||||||
| 
 | 
 | ||||||
| #[cfg_attr(feature = "conduit_bin", get("/_matrix/media/r0/config"))] | #[cfg_attr(feature = "conduit_bin", get("/_matrix/media/r0/config"))] | ||||||
| #[tracing::instrument(skip(db))] | #[tracing::instrument(skip(db))] | ||||||
| pub async fn get_media_config_route( | pub async fn get_media_config_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
| ) -> ConduitResult<get_media_config::Response> { | ) -> ConduitResult<get_media_config::Response> { | ||||||
|     Ok(get_media_config::Response { |     Ok(get_media_config::Response { | ||||||
|         upload_size: db.globals.max_request_size().into(), |         upload_size: db.globals.max_request_size().into(), | ||||||
|  | @ -28,7 +28,7 @@ pub async fn get_media_config_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn create_content_route( | pub async fn create_content_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<create_content::Request<'_>>, |     body: Ruma<create_content::Request<'_>>, | ||||||
| ) -> ConduitResult<create_content::Response> { | ) -> ConduitResult<create_content::Response> { | ||||||
|     let mxc = format!( |     let mxc = format!( | ||||||
|  | @ -36,16 +36,20 @@ pub async fn create_content_route( | ||||||
|         db.globals.server_name(), |         db.globals.server_name(), | ||||||
|         utils::random_string(MXC_LENGTH) |         utils::random_string(MXC_LENGTH) | ||||||
|     ); |     ); | ||||||
|     db.media.create( | 
 | ||||||
|         mxc.clone(), |     db.media | ||||||
|         &body |         .create( | ||||||
|             .filename |             mxc.clone(), | ||||||
|             .as_ref() |             &db.globals, | ||||||
|             .map(|filename| "inline; filename=".to_owned() + filename) |             &body | ||||||
|             .as_deref(), |                 .filename | ||||||
|         &body.content_type.as_deref(), |                 .as_ref() | ||||||
|         &body.file, |                 .map(|filename| "inline; filename=".to_owned() + filename) | ||||||
|     )?; |                 .as_deref(), | ||||||
|  |             &body.content_type.as_deref(), | ||||||
|  |             &body.file, | ||||||
|  |         ) | ||||||
|  |         .await?; | ||||||
| 
 | 
 | ||||||
|     db.flush().await?; |     db.flush().await?; | ||||||
| 
 | 
 | ||||||
|  | @ -62,7 +66,7 @@ pub async fn create_content_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_content_route( | pub async fn get_content_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_content::Request<'_>>, |     body: Ruma<get_content::Request<'_>>, | ||||||
| ) -> ConduitResult<get_content::Response> { | ) -> ConduitResult<get_content::Response> { | ||||||
|     let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); |     let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); | ||||||
|  | @ -71,7 +75,7 @@ pub async fn get_content_route( | ||||||
|         content_disposition, |         content_disposition, | ||||||
|         content_type, |         content_type, | ||||||
|         file, |         file, | ||||||
|     }) = db.media.get(&mxc)? |     }) = db.media.get(&db.globals, &mxc).await? | ||||||
|     { |     { | ||||||
|         Ok(get_content::Response { |         Ok(get_content::Response { | ||||||
|             file, |             file, | ||||||
|  | @ -93,12 +97,15 @@ pub async fn get_content_route( | ||||||
|             ) |             ) | ||||||
|             .await?; |             .await?; | ||||||
| 
 | 
 | ||||||
|         db.media.create( |         db.media | ||||||
|             mxc, |             .create( | ||||||
|             &get_content_response.content_disposition.as_deref(), |                 mxc, | ||||||
|             &get_content_response.content_type.as_deref(), |                 &db.globals, | ||||||
|             &get_content_response.file, |                 &get_content_response.content_disposition.as_deref(), | ||||||
|         )?; |                 &get_content_response.content_type.as_deref(), | ||||||
|  |                 &get_content_response.file, | ||||||
|  |             ) | ||||||
|  |             .await?; | ||||||
| 
 | 
 | ||||||
|         Ok(get_content_response.into()) |         Ok(get_content_response.into()) | ||||||
|     } else { |     } else { | ||||||
|  | @ -112,22 +119,27 @@ pub async fn get_content_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_content_thumbnail_route( | pub async fn get_content_thumbnail_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_content_thumbnail::Request<'_>>, |     body: Ruma<get_content_thumbnail::Request<'_>>, | ||||||
| ) -> ConduitResult<get_content_thumbnail::Response> { | ) -> ConduitResult<get_content_thumbnail::Response> { | ||||||
|     let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); |     let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); | ||||||
| 
 | 
 | ||||||
|     if let Some(FileMeta { |     if let Some(FileMeta { | ||||||
|         content_type, file, .. |         content_type, file, .. | ||||||
|     }) = db.media.get_thumbnail( |     }) = db | ||||||
|         mxc.clone(), |         .media | ||||||
|         body.width |         .get_thumbnail( | ||||||
|             .try_into() |             mxc.clone(), | ||||||
|             .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, |             &db.globals, | ||||||
|         body.height |             body.width | ||||||
|             .try_into() |                 .try_into() | ||||||
|             .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, |                 .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, | ||||||
|     )? { |             body.height | ||||||
|  |                 .try_into() | ||||||
|  |                 .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, | ||||||
|  |         ) | ||||||
|  |         .await? | ||||||
|  |     { | ||||||
|         Ok(get_content_thumbnail::Response { file, content_type }.into()) |         Ok(get_content_thumbnail::Response { file, content_type }.into()) | ||||||
|     } else if &*body.server_name != db.globals.server_name() && body.allow_remote { |     } else if &*body.server_name != db.globals.server_name() && body.allow_remote { | ||||||
|         let get_thumbnail_response = db |         let get_thumbnail_response = db | ||||||
|  | @ -146,14 +158,17 @@ pub async fn get_content_thumbnail_route( | ||||||
|             ) |             ) | ||||||
|             .await?; |             .await?; | ||||||
| 
 | 
 | ||||||
|         db.media.upload_thumbnail( |         db.media | ||||||
|             mxc, |             .upload_thumbnail( | ||||||
|             &None, |                 mxc, | ||||||
|             &get_thumbnail_response.content_type, |                 &db.globals, | ||||||
|             body.width.try_into().expect("all UInts are valid u32s"), |                 &None, | ||||||
|             body.height.try_into().expect("all UInts are valid u32s"), |                 &get_thumbnail_response.content_type, | ||||||
|             &get_thumbnail_response.file, |                 body.width.try_into().expect("all UInts are valid u32s"), | ||||||
|         )?; |                 body.height.try_into().expect("all UInts are valid u32s"), | ||||||
|  |                 &get_thumbnail_response.file, | ||||||
|  |             ) | ||||||
|  |             .await?; | ||||||
| 
 | 
 | ||||||
|         Ok(get_thumbnail_response.into()) |         Ok(get_thumbnail_response.into()) | ||||||
|     } else { |     } else { | ||||||
|  |  | ||||||
|  | @ -44,7 +44,7 @@ use rocket::{get, post}; | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn join_room_by_id_route( | pub async fn join_room_by_id_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<join_room_by_id::Request<'_>>, |     body: Ruma<join_room_by_id::Request<'_>>, | ||||||
| ) -> ConduitResult<join_room_by_id::Response> { | ) -> ConduitResult<join_room_by_id::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -81,7 +81,7 @@ pub async fn join_room_by_id_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn join_room_by_id_or_alias_route( | pub async fn join_room_by_id_or_alias_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<join_room_by_id_or_alias::Request<'_>>, |     body: Ruma<join_room_by_id_or_alias::Request<'_>>, | ||||||
| ) -> ConduitResult<join_room_by_id_or_alias::Response> { | ) -> ConduitResult<join_room_by_id_or_alias::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -135,7 +135,7 @@ pub async fn join_room_by_id_or_alias_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn leave_room_route( | pub async fn leave_room_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<leave_room::Request<'_>>, |     body: Ruma<leave_room::Request<'_>>, | ||||||
| ) -> ConduitResult<leave_room::Response> { | ) -> ConduitResult<leave_room::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -153,7 +153,7 @@ pub async fn leave_room_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn invite_user_route( | pub async fn invite_user_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<invite_user::Request<'_>>, |     body: Ruma<invite_user::Request<'_>>, | ||||||
| ) -> ConduitResult<invite_user::Response> { | ) -> ConduitResult<invite_user::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -173,7 +173,7 @@ pub async fn invite_user_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn kick_user_route( | pub async fn kick_user_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<kick_user::Request<'_>>, |     body: Ruma<kick_user::Request<'_>>, | ||||||
| ) -> ConduitResult<kick_user::Response> { | ) -> ConduitResult<kick_user::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -222,7 +222,7 @@ pub async fn kick_user_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn ban_user_route( | pub async fn ban_user_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<ban_user::Request<'_>>, |     body: Ruma<ban_user::Request<'_>>, | ||||||
| ) -> ConduitResult<ban_user::Response> { | ) -> ConduitResult<ban_user::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -279,7 +279,7 @@ pub async fn ban_user_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn unban_user_route( | pub async fn unban_user_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<unban_user::Request<'_>>, |     body: Ruma<unban_user::Request<'_>>, | ||||||
| ) -> ConduitResult<unban_user::Response> { | ) -> ConduitResult<unban_user::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -327,7 +327,7 @@ pub async fn unban_user_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn forget_room_route( | pub async fn forget_room_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<forget_room::Request<'_>>, |     body: Ruma<forget_room::Request<'_>>, | ||||||
| ) -> ConduitResult<forget_room::Response> { | ) -> ConduitResult<forget_room::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -345,7 +345,7 @@ pub async fn forget_room_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn joined_rooms_route( | pub async fn joined_rooms_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<joined_rooms::Request>, |     body: Ruma<joined_rooms::Request>, | ||||||
| ) -> ConduitResult<joined_rooms::Response> { | ) -> ConduitResult<joined_rooms::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -366,7 +366,7 @@ pub async fn joined_rooms_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_member_events_route( | pub async fn get_member_events_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_member_events::Request<'_>>, |     body: Ruma<get_member_events::Request<'_>>, | ||||||
| ) -> ConduitResult<get_member_events::Response> { | ) -> ConduitResult<get_member_events::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -396,7 +396,7 @@ pub async fn get_member_events_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn joined_members_route( | pub async fn joined_members_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<joined_members::Request<'_>>, |     body: Ruma<joined_members::Request<'_>>, | ||||||
| ) -> ConduitResult<joined_members::Response> { | ) -> ConduitResult<joined_members::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -621,7 +621,7 @@ async fn join_room_by_id_helper( | ||||||
|             &pdu, |             &pdu, | ||||||
|             utils::to_canonical_object(&pdu).expect("Pdu is valid canonical object"), |             utils::to_canonical_object(&pdu).expect("Pdu is valid canonical object"), | ||||||
|             count, |             count, | ||||||
|             pdu_id.into(), |             &pdu_id, | ||||||
|             &[pdu.event_id.clone()], |             &[pdu.event_id.clone()], | ||||||
|             db, |             db, | ||||||
|         )?; |         )?; | ||||||
|  |  | ||||||
|  | @ -11,6 +11,7 @@ use ruma::{ | ||||||
| use std::{ | use std::{ | ||||||
|     collections::BTreeMap, |     collections::BTreeMap, | ||||||
|     convert::{TryFrom, TryInto}, |     convert::{TryFrom, TryInto}, | ||||||
|  |     sync::Arc, | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| #[cfg(feature = "conduit_bin")] | #[cfg(feature = "conduit_bin")] | ||||||
|  | @ -22,7 +23,7 @@ use rocket::{get, put}; | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn send_message_event_route( | pub async fn send_message_event_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<send_message_event::Request<'_>>, |     body: Ruma<send_message_event::Request<'_>>, | ||||||
| ) -> ConduitResult<send_message_event::Response> { | ) -> ConduitResult<send_message_event::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -85,7 +86,7 @@ pub async fn send_message_event_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_message_events_route( | pub async fn get_message_events_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_message_events::Request<'_>>, |     body: Ruma<get_message_events::Request<'_>>, | ||||||
| ) -> ConduitResult<get_message_events::Response> { | ) -> ConduitResult<get_message_events::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  |  | ||||||
|  | @ -1,7 +1,7 @@ | ||||||
| use super::State; | use super::State; | ||||||
| use crate::{utils, ConduitResult, Database, Ruma}; | use crate::{utils, ConduitResult, Database, Ruma}; | ||||||
| use ruma::api::client::r0::presence::{get_presence, set_presence}; | use ruma::api::client::r0::presence::{get_presence, set_presence}; | ||||||
| use std::{convert::TryInto, time::Duration}; | use std::{convert::TryInto, sync::Arc, time::Duration}; | ||||||
| 
 | 
 | ||||||
| #[cfg(feature = "conduit_bin")] | #[cfg(feature = "conduit_bin")] | ||||||
| use rocket::{get, put}; | use rocket::{get, put}; | ||||||
|  | @ -12,7 +12,7 @@ use rocket::{get, put}; | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn set_presence_route( | pub async fn set_presence_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<set_presence::Request<'_>>, |     body: Ruma<set_presence::Request<'_>>, | ||||||
| ) -> ConduitResult<set_presence::Response> { | ) -> ConduitResult<set_presence::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -53,7 +53,7 @@ pub async fn set_presence_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_presence_route( | pub async fn get_presence_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_presence::Request<'_>>, |     body: Ruma<get_presence::Request<'_>>, | ||||||
| ) -> ConduitResult<get_presence::Response> { | ) -> ConduitResult<get_presence::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -62,7 +62,7 @@ pub async fn get_presence_route( | ||||||
| 
 | 
 | ||||||
|     for room_id in db |     for room_id in db | ||||||
|         .rooms |         .rooms | ||||||
|         .get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()]) |         .get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? | ||||||
|     { |     { | ||||||
|         let room_id = room_id?; |         let room_id = room_id?; | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -13,7 +13,7 @@ use ruma::{ | ||||||
| 
 | 
 | ||||||
| #[cfg(feature = "conduit_bin")] | #[cfg(feature = "conduit_bin")] | ||||||
| use rocket::{get, put}; | use rocket::{get, put}; | ||||||
| use std::convert::TryInto; | use std::{convert::TryInto, sync::Arc}; | ||||||
| 
 | 
 | ||||||
| #[cfg_attr(
 | #[cfg_attr(
 | ||||||
|     feature = "conduit_bin", |     feature = "conduit_bin", | ||||||
|  | @ -21,7 +21,7 @@ use std::convert::TryInto; | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn set_displayname_route( | pub async fn set_displayname_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<set_display_name::Request<'_>>, |     body: Ruma<set_display_name::Request<'_>>, | ||||||
| ) -> ConduitResult<set_display_name::Response> { | ) -> ConduitResult<set_display_name::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -107,7 +107,7 @@ pub async fn set_displayname_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_displayname_route( | pub async fn get_displayname_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_display_name::Request<'_>>, |     body: Ruma<get_display_name::Request<'_>>, | ||||||
| ) -> ConduitResult<get_display_name::Response> { | ) -> ConduitResult<get_display_name::Response> { | ||||||
|     Ok(get_display_name::Response { |     Ok(get_display_name::Response { | ||||||
|  | @ -122,7 +122,7 @@ pub async fn get_displayname_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn set_avatar_url_route( | pub async fn set_avatar_url_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<set_avatar_url::Request<'_>>, |     body: Ruma<set_avatar_url::Request<'_>>, | ||||||
| ) -> ConduitResult<set_avatar_url::Response> { | ) -> ConduitResult<set_avatar_url::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -208,7 +208,7 @@ pub async fn set_avatar_url_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_avatar_url_route( | pub async fn get_avatar_url_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_avatar_url::Request<'_>>, |     body: Ruma<get_avatar_url::Request<'_>>, | ||||||
| ) -> ConduitResult<get_avatar_url::Response> { | ) -> ConduitResult<get_avatar_url::Response> { | ||||||
|     Ok(get_avatar_url::Response { |     Ok(get_avatar_url::Response { | ||||||
|  | @ -223,7 +223,7 @@ pub async fn get_avatar_url_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_profile_route( | pub async fn get_profile_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_profile::Request<'_>>, |     body: Ruma<get_profile::Request<'_>>, | ||||||
| ) -> ConduitResult<get_profile::Response> { | ) -> ConduitResult<get_profile::Response> { | ||||||
|     if !db.users.exists(&body.user_id)? { |     if !db.users.exists(&body.user_id)? { | ||||||
|  |  | ||||||
|  | @ -1,3 +1,5 @@ | ||||||
|  | use std::sync::Arc; | ||||||
|  | 
 | ||||||
| use super::State; | use super::State; | ||||||
| use crate::{ConduitResult, Database, Error, Ruma}; | use crate::{ConduitResult, Database, Error, Ruma}; | ||||||
| use ruma::{ | use ruma::{ | ||||||
|  | @ -22,7 +24,7 @@ use rocket::{delete, get, post, put}; | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_pushrules_all_route( | pub async fn get_pushrules_all_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_pushrules_all::Request>, |     body: Ruma<get_pushrules_all::Request>, | ||||||
| ) -> ConduitResult<get_pushrules_all::Response> { | ) -> ConduitResult<get_pushrules_all::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -47,7 +49,7 @@ pub async fn get_pushrules_all_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_pushrule_route( | pub async fn get_pushrule_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_pushrule::Request<'_>>, |     body: Ruma<get_pushrule::Request<'_>>, | ||||||
| ) -> ConduitResult<get_pushrule::Response> { | ) -> ConduitResult<get_pushrule::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -101,7 +103,7 @@ pub async fn get_pushrule_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, req))] | #[tracing::instrument(skip(db, req))] | ||||||
| pub async fn set_pushrule_route( | pub async fn set_pushrule_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     req: Ruma<set_pushrule::Request<'_>>, |     req: Ruma<set_pushrule::Request<'_>>, | ||||||
| ) -> ConduitResult<set_pushrule::Response> { | ) -> ConduitResult<set_pushrule::Response> { | ||||||
|     let sender_user = req.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = req.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -204,7 +206,7 @@ pub async fn set_pushrule_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_pushrule_actions_route( | pub async fn get_pushrule_actions_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_pushrule_actions::Request<'_>>, |     body: Ruma<get_pushrule_actions::Request<'_>>, | ||||||
| ) -> ConduitResult<get_pushrule_actions::Response> { | ) -> ConduitResult<get_pushrule_actions::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -263,7 +265,7 @@ pub async fn get_pushrule_actions_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn set_pushrule_actions_route( | pub async fn set_pushrule_actions_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<set_pushrule_actions::Request<'_>>, |     body: Ruma<set_pushrule_actions::Request<'_>>, | ||||||
| ) -> ConduitResult<set_pushrule_actions::Response> { | ) -> ConduitResult<set_pushrule_actions::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -337,7 +339,7 @@ pub async fn set_pushrule_actions_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_pushrule_enabled_route( | pub async fn get_pushrule_enabled_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_pushrule_enabled::Request<'_>>, |     body: Ruma<get_pushrule_enabled::Request<'_>>, | ||||||
| ) -> ConduitResult<get_pushrule_enabled::Response> { | ) -> ConduitResult<get_pushrule_enabled::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -398,7 +400,7 @@ pub async fn get_pushrule_enabled_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn set_pushrule_enabled_route( | pub async fn set_pushrule_enabled_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<set_pushrule_enabled::Request<'_>>, |     body: Ruma<set_pushrule_enabled::Request<'_>>, | ||||||
| ) -> ConduitResult<set_pushrule_enabled::Response> { | ) -> ConduitResult<set_pushrule_enabled::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -477,7 +479,7 @@ pub async fn set_pushrule_enabled_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn delete_pushrule_route( | pub async fn delete_pushrule_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<delete_pushrule::Request<'_>>, |     body: Ruma<delete_pushrule::Request<'_>>, | ||||||
| ) -> ConduitResult<delete_pushrule::Response> { | ) -> ConduitResult<delete_pushrule::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -546,7 +548,7 @@ pub async fn delete_pushrule_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_pushers_route( | pub async fn get_pushers_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_pushers::Request>, |     body: Ruma<get_pushers::Request>, | ||||||
| ) -> ConduitResult<get_pushers::Response> { | ) -> ConduitResult<get_pushers::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -563,7 +565,7 @@ pub async fn get_pushers_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn set_pushers_route( | pub async fn set_pushers_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<set_pusher::Request>, |     body: Ruma<set_pusher::Request>, | ||||||
| ) -> ConduitResult<set_pusher::Response> { | ) -> ConduitResult<set_pusher::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  |  | ||||||
|  | @ -12,7 +12,7 @@ use ruma::{ | ||||||
| 
 | 
 | ||||||
| #[cfg(feature = "conduit_bin")] | #[cfg(feature = "conduit_bin")] | ||||||
| use rocket::post; | use rocket::post; | ||||||
| use std::collections::BTreeMap; | use std::{collections::BTreeMap, sync::Arc}; | ||||||
| 
 | 
 | ||||||
| #[cfg_attr(
 | #[cfg_attr(
 | ||||||
|     feature = "conduit_bin", |     feature = "conduit_bin", | ||||||
|  | @ -20,7 +20,7 @@ use std::collections::BTreeMap; | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn set_read_marker_route( | pub async fn set_read_marker_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<set_read_marker::Request<'_>>, |     body: Ruma<set_read_marker::Request<'_>>, | ||||||
| ) -> ConduitResult<set_read_marker::Response> { | ) -> ConduitResult<set_read_marker::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -87,7 +87,7 @@ pub async fn set_read_marker_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn create_receipt_route( | pub async fn create_receipt_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<create_receipt::Request<'_>>, |     body: Ruma<create_receipt::Request<'_>>, | ||||||
| ) -> ConduitResult<create_receipt::Response> { | ) -> ConduitResult<create_receipt::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  |  | ||||||
|  | @ -4,6 +4,7 @@ use ruma::{ | ||||||
|     api::client::r0::redact::redact_event, |     api::client::r0::redact::redact_event, | ||||||
|     events::{room::redaction, EventType}, |     events::{room::redaction, EventType}, | ||||||
| }; | }; | ||||||
|  | use std::sync::Arc; | ||||||
| 
 | 
 | ||||||
| #[cfg(feature = "conduit_bin")] | #[cfg(feature = "conduit_bin")] | ||||||
| use rocket::put; | use rocket::put; | ||||||
|  | @ -14,7 +15,7 @@ use rocket::put; | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn redact_event_route( | pub async fn redact_event_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<redact_event::Request<'_>>, |     body: Ruma<redact_event::Request<'_>>, | ||||||
| ) -> ConduitResult<redact_event::Response> { | ) -> ConduitResult<redact_event::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  |  | ||||||
|  | @ -13,7 +13,7 @@ use ruma::{ | ||||||
|     serde::Raw, |     serde::Raw, | ||||||
|     RoomAliasId, RoomId, RoomVersionId, |     RoomAliasId, RoomId, RoomVersionId, | ||||||
| }; | }; | ||||||
| use std::{cmp::max, collections::BTreeMap, convert::TryFrom}; | use std::{cmp::max, collections::BTreeMap, convert::TryFrom, sync::Arc}; | ||||||
| 
 | 
 | ||||||
| #[cfg(feature = "conduit_bin")] | #[cfg(feature = "conduit_bin")] | ||||||
| use rocket::{get, post}; | use rocket::{get, post}; | ||||||
|  | @ -24,7 +24,7 @@ use rocket::{get, post}; | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn create_room_route( | pub async fn create_room_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<create_room::Request<'_>>, |     body: Ruma<create_room::Request<'_>>, | ||||||
| ) -> ConduitResult<create_room::Response> { | ) -> ConduitResult<create_room::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -304,7 +304,7 @@ pub async fn create_room_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_room_event_route( | pub async fn get_room_event_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_room_event::Request<'_>>, |     body: Ruma<get_room_event::Request<'_>>, | ||||||
| ) -> ConduitResult<get_room_event::Response> { | ) -> ConduitResult<get_room_event::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -332,7 +332,7 @@ pub async fn get_room_event_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn upgrade_room_route( | pub async fn upgrade_room_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<upgrade_room::Request<'_>>, |     body: Ruma<upgrade_room::Request<'_>>, | ||||||
|     _room_id: String, |     _room_id: String, | ||||||
| ) -> ConduitResult<upgrade_room::Response> { | ) -> ConduitResult<upgrade_room::Response> { | ||||||
|  |  | ||||||
|  | @ -1,6 +1,7 @@ | ||||||
| use super::State; | use super::State; | ||||||
| use crate::{ConduitResult, Database, Error, Ruma}; | use crate::{ConduitResult, Database, Error, Ruma}; | ||||||
| use ruma::api::client::{error::ErrorKind, r0::search::search_events}; | use ruma::api::client::{error::ErrorKind, r0::search::search_events}; | ||||||
|  | use std::sync::Arc; | ||||||
| 
 | 
 | ||||||
| #[cfg(feature = "conduit_bin")] | #[cfg(feature = "conduit_bin")] | ||||||
| use rocket::post; | use rocket::post; | ||||||
|  | @ -13,7 +14,7 @@ use std::collections::BTreeMap; | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn search_events_route( | pub async fn search_events_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<search_events::Request<'_>>, |     body: Ruma<search_events::Request<'_>>, | ||||||
| ) -> ConduitResult<search_events::Response> { | ) -> ConduitResult<search_events::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  |  | ||||||
|  | @ -1,3 +1,5 @@ | ||||||
|  | use std::sync::Arc; | ||||||
|  | 
 | ||||||
| use super::{State, DEVICE_ID_LENGTH, TOKEN_LENGTH}; | use super::{State, DEVICE_ID_LENGTH, TOKEN_LENGTH}; | ||||||
| use crate::{utils, ConduitResult, Database, Error, Ruma}; | use crate::{utils, ConduitResult, Database, Error, Ruma}; | ||||||
| use log::info; | use log::info; | ||||||
|  | @ -50,7 +52,7 @@ pub async fn get_login_types_route() -> ConduitResult<get_login_types::Response> | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn login_route( | pub async fn login_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<login::Request<'_>>, |     body: Ruma<login::Request<'_>>, | ||||||
| ) -> ConduitResult<login::Response> { | ) -> ConduitResult<login::Response> { | ||||||
|     // Validate login method
 |     // Validate login method
 | ||||||
|  | @ -167,7 +169,7 @@ pub async fn login_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn logout_route( | pub async fn logout_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<logout::Request>, |     body: Ruma<logout::Request>, | ||||||
| ) -> ConduitResult<logout::Response> { | ) -> ConduitResult<logout::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -195,7 +197,7 @@ pub async fn logout_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn logout_all_route( | pub async fn logout_all_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<logout_all::Request>, |     body: Ruma<logout_all::Request>, | ||||||
| ) -> ConduitResult<logout_all::Response> { | ) -> ConduitResult<logout_all::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  |  | ||||||
|  | @ -1,3 +1,5 @@ | ||||||
|  | use std::sync::Arc; | ||||||
|  | 
 | ||||||
| use super::State; | use super::State; | ||||||
| use crate::{pdu::PduBuilder, ConduitResult, Database, Error, Result, Ruma}; | use crate::{pdu::PduBuilder, ConduitResult, Database, Error, Result, Ruma}; | ||||||
| use ruma::{ | use ruma::{ | ||||||
|  | @ -25,7 +27,7 @@ use rocket::{get, put}; | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn send_state_event_for_key_route( | pub async fn send_state_event_for_key_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<send_state_event::Request<'_>>, |     body: Ruma<send_state_event::Request<'_>>, | ||||||
| ) -> ConduitResult<send_state_event::Response> { | ) -> ConduitResult<send_state_event::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -51,7 +53,7 @@ pub async fn send_state_event_for_key_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn send_state_event_for_empty_key_route( | pub async fn send_state_event_for_empty_key_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<send_state_event::Request<'_>>, |     body: Ruma<send_state_event::Request<'_>>, | ||||||
| ) -> ConduitResult<send_state_event::Response> { | ) -> ConduitResult<send_state_event::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -77,7 +79,7 @@ pub async fn send_state_event_for_empty_key_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_state_events_route( | pub async fn get_state_events_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_state_events::Request<'_>>, |     body: Ruma<get_state_events::Request<'_>>, | ||||||
| ) -> ConduitResult<get_state_events::Response> { | ) -> ConduitResult<get_state_events::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -124,7 +126,7 @@ pub async fn get_state_events_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_state_events_for_key_route( | pub async fn get_state_events_for_key_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_state_events_for_key::Request<'_>>, |     body: Ruma<get_state_events_for_key::Request<'_>>, | ||||||
| ) -> ConduitResult<get_state_events_for_key::Response> { | ) -> ConduitResult<get_state_events_for_key::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -175,7 +177,7 @@ pub async fn get_state_events_for_key_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_state_events_for_empty_key_route( | pub async fn get_state_events_for_empty_key_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_state_events_for_key::Request<'_>>, |     body: Ruma<get_state_events_for_key::Request<'_>>, | ||||||
| ) -> ConduitResult<get_state_events_for_key::Response> { | ) -> ConduitResult<get_state_events_for_key::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  |  | ||||||
|  | @ -1,5 +1,5 @@ | ||||||
| use super::State; | use super::State; | ||||||
| use crate::{ConduitResult, Database, Error, Ruma}; | use crate::{ConduitResult, Database, Error, Result, Ruma}; | ||||||
| use log::error; | use log::error; | ||||||
| use ruma::{ | use ruma::{ | ||||||
|     api::client::r0::sync::sync_events, |     api::client::r0::sync::sync_events, | ||||||
|  | @ -13,6 +13,7 @@ use rocket::{get, tokio}; | ||||||
| use std::{ | use std::{ | ||||||
|     collections::{hash_map, BTreeMap, HashMap, HashSet}, |     collections::{hash_map, BTreeMap, HashMap, HashSet}, | ||||||
|     convert::{TryFrom, TryInto}, |     convert::{TryFrom, TryInto}, | ||||||
|  |     sync::Arc, | ||||||
|     time::Duration, |     time::Duration, | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | @ -33,7 +34,7 @@ use std::{ | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn sync_events_route( | pub async fn sync_events_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<sync_events::Request<'_>>, |     body: Ruma<sync_events::Request<'_>>, | ||||||
| ) -> ConduitResult<sync_events::Response> { | ) -> ConduitResult<sync_events::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -71,18 +72,23 @@ pub async fn sync_events_route( | ||||||
| 
 | 
 | ||||||
|         let mut non_timeline_pdus = db |         let mut non_timeline_pdus = db | ||||||
|             .rooms |             .rooms | ||||||
|             .pdus_since(&sender_user, &room_id, since)? |             .pdus_until(&sender_user, &room_id, u64::MAX) | ||||||
|             .filter_map(|r| { |             .filter_map(|r| { | ||||||
|  |                 // Filter out buggy events
 | ||||||
|                 if r.is_err() { |                 if r.is_err() { | ||||||
|                     error!("Bad pdu in pdus_since: {:?}", r); |                     error!("Bad pdu in pdus_since: {:?}", r); | ||||||
|                 } |                 } | ||||||
|                 r.ok() |                 r.ok() | ||||||
|             }); // Filter out buggy events
 |             }) | ||||||
|  |             .take_while(|(pduid, _)| { | ||||||
|  |                 db.rooms | ||||||
|  |                     .pdu_count(pduid) | ||||||
|  |                     .map_or(false, |count| count > since) | ||||||
|  |             }); | ||||||
| 
 | 
 | ||||||
|         // Take the last 10 events for the timeline
 |         // Take the last 10 events for the timeline
 | ||||||
|         let timeline_pdus = non_timeline_pdus |         let timeline_pdus = non_timeline_pdus | ||||||
|             .by_ref() |             .by_ref() | ||||||
|             .rev() |  | ||||||
|             .take(10) |             .take(10) | ||||||
|             .collect::<Vec<_>>() |             .collect::<Vec<_>>() | ||||||
|             .into_iter() |             .into_iter() | ||||||
|  | @ -226,7 +232,7 @@ pub async fn sync_events_route( | ||||||
|                     match (since_membership, current_membership) { |                     match (since_membership, current_membership) { | ||||||
|                         (MembershipState::Leave, MembershipState::Join) => { |                         (MembershipState::Leave, MembershipState::Join) => { | ||||||
|                             // A new user joined an encrypted room
 |                             // A new user joined an encrypted room
 | ||||||
|                             if !share_encrypted_room(&db, &sender_user, &user_id, &room_id) { |                             if !share_encrypted_room(&db, &sender_user, &user_id, &room_id)? { | ||||||
|                                 device_list_updates.insert(user_id); |                                 device_list_updates.insert(user_id); | ||||||
|                             } |                             } | ||||||
|                         } |                         } | ||||||
|  | @ -257,6 +263,7 @@ pub async fn sync_events_route( | ||||||
|                         .filter(|user_id| { |                         .filter(|user_id| { | ||||||
|                             // Only send keys if the sender doesn't share an encrypted room with the target already
 |                             // Only send keys if the sender doesn't share an encrypted room with the target already
 | ||||||
|                             !share_encrypted_room(&db, sender_user, user_id, &room_id) |                             !share_encrypted_room(&db, sender_user, user_id, &room_id) | ||||||
|  |                                 .unwrap_or(false) | ||||||
|                         }), |                         }), | ||||||
|                 ); |                 ); | ||||||
|             } |             } | ||||||
|  | @ -274,7 +281,7 @@ pub async fn sync_events_route( | ||||||
| 
 | 
 | ||||||
|                     for hero in db |                     for hero in db | ||||||
|                         .rooms |                         .rooms | ||||||
|                         .all_pdus(&sender_user, &room_id)? |                         .all_pdus(&sender_user, &room_id) | ||||||
|                         .filter_map(|pdu| pdu.ok()) // Ignore all broken pdus
 |                         .filter_map(|pdu| pdu.ok()) // Ignore all broken pdus
 | ||||||
|                         .filter(|(_, pdu)| pdu.kind == EventType::RoomMember) |                         .filter(|(_, pdu)| pdu.kind == EventType::RoomMember) | ||||||
|                         .map(|(_, pdu)| { |                         .map(|(_, pdu)| { | ||||||
|  | @ -411,7 +418,7 @@ pub async fn sync_events_route( | ||||||
|         let mut edus = db |         let mut edus = db | ||||||
|             .rooms |             .rooms | ||||||
|             .edus |             .edus | ||||||
|             .readreceipts_since(&room_id, since)? |             .readreceipts_since(&room_id, since) | ||||||
|             .filter_map(|r| r.ok()) // Filter out buggy events
 |             .filter_map(|r| r.ok()) // Filter out buggy events
 | ||||||
|             .map(|(_, _, v)| v) |             .map(|(_, _, v)| v) | ||||||
|             .collect::<Vec<_>>(); |             .collect::<Vec<_>>(); | ||||||
|  | @ -549,7 +556,7 @@ pub async fn sync_events_route( | ||||||
|     for user_id in left_encrypted_users { |     for user_id in left_encrypted_users { | ||||||
|         let still_share_encrypted_room = db |         let still_share_encrypted_room = db | ||||||
|             .rooms |             .rooms | ||||||
|             .get_shared_rooms(vec![sender_user.clone(), user_id.clone()]) |             .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? | ||||||
|             .filter_map(|r| r.ok()) |             .filter_map(|r| r.ok()) | ||||||
|             .filter_map(|other_room_id| { |             .filter_map(|other_room_id| { | ||||||
|                 Some( |                 Some( | ||||||
|  | @ -639,9 +646,10 @@ fn share_encrypted_room( | ||||||
|     sender_user: &UserId, |     sender_user: &UserId, | ||||||
|     user_id: &UserId, |     user_id: &UserId, | ||||||
|     ignore_room: &RoomId, |     ignore_room: &RoomId, | ||||||
| ) -> bool { | ) -> Result<bool> { | ||||||
|     db.rooms |     Ok(db | ||||||
|         .get_shared_rooms(vec![sender_user.clone(), user_id.clone()]) |         .rooms | ||||||
|  |         .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? | ||||||
|         .filter_map(|r| r.ok()) |         .filter_map(|r| r.ok()) | ||||||
|         .filter(|room_id| room_id != ignore_room) |         .filter(|room_id| room_id != ignore_room) | ||||||
|         .filter_map(|other_room_id| { |         .filter_map(|other_room_id| { | ||||||
|  | @ -652,5 +660,5 @@ fn share_encrypted_room( | ||||||
|                     .is_some(), |                     .is_some(), | ||||||
|             ) |             ) | ||||||
|         }) |         }) | ||||||
|         .any(|encrypted| encrypted) |         .any(|encrypted| encrypted)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -4,7 +4,7 @@ use ruma::{ | ||||||
|     api::client::r0::tag::{create_tag, delete_tag, get_tags}, |     api::client::r0::tag::{create_tag, delete_tag, get_tags}, | ||||||
|     events::EventType, |     events::EventType, | ||||||
| }; | }; | ||||||
| use std::collections::BTreeMap; | use std::{collections::BTreeMap, sync::Arc}; | ||||||
| 
 | 
 | ||||||
| #[cfg(feature = "conduit_bin")] | #[cfg(feature = "conduit_bin")] | ||||||
| use rocket::{delete, get, put}; | use rocket::{delete, get, put}; | ||||||
|  | @ -15,7 +15,7 @@ use rocket::{delete, get, put}; | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn update_tag_route( | pub async fn update_tag_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<create_tag::Request<'_>>, |     body: Ruma<create_tag::Request<'_>>, | ||||||
| ) -> ConduitResult<create_tag::Response> { | ) -> ConduitResult<create_tag::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -52,7 +52,7 @@ pub async fn update_tag_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn delete_tag_route( | pub async fn delete_tag_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<delete_tag::Request<'_>>, |     body: Ruma<delete_tag::Request<'_>>, | ||||||
| ) -> ConduitResult<delete_tag::Response> { | ) -> ConduitResult<delete_tag::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  | @ -86,7 +86,7 @@ pub async fn delete_tag_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_tags_route( | pub async fn get_tags_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_tags::Request<'_>>, |     body: Ruma<get_tags::Request<'_>>, | ||||||
| ) -> ConduitResult<get_tags::Response> { | ) -> ConduitResult<get_tags::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  |  | ||||||
|  | @ -1,3 +1,5 @@ | ||||||
|  | use std::sync::Arc; | ||||||
|  | 
 | ||||||
| use super::State; | use super::State; | ||||||
| use crate::{ConduitResult, Database, Error, Ruma}; | use crate::{ConduitResult, Database, Error, Ruma}; | ||||||
| use ruma::api::client::{ | use ruma::api::client::{ | ||||||
|  | @ -14,7 +16,7 @@ use rocket::put; | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn send_event_to_device_route( | pub async fn send_event_to_device_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<send_event_to_device::Request<'_>>, |     body: Ruma<send_event_to_device::Request<'_>>, | ||||||
| ) -> ConduitResult<send_event_to_device::Response> { | ) -> ConduitResult<send_event_to_device::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  |  | ||||||
|  | @ -1,3 +1,5 @@ | ||||||
|  | use std::sync::Arc; | ||||||
|  | 
 | ||||||
| use super::State; | use super::State; | ||||||
| use crate::{utils, ConduitResult, Database, Ruma}; | use crate::{utils, ConduitResult, Database, Ruma}; | ||||||
| use create_typing_event::Typing; | use create_typing_event::Typing; | ||||||
|  | @ -12,7 +14,7 @@ use rocket::put; | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub fn create_typing_event_route( | pub fn create_typing_event_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<create_typing_event::Request<'_>>, |     body: Ruma<create_typing_event::Request<'_>>, | ||||||
| ) -> ConduitResult<create_typing_event::Response> { | ) -> ConduitResult<create_typing_event::Response> { | ||||||
|     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); |     let sender_user = body.sender_user.as_ref().expect("user is authenticated"); | ||||||
|  |  | ||||||
|  | @ -1,3 +1,5 @@ | ||||||
|  | use std::sync::Arc; | ||||||
|  | 
 | ||||||
| use super::State; | use super::State; | ||||||
| use crate::{ConduitResult, Database, Ruma}; | use crate::{ConduitResult, Database, Ruma}; | ||||||
| use ruma::api::client::r0::user_directory::search_users; | use ruma::api::client::r0::user_directory::search_users; | ||||||
|  | @ -11,7 +13,7 @@ use rocket::post; | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn search_users_route( | pub async fn search_users_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<search_users::Request<'_>>, |     body: Ruma<search_users::Request<'_>>, | ||||||
| ) -> ConduitResult<search_users::Response> { | ) -> ConduitResult<search_users::Response> { | ||||||
|     let limit = u64::from(body.limit) as usize; |     let limit = u64::from(body.limit) as usize; | ||||||
|  |  | ||||||
							
								
								
									
										209
									
								
								src/database.rs
									
									
									
									
									
								
							
							
						
						
									
										209
									
								
								src/database.rs
									
									
									
									
									
								
							|  | @ -1,3 +1,5 @@ | ||||||
|  | pub mod abstraction; | ||||||
|  | 
 | ||||||
| pub mod account_data; | pub mod account_data; | ||||||
| pub mod admin; | pub mod admin; | ||||||
| pub mod appservice; | pub mod appservice; | ||||||
|  | @ -12,15 +14,16 @@ pub mod uiaa; | ||||||
| pub mod users; | pub mod users; | ||||||
| 
 | 
 | ||||||
| use crate::{utils, Error, Result}; | use crate::{utils, Error, Result}; | ||||||
|  | use abstraction::DatabaseEngine; | ||||||
| use directories::ProjectDirs; | use directories::ProjectDirs; | ||||||
| use futures::StreamExt; | use log::error; | ||||||
| use log::{error, info}; | use rocket::futures::{channel::mpsc, stream::FuturesUnordered, StreamExt}; | ||||||
| use rocket::futures::{self, channel::mpsc}; |  | ||||||
| use ruma::{DeviceId, ServerName, UserId}; | use ruma::{DeviceId, ServerName, UserId}; | ||||||
| use serde::Deserialize; | use serde::Deserialize; | ||||||
| use std::{ | use std::{ | ||||||
|     collections::HashMap, |     collections::HashMap, | ||||||
|     fs::remove_dir_all, |     fs::{self, remove_dir_all}, | ||||||
|  |     io::Write, | ||||||
|     sync::{Arc, RwLock}, |     sync::{Arc, RwLock}, | ||||||
| }; | }; | ||||||
| use tokio::sync::Semaphore; | use tokio::sync::Semaphore; | ||||||
|  | @ -74,7 +77,12 @@ fn default_log() -> String { | ||||||
|     "info,state_res=warn,rocket=off,_=off,sled=off".to_owned() |     "info,state_res=warn,rocket=off,_=off,sled=off".to_owned() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| #[derive(Clone)] | #[cfg(feature = "sled")] | ||||||
|  | pub type Engine = abstraction::SledEngine; | ||||||
|  | 
 | ||||||
|  | #[cfg(feature = "rocksdb")] | ||||||
|  | pub type Engine = abstraction::RocksDbEngine; | ||||||
|  | 
 | ||||||
| pub struct Database { | pub struct Database { | ||||||
|     pub globals: globals::Globals, |     pub globals: globals::Globals, | ||||||
|     pub users: users::Users, |     pub users: users::Users, | ||||||
|  | @ -88,7 +96,6 @@ pub struct Database { | ||||||
|     pub admin: admin::Admin, |     pub admin: admin::Admin, | ||||||
|     pub appservice: appservice::Appservice, |     pub appservice: appservice::Appservice, | ||||||
|     pub pusher: pusher::PushData, |     pub pusher: pusher::PushData, | ||||||
|     pub _db: sled::Db, |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl Database { | impl Database { | ||||||
|  | @ -105,126 +112,126 @@ impl Database { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Load an existing database or create a new one.
 |     /// Load an existing database or create a new one.
 | ||||||
|     pub async fn load_or_create(config: Config) -> Result<Self> { |     pub async fn load_or_create(config: Config) -> Result<Arc<Self>> { | ||||||
|         let db = sled::Config::default() |         let builder = Engine::open(&config)?; | ||||||
|             .path(&config.database_path) |  | ||||||
|             .cache_capacity(config.cache_capacity as u64) |  | ||||||
|             .use_compression(true) |  | ||||||
|             .open()?; |  | ||||||
| 
 | 
 | ||||||
|         if config.max_request_size < 1024 { |         if config.max_request_size < 1024 { | ||||||
|             eprintln!("ERROR: Max request size is less than 1KB. Please increase it."); |             eprintln!("ERROR: Max request size is less than 1KB. Please increase it."); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         let (admin_sender, admin_receiver) = mpsc::unbounded(); |         let (admin_sender, admin_receiver) = mpsc::unbounded(); | ||||||
|  |         let (sending_sender, sending_receiver) = mpsc::unbounded(); | ||||||
| 
 | 
 | ||||||
|         let db = Self { |         let db = Arc::new(Self { | ||||||
|             users: users::Users { |             users: users::Users { | ||||||
|                 userid_password: db.open_tree("userid_password")?, |                 userid_password: builder.open_tree("userid_password")?, | ||||||
|                 userid_displayname: db.open_tree("userid_displayname")?, |                 userid_displayname: builder.open_tree("userid_displayname")?, | ||||||
|                 userid_avatarurl: db.open_tree("userid_avatarurl")?, |                 userid_avatarurl: builder.open_tree("userid_avatarurl")?, | ||||||
|                 userdeviceid_token: db.open_tree("userdeviceid_token")?, |                 userdeviceid_token: builder.open_tree("userdeviceid_token")?, | ||||||
|                 userdeviceid_metadata: db.open_tree("userdeviceid_metadata")?, |                 userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?, | ||||||
|                 userid_devicelistversion: db.open_tree("userid_devicelistversion")?, |                 userid_devicelistversion: builder.open_tree("userid_devicelistversion")?, | ||||||
|                 token_userdeviceid: db.open_tree("token_userdeviceid")?, |                 token_userdeviceid: builder.open_tree("token_userdeviceid")?, | ||||||
|                 onetimekeyid_onetimekeys: db.open_tree("onetimekeyid_onetimekeys")?, |                 onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?, | ||||||
|                 userid_lastonetimekeyupdate: db.open_tree("userid_lastonetimekeyupdate")?, |                 userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?, | ||||||
|                 keychangeid_userid: db.open_tree("keychangeid_userid")?, |                 keychangeid_userid: builder.open_tree("keychangeid_userid")?, | ||||||
|                 keyid_key: db.open_tree("keyid_key")?, |                 keyid_key: builder.open_tree("keyid_key")?, | ||||||
|                 userid_masterkeyid: db.open_tree("userid_masterkeyid")?, |                 userid_masterkeyid: builder.open_tree("userid_masterkeyid")?, | ||||||
|                 userid_selfsigningkeyid: db.open_tree("userid_selfsigningkeyid")?, |                 userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?, | ||||||
|                 userid_usersigningkeyid: db.open_tree("userid_usersigningkeyid")?, |                 userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?, | ||||||
|                 todeviceid_events: db.open_tree("todeviceid_events")?, |                 todeviceid_events: builder.open_tree("todeviceid_events")?, | ||||||
|             }, |             }, | ||||||
|             uiaa: uiaa::Uiaa { |             uiaa: uiaa::Uiaa { | ||||||
|                 userdevicesessionid_uiaainfo: db.open_tree("userdevicesessionid_uiaainfo")?, |                 userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?, | ||||||
|                 userdevicesessionid_uiaarequest: db.open_tree("userdevicesessionid_uiaarequest")?, |                 userdevicesessionid_uiaarequest: builder | ||||||
|  |                     .open_tree("userdevicesessionid_uiaarequest")?, | ||||||
|             }, |             }, | ||||||
|             rooms: rooms::Rooms { |             rooms: rooms::Rooms { | ||||||
|                 edus: rooms::RoomEdus { |                 edus: rooms::RoomEdus { | ||||||
|                     readreceiptid_readreceipt: db.open_tree("readreceiptid_readreceipt")?, |                     readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?, | ||||||
|                     roomuserid_privateread: db.open_tree("roomuserid_privateread")?, // "Private" read receipt
 |                     roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt
 | ||||||
|                     roomuserid_lastprivatereadupdate: db |                     roomuserid_lastprivatereadupdate: builder | ||||||
|                         .open_tree("roomuserid_lastprivatereadupdate")?, |                         .open_tree("roomuserid_lastprivatereadupdate")?, | ||||||
|                     typingid_userid: db.open_tree("typingid_userid")?, |                     typingid_userid: builder.open_tree("typingid_userid")?, | ||||||
|                     roomid_lasttypingupdate: db.open_tree("roomid_lasttypingupdate")?, |                     roomid_lasttypingupdate: builder.open_tree("roomid_lasttypingupdate")?, | ||||||
|                     presenceid_presence: db.open_tree("presenceid_presence")?, |                     presenceid_presence: builder.open_tree("presenceid_presence")?, | ||||||
|                     userid_lastpresenceupdate: db.open_tree("userid_lastpresenceupdate")?, |                     userid_lastpresenceupdate: builder.open_tree("userid_lastpresenceupdate")?, | ||||||
|                 }, |                 }, | ||||||
|                 pduid_pdu: db.open_tree("pduid_pdu")?, |                 pduid_pdu: builder.open_tree("pduid_pdu")?, | ||||||
|                 eventid_pduid: db.open_tree("eventid_pduid")?, |                 eventid_pduid: builder.open_tree("eventid_pduid")?, | ||||||
|                 roomid_pduleaves: db.open_tree("roomid_pduleaves")?, |                 roomid_pduleaves: builder.open_tree("roomid_pduleaves")?, | ||||||
| 
 | 
 | ||||||
|                 alias_roomid: db.open_tree("alias_roomid")?, |                 alias_roomid: builder.open_tree("alias_roomid")?, | ||||||
|                 aliasid_alias: db.open_tree("aliasid_alias")?, |                 aliasid_alias: builder.open_tree("aliasid_alias")?, | ||||||
|                 publicroomids: db.open_tree("publicroomids")?, |                 publicroomids: builder.open_tree("publicroomids")?, | ||||||
| 
 | 
 | ||||||
|                 tokenids: db.open_tree("tokenids")?, |                 tokenids: builder.open_tree("tokenids")?, | ||||||
| 
 | 
 | ||||||
|                 roomserverids: db.open_tree("roomserverids")?, |                 roomserverids: builder.open_tree("roomserverids")?, | ||||||
|                 serverroomids: db.open_tree("serverroomids")?, |                 serverroomids: builder.open_tree("serverroomids")?, | ||||||
|                 userroomid_joined: db.open_tree("userroomid_joined")?, |                 userroomid_joined: builder.open_tree("userroomid_joined")?, | ||||||
|                 roomuserid_joined: db.open_tree("roomuserid_joined")?, |                 roomuserid_joined: builder.open_tree("roomuserid_joined")?, | ||||||
|                 roomuseroncejoinedids: db.open_tree("roomuseroncejoinedids")?, |                 roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?, | ||||||
|                 userroomid_invitestate: db.open_tree("userroomid_invitestate")?, |                 userroomid_invitestate: builder.open_tree("userroomid_invitestate")?, | ||||||
|                 roomuserid_invitecount: db.open_tree("roomuserid_invitecount")?, |                 roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?, | ||||||
|                 userroomid_leftstate: db.open_tree("userroomid_leftstate")?, |                 userroomid_leftstate: builder.open_tree("userroomid_leftstate")?, | ||||||
|                 roomuserid_leftcount: db.open_tree("roomuserid_leftcount")?, |                 roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?, | ||||||
| 
 | 
 | ||||||
|                 userroomid_notificationcount: db.open_tree("userroomid_notificationcount")?, |                 userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?, | ||||||
|                 userroomid_highlightcount: db.open_tree("userroomid_highlightcount")?, |                 userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?, | ||||||
| 
 | 
 | ||||||
|                 statekey_shortstatekey: db.open_tree("statekey_shortstatekey")?, |                 statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?, | ||||||
|                 stateid_shorteventid: db.open_tree("stateid_shorteventid")?, |                 stateid_shorteventid: builder.open_tree("stateid_shorteventid")?, | ||||||
|                 eventid_shorteventid: db.open_tree("eventid_shorteventid")?, |                 eventid_shorteventid: builder.open_tree("eventid_shorteventid")?, | ||||||
|                 shorteventid_eventid: db.open_tree("shorteventid_eventid")?, |                 shorteventid_eventid: builder.open_tree("shorteventid_eventid")?, | ||||||
|                 shorteventid_shortstatehash: db.open_tree("shorteventid_shortstatehash")?, |                 shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?, | ||||||
|                 roomid_shortstatehash: db.open_tree("roomid_shortstatehash")?, |                 roomid_shortstatehash: builder.open_tree("roomid_shortstatehash")?, | ||||||
|                 statehash_shortstatehash: db.open_tree("statehash_shortstatehash")?, |                 statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?, | ||||||
| 
 | 
 | ||||||
|                 eventid_outlierpdu: db.open_tree("eventid_outlierpdu")?, |                 eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, | ||||||
|                 prevevent_parent: db.open_tree("prevevent_parent")?, |                 prevevent_parent: builder.open_tree("prevevent_parent")?, | ||||||
|             }, |             }, | ||||||
|             account_data: account_data::AccountData { |             account_data: account_data::AccountData { | ||||||
|                 roomuserdataid_accountdata: db.open_tree("roomuserdataid_accountdata")?, |                 roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, | ||||||
|             }, |             }, | ||||||
|             media: media::Media { |             media: media::Media { | ||||||
|                 mediaid_file: db.open_tree("mediaid_file")?, |                 mediaid_file: builder.open_tree("mediaid_file")?, | ||||||
|             }, |             }, | ||||||
|             key_backups: key_backups::KeyBackups { |             key_backups: key_backups::KeyBackups { | ||||||
|                 backupid_algorithm: db.open_tree("backupid_algorithm")?, |                 backupid_algorithm: builder.open_tree("backupid_algorithm")?, | ||||||
|                 backupid_etag: db.open_tree("backupid_etag")?, |                 backupid_etag: builder.open_tree("backupid_etag")?, | ||||||
|                 backupkeyid_backup: db.open_tree("backupkeyid_backup")?, |                 backupkeyid_backup: builder.open_tree("backupkeyid_backup")?, | ||||||
|             }, |             }, | ||||||
|             transaction_ids: transaction_ids::TransactionIds { |             transaction_ids: transaction_ids::TransactionIds { | ||||||
|                 userdevicetxnid_response: db.open_tree("userdevicetxnid_response")?, |                 userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?, | ||||||
|             }, |             }, | ||||||
|             sending: sending::Sending { |             sending: sending::Sending { | ||||||
|                 servername_educount: db.open_tree("servername_educount")?, |                 servername_educount: builder.open_tree("servername_educount")?, | ||||||
|                 servernamepduids: db.open_tree("servernamepduids")?, |                 servernamepduids: builder.open_tree("servernamepduids")?, | ||||||
|                 servercurrentevents: db.open_tree("servercurrentevents")?, |                 servercurrentevents: builder.open_tree("servercurrentevents")?, | ||||||
|                 maximum_requests: Arc::new(Semaphore::new(config.max_concurrent_requests as usize)), |                 maximum_requests: Arc::new(Semaphore::new(config.max_concurrent_requests as usize)), | ||||||
|  |                 sender: sending_sender, | ||||||
|             }, |             }, | ||||||
|             admin: admin::Admin { |             admin: admin::Admin { | ||||||
|                 sender: admin_sender, |                 sender: admin_sender, | ||||||
|             }, |             }, | ||||||
|             appservice: appservice::Appservice { |             appservice: appservice::Appservice { | ||||||
|                 cached_registrations: Arc::new(RwLock::new(HashMap::new())), |                 cached_registrations: Arc::new(RwLock::new(HashMap::new())), | ||||||
|                 id_appserviceregistrations: db.open_tree("id_appserviceregistrations")?, |                 id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?, | ||||||
|  |             }, | ||||||
|  |             pusher: pusher::PushData { | ||||||
|  |                 senderkey_pusher: builder.open_tree("senderkey_pusher")?, | ||||||
|             }, |             }, | ||||||
|             pusher: pusher::PushData::new(&db)?, |  | ||||||
|             globals: globals::Globals::load( |             globals: globals::Globals::load( | ||||||
|                 db.open_tree("global")?, |                 builder.open_tree("global")?, | ||||||
|                 db.open_tree("server_signingkeys")?, |                 builder.open_tree("server_signingkeys")?, | ||||||
|                 config, |                 config, | ||||||
|             )?, |             )?, | ||||||
|             _db: db, |         }); | ||||||
|         }; |  | ||||||
| 
 | 
 | ||||||
|         // MIGRATIONS
 |         // MIGRATIONS
 | ||||||
|  |         // TODO: database versions of new dbs should probably not be 0
 | ||||||
|         if db.globals.database_version()? < 1 { |         if db.globals.database_version()? < 1 { | ||||||
|             for roomserverid in db.rooms.roomserverids.iter().keys() { |             for (roomserverid, _) in db.rooms.roomserverids.iter() { | ||||||
|                 let roomserverid = roomserverid?; |  | ||||||
|                 let mut parts = roomserverid.split(|&b| b == 0xff); |                 let mut parts = roomserverid.split(|&b| b == 0xff); | ||||||
|                 let room_id = parts.next().expect("split always returns one element"); |                 let room_id = parts.next().expect("split always returns one element"); | ||||||
|                 let servername = match parts.next() { |                 let servername = match parts.next() { | ||||||
|  | @ -238,37 +245,55 @@ impl Database { | ||||||
|                 serverroomid.push(0xff); |                 serverroomid.push(0xff); | ||||||
|                 serverroomid.extend_from_slice(room_id); |                 serverroomid.extend_from_slice(room_id); | ||||||
| 
 | 
 | ||||||
|                 db.rooms.serverroomids.insert(serverroomid, &[])?; |                 db.rooms.serverroomids.insert(&serverroomid, &[])?; | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             db.globals.bump_database_version(1)?; |             db.globals.bump_database_version(1)?; | ||||||
| 
 | 
 | ||||||
|             info!("Migration: 0 -> 1 finished"); |             println!("Migration: 0 -> 1 finished"); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         if db.globals.database_version()? < 2 { |         if db.globals.database_version()? < 2 { | ||||||
|             // We accidentally inserted hashed versions of "" into the db instead of just ""
 |             // We accidentally inserted hashed versions of "" into the db instead of just ""
 | ||||||
|             for userid_password in db.users.userid_password.iter() { |             for (userid, password) in db.users.userid_password.iter() { | ||||||
|                 let (userid, password) = userid_password?; |  | ||||||
| 
 |  | ||||||
|                 let password = utils::string_from_bytes(&password); |                 let password = utils::string_from_bytes(&password); | ||||||
| 
 | 
 | ||||||
|                 if password.map_or(false, |password| { |                 let empty_hashed_password = password.map_or(false, |password| { | ||||||
|                     argon2::verify_encoded(&password, b"").unwrap_or(false) |                     argon2::verify_encoded(&password, b"").unwrap_or(false) | ||||||
|                 }) { |                 }); | ||||||
|                     db.users.userid_password.insert(userid, b"")?; | 
 | ||||||
|  |                 if empty_hashed_password { | ||||||
|  |                     db.users.userid_password.insert(&userid, b"")?; | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             db.globals.bump_database_version(2)?; |             db.globals.bump_database_version(2)?; | ||||||
| 
 | 
 | ||||||
|             info!("Migration: 1 -> 2 finished"); |             println!("Migration: 1 -> 2 finished"); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  |         if db.globals.database_version()? < 3 { | ||||||
|  |             // Move media to filesystem
 | ||||||
|  |             for (key, content) in db.media.mediaid_file.iter() { | ||||||
|  |                 if content.len() == 0 { | ||||||
|  |                     continue; | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 let path = db.globals.get_media_file(&key); | ||||||
|  |                 let mut file = fs::File::create(path)?; | ||||||
|  |                 file.write_all(&content)?; | ||||||
|  |                 db.media.mediaid_file.insert(&key, &[])?; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             db.globals.bump_database_version(3)?; | ||||||
|  | 
 | ||||||
|  |             println!("Migration: 2 -> 3 finished"); | ||||||
|  |         } | ||||||
|         // This data is probably outdated
 |         // This data is probably outdated
 | ||||||
|         db.rooms.edus.presenceid_presence.clear()?; |         db.rooms.edus.presenceid_presence.clear()?; | ||||||
| 
 | 
 | ||||||
|         db.admin.start_handler(db.clone(), admin_receiver); |         db.admin.start_handler(Arc::clone(&db), admin_receiver); | ||||||
|  |         db.sending.start_handler(Arc::clone(&db), sending_receiver); | ||||||
| 
 | 
 | ||||||
|         Ok(db) |         Ok(db) | ||||||
|     } |     } | ||||||
|  | @ -282,7 +307,7 @@ impl Database { | ||||||
|         userdeviceid_prefix.extend_from_slice(device_id.as_bytes()); |         userdeviceid_prefix.extend_from_slice(device_id.as_bytes()); | ||||||
|         userdeviceid_prefix.push(0xff); |         userdeviceid_prefix.push(0xff); | ||||||
| 
 | 
 | ||||||
|         let mut futures = futures::stream::FuturesUnordered::new(); |         let mut futures = FuturesUnordered::new(); | ||||||
| 
 | 
 | ||||||
|         // Return when *any* user changed his key
 |         // Return when *any* user changed his key
 | ||||||
|         // TODO: only send for user they share a room with
 |         // TODO: only send for user they share a room with
 | ||||||
|  |  | ||||||
							
								
								
									
										329
									
								
								src/database/abstraction.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										329
									
								
								src/database/abstraction.rs
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,329 @@ | ||||||
|  | use super::Config; | ||||||
|  | use crate::{utils, Result}; | ||||||
|  | use log::warn; | ||||||
|  | use std::{future::Future, pin::Pin, sync::Arc}; | ||||||
|  | 
 | ||||||
|  | #[cfg(feature = "rocksdb")] | ||||||
|  | use std::{collections::BTreeMap, sync::RwLock}; | ||||||
|  | 
 | ||||||
|  | #[cfg(feature = "sled")] | ||||||
|  | pub struct SledEngine(sled::Db); | ||||||
|  | #[cfg(feature = "sled")] | ||||||
|  | pub struct SledEngineTree(sled::Tree); | ||||||
|  | 
 | ||||||
|  | #[cfg(feature = "rocksdb")] | ||||||
|  | pub struct RocksDbEngine(rocksdb::DBWithThreadMode<rocksdb::MultiThreaded>); | ||||||
|  | #[cfg(feature = "rocksdb")] | ||||||
|  | pub struct RocksDbEngineTree<'a> { | ||||||
|  |     db: Arc<RocksDbEngine>, | ||||||
|  |     name: &'a str, | ||||||
|  |     watchers: RwLock<BTreeMap<Vec<u8>, Vec<tokio::sync::oneshot::Sender<()>>>>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | pub trait DatabaseEngine: Sized { | ||||||
|  |     fn open(config: &Config) -> Result<Arc<Self>>; | ||||||
|  |     fn open_tree(self: &Arc<Self>, name: &'static str) -> Result<Arc<dyn Tree>>; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | pub trait Tree: Send + Sync { | ||||||
|  |     fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>; | ||||||
|  | 
 | ||||||
|  |     fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>; | ||||||
|  | 
 | ||||||
|  |     fn remove(&self, key: &[u8]) -> Result<()>; | ||||||
|  | 
 | ||||||
|  |     fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + Send + Sync + 'a>; | ||||||
|  | 
 | ||||||
|  |     fn iter_from<'a>( | ||||||
|  |         &'a self, | ||||||
|  |         from: &[u8], | ||||||
|  |         backwards: bool, | ||||||
|  |     ) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + 'a>; | ||||||
|  | 
 | ||||||
|  |     fn increment(&self, key: &[u8]) -> Result<Vec<u8>>; | ||||||
|  | 
 | ||||||
|  |     fn scan_prefix<'a>( | ||||||
|  |         &'a self, | ||||||
|  |         prefix: Vec<u8>, | ||||||
|  |     ) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + Send + 'a>; | ||||||
|  | 
 | ||||||
|  |     fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>; | ||||||
|  | 
 | ||||||
|  |     fn clear(&self) -> Result<()> { | ||||||
|  |         for (key, _) in self.iter() { | ||||||
|  |             self.remove(&key)?; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[cfg(feature = "sled")] | ||||||
|  | impl DatabaseEngine for SledEngine { | ||||||
|  |     fn open(config: &Config) -> Result<Arc<Self>> { | ||||||
|  |         Ok(Arc::new(SledEngine( | ||||||
|  |             sled::Config::default() | ||||||
|  |                 .path(&config.database_path) | ||||||
|  |                 .cache_capacity(config.cache_capacity as u64) | ||||||
|  |                 .use_compression(true) | ||||||
|  |                 .open()?, | ||||||
|  |         ))) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn open_tree(self: &Arc<Self>, name: &'static str) -> Result<Arc<dyn Tree>> { | ||||||
|  |         Ok(Arc::new(SledEngineTree(self.0.open_tree(name)?))) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[cfg(feature = "sled")] | ||||||
|  | impl Tree for SledEngineTree { | ||||||
|  |     fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { | ||||||
|  |         Ok(self.0.get(key)?.map(|v| v.to_vec())) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { | ||||||
|  |         self.0.insert(key, value)?; | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn remove(&self, key: &[u8]) -> Result<()> { | ||||||
|  |         self.0.remove(key)?; | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + Send + Sync + 'a> { | ||||||
|  |         Box::new( | ||||||
|  |             self.0 | ||||||
|  |                 .iter() | ||||||
|  |                 .filter_map(|r| { | ||||||
|  |                     if let Err(e) = &r { | ||||||
|  |                         warn!("Error: {}", e); | ||||||
|  |                     } | ||||||
|  |                     r.ok() | ||||||
|  |                 }) | ||||||
|  |                 .map(|(k, v)| (k.to_vec().into(), v.to_vec().into())), | ||||||
|  |         ) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn iter_from( | ||||||
|  |         &self, | ||||||
|  |         from: &[u8], | ||||||
|  |         backwards: bool, | ||||||
|  |     ) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)>> { | ||||||
|  |         let iter = if backwards { | ||||||
|  |             self.0.range(..from) | ||||||
|  |         } else { | ||||||
|  |             self.0.range(from..) | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         let iter = iter | ||||||
|  |             .filter_map(|r| { | ||||||
|  |                 if let Err(e) = &r { | ||||||
|  |                     warn!("Error: {}", e); | ||||||
|  |                 } | ||||||
|  |                 r.ok() | ||||||
|  |             }) | ||||||
|  |             .map(|(k, v)| (k.to_vec().into(), v.to_vec().into())); | ||||||
|  | 
 | ||||||
|  |         if backwards { | ||||||
|  |             Box::new(iter.rev()) | ||||||
|  |         } else { | ||||||
|  |             Box::new(iter) | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn increment(&self, key: &[u8]) -> Result<Vec<u8>> { | ||||||
|  |         Ok(self | ||||||
|  |             .0 | ||||||
|  |             .update_and_fetch(key, utils::increment) | ||||||
|  |             .map(|o| o.expect("increment always sets a value").to_vec())?) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn scan_prefix<'a>( | ||||||
|  |         &'a self, | ||||||
|  |         prefix: Vec<u8>, | ||||||
|  |     ) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + Send + 'a> { | ||||||
|  |         let iter = self | ||||||
|  |             .0 | ||||||
|  |             .scan_prefix(prefix) | ||||||
|  |             .filter_map(|r| { | ||||||
|  |                 if let Err(e) = &r { | ||||||
|  |                     warn!("Error: {}", e); | ||||||
|  |                 } | ||||||
|  |                 r.ok() | ||||||
|  |             }) | ||||||
|  |             .map(|(k, v)| (k.to_vec().into(), v.to_vec().into())); | ||||||
|  | 
 | ||||||
|  |         Box::new(iter) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> { | ||||||
|  |         let prefix = prefix.to_vec(); | ||||||
|  |         Box::pin(async move { | ||||||
|  |             self.0.watch_prefix(prefix).await; | ||||||
|  |         }) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[cfg(feature = "rocksdb")] | ||||||
|  | impl DatabaseEngine for RocksDbEngine { | ||||||
|  |     fn open(config: &Config) -> Result<Arc<Self>> { | ||||||
|  |         let mut db_opts = rocksdb::Options::default(); | ||||||
|  |         db_opts.create_if_missing(true); | ||||||
|  |         db_opts.set_max_open_files(16); | ||||||
|  |         db_opts.set_compaction_style(rocksdb::DBCompactionStyle::Level); | ||||||
|  |         db_opts.set_compression_type(rocksdb::DBCompressionType::Snappy); | ||||||
|  |         db_opts.set_target_file_size_base(256 << 20); | ||||||
|  |         db_opts.set_write_buffer_size(256 << 20); | ||||||
|  | 
 | ||||||
|  |         let mut block_based_options = rocksdb::BlockBasedOptions::default(); | ||||||
|  |         block_based_options.set_block_size(512 << 10); | ||||||
|  |         db_opts.set_block_based_table_factory(&block_based_options); | ||||||
|  | 
 | ||||||
|  |         let cfs = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::list_cf( | ||||||
|  |             &db_opts, | ||||||
|  |             &config.database_path, | ||||||
|  |         ) | ||||||
|  |         .unwrap_or_default(); | ||||||
|  | 
 | ||||||
|  |         let mut options = rocksdb::Options::default(); | ||||||
|  |         options.set_merge_operator_associative("increment", utils::increment_rocksdb); | ||||||
|  | 
 | ||||||
|  |         let db = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::open_cf_descriptors( | ||||||
|  |             &db_opts, | ||||||
|  |             &config.database_path, | ||||||
|  |             cfs.iter() | ||||||
|  |                 .map(|name| rocksdb::ColumnFamilyDescriptor::new(name, options.clone())), | ||||||
|  |         )?; | ||||||
|  | 
 | ||||||
|  |         Ok(Arc::new(RocksDbEngine(db))) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn open_tree(self: &Arc<Self>, name: &'static str) -> Result<Arc<dyn Tree>> { | ||||||
|  |         let mut options = rocksdb::Options::default(); | ||||||
|  |         options.set_merge_operator_associative("increment", utils::increment_rocksdb); | ||||||
|  | 
 | ||||||
|  |         // Create if it doesn't exist
 | ||||||
|  |         let _ = self.0.create_cf(name, &options); | ||||||
|  | 
 | ||||||
|  |         Ok(Arc::new(RocksDbEngineTree { | ||||||
|  |             name, | ||||||
|  |             db: Arc::clone(self), | ||||||
|  |             watchers: RwLock::new(BTreeMap::new()), | ||||||
|  |         })) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[cfg(feature = "rocksdb")] | ||||||
|  | impl RocksDbEngineTree<'_> { | ||||||
|  |     fn cf(&self) -> rocksdb::BoundColumnFamily<'_> { | ||||||
|  |         self.db.0.cf_handle(self.name).unwrap() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[cfg(feature = "rocksdb")] | ||||||
|  | impl Tree for RocksDbEngineTree<'_> { | ||||||
|  |     fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { | ||||||
|  |         Ok(self.db.0.get_cf(self.cf(), key)?) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { | ||||||
|  |         let watchers = self.watchers.read().unwrap(); | ||||||
|  |         let mut triggered = Vec::new(); | ||||||
|  | 
 | ||||||
|  |         for length in 0..=key.len() { | ||||||
|  |             if watchers.contains_key(&key[..length]) { | ||||||
|  |                 triggered.push(&key[..length]); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         drop(watchers); | ||||||
|  | 
 | ||||||
|  |         if !triggered.is_empty() { | ||||||
|  |             let mut watchers = self.watchers.write().unwrap(); | ||||||
|  |             for prefix in triggered { | ||||||
|  |                 if let Some(txs) = watchers.remove(prefix) { | ||||||
|  |                     for tx in txs { | ||||||
|  |                         let _ = tx.send(()); | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         Ok(self.db.0.put_cf(self.cf(), key, value)?) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn remove(&self, key: &[u8]) -> Result<()> { | ||||||
|  |         Ok(self.db.0.delete_cf(self.cf(), key)?) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + Send + Sync + 'a> { | ||||||
|  |         Box::new( | ||||||
|  |             self.db | ||||||
|  |                 .0 | ||||||
|  |                 .iterator_cf(self.cf(), rocksdb::IteratorMode::Start), | ||||||
|  |         ) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn iter_from<'a>( | ||||||
|  |         &'a self, | ||||||
|  |         from: &[u8], | ||||||
|  |         backwards: bool, | ||||||
|  |     ) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + 'a> { | ||||||
|  |         Box::new(self.db.0.iterator_cf( | ||||||
|  |             self.cf(), | ||||||
|  |             rocksdb::IteratorMode::From( | ||||||
|  |                 from, | ||||||
|  |                 if backwards { | ||||||
|  |                     rocksdb::Direction::Reverse | ||||||
|  |                 } else { | ||||||
|  |                     rocksdb::Direction::Forward | ||||||
|  |                 }, | ||||||
|  |             ), | ||||||
|  |         )) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn increment(&self, key: &[u8]) -> Result<Vec<u8>> { | ||||||
|  |         let stats = rocksdb::perf::get_memory_usage_stats(Some(&[&self.db.0]), None).unwrap(); | ||||||
|  |         dbg!(stats.mem_table_total); | ||||||
|  |         dbg!(stats.mem_table_unflushed); | ||||||
|  |         dbg!(stats.mem_table_readers_total); | ||||||
|  |         dbg!(stats.cache_total); | ||||||
|  |         // TODO: atomic?
 | ||||||
|  |         let old = self.get(key)?; | ||||||
|  |         let new = utils::increment(old.as_deref()).unwrap(); | ||||||
|  |         self.insert(key, &new)?; | ||||||
|  |         Ok(new) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn scan_prefix<'a>( | ||||||
|  |         &'a self, | ||||||
|  |         prefix: Vec<u8>, | ||||||
|  |     ) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + Send + 'a> { | ||||||
|  |         Box::new( | ||||||
|  |             self.db | ||||||
|  |                 .0 | ||||||
|  |                 .iterator_cf( | ||||||
|  |                     self.cf(), | ||||||
|  |                     rocksdb::IteratorMode::From(&prefix, rocksdb::Direction::Forward), | ||||||
|  |                 ) | ||||||
|  |                 .take_while(move |(k, _)| k.starts_with(&prefix)), | ||||||
|  |         ) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> { | ||||||
|  |         let (tx, rx) = tokio::sync::oneshot::channel(); | ||||||
|  | 
 | ||||||
|  |         self.watchers | ||||||
|  |             .write() | ||||||
|  |             .unwrap() | ||||||
|  |             .entry(prefix.to_vec()) | ||||||
|  |             .or_default() | ||||||
|  |             .push(tx); | ||||||
|  | 
 | ||||||
|  |         Box::pin(async move { | ||||||
|  |             // Tx is never destroyed
 | ||||||
|  |             rx.await.unwrap(); | ||||||
|  |         }) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | @ -6,12 +6,12 @@ use ruma::{ | ||||||
|     RoomId, UserId, |     RoomId, UserId, | ||||||
| }; | }; | ||||||
| use serde::{de::DeserializeOwned, Serialize}; | use serde::{de::DeserializeOwned, Serialize}; | ||||||
| use sled::IVec; | use std::{collections::HashMap, convert::TryFrom, sync::Arc}; | ||||||
| use std::{collections::HashMap, convert::TryFrom}; | 
 | ||||||
|  | use super::abstraction::Tree; | ||||||
| 
 | 
 | ||||||
| #[derive(Clone)] |  | ||||||
| pub struct AccountData { | pub struct AccountData { | ||||||
|     pub(super) roomuserdataid_accountdata: sled::Tree, // RoomUserDataId = Room + User + Count + Type
 |     pub(super) roomuserdataid_accountdata: Arc<dyn Tree>, // RoomUserDataId = Room + User + Count + Type
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl AccountData { | impl AccountData { | ||||||
|  | @ -34,9 +34,8 @@ impl AccountData { | ||||||
|         prefix.push(0xff); |         prefix.push(0xff); | ||||||
| 
 | 
 | ||||||
|         // Remove old entry
 |         // Remove old entry
 | ||||||
|         if let Some(previous) = self.find_event(room_id, user_id, &event_type) { |         if let Some((old_key, _)) = self.find_event(room_id, user_id, &event_type)? { | ||||||
|             let (old_key, _) = previous?; |             self.roomuserdataid_accountdata.remove(&old_key)?; | ||||||
|             self.roomuserdataid_accountdata.remove(old_key)?; |  | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         let mut key = prefix; |         let mut key = prefix; | ||||||
|  | @ -52,8 +51,10 @@ impl AccountData { | ||||||
|             )); |             )); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         self.roomuserdataid_accountdata |         self.roomuserdataid_accountdata.insert( | ||||||
|             .insert(key, &*json.to_string())?; |             &key, | ||||||
|  |             &serde_json::to_vec(&json).expect("to_vec always works on json values"), | ||||||
|  |         )?; | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|     } |     } | ||||||
|  | @ -65,9 +66,8 @@ impl AccountData { | ||||||
|         user_id: &UserId, |         user_id: &UserId, | ||||||
|         kind: EventType, |         kind: EventType, | ||||||
|     ) -> Result<Option<T>> { |     ) -> Result<Option<T>> { | ||||||
|         self.find_event(room_id, user_id, &kind) |         self.find_event(room_id, user_id, &kind)? | ||||||
|             .map(|r| { |             .map(|(_, v)| { | ||||||
|                 let (_, v) = r?; |  | ||||||
|                 serde_json::from_slice(&v).map_err(|_| Error::bad_database("could not deserialize")) |                 serde_json::from_slice(&v).map_err(|_| Error::bad_database("could not deserialize")) | ||||||
|             }) |             }) | ||||||
|             .transpose() |             .transpose() | ||||||
|  | @ -98,8 +98,7 @@ impl AccountData { | ||||||
| 
 | 
 | ||||||
|         for r in self |         for r in self | ||||||
|             .roomuserdataid_accountdata |             .roomuserdataid_accountdata | ||||||
|             .range(&*first_possible..) |             .iter_from(&first_possible, false) | ||||||
|             .filter_map(|r| r.ok()) |  | ||||||
|             .take_while(move |(k, _)| k.starts_with(&prefix)) |             .take_while(move |(k, _)| k.starts_with(&prefix)) | ||||||
|             .map(|(k, v)| { |             .map(|(k, v)| { | ||||||
|                 Ok::<_, Error>(( |                 Ok::<_, Error>(( | ||||||
|  | @ -128,7 +127,7 @@ impl AccountData { | ||||||
|         room_id: Option<&RoomId>, |         room_id: Option<&RoomId>, | ||||||
|         user_id: &UserId, |         user_id: &UserId, | ||||||
|         kind: &EventType, |         kind: &EventType, | ||||||
|     ) -> Option<Result<(IVec, IVec)>> { |     ) -> Result<Option<(Box<[u8]>, Box<[u8]>)>> { | ||||||
|         let mut prefix = room_id |         let mut prefix = room_id | ||||||
|             .map(|r| r.to_string()) |             .map(|r| r.to_string()) | ||||||
|             .unwrap_or_default() |             .unwrap_or_default() | ||||||
|  | @ -137,23 +136,21 @@ impl AccountData { | ||||||
|         prefix.push(0xff); |         prefix.push(0xff); | ||||||
|         prefix.extend_from_slice(&user_id.as_bytes()); |         prefix.extend_from_slice(&user_id.as_bytes()); | ||||||
|         prefix.push(0xff); |         prefix.push(0xff); | ||||||
|  | 
 | ||||||
|  |         let mut last_possible_key = prefix.clone(); | ||||||
|  |         last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); | ||||||
|  | 
 | ||||||
|         let kind = kind.clone(); |         let kind = kind.clone(); | ||||||
| 
 | 
 | ||||||
|         self.roomuserdataid_accountdata |         Ok(self | ||||||
|             .scan_prefix(prefix) |             .roomuserdataid_accountdata | ||||||
|             .rev() |             .iter_from(&last_possible_key, true) | ||||||
|             .find(move |r| { |             .take_while(move |(k, _)| k.starts_with(&prefix)) | ||||||
|                 r.as_ref() |             .find(move |(k, _)| { | ||||||
|                     .map(|(k, _)| { |                 k.rsplit(|&b| b == 0xff) | ||||||
|                         k.rsplit(|&b| b == 0xff) |                     .next() | ||||||
|                             .next() |                     .map(|current_event_type| current_event_type == kind.as_ref().as_bytes()) | ||||||
|                             .map(|current_event_type| { |  | ||||||
|                                 current_event_type == kind.as_ref().as_bytes() |  | ||||||
|                             }) |  | ||||||
|                             .unwrap_or(false) |  | ||||||
|                     }) |  | ||||||
|                     .unwrap_or(false) |                     .unwrap_or(false) | ||||||
|             }) |             })) | ||||||
|             .map(|r| Ok(r?)) |  | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -1,6 +1,9 @@ | ||||||
| use std::convert::{TryFrom, TryInto}; | use std::{ | ||||||
|  |     convert::{TryFrom, TryInto}, | ||||||
|  |     sync::Arc, | ||||||
|  | }; | ||||||
| 
 | 
 | ||||||
| use crate::pdu::PduBuilder; | use crate::{pdu::PduBuilder, Database}; | ||||||
| use log::warn; | use log::warn; | ||||||
| use rocket::futures::{channel::mpsc, stream::StreamExt}; | use rocket::futures::{channel::mpsc, stream::StreamExt}; | ||||||
| use ruma::{ | use ruma::{ | ||||||
|  | @ -22,7 +25,7 @@ pub struct Admin { | ||||||
| impl Admin { | impl Admin { | ||||||
|     pub fn start_handler( |     pub fn start_handler( | ||||||
|         &self, |         &self, | ||||||
|         db: super::Database, |         db: Arc<Database>, | ||||||
|         mut receiver: mpsc::UnboundedReceiver<AdminCommand>, |         mut receiver: mpsc::UnboundedReceiver<AdminCommand>, | ||||||
|     ) { |     ) { | ||||||
|         tokio::spawn(async move { |         tokio::spawn(async move { | ||||||
|  | @ -73,14 +76,17 @@ impl Admin { | ||||||
|                                 db.appservice.register_appservice(yaml).unwrap(); // TODO handle error
 |                                 db.appservice.register_appservice(yaml).unwrap(); // TODO handle error
 | ||||||
|                             } |                             } | ||||||
|                             AdminCommand::ListAppservices => { |                             AdminCommand::ListAppservices => { | ||||||
|                                 let appservices = db.appservice.iter_ids().collect::<Vec<_>>(); |                                 if let Ok(appservices) = db.appservice.iter_ids().map(|ids| ids.collect::<Vec<_>>()) { | ||||||
|                                 let count = appservices.len(); |                                     let count = appservices.len(); | ||||||
|                                 let output = format!( |                                     let output = format!( | ||||||
|                                     "Appservices ({}): {}", |                                         "Appservices ({}): {}", | ||||||
|                                     count, |                                         count, | ||||||
|                                     appservices.into_iter().filter_map(|r| r.ok()).collect::<Vec<_>>().join(", ") |                                         appservices.into_iter().filter_map(|r| r.ok()).collect::<Vec<_>>().join(", ") | ||||||
|                                 ); |                                     ); | ||||||
|                                 send_message(message::MessageEventContent::text_plain(output)); |                                     send_message(message::MessageEventContent::text_plain(output)); | ||||||
|  |                                 } else { | ||||||
|  |                                     send_message(message::MessageEventContent::text_plain("Failed to get appservices.")); | ||||||
|  |                                 } | ||||||
|                             } |                             } | ||||||
|                             AdminCommand::SendMessage(message) => { |                             AdminCommand::SendMessage(message) => { | ||||||
|                                 send_message(message); |                                 send_message(message); | ||||||
|  | @ -93,6 +99,6 @@ impl Admin { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn send(&self, command: AdminCommand) { |     pub fn send(&self, command: AdminCommand) { | ||||||
|         self.sender.unbounded_send(command).unwrap() |         self.sender.unbounded_send(command).unwrap(); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -4,18 +4,21 @@ use std::{ | ||||||
|     sync::{Arc, RwLock}, |     sync::{Arc, RwLock}, | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| #[derive(Clone)] | use super::abstraction::Tree; | ||||||
|  | 
 | ||||||
| pub struct Appservice { | pub struct Appservice { | ||||||
|     pub(super) cached_registrations: Arc<RwLock<HashMap<String, serde_yaml::Value>>>, |     pub(super) cached_registrations: Arc<RwLock<HashMap<String, serde_yaml::Value>>>, | ||||||
|     pub(super) id_appserviceregistrations: sled::Tree, |     pub(super) id_appserviceregistrations: Arc<dyn Tree>, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl Appservice { | impl Appservice { | ||||||
|     pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result<()> { |     pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result<()> { | ||||||
|         // TODO: Rumaify
 |         // TODO: Rumaify
 | ||||||
|         let id = yaml.get("id").unwrap().as_str().unwrap(); |         let id = yaml.get("id").unwrap().as_str().unwrap(); | ||||||
|         self.id_appserviceregistrations |         self.id_appserviceregistrations.insert( | ||||||
|             .insert(id, serde_yaml::to_string(&yaml).unwrap().as_bytes())?; |             id.as_bytes(), | ||||||
|  |             serde_yaml::to_string(&yaml).unwrap().as_bytes(), | ||||||
|  |         )?; | ||||||
|         self.cached_registrations |         self.cached_registrations | ||||||
|             .write() |             .write() | ||||||
|             .unwrap() |             .unwrap() | ||||||
|  | @ -33,7 +36,7 @@ impl Appservice { | ||||||
|                 || { |                 || { | ||||||
|                     Ok(self |                     Ok(self | ||||||
|                         .id_appserviceregistrations |                         .id_appserviceregistrations | ||||||
|                         .get(id)? |                         .get(id.as_bytes())? | ||||||
|                         .map(|bytes| { |                         .map(|bytes| { | ||||||
|                             Ok::<_, Error>(serde_yaml::from_slice(&bytes).map_err(|_| { |                             Ok::<_, Error>(serde_yaml::from_slice(&bytes).map_err(|_| { | ||||||
|                                 Error::bad_database( |                                 Error::bad_database( | ||||||
|  | @ -47,21 +50,25 @@ impl Appservice { | ||||||
|             ) |             ) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn iter_ids(&self) -> impl Iterator<Item = Result<String>> { |     pub fn iter_ids<'a>( | ||||||
|         self.id_appserviceregistrations.iter().keys().map(|id| { |         &'a self, | ||||||
|             Ok(utils::string_from_bytes(&id?).map_err(|_| { |     ) -> Result<impl Iterator<Item = Result<String>> + Send + Sync + 'a> { | ||||||
|  |         Ok(self.id_appserviceregistrations.iter().map(|(id, _)| { | ||||||
|  |             Ok(utils::string_from_bytes(&id).map_err(|_| { | ||||||
|                 Error::bad_database("Invalid id bytes in id_appserviceregistrations.") |                 Error::bad_database("Invalid id bytes in id_appserviceregistrations.") | ||||||
|             })?) |             })?) | ||||||
|         }) |         })) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn iter_all(&self) -> impl Iterator<Item = Result<(String, serde_yaml::Value)>> + '_ { |     pub fn iter_all( | ||||||
|         self.iter_ids().filter_map(|id| id.ok()).map(move |id| { |         &self, | ||||||
|  |     ) -> Result<impl Iterator<Item = Result<(String, serde_yaml::Value)>> + '_ + Send + Sync> { | ||||||
|  |         Ok(self.iter_ids()?.filter_map(|id| id.ok()).map(move |id| { | ||||||
|             Ok(( |             Ok(( | ||||||
|                 id.clone(), |                 id.clone(), | ||||||
|                 self.get_registration(&id)? |                 self.get_registration(&id)? | ||||||
|                     .expect("iter_ids only returns appservices that exist"), |                     .expect("iter_ids only returns appservices that exist"), | ||||||
|             )) |             )) | ||||||
|         }) |         })) | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -7,28 +7,31 @@ use ruma::{ | ||||||
| use rustls::{ServerCertVerifier, WebPKIVerifier}; | use rustls::{ServerCertVerifier, WebPKIVerifier}; | ||||||
| use std::{ | use std::{ | ||||||
|     collections::{BTreeMap, HashMap}, |     collections::{BTreeMap, HashMap}, | ||||||
|  |     fs, | ||||||
|  |     path::PathBuf, | ||||||
|     sync::{Arc, RwLock}, |     sync::{Arc, RwLock}, | ||||||
|     time::{Duration, Instant}, |     time::{Duration, Instant}, | ||||||
| }; | }; | ||||||
| use tokio::sync::Semaphore; | use tokio::sync::Semaphore; | ||||||
| use trust_dns_resolver::TokioAsyncResolver; | use trust_dns_resolver::TokioAsyncResolver; | ||||||
| 
 | 
 | ||||||
| pub const COUNTER: &str = "c"; | use super::abstraction::Tree; | ||||||
|  | 
 | ||||||
|  | pub const COUNTER: &[u8] = b"c"; | ||||||
| 
 | 
 | ||||||
| type WellKnownMap = HashMap<Box<ServerName>, (String, String)>; | type WellKnownMap = HashMap<Box<ServerName>, (String, String)>; | ||||||
| type TlsNameMap = HashMap<String, webpki::DNSName>; | type TlsNameMap = HashMap<String, webpki::DNSName>; | ||||||
| type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
 | type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
 | ||||||
| #[derive(Clone)] |  | ||||||
| pub struct Globals { | pub struct Globals { | ||||||
|     pub actual_destination_cache: Arc<RwLock<WellKnownMap>>, // actual_destination, host
 |     pub actual_destination_cache: Arc<RwLock<WellKnownMap>>, // actual_destination, host
 | ||||||
|     pub tls_name_override: Arc<RwLock<TlsNameMap>>, |     pub tls_name_override: Arc<RwLock<TlsNameMap>>, | ||||||
|     pub(super) globals: sled::Tree, |     pub(super) globals: Arc<dyn Tree>, | ||||||
|     config: Config, |     config: Config, | ||||||
|     keypair: Arc<ruma::signatures::Ed25519KeyPair>, |     keypair: Arc<ruma::signatures::Ed25519KeyPair>, | ||||||
|     reqwest_client: reqwest::Client, |     reqwest_client: reqwest::Client, | ||||||
|     dns_resolver: TokioAsyncResolver, |     dns_resolver: TokioAsyncResolver, | ||||||
|     jwt_decoding_key: Option<jsonwebtoken::DecodingKey<'static>>, |     jwt_decoding_key: Option<jsonwebtoken::DecodingKey<'static>>, | ||||||
|     pub(super) server_signingkeys: sled::Tree, |     pub(super) server_signingkeys: Arc<dyn Tree>, | ||||||
|     pub bad_event_ratelimiter: Arc<RwLock<BTreeMap<EventId, RateLimitState>>>, |     pub bad_event_ratelimiter: Arc<RwLock<BTreeMap<EventId, RateLimitState>>>, | ||||||
|     pub bad_signature_ratelimiter: Arc<RwLock<BTreeMap<Vec<String>, RateLimitState>>>, |     pub bad_signature_ratelimiter: Arc<RwLock<BTreeMap<Vec<String>, RateLimitState>>>, | ||||||
|     pub servername_ratelimiter: Arc<RwLock<BTreeMap<Box<ServerName>, Arc<Semaphore>>>>, |     pub servername_ratelimiter: Arc<RwLock<BTreeMap<Box<ServerName>, Arc<Semaphore>>>>, | ||||||
|  | @ -69,15 +72,20 @@ impl ServerCertVerifier for MatrixServerVerifier { | ||||||
| 
 | 
 | ||||||
| impl Globals { | impl Globals { | ||||||
|     pub fn load( |     pub fn load( | ||||||
|         globals: sled::Tree, |         globals: Arc<dyn Tree>, | ||||||
|         server_signingkeys: sled::Tree, |         server_signingkeys: Arc<dyn Tree>, | ||||||
|         config: Config, |         config: Config, | ||||||
|     ) -> Result<Self> { |     ) -> Result<Self> { | ||||||
|         let bytes = &*globals |         let keypair_bytes = globals.get(b"keypair")?.map_or_else( | ||||||
|             .update_and_fetch("keypair", utils::generate_keypair)? |             || { | ||||||
|             .expect("utils::generate_keypair always returns Some"); |                 let keypair = utils::generate_keypair(); | ||||||
|  |                 globals.insert(b"keypair", &keypair)?; | ||||||
|  |                 Ok::<_, Error>(keypair) | ||||||
|  |             }, | ||||||
|  |             |s| Ok(s.to_vec()), | ||||||
|  |         )?; | ||||||
| 
 | 
 | ||||||
|         let mut parts = bytes.splitn(2, |&b| b == 0xff); |         let mut parts = keypair_bytes.splitn(2, |&b| b == 0xff); | ||||||
| 
 | 
 | ||||||
|         let keypair = utils::string_from_bytes( |         let keypair = utils::string_from_bytes( | ||||||
|             // 1. version
 |             // 1. version
 | ||||||
|  | @ -102,7 +110,7 @@ impl Globals { | ||||||
|             Ok(k) => k, |             Ok(k) => k, | ||||||
|             Err(e) => { |             Err(e) => { | ||||||
|                 error!("Keypair invalid. Deleting..."); |                 error!("Keypair invalid. Deleting..."); | ||||||
|                 globals.remove("keypair")?; |                 globals.remove(b"keypair")?; | ||||||
|                 return Err(e); |                 return Err(e); | ||||||
|             } |             } | ||||||
|         }; |         }; | ||||||
|  | @ -130,7 +138,7 @@ impl Globals { | ||||||
|             .as_ref() |             .as_ref() | ||||||
|             .map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()).into_static()); |             .map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()).into_static()); | ||||||
| 
 | 
 | ||||||
|         Ok(Self { |         let s = Self { | ||||||
|             globals, |             globals, | ||||||
|             config, |             config, | ||||||
|             keypair: Arc::new(keypair), |             keypair: Arc::new(keypair), | ||||||
|  | @ -145,7 +153,11 @@ impl Globals { | ||||||
|             bad_event_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), |             bad_event_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), | ||||||
|             bad_signature_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), |             bad_signature_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), | ||||||
|             servername_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), |             servername_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), | ||||||
|         }) |         }; | ||||||
|  | 
 | ||||||
|  |         fs::create_dir_all(s.get_media_folder())?; | ||||||
|  | 
 | ||||||
|  |         Ok(s) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Returns this server's keypair.
 |     /// Returns this server's keypair.
 | ||||||
|  | @ -159,13 +171,8 @@ impl Globals { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn next_count(&self) -> Result<u64> { |     pub fn next_count(&self) -> Result<u64> { | ||||||
|         Ok(utils::u64_from_bytes( |         Ok(utils::u64_from_bytes(&self.globals.increment(COUNTER)?) | ||||||
|             &self |             .map_err(|_| Error::bad_database("Count has invalid bytes."))?) | ||||||
|                 .globals |  | ||||||
|                 .update_and_fetch(COUNTER, utils::increment)? |  | ||||||
|                 .expect("utils::increment will always put in a value"), |  | ||||||
|         ) |  | ||||||
|         .map_err(|_| Error::bad_database("Count has invalid bytes."))?) |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn current_count(&self) -> Result<u64> { |     pub fn current_count(&self) -> Result<u64> { | ||||||
|  | @ -211,21 +218,30 @@ impl Globals { | ||||||
|     /// Remove the outdated keys and insert the new ones.
 |     /// Remove the outdated keys and insert the new ones.
 | ||||||
|     ///
 |     ///
 | ||||||
|     /// This doesn't actually check that the keys provided are newer than the old set.
 |     /// This doesn't actually check that the keys provided are newer than the old set.
 | ||||||
|     pub fn add_signing_key(&self, origin: &ServerName, new_keys: &ServerSigningKeys) -> Result<()> { |     pub fn add_signing_key(&self, origin: &ServerName, new_keys: ServerSigningKeys) -> Result<()> { | ||||||
|         self.server_signingkeys |         // Not atomic, but this is not critical
 | ||||||
|             .update_and_fetch(origin.as_bytes(), |signingkeys| { |         let signingkeys = self.server_signingkeys.get(origin.as_bytes())?; | ||||||
|                 let mut keys = signingkeys | 
 | ||||||
|                     .and_then(|keys| serde_json::from_slice(keys).ok()) |         let mut keys = signingkeys | ||||||
|                     .unwrap_or_else(|| { |             .and_then(|keys| serde_json::from_slice(&keys).ok()) | ||||||
|                         // Just insert "now", it doesn't matter
 |             .unwrap_or_else(|| { | ||||||
|                         ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) |                 // Just insert "now", it doesn't matter
 | ||||||
|                     }); |                 ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) | ||||||
|                 keys.verify_keys |             }); | ||||||
|                     .extend(new_keys.verify_keys.clone().into_iter()); | 
 | ||||||
|                 keys.old_verify_keys |         let ServerSigningKeys { | ||||||
|                     .extend(new_keys.old_verify_keys.clone().into_iter()); |             verify_keys, | ||||||
|                 Some(serde_json::to_vec(&keys).expect("serversigningkeys can be serialized")) |             old_verify_keys, | ||||||
|             })?; |             .. | ||||||
|  |         } = new_keys; | ||||||
|  | 
 | ||||||
|  |         keys.verify_keys.extend(verify_keys.into_iter()); | ||||||
|  |         keys.old_verify_keys.extend(old_verify_keys.into_iter()); | ||||||
|  | 
 | ||||||
|  |         self.server_signingkeys.insert( | ||||||
|  |             origin.as_bytes(), | ||||||
|  |             &serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"), | ||||||
|  |         )?; | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|     } |     } | ||||||
|  | @ -254,14 +270,30 @@ impl Globals { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn database_version(&self) -> Result<u64> { |     pub fn database_version(&self) -> Result<u64> { | ||||||
|         self.globals.get("version")?.map_or(Ok(0), |version| { |         self.globals.get(b"version")?.map_or(Ok(0), |version| { | ||||||
|             utils::u64_from_bytes(&version) |             utils::u64_from_bytes(&version) | ||||||
|                 .map_err(|_| Error::bad_database("Database version id is invalid.")) |                 .map_err(|_| Error::bad_database("Database version id is invalid.")) | ||||||
|         }) |         }) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn bump_database_version(&self, new_version: u64) -> Result<()> { |     pub fn bump_database_version(&self, new_version: u64) -> Result<()> { | ||||||
|         self.globals.insert("version", &new_version.to_be_bytes())?; |         self.globals | ||||||
|  |             .insert(b"version", &new_version.to_be_bytes())?; | ||||||
|         Ok(()) |         Ok(()) | ||||||
|     } |     } | ||||||
|  | 
 | ||||||
|  |     pub fn get_media_folder(&self) -> PathBuf { | ||||||
|  |         let mut r = PathBuf::new(); | ||||||
|  |         r.push(self.config.database_path.clone()); | ||||||
|  |         r.push("media"); | ||||||
|  |         r | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn get_media_file(&self, key: &[u8]) -> PathBuf { | ||||||
|  |         let mut r = PathBuf::new(); | ||||||
|  |         r.push(self.config.database_path.clone()); | ||||||
|  |         r.push("media"); | ||||||
|  |         r.push(base64::encode_config(key, base64::URL_SAFE_NO_PAD)); | ||||||
|  |         r | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -6,13 +6,14 @@ use ruma::{ | ||||||
|     }, |     }, | ||||||
|     RoomId, UserId, |     RoomId, UserId, | ||||||
| }; | }; | ||||||
| use std::{collections::BTreeMap, convert::TryFrom}; | use std::{collections::BTreeMap, convert::TryFrom, sync::Arc}; | ||||||
|  | 
 | ||||||
|  | use super::abstraction::Tree; | ||||||
| 
 | 
 | ||||||
| #[derive(Clone)] |  | ||||||
| pub struct KeyBackups { | pub struct KeyBackups { | ||||||
|     pub(super) backupid_algorithm: sled::Tree, // BackupId = UserId + Version(Count)
 |     pub(super) backupid_algorithm: Arc<dyn Tree>, // BackupId = UserId + Version(Count)
 | ||||||
|     pub(super) backupid_etag: sled::Tree,      // BackupId = UserId + Version(Count)
 |     pub(super) backupid_etag: Arc<dyn Tree>,      // BackupId = UserId + Version(Count)
 | ||||||
|     pub(super) backupkeyid_backup: sled::Tree, // BackupKeyId = UserId + Version + RoomId + SessionId
 |     pub(super) backupkeyid_backup: Arc<dyn Tree>, // BackupKeyId = UserId + Version + RoomId + SessionId
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl KeyBackups { | impl KeyBackups { | ||||||
|  | @ -30,8 +31,7 @@ impl KeyBackups { | ||||||
| 
 | 
 | ||||||
|         self.backupid_algorithm.insert( |         self.backupid_algorithm.insert( | ||||||
|             &key, |             &key, | ||||||
|             &*serde_json::to_string(backup_metadata) |             &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), | ||||||
|                 .expect("BackupAlgorithm::to_string always works"), |  | ||||||
|         )?; |         )?; | ||||||
|         self.backupid_etag |         self.backupid_etag | ||||||
|             .insert(&key, &globals.next_count()?.to_be_bytes())?; |             .insert(&key, &globals.next_count()?.to_be_bytes())?; | ||||||
|  | @ -48,13 +48,8 @@ impl KeyBackups { | ||||||
| 
 | 
 | ||||||
|         key.push(0xff); |         key.push(0xff); | ||||||
| 
 | 
 | ||||||
|         for outdated_key in self |         for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { | ||||||
|             .backupkeyid_backup |             self.backupkeyid_backup.remove(&outdated_key)?; | ||||||
|             .scan_prefix(&key) |  | ||||||
|             .keys() |  | ||||||
|             .filter_map(|r| r.ok()) |  | ||||||
|         { |  | ||||||
|             self.backupkeyid_backup.remove(outdated_key)?; |  | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|  | @ -80,8 +75,9 @@ impl KeyBackups { | ||||||
| 
 | 
 | ||||||
|         self.backupid_algorithm.insert( |         self.backupid_algorithm.insert( | ||||||
|             &key, |             &key, | ||||||
|             &*serde_json::to_string(backup_metadata) |             &serde_json::to_string(backup_metadata) | ||||||
|                 .expect("BackupAlgorithm::to_string always works"), |                 .expect("BackupAlgorithm::to_string always works") | ||||||
|  |                 .as_bytes(), | ||||||
|         )?; |         )?; | ||||||
|         self.backupid_etag |         self.backupid_etag | ||||||
|             .insert(&key, &globals.next_count()?.to_be_bytes())?; |             .insert(&key, &globals.next_count()?.to_be_bytes())?; | ||||||
|  | @ -91,11 +87,14 @@ impl KeyBackups { | ||||||
|     pub fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, BackupAlgorithm)>> { |     pub fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, BackupAlgorithm)>> { | ||||||
|         let mut prefix = user_id.as_bytes().to_vec(); |         let mut prefix = user_id.as_bytes().to_vec(); | ||||||
|         prefix.push(0xff); |         prefix.push(0xff); | ||||||
|  |         let mut last_possible_key = prefix.clone(); | ||||||
|  |         last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); | ||||||
|  | 
 | ||||||
|         self.backupid_algorithm |         self.backupid_algorithm | ||||||
|             .scan_prefix(&prefix) |             .iter_from(&last_possible_key, true) | ||||||
|             .last() |             .take_while(move |(k, _)| k.starts_with(&prefix)) | ||||||
|             .map_or(Ok(None), |r| { |             .next() | ||||||
|                 let (key, value) = r?; |             .map_or(Ok(None), |(key, value)| { | ||||||
|                 let version = utils::string_from_bytes( |                 let version = utils::string_from_bytes( | ||||||
|                     key.rsplit(|&b| b == 0xff) |                     key.rsplit(|&b| b == 0xff) | ||||||
|                         .next() |                         .next() | ||||||
|  | @ -117,10 +116,13 @@ impl KeyBackups { | ||||||
|         key.push(0xff); |         key.push(0xff); | ||||||
|         key.extend_from_slice(version.as_bytes()); |         key.extend_from_slice(version.as_bytes()); | ||||||
| 
 | 
 | ||||||
|         self.backupid_algorithm.get(key)?.map_or(Ok(None), |bytes| { |         self.backupid_algorithm | ||||||
|             Ok(serde_json::from_slice(&bytes) |             .get(&key)? | ||||||
|                 .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?) |             .map_or(Ok(None), |bytes| { | ||||||
|         }) |                 Ok(serde_json::from_slice(&bytes).map_err(|_| { | ||||||
|  |                     Error::bad_database("Algorithm in backupid_algorithm is invalid.") | ||||||
|  |                 })?) | ||||||
|  |             }) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn add_key( |     pub fn add_key( | ||||||
|  | @ -153,7 +155,7 @@ impl KeyBackups { | ||||||
| 
 | 
 | ||||||
|         self.backupkeyid_backup.insert( |         self.backupkeyid_backup.insert( | ||||||
|             &key, |             &key, | ||||||
|             &*serde_json::to_string(&key_data).expect("KeyBackupData::to_string always works"), |             &serde_json::to_vec(&key_data).expect("KeyBackupData::to_vec always works"), | ||||||
|         )?; |         )?; | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|  | @ -164,7 +166,7 @@ impl KeyBackups { | ||||||
|         prefix.push(0xff); |         prefix.push(0xff); | ||||||
|         prefix.extend_from_slice(version.as_bytes()); |         prefix.extend_from_slice(version.as_bytes()); | ||||||
| 
 | 
 | ||||||
|         Ok(self.backupkeyid_backup.scan_prefix(&prefix).count()) |         Ok(self.backupkeyid_backup.scan_prefix(prefix).count()) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> { |     pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> { | ||||||
|  | @ -194,33 +196,37 @@ impl KeyBackups { | ||||||
| 
 | 
 | ||||||
|         let mut rooms = BTreeMap::<RoomId, RoomKeyBackup>::new(); |         let mut rooms = BTreeMap::<RoomId, RoomKeyBackup>::new(); | ||||||
| 
 | 
 | ||||||
|         for result in self.backupkeyid_backup.scan_prefix(&prefix).map(|r| { |         for result in self | ||||||
|             let (key, value) = r?; |             .backupkeyid_backup | ||||||
|             let mut parts = key.rsplit(|&b| b == 0xff); |             .scan_prefix(prefix) | ||||||
|  |             .map(|(key, value)| { | ||||||
|  |                 let mut parts = key.rsplit(|&b| b == 0xff); | ||||||
| 
 | 
 | ||||||
|             let session_id = utils::string_from_bytes( |                 let session_id = | ||||||
|                 &parts |                     utils::string_from_bytes(&parts.next().ok_or_else(|| { | ||||||
|                     .next() |                         Error::bad_database("backupkeyid_backup key is invalid.") | ||||||
|                     .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, |                     })?) | ||||||
|             ) |                     .map_err(|_| { | ||||||
|             .map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?; |                         Error::bad_database("backupkeyid_backup session_id is invalid.") | ||||||
|  |                     })?; | ||||||
| 
 | 
 | ||||||
|             let room_id = RoomId::try_from( |                 let room_id = RoomId::try_from( | ||||||
|                 utils::string_from_bytes( |                     utils::string_from_bytes(&parts.next().ok_or_else(|| { | ||||||
|                     &parts |                         Error::bad_database("backupkeyid_backup key is invalid.") | ||||||
|                         .next() |                     })?) | ||||||
|                         .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, |                     .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?, | ||||||
|                 ) |                 ) | ||||||
|                 .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?, |                 .map_err(|_| { | ||||||
|             ) |                     Error::bad_database("backupkeyid_backup room_id is invalid room id.") | ||||||
|             .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid room id."))?; |                 })?; | ||||||
| 
 | 
 | ||||||
|             let key_data = serde_json::from_slice(&value).map_err(|_| { |                 let key_data = serde_json::from_slice(&value).map_err(|_| { | ||||||
|                 Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.") |                     Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.") | ||||||
|             })?; |                 })?; | ||||||
| 
 | 
 | ||||||
|             Ok::<_, Error>((room_id, session_id, key_data)) |                 Ok::<_, Error>((room_id, session_id, key_data)) | ||||||
|         }) { |             }) | ||||||
|  |         { | ||||||
|             let (room_id, session_id, key_data) = result?; |             let (room_id, session_id, key_data) = result?; | ||||||
|             rooms |             rooms | ||||||
|                 .entry(room_id) |                 .entry(room_id) | ||||||
|  | @ -239,7 +245,7 @@ impl KeyBackups { | ||||||
|         user_id: &UserId, |         user_id: &UserId, | ||||||
|         version: &str, |         version: &str, | ||||||
|         room_id: &RoomId, |         room_id: &RoomId, | ||||||
|     ) -> BTreeMap<String, KeyBackupData> { |     ) -> Result<BTreeMap<String, KeyBackupData>> { | ||||||
|         let mut prefix = user_id.as_bytes().to_vec(); |         let mut prefix = user_id.as_bytes().to_vec(); | ||||||
|         prefix.push(0xff); |         prefix.push(0xff); | ||||||
|         prefix.extend_from_slice(version.as_bytes()); |         prefix.extend_from_slice(version.as_bytes()); | ||||||
|  | @ -247,10 +253,10 @@ impl KeyBackups { | ||||||
|         prefix.extend_from_slice(room_id.as_bytes()); |         prefix.extend_from_slice(room_id.as_bytes()); | ||||||
|         prefix.push(0xff); |         prefix.push(0xff); | ||||||
| 
 | 
 | ||||||
|         self.backupkeyid_backup |         Ok(self | ||||||
|             .scan_prefix(&prefix) |             .backupkeyid_backup | ||||||
|             .map(|r| { |             .scan_prefix(prefix) | ||||||
|                 let (key, value) = r?; |             .map(|(key, value)| { | ||||||
|                 let mut parts = key.rsplit(|&b| b == 0xff); |                 let mut parts = key.rsplit(|&b| b == 0xff); | ||||||
| 
 | 
 | ||||||
|                 let session_id = |                 let session_id = | ||||||
|  | @ -268,7 +274,7 @@ impl KeyBackups { | ||||||
|                 Ok::<_, Error>((session_id, key_data)) |                 Ok::<_, Error>((session_id, key_data)) | ||||||
|             }) |             }) | ||||||
|             .filter_map(|r| r.ok()) |             .filter_map(|r| r.ok()) | ||||||
|             .collect() |             .collect()) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn get_session( |     pub fn get_session( | ||||||
|  | @ -302,13 +308,8 @@ impl KeyBackups { | ||||||
|         key.extend_from_slice(&version.as_bytes()); |         key.extend_from_slice(&version.as_bytes()); | ||||||
|         key.push(0xff); |         key.push(0xff); | ||||||
| 
 | 
 | ||||||
|         for outdated_key in self |         for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { | ||||||
|             .backupkeyid_backup |             self.backupkeyid_backup.remove(&outdated_key)?; | ||||||
|             .scan_prefix(&key) |  | ||||||
|             .keys() |  | ||||||
|             .filter_map(|r| r.ok()) |  | ||||||
|         { |  | ||||||
|             self.backupkeyid_backup.remove(outdated_key)?; |  | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|  | @ -327,13 +328,8 @@ impl KeyBackups { | ||||||
|         key.extend_from_slice(&room_id.as_bytes()); |         key.extend_from_slice(&room_id.as_bytes()); | ||||||
|         key.push(0xff); |         key.push(0xff); | ||||||
| 
 | 
 | ||||||
|         for outdated_key in self |         for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { | ||||||
|             .backupkeyid_backup |             self.backupkeyid_backup.remove(&outdated_key)?; | ||||||
|             .scan_prefix(&key) |  | ||||||
|             .keys() |  | ||||||
|             .filter_map(|r| r.ok()) |  | ||||||
|         { |  | ||||||
|             self.backupkeyid_backup.remove(outdated_key)?; |  | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|  | @ -354,13 +350,8 @@ impl KeyBackups { | ||||||
|         key.push(0xff); |         key.push(0xff); | ||||||
|         key.extend_from_slice(&session_id.as_bytes()); |         key.extend_from_slice(&session_id.as_bytes()); | ||||||
| 
 | 
 | ||||||
|         for outdated_key in self |         for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { | ||||||
|             .backupkeyid_backup |             self.backupkeyid_backup.remove(&outdated_key)?; | ||||||
|             .scan_prefix(&key) |  | ||||||
|             .keys() |  | ||||||
|             .filter_map(|r| r.ok()) |  | ||||||
|         { |  | ||||||
|             self.backupkeyid_backup.remove(outdated_key)?; |  | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|  |  | ||||||
|  | @ -1,7 +1,10 @@ | ||||||
|  | use crate::database::globals::Globals; | ||||||
| use image::{imageops::FilterType, GenericImageView}; | use image::{imageops::FilterType, GenericImageView}; | ||||||
| 
 | 
 | ||||||
|  | use super::abstraction::Tree; | ||||||
| use crate::{utils, Error, Result}; | use crate::{utils, Error, Result}; | ||||||
| use std::mem; | use std::{mem, sync::Arc}; | ||||||
|  | use tokio::{fs::File, io::AsyncReadExt, io::AsyncWriteExt}; | ||||||
| 
 | 
 | ||||||
| pub struct FileMeta { | pub struct FileMeta { | ||||||
|     pub content_disposition: Option<String>, |     pub content_disposition: Option<String>, | ||||||
|  | @ -9,16 +12,16 @@ pub struct FileMeta { | ||||||
|     pub file: Vec<u8>, |     pub file: Vec<u8>, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| #[derive(Clone)] |  | ||||||
| pub struct Media { | pub struct Media { | ||||||
|     pub(super) mediaid_file: sled::Tree, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType
 |     pub(super) mediaid_file: Arc<dyn Tree>, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl Media { | impl Media { | ||||||
|     /// Uploads or replaces a file.
 |     /// Uploads a file.
 | ||||||
|     pub fn create( |     pub async fn create( | ||||||
|         &self, |         &self, | ||||||
|         mxc: String, |         mxc: String, | ||||||
|  |         globals: &Globals, | ||||||
|         content_disposition: &Option<&str>, |         content_disposition: &Option<&str>, | ||||||
|         content_type: &Option<&str>, |         content_type: &Option<&str>, | ||||||
|         file: &[u8], |         file: &[u8], | ||||||
|  | @ -42,15 +45,19 @@ impl Media { | ||||||
|                 .unwrap_or_default(), |                 .unwrap_or_default(), | ||||||
|         ); |         ); | ||||||
| 
 | 
 | ||||||
|         self.mediaid_file.insert(key, file)?; |         let path = globals.get_media_file(&key); | ||||||
|  |         let mut f = File::create(path).await?; | ||||||
|  |         f.write_all(file).await?; | ||||||
| 
 | 
 | ||||||
|  |         self.mediaid_file.insert(&key, &[])?; | ||||||
|         Ok(()) |         Ok(()) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Uploads or replaces a file thumbnail.
 |     /// Uploads or replaces a file thumbnail.
 | ||||||
|     pub fn upload_thumbnail( |     pub async fn upload_thumbnail( | ||||||
|         &self, |         &self, | ||||||
|         mxc: String, |         mxc: String, | ||||||
|  |         globals: &Globals, | ||||||
|         content_disposition: &Option<String>, |         content_disposition: &Option<String>, | ||||||
|         content_type: &Option<String>, |         content_type: &Option<String>, | ||||||
|         width: u32, |         width: u32, | ||||||
|  | @ -76,21 +83,28 @@ impl Media { | ||||||
|                 .unwrap_or_default(), |                 .unwrap_or_default(), | ||||||
|         ); |         ); | ||||||
| 
 | 
 | ||||||
|         self.mediaid_file.insert(key, file)?; |         let path = globals.get_media_file(&key); | ||||||
|  |         let mut f = File::create(path).await?; | ||||||
|  |         f.write_all(file).await?; | ||||||
|  | 
 | ||||||
|  |         self.mediaid_file.insert(&key, &[])?; | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Downloads a file.
 |     /// Downloads a file.
 | ||||||
|     pub fn get(&self, mxc: &str) -> Result<Option<FileMeta>> { |     pub async fn get(&self, globals: &Globals, mxc: &str) -> Result<Option<FileMeta>> { | ||||||
|         let mut prefix = mxc.as_bytes().to_vec(); |         let mut prefix = mxc.as_bytes().to_vec(); | ||||||
|         prefix.push(0xff); |         prefix.push(0xff); | ||||||
|         prefix.extend_from_slice(&0_u32.to_be_bytes()); // Width = 0 if it's not a thumbnail
 |         prefix.extend_from_slice(&0_u32.to_be_bytes()); // Width = 0 if it's not a thumbnail
 | ||||||
|         prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail
 |         prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail
 | ||||||
|         prefix.push(0xff); |         prefix.push(0xff); | ||||||
| 
 | 
 | ||||||
|         if let Some(r) = self.mediaid_file.scan_prefix(&prefix).next() { |         let mut iter = self.mediaid_file.scan_prefix(prefix); | ||||||
|             let (key, file) = r?; |         if let Some((key, _)) = iter.next() { | ||||||
|  |             let path = globals.get_media_file(&key); | ||||||
|  |             let mut file = Vec::new(); | ||||||
|  |             File::open(path).await?.read_to_end(&mut file).await?; | ||||||
|             let mut parts = key.rsplit(|&b| b == 0xff); |             let mut parts = key.rsplit(|&b| b == 0xff); | ||||||
| 
 | 
 | ||||||
|             let content_type = parts |             let content_type = parts | ||||||
|  | @ -121,7 +135,7 @@ impl Media { | ||||||
|             Ok(Some(FileMeta { |             Ok(Some(FileMeta { | ||||||
|                 content_disposition, |                 content_disposition, | ||||||
|                 content_type, |                 content_type, | ||||||
|                 file: file.to_vec(), |                 file, | ||||||
|             })) |             })) | ||||||
|         } else { |         } else { | ||||||
|             Ok(None) |             Ok(None) | ||||||
|  | @ -151,7 +165,13 @@ impl Media { | ||||||
|     /// - Server creates the thumbnail and sends it to the user
 |     /// - Server creates the thumbnail and sends it to the user
 | ||||||
|     ///
 |     ///
 | ||||||
|     /// For width,height <= 96 the server uses another thumbnailing algorithm which crops the image afterwards.
 |     /// For width,height <= 96 the server uses another thumbnailing algorithm which crops the image afterwards.
 | ||||||
|     pub fn get_thumbnail(&self, mxc: String, width: u32, height: u32) -> Result<Option<FileMeta>> { |     pub async fn get_thumbnail( | ||||||
|  |         &self, | ||||||
|  |         mxc: String, | ||||||
|  |         globals: &Globals, | ||||||
|  |         width: u32, | ||||||
|  |         height: u32, | ||||||
|  |     ) -> Result<Option<FileMeta>> { | ||||||
|         let (width, height, crop) = self |         let (width, height, crop) = self | ||||||
|             .thumbnail_properties(width, height) |             .thumbnail_properties(width, height) | ||||||
|             .unwrap_or((0, 0, false)); // 0, 0 because that's the original file
 |             .unwrap_or((0, 0, false)); // 0, 0 because that's the original file
 | ||||||
|  | @ -169,9 +189,11 @@ impl Media { | ||||||
|         original_prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail
 |         original_prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail
 | ||||||
|         original_prefix.push(0xff); |         original_prefix.push(0xff); | ||||||
| 
 | 
 | ||||||
|         if let Some(r) = self.mediaid_file.scan_prefix(&thumbnail_prefix).next() { |         if let Some((key, _)) = self.mediaid_file.scan_prefix(thumbnail_prefix).next() { | ||||||
|             // Using saved thumbnail
 |             // Using saved thumbnail
 | ||||||
|             let (key, file) = r?; |             let path = globals.get_media_file(&key); | ||||||
|  |             let mut file = Vec::new(); | ||||||
|  |             File::open(path).await?.read_to_end(&mut file).await?; | ||||||
|             let mut parts = key.rsplit(|&b| b == 0xff); |             let mut parts = key.rsplit(|&b| b == 0xff); | ||||||
| 
 | 
 | ||||||
|             let content_type = parts |             let content_type = parts | ||||||
|  | @ -202,10 +224,12 @@ impl Media { | ||||||
|                 content_type, |                 content_type, | ||||||
|                 file: file.to_vec(), |                 file: file.to_vec(), | ||||||
|             })) |             })) | ||||||
|         } else if let Some(r) = self.mediaid_file.scan_prefix(&original_prefix).next() { |         } else if let Some((key, _)) = self.mediaid_file.scan_prefix(original_prefix).next() { | ||||||
|             // Generate a thumbnail
 |             // Generate a thumbnail
 | ||||||
|  |             let path = globals.get_media_file(&key); | ||||||
|  |             let mut file = Vec::new(); | ||||||
|  |             File::open(path).await?.read_to_end(&mut file).await?; | ||||||
| 
 | 
 | ||||||
|             let (key, file) = r?; |  | ||||||
|             let mut parts = key.rsplit(|&b| b == 0xff); |             let mut parts = key.rsplit(|&b| b == 0xff); | ||||||
| 
 | 
 | ||||||
|             let content_type = parts |             let content_type = parts | ||||||
|  | @ -302,7 +326,11 @@ impl Media { | ||||||
|                     widthheight, |                     widthheight, | ||||||
|                 ); |                 ); | ||||||
| 
 | 
 | ||||||
|                 self.mediaid_file.insert(thumbnail_key, &*thumbnail_bytes)?; |                 let path = globals.get_media_file(&thumbnail_key); | ||||||
|  |                 let mut f = File::create(path).await?; | ||||||
|  |                 f.write_all(&thumbnail_bytes).await?; | ||||||
|  | 
 | ||||||
|  |                 self.mediaid_file.insert(&thumbnail_key, &[])?; | ||||||
| 
 | 
 | ||||||
|                 Ok(Some(FileMeta { |                 Ok(Some(FileMeta { | ||||||
|                     content_disposition, |                     content_disposition, | ||||||
|  |  | ||||||
|  | @ -14,23 +14,17 @@ use ruma::{ | ||||||
|     push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak}, |     push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak}, | ||||||
|     uint, UInt, UserId, |     uint, UInt, UserId, | ||||||
| }; | }; | ||||||
| use sled::IVec; |  | ||||||
| 
 | 
 | ||||||
| use std::{convert::TryFrom, fmt::Debug, mem}; | use std::{convert::TryFrom, fmt::Debug, mem, sync::Arc}; | ||||||
|  | 
 | ||||||
|  | use super::abstraction::Tree; | ||||||
| 
 | 
 | ||||||
| #[derive(Debug, Clone)] |  | ||||||
| pub struct PushData { | pub struct PushData { | ||||||
|     /// UserId + pushkey -> Pusher
 |     /// UserId + pushkey -> Pusher
 | ||||||
|     pub(super) senderkey_pusher: sled::Tree, |     pub(super) senderkey_pusher: Arc<dyn Tree>, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl PushData { | impl PushData { | ||||||
|     pub fn new(db: &sled::Db) -> Result<Self> { |  | ||||||
|         Ok(Self { |  | ||||||
|             senderkey_pusher: db.open_tree("senderkey_pusher")?, |  | ||||||
|         }) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::Pusher) -> Result<()> { |     pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::Pusher) -> Result<()> { | ||||||
|         let mut key = sender.as_bytes().to_vec(); |         let mut key = sender.as_bytes().to_vec(); | ||||||
|         key.push(0xff); |         key.push(0xff); | ||||||
|  | @ -40,14 +34,14 @@ impl PushData { | ||||||
|         if pusher.kind.is_none() { |         if pusher.kind.is_none() { | ||||||
|             return self |             return self | ||||||
|                 .senderkey_pusher |                 .senderkey_pusher | ||||||
|                 .remove(key) |                 .remove(&key) | ||||||
|                 .map(|_| ()) |                 .map(|_| ()) | ||||||
|                 .map_err(Into::into); |                 .map_err(Into::into); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         self.senderkey_pusher.insert( |         self.senderkey_pusher.insert( | ||||||
|             key, |             &key, | ||||||
|             &*serde_json::to_string(&pusher).expect("Pusher is valid JSON string"), |             &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"), | ||||||
|         )?; |         )?; | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|  | @ -69,23 +63,21 @@ impl PushData { | ||||||
| 
 | 
 | ||||||
|         self.senderkey_pusher |         self.senderkey_pusher | ||||||
|             .scan_prefix(prefix) |             .scan_prefix(prefix) | ||||||
|             .values() |             .map(|(_, push)| { | ||||||
|             .map(|push| { |  | ||||||
|                 let push = push.map_err(|_| Error::bad_database("Invalid push bytes in db."))?; |  | ||||||
|                 Ok(serde_json::from_slice(&*push) |                 Ok(serde_json::from_slice(&*push) | ||||||
|                     .map_err(|_| Error::bad_database("Invalid Pusher in db."))?) |                     .map_err(|_| Error::bad_database("Invalid Pusher in db."))?) | ||||||
|             }) |             }) | ||||||
|             .collect() |             .collect() | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn get_pusher_senderkeys(&self, sender: &UserId) -> impl Iterator<Item = Result<IVec>> { |     pub fn get_pusher_senderkeys<'a>( | ||||||
|  |         &'a self, | ||||||
|  |         sender: &UserId, | ||||||
|  |     ) -> impl Iterator<Item = Box<[u8]>> + 'a { | ||||||
|         let mut prefix = sender.as_bytes().to_vec(); |         let mut prefix = sender.as_bytes().to_vec(); | ||||||
|         prefix.push(0xff); |         prefix.push(0xff); | ||||||
| 
 | 
 | ||||||
|         self.senderkey_pusher |         self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| k) | ||||||
|             .scan_prefix(prefix) |  | ||||||
|             .keys() |  | ||||||
|             .map(|r| Ok(r?)) |  | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
										
											
												File diff suppressed because it is too large
												Load diff
											
										
									
								
							|  | @ -1,4 +1,4 @@ | ||||||
| use crate::{utils, Error, Result}; | use crate::{database::abstraction::Tree, utils, Error, Result}; | ||||||
| use ruma::{ | use ruma::{ | ||||||
|     events::{ |     events::{ | ||||||
|         presence::{PresenceEvent, PresenceEventContent}, |         presence::{PresenceEvent, PresenceEventContent}, | ||||||
|  | @ -13,17 +13,17 @@ use std::{ | ||||||
|     collections::{HashMap, HashSet}, |     collections::{HashMap, HashSet}, | ||||||
|     convert::{TryFrom, TryInto}, |     convert::{TryFrom, TryInto}, | ||||||
|     mem, |     mem, | ||||||
|  |     sync::Arc, | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| #[derive(Clone)] |  | ||||||
| pub struct RoomEdus { | pub struct RoomEdus { | ||||||
|     pub(in super::super) readreceiptid_readreceipt: sled::Tree, // ReadReceiptId = RoomId + Count + UserId
 |     pub(in super::super) readreceiptid_readreceipt: Arc<dyn Tree>, // ReadReceiptId = RoomId + Count + UserId
 | ||||||
|     pub(in super::super) roomuserid_privateread: sled::Tree, // RoomUserId = Room + User, PrivateRead = Count
 |     pub(in super::super) roomuserid_privateread: Arc<dyn Tree>, // RoomUserId = Room + User, PrivateRead = Count
 | ||||||
|     pub(in super::super) roomuserid_lastprivatereadupdate: sled::Tree, // LastPrivateReadUpdate = Count
 |     pub(in super::super) roomuserid_lastprivatereadupdate: Arc<dyn Tree>, // LastPrivateReadUpdate = Count
 | ||||||
|     pub(in super::super) typingid_userid: sled::Tree, // TypingId = RoomId + TimeoutTime + Count
 |     pub(in super::super) typingid_userid: Arc<dyn Tree>, // TypingId = RoomId + TimeoutTime + Count
 | ||||||
|     pub(in super::super) roomid_lasttypingupdate: sled::Tree, // LastRoomTypingUpdate = Count
 |     pub(in super::super) roomid_lasttypingupdate: Arc<dyn Tree>, // LastRoomTypingUpdate = Count
 | ||||||
|     pub(in super::super) presenceid_presence: sled::Tree, // PresenceId = RoomId + Count + UserId
 |     pub(in super::super) presenceid_presence: Arc<dyn Tree>, // PresenceId = RoomId + Count + UserId
 | ||||||
|     pub(in super::super) userid_lastpresenceupdate: sled::Tree, // LastPresenceUpdate = Count
 |     pub(in super::super) userid_lastpresenceupdate: Arc<dyn Tree>, // LastPresenceUpdate = Count
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl RoomEdus { | impl RoomEdus { | ||||||
|  | @ -38,15 +38,15 @@ impl RoomEdus { | ||||||
|         let mut prefix = room_id.as_bytes().to_vec(); |         let mut prefix = room_id.as_bytes().to_vec(); | ||||||
|         prefix.push(0xff); |         prefix.push(0xff); | ||||||
| 
 | 
 | ||||||
|  |         let mut last_possible_key = prefix.clone(); | ||||||
|  |         last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); | ||||||
|  | 
 | ||||||
|         // Remove old entry
 |         // Remove old entry
 | ||||||
|         if let Some(old) = self |         if let Some((old, _)) = self | ||||||
|             .readreceiptid_readreceipt |             .readreceiptid_readreceipt | ||||||
|             .scan_prefix(&prefix) |             .iter_from(&last_possible_key, true) | ||||||
|             .keys() |             .take_while(|(key, _)| key.starts_with(&prefix)) | ||||||
|             .rev() |             .find(|(key, _)| { | ||||||
|             .filter_map(|r| r.ok()) |  | ||||||
|             .take_while(|key| key.starts_with(&prefix)) |  | ||||||
|             .find(|key| { |  | ||||||
|                 key.rsplit(|&b| b == 0xff) |                 key.rsplit(|&b| b == 0xff) | ||||||
|                     .next() |                     .next() | ||||||
|                     .expect("rsplit always returns an element") |                     .expect("rsplit always returns an element") | ||||||
|  | @ -54,7 +54,7 @@ impl RoomEdus { | ||||||
|             }) |             }) | ||||||
|         { |         { | ||||||
|             // This is the old room_latest
 |             // This is the old room_latest
 | ||||||
|             self.readreceiptid_readreceipt.remove(old)?; |             self.readreceiptid_readreceipt.remove(&old)?; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         let mut room_latest_id = prefix; |         let mut room_latest_id = prefix; | ||||||
|  | @ -63,8 +63,8 @@ impl RoomEdus { | ||||||
|         room_latest_id.extend_from_slice(&user_id.as_bytes()); |         room_latest_id.extend_from_slice(&user_id.as_bytes()); | ||||||
| 
 | 
 | ||||||
|         self.readreceiptid_readreceipt.insert( |         self.readreceiptid_readreceipt.insert( | ||||||
|             room_latest_id, |             &room_latest_id, | ||||||
|             &*serde_json::to_string(&event).expect("EduEvent::to_string always works"), |             &serde_json::to_vec(&event).expect("EduEvent::to_string always works"), | ||||||
|         )?; |         )?; | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|  | @ -72,13 +72,12 @@ impl RoomEdus { | ||||||
| 
 | 
 | ||||||
|     /// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`.
 |     /// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`.
 | ||||||
|     #[tracing::instrument(skip(self))] |     #[tracing::instrument(skip(self))] | ||||||
|     pub fn readreceipts_since( |     pub fn readreceipts_since<'a>( | ||||||
|         &self, |         &'a self, | ||||||
|         room_id: &RoomId, |         room_id: &RoomId, | ||||||
|         since: u64, |         since: u64, | ||||||
|     ) -> Result< |     ) -> impl Iterator<Item = Result<(UserId, u64, Raw<ruma::events::AnySyncEphemeralRoomEvent>)>> + 'a | ||||||
|         impl Iterator<Item = Result<(UserId, u64, Raw<ruma::events::AnySyncEphemeralRoomEvent>)>>, |     { | ||||||
|     > { |  | ||||||
|         let mut prefix = room_id.as_bytes().to_vec(); |         let mut prefix = room_id.as_bytes().to_vec(); | ||||||
|         prefix.push(0xff); |         prefix.push(0xff); | ||||||
|         let prefix2 = prefix.clone(); |         let prefix2 = prefix.clone(); | ||||||
|  | @ -86,10 +85,8 @@ impl RoomEdus { | ||||||
|         let mut first_possible_edu = prefix.clone(); |         let mut first_possible_edu = prefix.clone(); | ||||||
|         first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since
 |         first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since
 | ||||||
| 
 | 
 | ||||||
|         Ok(self |         self.readreceiptid_readreceipt | ||||||
|             .readreceiptid_readreceipt |             .iter_from(&first_possible_edu, false) | ||||||
|             .range(&*first_possible_edu..) |  | ||||||
|             .filter_map(|r| r.ok()) |  | ||||||
|             .take_while(move |(k, _)| k.starts_with(&prefix2)) |             .take_while(move |(k, _)| k.starts_with(&prefix2)) | ||||||
|             .map(move |(k, v)| { |             .map(move |(k, v)| { | ||||||
|                 let count = |                 let count = | ||||||
|  | @ -115,7 +112,7 @@ impl RoomEdus { | ||||||
|                         serde_json::value::to_raw_value(&json).expect("json is valid raw value"), |                         serde_json::value::to_raw_value(&json).expect("json is valid raw value"), | ||||||
|                     ), |                     ), | ||||||
|                 )) |                 )) | ||||||
|             })) |             }) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Sets a private read marker at `count`.
 |     /// Sets a private read marker at `count`.
 | ||||||
|  | @ -146,11 +143,13 @@ impl RoomEdus { | ||||||
|         key.push(0xff); |         key.push(0xff); | ||||||
|         key.extend_from_slice(&user_id.as_bytes()); |         key.extend_from_slice(&user_id.as_bytes()); | ||||||
| 
 | 
 | ||||||
|         self.roomuserid_privateread.get(key)?.map_or(Ok(None), |v| { |         self.roomuserid_privateread | ||||||
|             Ok(Some(utils::u64_from_bytes(&v).map_err(|_| { |             .get(&key)? | ||||||
|                 Error::bad_database("Invalid private read marker bytes") |             .map_or(Ok(None), |v| { | ||||||
|             })?)) |                 Ok(Some(utils::u64_from_bytes(&v).map_err(|_| { | ||||||
|         }) |                     Error::bad_database("Invalid private read marker bytes") | ||||||
|  |                 })?)) | ||||||
|  |             }) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Returns the count of the last typing update in this room.
 |     /// Returns the count of the last typing update in this room.
 | ||||||
|  | @ -215,11 +214,10 @@ impl RoomEdus { | ||||||
|         // Maybe there are multiple ones from calling roomtyping_add multiple times
 |         // Maybe there are multiple ones from calling roomtyping_add multiple times
 | ||||||
|         for outdated_edu in self |         for outdated_edu in self | ||||||
|             .typingid_userid |             .typingid_userid | ||||||
|             .scan_prefix(&prefix) |             .scan_prefix(prefix) | ||||||
|             .filter_map(|r| r.ok()) |             .filter(|(_, v)| &**v == user_id.as_bytes()) | ||||||
|             .filter(|(_, v)| v == user_id.as_bytes()) |  | ||||||
|         { |         { | ||||||
|             self.typingid_userid.remove(outdated_edu.0)?; |             self.typingid_userid.remove(&outdated_edu.0)?; | ||||||
|             found_outdated = true; |             found_outdated = true; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  | @ -247,10 +245,8 @@ impl RoomEdus { | ||||||
|         // Find all outdated edus before inserting a new one
 |         // Find all outdated edus before inserting a new one
 | ||||||
|         for outdated_edu in self |         for outdated_edu in self | ||||||
|             .typingid_userid |             .typingid_userid | ||||||
|             .scan_prefix(&prefix) |             .scan_prefix(prefix) | ||||||
|             .keys() |             .map(|(key, _)| { | ||||||
|             .map(|key| { |  | ||||||
|                 let key = key?; |  | ||||||
|                 Ok::<_, Error>(( |                 Ok::<_, Error>(( | ||||||
|                     key.clone(), |                     key.clone(), | ||||||
|                     utils::u64_from_bytes( |                     utils::u64_from_bytes( | ||||||
|  | @ -265,7 +261,7 @@ impl RoomEdus { | ||||||
|             .take_while(|&(_, timestamp)| timestamp < current_timestamp) |             .take_while(|&(_, timestamp)| timestamp < current_timestamp) | ||||||
|         { |         { | ||||||
|             // This is an outdated edu (time > timestamp)
 |             // This is an outdated edu (time > timestamp)
 | ||||||
|             self.typingid_userid.remove(outdated_edu.0)?; |             self.typingid_userid.remove(&outdated_edu.0)?; | ||||||
|             found_outdated = true; |             found_outdated = true; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  | @ -309,10 +305,9 @@ impl RoomEdus { | ||||||
|         for user_id in self |         for user_id in self | ||||||
|             .typingid_userid |             .typingid_userid | ||||||
|             .scan_prefix(prefix) |             .scan_prefix(prefix) | ||||||
|             .values() |             .map(|(_, user_id)| { | ||||||
|             .map(|user_id| { |  | ||||||
|                 Ok::<_, Error>( |                 Ok::<_, Error>( | ||||||
|                     UserId::try_from(utils::string_from_bytes(&user_id?).map_err(|_| { |                     UserId::try_from(utils::string_from_bytes(&user_id).map_err(|_| { | ||||||
|                         Error::bad_database("User ID in typingid_userid is invalid unicode.") |                         Error::bad_database("User ID in typingid_userid is invalid unicode.") | ||||||
|                     })?) |                     })?) | ||||||
|                     .map_err(|_| Error::bad_database("User ID in typingid_userid is invalid."))?, |                     .map_err(|_| Error::bad_database("User ID in typingid_userid is invalid."))?, | ||||||
|  | @ -351,12 +346,12 @@ impl RoomEdus { | ||||||
|         presence_id.extend_from_slice(&presence.sender.as_bytes()); |         presence_id.extend_from_slice(&presence.sender.as_bytes()); | ||||||
| 
 | 
 | ||||||
|         self.presenceid_presence.insert( |         self.presenceid_presence.insert( | ||||||
|             presence_id, |             &presence_id, | ||||||
|             &*serde_json::to_string(&presence).expect("PresenceEvent can be serialized"), |             &serde_json::to_vec(&presence).expect("PresenceEvent can be serialized"), | ||||||
|         )?; |         )?; | ||||||
| 
 | 
 | ||||||
|         self.userid_lastpresenceupdate.insert( |         self.userid_lastpresenceupdate.insert( | ||||||
|             &user_id.as_bytes(), |             user_id.as_bytes(), | ||||||
|             &utils::millis_since_unix_epoch().to_be_bytes(), |             &utils::millis_since_unix_epoch().to_be_bytes(), | ||||||
|         )?; |         )?; | ||||||
| 
 | 
 | ||||||
|  | @ -403,7 +398,7 @@ impl RoomEdus { | ||||||
|         presence_id.extend_from_slice(&user_id.as_bytes()); |         presence_id.extend_from_slice(&user_id.as_bytes()); | ||||||
| 
 | 
 | ||||||
|         self.presenceid_presence |         self.presenceid_presence | ||||||
|             .get(presence_id)? |             .get(&presence_id)? | ||||||
|             .map(|value| { |             .map(|value| { | ||||||
|                 let mut presence = serde_json::from_slice::<PresenceEvent>(&value) |                 let mut presence = serde_json::from_slice::<PresenceEvent>(&value) | ||||||
|                     .map_err(|_| Error::bad_database("Invalid presence event in db."))?; |                     .map_err(|_| Error::bad_database("Invalid presence event in db."))?; | ||||||
|  | @ -438,7 +433,6 @@ impl RoomEdus { | ||||||
|         for (user_id_bytes, last_timestamp) in self |         for (user_id_bytes, last_timestamp) in self | ||||||
|             .userid_lastpresenceupdate |             .userid_lastpresenceupdate | ||||||
|             .iter() |             .iter() | ||||||
|             .filter_map(|r| r.ok()) |  | ||||||
|             .filter_map(|(k, bytes)| { |             .filter_map(|(k, bytes)| { | ||||||
|                 Some(( |                 Some(( | ||||||
|                     k, |                     k, | ||||||
|  | @ -468,8 +462,8 @@ impl RoomEdus { | ||||||
|                 presence_id.extend_from_slice(&user_id_bytes); |                 presence_id.extend_from_slice(&user_id_bytes); | ||||||
| 
 | 
 | ||||||
|                 self.presenceid_presence.insert( |                 self.presenceid_presence.insert( | ||||||
|                     presence_id, |                     &presence_id, | ||||||
|                     &*serde_json::to_string(&PresenceEvent { |                     &serde_json::to_vec(&PresenceEvent { | ||||||
|                         content: PresenceEventContent { |                         content: PresenceEventContent { | ||||||
|                             avatar_url: None, |                             avatar_url: None, | ||||||
|                             currently_active: None, |                             currently_active: None, | ||||||
|  | @ -515,8 +509,7 @@ impl RoomEdus { | ||||||
| 
 | 
 | ||||||
|         for (key, value) in self |         for (key, value) in self | ||||||
|             .presenceid_presence |             .presenceid_presence | ||||||
|             .range(&*first_possible_edu..) |             .iter_from(&*first_possible_edu, false) | ||||||
|             .filter_map(|r| r.ok()) |  | ||||||
|             .take_while(|(key, _)| key.starts_with(&prefix)) |             .take_while(|(key, _)| key.starts_with(&prefix)) | ||||||
|         { |         { | ||||||
|             let user_id = UserId::try_from( |             let user_id = UserId::try_from( | ||||||
|  |  | ||||||
|  | @ -12,7 +12,10 @@ use crate::{ | ||||||
| use federation::transactions::send_transaction_message; | use federation::transactions::send_transaction_message; | ||||||
| use log::{error, warn}; | use log::{error, warn}; | ||||||
| use ring::digest; | use ring::digest; | ||||||
| use rocket::futures::stream::{FuturesUnordered, StreamExt}; | use rocket::futures::{ | ||||||
|  |     channel::mpsc, | ||||||
|  |     stream::{FuturesUnordered, StreamExt}, | ||||||
|  | }; | ||||||
| use ruma::{ | use ruma::{ | ||||||
|     api::{ |     api::{ | ||||||
|         appservice, |         appservice, | ||||||
|  | @ -27,9 +30,10 @@ use ruma::{ | ||||||
|     receipt::ReceiptType, |     receipt::ReceiptType, | ||||||
|     MilliSecondsSinceUnixEpoch, ServerName, UInt, UserId, |     MilliSecondsSinceUnixEpoch, ServerName, UInt, UserId, | ||||||
| }; | }; | ||||||
| use sled::IVec; |  | ||||||
| use tokio::{select, sync::Semaphore}; | use tokio::{select, sync::Semaphore}; | ||||||
| 
 | 
 | ||||||
|  | use super::abstraction::Tree; | ||||||
|  | 
 | ||||||
| #[derive(Clone, Debug, PartialEq, Eq, Hash)] | #[derive(Clone, Debug, PartialEq, Eq, Hash)] | ||||||
| pub enum OutgoingKind { | pub enum OutgoingKind { | ||||||
|     Appservice(Box<ServerName>), |     Appservice(Box<ServerName>), | ||||||
|  | @ -70,13 +74,13 @@ pub enum SendingEventType { | ||||||
|     Edu(Vec<u8>), |     Edu(Vec<u8>), | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| #[derive(Clone)] |  | ||||||
| pub struct Sending { | pub struct Sending { | ||||||
|     /// The state for a given state hash.
 |     /// The state for a given state hash.
 | ||||||
|     pub(super) servername_educount: sled::Tree, // EduCount: Count of last EDU sync
 |     pub(super) servername_educount: Arc<dyn Tree>, // EduCount: Count of last EDU sync
 | ||||||
|     pub(super) servernamepduids: sled::Tree, // ServernamePduId = (+ / $)SenderKey / ServerName / UserId + PduId
 |     pub(super) servernamepduids: Arc<dyn Tree>, // ServernamePduId = (+ / $)SenderKey / ServerName / UserId + PduId
 | ||||||
|     pub(super) servercurrentevents: sled::Tree, // ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / (*)EduEvent
 |     pub(super) servercurrentevents: Arc<dyn Tree>, // ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / (*)EduEvent
 | ||||||
|     pub(super) maximum_requests: Arc<Semaphore>, |     pub(super) maximum_requests: Arc<Semaphore>, | ||||||
|  |     pub sender: mpsc::UnboundedSender<Vec<u8>>, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| enum TransactionStatus { | enum TransactionStatus { | ||||||
|  | @ -86,28 +90,23 @@ enum TransactionStatus { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl Sending { | impl Sending { | ||||||
|     pub fn start_handler(&self, db: &Database) { |     pub fn start_handler(&self, db: Arc<Database>, mut receiver: mpsc::UnboundedReceiver<Vec<u8>>) { | ||||||
|         let servernamepduids = self.servernamepduids.clone(); |  | ||||||
|         let servercurrentevents = self.servercurrentevents.clone(); |  | ||||||
| 
 |  | ||||||
|         let db = db.clone(); |  | ||||||
| 
 |  | ||||||
|         tokio::spawn(async move { |         tokio::spawn(async move { | ||||||
|             let mut futures = FuturesUnordered::new(); |             let mut futures = FuturesUnordered::new(); | ||||||
| 
 | 
 | ||||||
|             // Retry requests we could not finish yet
 |  | ||||||
|             let mut subscriber = servernamepduids.watch_prefix(b""); |  | ||||||
|             let mut current_transaction_status = HashMap::<Vec<u8>, TransactionStatus>::new(); |             let mut current_transaction_status = HashMap::<Vec<u8>, TransactionStatus>::new(); | ||||||
| 
 | 
 | ||||||
|  |             // Retry requests we could not finish yet
 | ||||||
|             let mut initial_transactions = HashMap::<OutgoingKind, Vec<SendingEventType>>::new(); |             let mut initial_transactions = HashMap::<OutgoingKind, Vec<SendingEventType>>::new(); | ||||||
|             for (key, outgoing_kind, event) in servercurrentevents |             for (key, outgoing_kind, event) in | ||||||
|                 .iter() |                 db.sending | ||||||
|                 .filter_map(|r| r.ok()) |                     .servercurrentevents | ||||||
|                 .filter_map(|(key, _)| { |                     .iter() | ||||||
|                     Self::parse_servercurrentevent(&key) |                     .filter_map(|(key, _)| { | ||||||
|                         .ok() |                         Self::parse_servercurrentevent(&key) | ||||||
|                         .map(|(k, e)| (key, k, e)) |                             .ok() | ||||||
|                 }) |                             .map(|(k, e)| (key, k, e)) | ||||||
|  |                     }) | ||||||
|             { |             { | ||||||
|                 let entry = initial_transactions |                 let entry = initial_transactions | ||||||
|                     .entry(outgoing_kind.clone()) |                     .entry(outgoing_kind.clone()) | ||||||
|  | @ -118,7 +117,7 @@ impl Sending { | ||||||
|                         "Dropping some current events: {:?} {:?} {:?}", |                         "Dropping some current events: {:?} {:?} {:?}", | ||||||
|                         key, outgoing_kind, event |                         key, outgoing_kind, event | ||||||
|                     ); |                     ); | ||||||
|                     servercurrentevents.remove(key).unwrap(); |                     db.sending.servercurrentevents.remove(&key).unwrap(); | ||||||
|                     continue; |                     continue; | ||||||
|                 } |                 } | ||||||
| 
 | 
 | ||||||
|  | @ -137,20 +136,16 @@ impl Sending { | ||||||
|                         match response { |                         match response { | ||||||
|                             Ok(outgoing_kind) => { |                             Ok(outgoing_kind) => { | ||||||
|                                 let prefix = outgoing_kind.get_prefix(); |                                 let prefix = outgoing_kind.get_prefix(); | ||||||
|                                 for key in servercurrentevents |                                 for (key, _) in db.sending.servercurrentevents | ||||||
|                                     .scan_prefix(&prefix) |                                     .scan_prefix(prefix.clone()) | ||||||
|                                     .keys() |  | ||||||
|                                     .filter_map(|r| r.ok()) |  | ||||||
|                                 { |                                 { | ||||||
|                                     servercurrentevents.remove(key).unwrap(); |                                     db.sending.servercurrentevents.remove(&key).unwrap(); | ||||||
|                                 } |                                 } | ||||||
| 
 | 
 | ||||||
|                                 // Find events that have been added since starting the last request
 |                                 // Find events that have been added since starting the last request
 | ||||||
|                                 let new_events = servernamepduids |                                 let new_events = db.sending.servernamepduids | ||||||
|                                     .scan_prefix(&prefix) |                                     .scan_prefix(prefix.clone()) | ||||||
|                                     .keys() |                                     .map(|(k, _)| { | ||||||
|                                     .filter_map(|r| r.ok()) |  | ||||||
|                                     .map(|k| { |  | ||||||
|                                         SendingEventType::Pdu(k[prefix.len()..].to_vec()) |                                         SendingEventType::Pdu(k[prefix.len()..].to_vec()) | ||||||
|                                     }) |                                     }) | ||||||
|                                     .take(30) |                                     .take(30) | ||||||
|  | @ -166,8 +161,8 @@ impl Sending { | ||||||
|                                             SendingEventType::Pdu(b) | |                                             SendingEventType::Pdu(b) | | ||||||
|                                             SendingEventType::Edu(b) => { |                                             SendingEventType::Edu(b) => { | ||||||
|                                                 current_key.extend_from_slice(&b); |                                                 current_key.extend_from_slice(&b); | ||||||
|                                                 servercurrentevents.insert(¤t_key, &[]).unwrap(); |                                                 db.sending.servercurrentevents.insert(¤t_key, &[]).unwrap(); | ||||||
|                                                 servernamepduids.remove(¤t_key).unwrap(); |                                                 db.sending.servernamepduids.remove(¤t_key).unwrap(); | ||||||
|                                              } |                                              } | ||||||
|                                         } |                                         } | ||||||
|                                     } |                                     } | ||||||
|  | @ -195,18 +190,15 @@ impl Sending { | ||||||
|                             } |                             } | ||||||
|                         }; |                         }; | ||||||
|                     }, |                     }, | ||||||
|                     Some(event) = &mut subscriber => { |                     Some(key) = receiver.next() => { | ||||||
|                         // New sled version:
 |                         if let Ok((outgoing_kind, event)) = Self::parse_servercurrentevent(&key) { | ||||||
|                         //for (_tree, key, value_opt) in &event {
 |                             if let Ok(Some(events)) = Self::select_events( | ||||||
|                         //    if value_opt.is_none() {
 |                                 &outgoing_kind, | ||||||
|                         //        continue;
 |                                 vec![(event, key)], | ||||||
|                         //    }
 |                                 &mut current_transaction_status, | ||||||
| 
 |                                 &db | ||||||
|                         if let sled::Event::Insert { key, .. } = event { |                             ) { | ||||||
|                             if let Ok((outgoing_kind, event)) = Self::parse_servercurrentevent(&key) { |                                 futures.push(Self::handle_events(outgoing_kind, events, &db)); | ||||||
|                                 if let Some(events) = Self::select_events(&outgoing_kind, vec![(event, key)], &mut current_transaction_status, &servercurrentevents, &servernamepduids, &db) { |  | ||||||
|                                     futures.push(Self::handle_events(outgoing_kind, events, &db)); |  | ||||||
|                                 } |  | ||||||
|                             } |                             } | ||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
|  | @ -217,12 +209,10 @@ impl Sending { | ||||||
| 
 | 
 | ||||||
|     fn select_events( |     fn select_events( | ||||||
|         outgoing_kind: &OutgoingKind, |         outgoing_kind: &OutgoingKind, | ||||||
|         new_events: Vec<(SendingEventType, IVec)>, // Events we want to send: event and full key
 |         new_events: Vec<(SendingEventType, Vec<u8>)>, // Events we want to send: event and full key
 | ||||||
|         current_transaction_status: &mut HashMap<Vec<u8>, TransactionStatus>, |         current_transaction_status: &mut HashMap<Vec<u8>, TransactionStatus>, | ||||||
|         servercurrentevents: &sled::Tree, |  | ||||||
|         servernamepduids: &sled::Tree, |  | ||||||
|         db: &Database, |         db: &Database, | ||||||
|     ) -> Option<Vec<SendingEventType>> { |     ) -> Result<Option<Vec<SendingEventType>>> { | ||||||
|         let mut retry = false; |         let mut retry = false; | ||||||
|         let mut allow = true; |         let mut allow = true; | ||||||
| 
 | 
 | ||||||
|  | @ -252,29 +242,25 @@ impl Sending { | ||||||
|             .or_insert(TransactionStatus::Running); |             .or_insert(TransactionStatus::Running); | ||||||
| 
 | 
 | ||||||
|         if !allow { |         if !allow { | ||||||
|             return None; |             return Ok(None); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         let mut events = Vec::new(); |         let mut events = Vec::new(); | ||||||
| 
 | 
 | ||||||
|         if retry { |         if retry { | ||||||
|             // We retry the previous transaction
 |             // We retry the previous transaction
 | ||||||
|             for key in servercurrentevents |             for (key, _) in db.sending.servercurrentevents.scan_prefix(prefix) { | ||||||
|                 .scan_prefix(&prefix) |  | ||||||
|                 .keys() |  | ||||||
|                 .filter_map(|r| r.ok()) |  | ||||||
|             { |  | ||||||
|                 if let Ok((_, e)) = Self::parse_servercurrentevent(&key) { |                 if let Ok((_, e)) = Self::parse_servercurrentevent(&key) { | ||||||
|                     events.push(e); |                     events.push(e); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } else { |         } else { | ||||||
|             for (e, full_key) in new_events { |             for (e, full_key) in new_events { | ||||||
|                 servercurrentevents.insert(&full_key, &[]).unwrap(); |                 db.sending.servercurrentevents.insert(&full_key, &[])?; | ||||||
| 
 | 
 | ||||||
|                 // If it was a PDU we have to unqueue it
 |                 // If it was a PDU we have to unqueue it
 | ||||||
|                 // TODO: don't try to unqueue EDUs
 |                 // TODO: don't try to unqueue EDUs
 | ||||||
|                 servernamepduids.remove(&full_key).unwrap(); |                 db.sending.servernamepduids.remove(&full_key)?; | ||||||
| 
 | 
 | ||||||
|                 events.push(e); |                 events.push(e); | ||||||
|             } |             } | ||||||
|  | @ -284,13 +270,12 @@ impl Sending { | ||||||
|                     events.extend_from_slice(&select_edus); |                     events.extend_from_slice(&select_edus); | ||||||
|                     db.sending |                     db.sending | ||||||
|                         .servername_educount |                         .servername_educount | ||||||
|                         .insert(server_name.as_bytes(), &last_count.to_be_bytes()) |                         .insert(server_name.as_bytes(), &last_count.to_be_bytes())?; | ||||||
|                         .unwrap(); |  | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         Some(events) |         Ok(Some(events)) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn select_edus(db: &Database, server: &ServerName) -> Result<(Vec<SendingEventType>, u64)> { |     pub fn select_edus(db: &Database, server: &ServerName) -> Result<(Vec<SendingEventType>, u64)> { | ||||||
|  | @ -307,7 +292,7 @@ impl Sending { | ||||||
|         let mut max_edu_count = since; |         let mut max_edu_count = since; | ||||||
|         'outer: for room_id in db.rooms.server_rooms(server) { |         'outer: for room_id in db.rooms.server_rooms(server) { | ||||||
|             let room_id = room_id?; |             let room_id = room_id?; | ||||||
|             for r in db.rooms.edus.readreceipts_since(&room_id, since)? { |             for r in db.rooms.edus.readreceipts_since(&room_id, since) { | ||||||
|                 let (user_id, count, read_receipt) = r?; |                 let (user_id, count, read_receipt) = r?; | ||||||
| 
 | 
 | ||||||
|                 if count > max_edu_count { |                 if count > max_edu_count { | ||||||
|  | @ -372,12 +357,13 @@ impl Sending { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     #[tracing::instrument(skip(self))] |     #[tracing::instrument(skip(self))] | ||||||
|     pub fn send_push_pdu(&self, pdu_id: &[u8], senderkey: IVec) -> Result<()> { |     pub fn send_push_pdu(&self, pdu_id: &[u8], senderkey: Box<[u8]>) -> Result<()> { | ||||||
|         let mut key = b"$".to_vec(); |         let mut key = b"$".to_vec(); | ||||||
|         key.extend_from_slice(&senderkey); |         key.extend_from_slice(&senderkey); | ||||||
|         key.push(0xff); |         key.push(0xff); | ||||||
|         key.extend_from_slice(pdu_id); |         key.extend_from_slice(pdu_id); | ||||||
|         self.servernamepduids.insert(key, b"")?; |         self.servernamepduids.insert(&key, b"")?; | ||||||
|  |         self.sender.unbounded_send(key).unwrap(); | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|     } |     } | ||||||
|  | @ -387,7 +373,8 @@ impl Sending { | ||||||
|         let mut key = server.as_bytes().to_vec(); |         let mut key = server.as_bytes().to_vec(); | ||||||
|         key.push(0xff); |         key.push(0xff); | ||||||
|         key.extend_from_slice(pdu_id); |         key.extend_from_slice(pdu_id); | ||||||
|         self.servernamepduids.insert(key, b"")?; |         self.servernamepduids.insert(&key, b"")?; | ||||||
|  |         self.sender.unbounded_send(key).unwrap(); | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|     } |     } | ||||||
|  | @ -398,7 +385,8 @@ impl Sending { | ||||||
|         key.extend_from_slice(appservice_id.as_bytes()); |         key.extend_from_slice(appservice_id.as_bytes()); | ||||||
|         key.push(0xff); |         key.push(0xff); | ||||||
|         key.extend_from_slice(pdu_id); |         key.extend_from_slice(pdu_id); | ||||||
|         self.servernamepduids.insert(key, b"")?; |         self.servernamepduids.insert(&key, b"")?; | ||||||
|  |         self.sender.unbounded_send(key).unwrap(); | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|     } |     } | ||||||
|  | @ -641,7 +629,7 @@ impl Sending { | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     fn parse_servercurrentevent(key: &IVec) -> Result<(OutgoingKind, SendingEventType)> { |     fn parse_servercurrentevent(key: &[u8]) -> Result<(OutgoingKind, SendingEventType)> { | ||||||
|         // Appservices start with a plus
 |         // Appservices start with a plus
 | ||||||
|         Ok::<_, Error>(if key.starts_with(b"+") { |         Ok::<_, Error>(if key.starts_with(b"+") { | ||||||
|             let mut parts = key[1..].splitn(2, |&b| b == 0xff); |             let mut parts = key[1..].splitn(2, |&b| b == 0xff); | ||||||
|  |  | ||||||
|  | @ -1,10 +1,12 @@ | ||||||
|  | use std::sync::Arc; | ||||||
|  | 
 | ||||||
| use crate::Result; | use crate::Result; | ||||||
| use ruma::{DeviceId, UserId}; | use ruma::{DeviceId, UserId}; | ||||||
| use sled::IVec; |  | ||||||
| 
 | 
 | ||||||
| #[derive(Clone)] | use super::abstraction::Tree; | ||||||
|  | 
 | ||||||
| pub struct TransactionIds { | pub struct TransactionIds { | ||||||
|     pub(super) userdevicetxnid_response: sled::Tree, // Response can be empty (/sendToDevice) or the event id (/send)
 |     pub(super) userdevicetxnid_response: Arc<dyn Tree>, // Response can be empty (/sendToDevice) or the event id (/send)
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl TransactionIds { | impl TransactionIds { | ||||||
|  | @ -21,7 +23,7 @@ impl TransactionIds { | ||||||
|         key.push(0xff); |         key.push(0xff); | ||||||
|         key.extend_from_slice(txn_id.as_bytes()); |         key.extend_from_slice(txn_id.as_bytes()); | ||||||
| 
 | 
 | ||||||
|         self.userdevicetxnid_response.insert(key, data)?; |         self.userdevicetxnid_response.insert(&key, data)?; | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|     } |     } | ||||||
|  | @ -31,7 +33,7 @@ impl TransactionIds { | ||||||
|         user_id: &UserId, |         user_id: &UserId, | ||||||
|         device_id: Option<&DeviceId>, |         device_id: Option<&DeviceId>, | ||||||
|         txn_id: &str, |         txn_id: &str, | ||||||
|     ) -> Result<Option<IVec>> { |     ) -> Result<Option<Vec<u8>>> { | ||||||
|         let mut key = user_id.as_bytes().to_vec(); |         let mut key = user_id.as_bytes().to_vec(); | ||||||
|         key.push(0xff); |         key.push(0xff); | ||||||
|         key.extend_from_slice(device_id.map(|d| d.as_bytes()).unwrap_or_default()); |         key.extend_from_slice(device_id.map(|d| d.as_bytes()).unwrap_or_default()); | ||||||
|  | @ -39,6 +41,6 @@ impl TransactionIds { | ||||||
|         key.extend_from_slice(txn_id.as_bytes()); |         key.extend_from_slice(txn_id.as_bytes()); | ||||||
| 
 | 
 | ||||||
|         // If there's no entry, this is a new transaction
 |         // If there's no entry, this is a new transaction
 | ||||||
|         Ok(self.userdevicetxnid_response.get(key)?) |         Ok(self.userdevicetxnid_response.get(&key)?) | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -1,3 +1,5 @@ | ||||||
|  | use std::sync::Arc; | ||||||
|  | 
 | ||||||
| use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result}; | use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result}; | ||||||
| use ruma::{ | use ruma::{ | ||||||
|     api::client::{ |     api::client::{ | ||||||
|  | @ -8,10 +10,11 @@ use ruma::{ | ||||||
|     DeviceId, UserId, |     DeviceId, UserId, | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| #[derive(Clone)] | use super::abstraction::Tree; | ||||||
|  | 
 | ||||||
| pub struct Uiaa { | pub struct Uiaa { | ||||||
|     pub(super) userdevicesessionid_uiaainfo: sled::Tree, // User-interactive authentication
 |     pub(super) userdevicesessionid_uiaainfo: Arc<dyn Tree>, // User-interactive authentication
 | ||||||
|     pub(super) userdevicesessionid_uiaarequest: sled::Tree, // UiaaRequest = canonical json value
 |     pub(super) userdevicesessionid_uiaarequest: Arc<dyn Tree>, // UiaaRequest = canonical json value
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl Uiaa { | impl Uiaa { | ||||||
|  | @ -185,7 +188,7 @@ impl Uiaa { | ||||||
| 
 | 
 | ||||||
|         self.userdevicesessionid_uiaarequest.insert( |         self.userdevicesessionid_uiaarequest.insert( | ||||||
|             &userdevicesessionid, |             &userdevicesessionid, | ||||||
|             &*serde_json::to_string(request).expect("json value to string always works"), |             &serde_json::to_vec(request).expect("json value to vec always works"), | ||||||
|         )?; |         )?; | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|  | @ -233,7 +236,7 @@ impl Uiaa { | ||||||
|         if let Some(uiaainfo) = uiaainfo { |         if let Some(uiaainfo) = uiaainfo { | ||||||
|             self.userdevicesessionid_uiaainfo.insert( |             self.userdevicesessionid_uiaainfo.insert( | ||||||
|                 &userdevicesessionid, |                 &userdevicesessionid, | ||||||
|                 &*serde_json::to_string(&uiaainfo).expect("UiaaInfo::to_string always works"), |                 &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), | ||||||
|             )?; |             )?; | ||||||
|         } else { |         } else { | ||||||
|             self.userdevicesessionid_uiaainfo |             self.userdevicesessionid_uiaainfo | ||||||
|  |  | ||||||
|  | @ -7,40 +7,41 @@ use ruma::{ | ||||||
|     serde::Raw, |     serde::Raw, | ||||||
|     DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, UInt, UserId, |     DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, UInt, UserId, | ||||||
| }; | }; | ||||||
| use std::{collections::BTreeMap, convert::TryFrom, mem}; | use std::{collections::BTreeMap, convert::TryFrom, mem, sync::Arc}; | ||||||
|  | 
 | ||||||
|  | use super::abstraction::Tree; | ||||||
| 
 | 
 | ||||||
| #[derive(Clone)] |  | ||||||
| pub struct Users { | pub struct Users { | ||||||
|     pub(super) userid_password: sled::Tree, |     pub(super) userid_password: Arc<dyn Tree>, | ||||||
|     pub(super) userid_displayname: sled::Tree, |     pub(super) userid_displayname: Arc<dyn Tree>, | ||||||
|     pub(super) userid_avatarurl: sled::Tree, |     pub(super) userid_avatarurl: Arc<dyn Tree>, | ||||||
|     pub(super) userdeviceid_token: sled::Tree, |     pub(super) userdeviceid_token: Arc<dyn Tree>, | ||||||
|     pub(super) userdeviceid_metadata: sled::Tree, // This is also used to check if a device exists
 |     pub(super) userdeviceid_metadata: Arc<dyn Tree>, // This is also used to check if a device exists
 | ||||||
|     pub(super) userid_devicelistversion: sled::Tree, // DevicelistVersion = u64
 |     pub(super) userid_devicelistversion: Arc<dyn Tree>, // DevicelistVersion = u64
 | ||||||
|     pub(super) token_userdeviceid: sled::Tree, |     pub(super) token_userdeviceid: Arc<dyn Tree>, | ||||||
| 
 | 
 | ||||||
|     pub(super) onetimekeyid_onetimekeys: sled::Tree, // OneTimeKeyId = UserId + DeviceKeyId
 |     pub(super) onetimekeyid_onetimekeys: Arc<dyn Tree>, // OneTimeKeyId = UserId + DeviceKeyId
 | ||||||
|     pub(super) userid_lastonetimekeyupdate: sled::Tree, // LastOneTimeKeyUpdate = Count
 |     pub(super) userid_lastonetimekeyupdate: Arc<dyn Tree>, // LastOneTimeKeyUpdate = Count
 | ||||||
|     pub(super) keychangeid_userid: sled::Tree,       // KeyChangeId = UserId/RoomId + Count
 |     pub(super) keychangeid_userid: Arc<dyn Tree>,       // KeyChangeId = UserId/RoomId + Count
 | ||||||
|     pub(super) keyid_key: sled::Tree,                // KeyId = UserId + KeyId (depends on key type)
 |     pub(super) keyid_key: Arc<dyn Tree>, // KeyId = UserId + KeyId (depends on key type)
 | ||||||
|     pub(super) userid_masterkeyid: sled::Tree, |     pub(super) userid_masterkeyid: Arc<dyn Tree>, | ||||||
|     pub(super) userid_selfsigningkeyid: sled::Tree, |     pub(super) userid_selfsigningkeyid: Arc<dyn Tree>, | ||||||
|     pub(super) userid_usersigningkeyid: sled::Tree, |     pub(super) userid_usersigningkeyid: Arc<dyn Tree>, | ||||||
| 
 | 
 | ||||||
|     pub(super) todeviceid_events: sled::Tree, // ToDeviceId = UserId + DeviceId + Count
 |     pub(super) todeviceid_events: Arc<dyn Tree>, // ToDeviceId = UserId + DeviceId + Count
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl Users { | impl Users { | ||||||
|     /// Check if a user has an account on this homeserver.
 |     /// Check if a user has an account on this homeserver.
 | ||||||
|     pub fn exists(&self, user_id: &UserId) -> Result<bool> { |     pub fn exists(&self, user_id: &UserId) -> Result<bool> { | ||||||
|         Ok(self.userid_password.contains_key(user_id.to_string())?) |         Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Check if account is deactivated
 |     /// Check if account is deactivated
 | ||||||
|     pub fn is_deactivated(&self, user_id: &UserId) -> Result<bool> { |     pub fn is_deactivated(&self, user_id: &UserId) -> Result<bool> { | ||||||
|         Ok(self |         Ok(self | ||||||
|             .userid_password |             .userid_password | ||||||
|             .get(user_id.to_string())? |             .get(user_id.as_bytes())? | ||||||
|             .ok_or(Error::BadRequest( |             .ok_or(Error::BadRequest( | ||||||
|                 ErrorKind::InvalidParam, |                 ErrorKind::InvalidParam, | ||||||
|                 "User does not exist.", |                 "User does not exist.", | ||||||
|  | @ -55,14 +56,14 @@ impl Users { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Returns the number of users registered on this server.
 |     /// Returns the number of users registered on this server.
 | ||||||
|     pub fn count(&self) -> usize { |     pub fn count(&self) -> Result<usize> { | ||||||
|         self.userid_password.iter().count() |         Ok(self.userid_password.iter().count()) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Find out which user an access token belongs to.
 |     /// Find out which user an access token belongs to.
 | ||||||
|     pub fn find_from_token(&self, token: &str) -> Result<Option<(UserId, String)>> { |     pub fn find_from_token(&self, token: &str) -> Result<Option<(UserId, String)>> { | ||||||
|         self.token_userdeviceid |         self.token_userdeviceid | ||||||
|             .get(token)? |             .get(token.as_bytes())? | ||||||
|             .map_or(Ok(None), |bytes| { |             .map_or(Ok(None), |bytes| { | ||||||
|                 let mut parts = bytes.split(|&b| b == 0xff); |                 let mut parts = bytes.split(|&b| b == 0xff); | ||||||
|                 let user_bytes = parts.next().ok_or_else(|| { |                 let user_bytes = parts.next().ok_or_else(|| { | ||||||
|  | @ -87,10 +88,10 @@ impl Users { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Returns an iterator over all users on this homeserver.
 |     /// Returns an iterator over all users on this homeserver.
 | ||||||
|     pub fn iter(&self) -> impl Iterator<Item = Result<UserId>> { |     pub fn iter<'a>(&'a self) -> impl Iterator<Item = Result<UserId>> + 'a { | ||||||
|         self.userid_password.iter().keys().map(|bytes| { |         self.userid_password.iter().map(|(bytes, _)| { | ||||||
|             Ok( |             Ok( | ||||||
|                 UserId::try_from(utils::string_from_bytes(&bytes?).map_err(|_| { |                 UserId::try_from(utils::string_from_bytes(&bytes).map_err(|_| { | ||||||
|                     Error::bad_database("User ID in userid_password is invalid unicode.") |                     Error::bad_database("User ID in userid_password is invalid unicode.") | ||||||
|                 })?) |                 })?) | ||||||
|                 .map_err(|_| Error::bad_database("User ID in userid_password is invalid."))?, |                 .map_err(|_| Error::bad_database("User ID in userid_password is invalid."))?, | ||||||
|  | @ -101,7 +102,7 @@ impl Users { | ||||||
|     /// Returns the password hash for the given user.
 |     /// Returns the password hash for the given user.
 | ||||||
|     pub fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> { |     pub fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> { | ||||||
|         self.userid_password |         self.userid_password | ||||||
|             .get(user_id.to_string())? |             .get(user_id.as_bytes())? | ||||||
|             .map_or(Ok(None), |bytes| { |             .map_or(Ok(None), |bytes| { | ||||||
|                 Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { |                 Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { | ||||||
|                     Error::bad_database("Password hash in db is not valid string.") |                     Error::bad_database("Password hash in db is not valid string.") | ||||||
|  | @ -113,7 +114,8 @@ impl Users { | ||||||
|     pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { |     pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { | ||||||
|         if let Some(password) = password { |         if let Some(password) = password { | ||||||
|             if let Ok(hash) = utils::calculate_hash(&password) { |             if let Ok(hash) = utils::calculate_hash(&password) { | ||||||
|                 self.userid_password.insert(user_id.to_string(), &*hash)?; |                 self.userid_password | ||||||
|  |                     .insert(user_id.as_bytes(), hash.as_bytes())?; | ||||||
|                 Ok(()) |                 Ok(()) | ||||||
|             } else { |             } else { | ||||||
|                 Err(Error::BadRequest( |                 Err(Error::BadRequest( | ||||||
|  | @ -122,7 +124,7 @@ impl Users { | ||||||
|                 )) |                 )) | ||||||
|             } |             } | ||||||
|         } else { |         } else { | ||||||
|             self.userid_password.insert(user_id.to_string(), "")?; |             self.userid_password.insert(user_id.as_bytes(), b"")?; | ||||||
|             Ok(()) |             Ok(()) | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  | @ -130,7 +132,7 @@ impl Users { | ||||||
|     /// Returns the displayname of a user on this homeserver.
 |     /// Returns the displayname of a user on this homeserver.
 | ||||||
|     pub fn displayname(&self, user_id: &UserId) -> Result<Option<String>> { |     pub fn displayname(&self, user_id: &UserId) -> Result<Option<String>> { | ||||||
|         self.userid_displayname |         self.userid_displayname | ||||||
|             .get(user_id.to_string())? |             .get(user_id.as_bytes())? | ||||||
|             .map_or(Ok(None), |bytes| { |             .map_or(Ok(None), |bytes| { | ||||||
|                 Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { |                 Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { | ||||||
|                     Error::bad_database("Displayname in db is invalid.") |                     Error::bad_database("Displayname in db is invalid.") | ||||||
|  | @ -142,9 +144,9 @@ impl Users { | ||||||
|     pub fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> { |     pub fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> { | ||||||
|         if let Some(displayname) = displayname { |         if let Some(displayname) = displayname { | ||||||
|             self.userid_displayname |             self.userid_displayname | ||||||
|                 .insert(user_id.to_string(), &*displayname)?; |                 .insert(user_id.as_bytes(), displayname.as_bytes())?; | ||||||
|         } else { |         } else { | ||||||
|             self.userid_displayname.remove(user_id.to_string())?; |             self.userid_displayname.remove(user_id.as_bytes())?; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|  | @ -153,7 +155,7 @@ impl Users { | ||||||
|     /// Get a the avatar_url of a user.
 |     /// Get a the avatar_url of a user.
 | ||||||
|     pub fn avatar_url(&self, user_id: &UserId) -> Result<Option<MxcUri>> { |     pub fn avatar_url(&self, user_id: &UserId) -> Result<Option<MxcUri>> { | ||||||
|         self.userid_avatarurl |         self.userid_avatarurl | ||||||
|             .get(user_id.to_string())? |             .get(user_id.as_bytes())? | ||||||
|             .map(|bytes| { |             .map(|bytes| { | ||||||
|                 let s = utils::string_from_bytes(&bytes) |                 let s = utils::string_from_bytes(&bytes) | ||||||
|                     .map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?; |                     .map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?; | ||||||
|  | @ -166,9 +168,9 @@ impl Users { | ||||||
|     pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<MxcUri>) -> Result<()> { |     pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<MxcUri>) -> Result<()> { | ||||||
|         if let Some(avatar_url) = avatar_url { |         if let Some(avatar_url) = avatar_url { | ||||||
|             self.userid_avatarurl |             self.userid_avatarurl | ||||||
|                 .insert(user_id.to_string(), avatar_url.to_string().as_str())?; |                 .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; | ||||||
|         } else { |         } else { | ||||||
|             self.userid_avatarurl.remove(user_id.to_string())?; |             self.userid_avatarurl.remove(user_id.as_bytes())?; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|  | @ -190,19 +192,17 @@ impl Users { | ||||||
|         userdeviceid.extend_from_slice(device_id.as_bytes()); |         userdeviceid.extend_from_slice(device_id.as_bytes()); | ||||||
| 
 | 
 | ||||||
|         self.userid_devicelistversion |         self.userid_devicelistversion | ||||||
|             .update_and_fetch(&user_id.as_bytes(), utils::increment)? |             .increment(user_id.as_bytes())?; | ||||||
|             .expect("utils::increment will always put in a value"); |  | ||||||
| 
 | 
 | ||||||
|         self.userdeviceid_metadata.insert( |         self.userdeviceid_metadata.insert( | ||||||
|             userdeviceid, |             &userdeviceid, | ||||||
|             serde_json::to_string(&Device { |             &serde_json::to_vec(&Device { | ||||||
|                 device_id: device_id.into(), |                 device_id: device_id.into(), | ||||||
|                 display_name: initial_device_display_name, |                 display_name: initial_device_display_name, | ||||||
|                 last_seen_ip: None, // TODO
 |                 last_seen_ip: None, // TODO
 | ||||||
|                 last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()), |                 last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()), | ||||||
|             }) |             }) | ||||||
|             .expect("Device::to_string never fails.") |             .expect("Device::to_string never fails."), | ||||||
|             .as_bytes(), |  | ||||||
|         )?; |         )?; | ||||||
| 
 | 
 | ||||||
|         self.set_token(user_id, &device_id, token)?; |         self.set_token(user_id, &device_id, token)?; | ||||||
|  | @ -217,7 +217,8 @@ impl Users { | ||||||
|         userdeviceid.extend_from_slice(device_id.as_bytes()); |         userdeviceid.extend_from_slice(device_id.as_bytes()); | ||||||
| 
 | 
 | ||||||
|         // Remove tokens
 |         // Remove tokens
 | ||||||
|         if let Some(old_token) = self.userdeviceid_token.remove(&userdeviceid)? { |         if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { | ||||||
|  |             self.userdeviceid_token.remove(&userdeviceid)?; | ||||||
|             self.token_userdeviceid.remove(&old_token)?; |             self.token_userdeviceid.remove(&old_token)?; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  | @ -225,15 +226,14 @@ impl Users { | ||||||
|         let mut prefix = userdeviceid.clone(); |         let mut prefix = userdeviceid.clone(); | ||||||
|         prefix.push(0xff); |         prefix.push(0xff); | ||||||
| 
 | 
 | ||||||
|         for key in self.todeviceid_events.scan_prefix(&prefix).keys() { |         for (key, _) in self.todeviceid_events.scan_prefix(prefix) { | ||||||
|             self.todeviceid_events.remove(key?)?; |             self.todeviceid_events.remove(&key)?; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         // TODO: Remove onetimekeys
 |         // TODO: Remove onetimekeys
 | ||||||
| 
 | 
 | ||||||
|         self.userid_devicelistversion |         self.userid_devicelistversion | ||||||
|             .update_and_fetch(&user_id.as_bytes(), utils::increment)? |             .increment(user_id.as_bytes())?; | ||||||
|             .expect("utils::increment will always put in a value"); |  | ||||||
| 
 | 
 | ||||||
|         self.userdeviceid_metadata.remove(&userdeviceid)?; |         self.userdeviceid_metadata.remove(&userdeviceid)?; | ||||||
| 
 | 
 | ||||||
|  | @ -241,16 +241,18 @@ impl Users { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Returns an iterator over all device ids of this user.
 |     /// Returns an iterator over all device ids of this user.
 | ||||||
|     pub fn all_device_ids(&self, user_id: &UserId) -> impl Iterator<Item = Result<Box<DeviceId>>> { |     pub fn all_device_ids<'a>( | ||||||
|  |         &'a self, | ||||||
|  |         user_id: &UserId, | ||||||
|  |     ) -> impl Iterator<Item = Result<Box<DeviceId>>> + 'a { | ||||||
|         let mut prefix = user_id.as_bytes().to_vec(); |         let mut prefix = user_id.as_bytes().to_vec(); | ||||||
|         prefix.push(0xff); |         prefix.push(0xff); | ||||||
|         // All devices have metadata
 |         // All devices have metadata
 | ||||||
|         self.userdeviceid_metadata |         self.userdeviceid_metadata | ||||||
|             .scan_prefix(prefix) |             .scan_prefix(prefix) | ||||||
|             .keys() |             .map(|(bytes, _)| { | ||||||
|             .map(|bytes| { |  | ||||||
|                 Ok(utils::string_from_bytes( |                 Ok(utils::string_from_bytes( | ||||||
|                     &*bytes? |                     &bytes | ||||||
|                         .rsplit(|&b| b == 0xff) |                         .rsplit(|&b| b == 0xff) | ||||||
|                         .next() |                         .next() | ||||||
|                         .ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?, |                         .ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?, | ||||||
|  | @ -271,13 +273,15 @@ impl Users { | ||||||
| 
 | 
 | ||||||
|         // Remove old token
 |         // Remove old token
 | ||||||
|         if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { |         if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { | ||||||
|             self.token_userdeviceid.remove(old_token)?; |             self.token_userdeviceid.remove(&old_token)?; | ||||||
|             // It will be removed from userdeviceid_token by the insert later
 |             // It will be removed from userdeviceid_token by the insert later
 | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         // Assign token to user device combination
 |         // Assign token to user device combination
 | ||||||
|         self.userdeviceid_token.insert(&userdeviceid, &*token)?; |         self.userdeviceid_token | ||||||
|         self.token_userdeviceid.insert(token, userdeviceid)?; |             .insert(&userdeviceid, token.as_bytes())?; | ||||||
|  |         self.token_userdeviceid | ||||||
|  |             .insert(token.as_bytes(), &userdeviceid)?; | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|     } |     } | ||||||
|  | @ -309,8 +313,7 @@ impl Users { | ||||||
| 
 | 
 | ||||||
|         self.onetimekeyid_onetimekeys.insert( |         self.onetimekeyid_onetimekeys.insert( | ||||||
|             &key, |             &key, | ||||||
|             &*serde_json::to_string(&one_time_key_value) |             &serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"), | ||||||
|                 .expect("OneTimeKey::to_string always works"), |  | ||||||
|         )?; |         )?; | ||||||
| 
 | 
 | ||||||
|         self.userid_lastonetimekeyupdate |         self.userid_lastonetimekeyupdate | ||||||
|  | @ -350,10 +353,9 @@ impl Users { | ||||||
|             .insert(&user_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; |             .insert(&user_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; | ||||||
| 
 | 
 | ||||||
|         self.onetimekeyid_onetimekeys |         self.onetimekeyid_onetimekeys | ||||||
|             .scan_prefix(&prefix) |             .scan_prefix(prefix) | ||||||
|             .next() |             .next() | ||||||
|             .map(|r| { |             .map(|(key, value)| { | ||||||
|                 let (key, value) = r?; |  | ||||||
|                 self.onetimekeyid_onetimekeys.remove(&key)?; |                 self.onetimekeyid_onetimekeys.remove(&key)?; | ||||||
| 
 | 
 | ||||||
|                 Ok(( |                 Ok(( | ||||||
|  | @ -383,21 +385,20 @@ impl Users { | ||||||
| 
 | 
 | ||||||
|         let mut counts = BTreeMap::new(); |         let mut counts = BTreeMap::new(); | ||||||
| 
 | 
 | ||||||
|         for algorithm in self |         for algorithm in | ||||||
|             .onetimekeyid_onetimekeys |             self.onetimekeyid_onetimekeys | ||||||
|             .scan_prefix(&userdeviceid) |                 .scan_prefix(userdeviceid) | ||||||
|             .keys() |                 .map(|(bytes, _)| { | ||||||
|             .map(|bytes| { |                     Ok::<_, Error>( | ||||||
|                 Ok::<_, Error>( |                         serde_json::from_slice::<DeviceKeyId>( | ||||||
|                     serde_json::from_slice::<DeviceKeyId>( |                             &*bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| { | ||||||
|                         &*bytes?.rsplit(|&b| b == 0xff).next().ok_or_else(|| { |                                 Error::bad_database("OneTimeKey ID in db is invalid.") | ||||||
|                             Error::bad_database("OneTimeKey ID in db is invalid.") |                             })?, | ||||||
|                         })?, |                         ) | ||||||
|  |                         .map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))? | ||||||
|  |                         .algorithm(), | ||||||
|                     ) |                     ) | ||||||
|                     .map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))? |                 }) | ||||||
|                     .algorithm(), |  | ||||||
|                 ) |  | ||||||
|             }) |  | ||||||
|         { |         { | ||||||
|             *counts.entry(algorithm?).or_default() += UInt::from(1_u32); |             *counts.entry(algorithm?).or_default() += UInt::from(1_u32); | ||||||
|         } |         } | ||||||
|  | @ -419,7 +420,7 @@ impl Users { | ||||||
| 
 | 
 | ||||||
|         self.keyid_key.insert( |         self.keyid_key.insert( | ||||||
|             &userdeviceid, |             &userdeviceid, | ||||||
|             &*serde_json::to_string(&device_keys).expect("DeviceKeys::to_string always works"), |             &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), | ||||||
|         )?; |         )?; | ||||||
| 
 | 
 | ||||||
|         self.mark_device_key_update(user_id, rooms, globals)?; |         self.mark_device_key_update(user_id, rooms, globals)?; | ||||||
|  | @ -460,11 +461,11 @@ impl Users { | ||||||
| 
 | 
 | ||||||
|         self.keyid_key.insert( |         self.keyid_key.insert( | ||||||
|             &master_key_key, |             &master_key_key, | ||||||
|             &*serde_json::to_string(&master_key).expect("CrossSigningKey::to_string always works"), |             &serde_json::to_vec(&master_key).expect("CrossSigningKey::to_vec always works"), | ||||||
|         )?; |         )?; | ||||||
| 
 | 
 | ||||||
|         self.userid_masterkeyid |         self.userid_masterkeyid | ||||||
|             .insert(&*user_id.to_string(), master_key_key)?; |             .insert(user_id.as_bytes(), &master_key_key)?; | ||||||
| 
 | 
 | ||||||
|         // Self-signing key
 |         // Self-signing key
 | ||||||
|         if let Some(self_signing_key) = self_signing_key { |         if let Some(self_signing_key) = self_signing_key { | ||||||
|  | @ -486,12 +487,12 @@ impl Users { | ||||||
| 
 | 
 | ||||||
|             self.keyid_key.insert( |             self.keyid_key.insert( | ||||||
|                 &self_signing_key_key, |                 &self_signing_key_key, | ||||||
|                 &*serde_json::to_string(&self_signing_key) |                 &serde_json::to_vec(&self_signing_key) | ||||||
|                     .expect("CrossSigningKey::to_string always works"), |                     .expect("CrossSigningKey::to_vec always works"), | ||||||
|             )?; |             )?; | ||||||
| 
 | 
 | ||||||
|             self.userid_selfsigningkeyid |             self.userid_selfsigningkeyid | ||||||
|                 .insert(&*user_id.to_string(), self_signing_key_key)?; |                 .insert(user_id.as_bytes(), &self_signing_key_key)?; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         // User-signing key
 |         // User-signing key
 | ||||||
|  | @ -514,12 +515,12 @@ impl Users { | ||||||
| 
 | 
 | ||||||
|             self.keyid_key.insert( |             self.keyid_key.insert( | ||||||
|                 &user_signing_key_key, |                 &user_signing_key_key, | ||||||
|                 &*serde_json::to_string(&user_signing_key) |                 &serde_json::to_vec(&user_signing_key) | ||||||
|                     .expect("CrossSigningKey::to_string always works"), |                     .expect("CrossSigningKey::to_vec always works"), | ||||||
|             )?; |             )?; | ||||||
| 
 | 
 | ||||||
|             self.userid_usersigningkeyid |             self.userid_usersigningkeyid | ||||||
|                 .insert(&*user_id.to_string(), user_signing_key_key)?; |                 .insert(user_id.as_bytes(), &user_signing_key_key)?; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         self.mark_device_key_update(user_id, rooms, globals)?; |         self.mark_device_key_update(user_id, rooms, globals)?; | ||||||
|  | @ -561,8 +562,7 @@ impl Users { | ||||||
| 
 | 
 | ||||||
|         self.keyid_key.insert( |         self.keyid_key.insert( | ||||||
|             &key, |             &key, | ||||||
|             &*serde_json::to_string(&cross_signing_key) |             &serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"), | ||||||
|                 .expect("CrossSigningKey::to_string always works"), |  | ||||||
|         )?; |         )?; | ||||||
| 
 | 
 | ||||||
|         // TODO: Should we notify about this change?
 |         // TODO: Should we notify about this change?
 | ||||||
|  | @ -572,24 +572,20 @@ impl Users { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     #[tracing::instrument(skip(self))] |     #[tracing::instrument(skip(self))] | ||||||
|     pub fn keys_changed( |     pub fn keys_changed<'a>( | ||||||
|         &self, |         &'a self, | ||||||
|         user_or_room_id: &str, |         user_or_room_id: &str, | ||||||
|         from: u64, |         from: u64, | ||||||
|         to: Option<u64>, |         to: Option<u64>, | ||||||
|     ) -> impl Iterator<Item = Result<UserId>> { |     ) -> impl Iterator<Item = Result<UserId>> + 'a { | ||||||
|         let mut prefix = user_or_room_id.as_bytes().to_vec(); |         let mut prefix = user_or_room_id.as_bytes().to_vec(); | ||||||
|         prefix.push(0xff); |         prefix.push(0xff); | ||||||
| 
 | 
 | ||||||
|         let mut start = prefix.clone(); |         let mut start = prefix.clone(); | ||||||
|         start.extend_from_slice(&(from + 1).to_be_bytes()); |         start.extend_from_slice(&(from + 1).to_be_bytes()); | ||||||
| 
 | 
 | ||||||
|         let mut end = prefix.clone(); |  | ||||||
|         end.extend_from_slice(&to.unwrap_or(u64::MAX).to_be_bytes()); |  | ||||||
| 
 |  | ||||||
|         self.keychangeid_userid |         self.keychangeid_userid | ||||||
|             .range(start..end) |             .iter_from(&start, false) | ||||||
|             .filter_map(|r| r.ok()) |  | ||||||
|             .take_while(move |(k, _)| k.starts_with(&prefix)) |             .take_while(move |(k, _)| k.starts_with(&prefix)) | ||||||
|             .map(|(_, bytes)| { |             .map(|(_, bytes)| { | ||||||
|                 Ok( |                 Ok( | ||||||
|  | @ -625,13 +621,13 @@ impl Users { | ||||||
|             key.push(0xff); |             key.push(0xff); | ||||||
|             key.extend_from_slice(&count); |             key.extend_from_slice(&count); | ||||||
| 
 | 
 | ||||||
|             self.keychangeid_userid.insert(key, &*user_id.to_string())?; |             self.keychangeid_userid.insert(&key, user_id.as_bytes())?; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         let mut key = user_id.as_bytes().to_vec(); |         let mut key = user_id.as_bytes().to_vec(); | ||||||
|         key.push(0xff); |         key.push(0xff); | ||||||
|         key.extend_from_slice(&count); |         key.extend_from_slice(&count); | ||||||
|         self.keychangeid_userid.insert(key, &*user_id.to_string())?; |         self.keychangeid_userid.insert(&key, user_id.as_bytes())?; | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|     } |     } | ||||||
|  | @ -645,7 +641,7 @@ impl Users { | ||||||
|         key.push(0xff); |         key.push(0xff); | ||||||
|         key.extend_from_slice(device_id.as_bytes()); |         key.extend_from_slice(device_id.as_bytes()); | ||||||
| 
 | 
 | ||||||
|         self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { |         self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { | ||||||
|             Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { |             Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { | ||||||
|                 Error::bad_database("DeviceKeys in db are invalid.") |                 Error::bad_database("DeviceKeys in db are invalid.") | ||||||
|             })?)) |             })?)) | ||||||
|  | @ -658,9 +654,9 @@ impl Users { | ||||||
|         allowed_signatures: F, |         allowed_signatures: F, | ||||||
|     ) -> Result<Option<CrossSigningKey>> { |     ) -> Result<Option<CrossSigningKey>> { | ||||||
|         self.userid_masterkeyid |         self.userid_masterkeyid | ||||||
|             .get(user_id.to_string())? |             .get(user_id.as_bytes())? | ||||||
|             .map_or(Ok(None), |key| { |             .map_or(Ok(None), |key| { | ||||||
|                 self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { |                 self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { | ||||||
|                     let mut cross_signing_key = serde_json::from_slice::<CrossSigningKey>(&bytes) |                     let mut cross_signing_key = serde_json::from_slice::<CrossSigningKey>(&bytes) | ||||||
|                         .map_err(|_| { |                         .map_err(|_| { | ||||||
|                         Error::bad_database("CrossSigningKey in db is invalid.") |                         Error::bad_database("CrossSigningKey in db is invalid.") | ||||||
|  | @ -685,9 +681,9 @@ impl Users { | ||||||
|         allowed_signatures: F, |         allowed_signatures: F, | ||||||
|     ) -> Result<Option<CrossSigningKey>> { |     ) -> Result<Option<CrossSigningKey>> { | ||||||
|         self.userid_selfsigningkeyid |         self.userid_selfsigningkeyid | ||||||
|             .get(user_id.to_string())? |             .get(user_id.as_bytes())? | ||||||
|             .map_or(Ok(None), |key| { |             .map_or(Ok(None), |key| { | ||||||
|                 self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { |                 self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { | ||||||
|                     let mut cross_signing_key = serde_json::from_slice::<CrossSigningKey>(&bytes) |                     let mut cross_signing_key = serde_json::from_slice::<CrossSigningKey>(&bytes) | ||||||
|                         .map_err(|_| { |                         .map_err(|_| { | ||||||
|                         Error::bad_database("CrossSigningKey in db is invalid.") |                         Error::bad_database("CrossSigningKey in db is invalid.") | ||||||
|  | @ -708,9 +704,9 @@ impl Users { | ||||||
| 
 | 
 | ||||||
|     pub fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<CrossSigningKey>> { |     pub fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<CrossSigningKey>> { | ||||||
|         self.userid_usersigningkeyid |         self.userid_usersigningkeyid | ||||||
|             .get(user_id.to_string())? |             .get(user_id.as_bytes())? | ||||||
|             .map_or(Ok(None), |key| { |             .map_or(Ok(None), |key| { | ||||||
|                 self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { |                 self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { | ||||||
|                     Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { |                     Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { | ||||||
|                         Error::bad_database("CrossSigningKey in db is invalid.") |                         Error::bad_database("CrossSigningKey in db is invalid.") | ||||||
|                     })?)) |                     })?)) | ||||||
|  | @ -740,7 +736,7 @@ impl Users { | ||||||
| 
 | 
 | ||||||
|         self.todeviceid_events.insert( |         self.todeviceid_events.insert( | ||||||
|             &key, |             &key, | ||||||
|             &*serde_json::to_string(&json).expect("Map::to_string always works"), |             &serde_json::to_vec(&json).expect("Map::to_vec always works"), | ||||||
|         )?; |         )?; | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|  | @ -759,9 +755,9 @@ impl Users { | ||||||
|         prefix.extend_from_slice(device_id.as_bytes()); |         prefix.extend_from_slice(device_id.as_bytes()); | ||||||
|         prefix.push(0xff); |         prefix.push(0xff); | ||||||
| 
 | 
 | ||||||
|         for value in self.todeviceid_events.scan_prefix(&prefix).values() { |         for (_, value) in self.todeviceid_events.scan_prefix(prefix) { | ||||||
|             events.push( |             events.push( | ||||||
|                 serde_json::from_slice(&*value?) |                 serde_json::from_slice(&value) | ||||||
|                     .map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?, |                     .map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?, | ||||||
|             ); |             ); | ||||||
|         } |         } | ||||||
|  | @ -786,10 +782,9 @@ impl Users { | ||||||
| 
 | 
 | ||||||
|         for (key, _) in self |         for (key, _) in self | ||||||
|             .todeviceid_events |             .todeviceid_events | ||||||
|             .range(&*prefix..=&*last) |             .iter_from(&last, true) | ||||||
|             .keys() |             .take_while(move |(k, _)| k.starts_with(&prefix)) | ||||||
|             .map(|key| { |             .map(|(key, _)| { | ||||||
|                 let key = key?; |  | ||||||
|                 Ok::<_, Error>(( |                 Ok::<_, Error>(( | ||||||
|                     key.clone(), |                     key.clone(), | ||||||
|                     utils::u64_from_bytes(&key[key.len() - mem::size_of::<u64>()..key.len()]) |                     utils::u64_from_bytes(&key[key.len() - mem::size_of::<u64>()..key.len()]) | ||||||
|  | @ -799,7 +794,7 @@ impl Users { | ||||||
|             .filter_map(|r| r.ok()) |             .filter_map(|r| r.ok()) | ||||||
|             .take_while(|&(_, count)| count <= until) |             .take_while(|&(_, count)| count <= until) | ||||||
|         { |         { | ||||||
|             self.todeviceid_events.remove(key)?; |             self.todeviceid_events.remove(&key)?; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|  | @ -819,14 +814,11 @@ impl Users { | ||||||
|         assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some()); |         assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some()); | ||||||
| 
 | 
 | ||||||
|         self.userid_devicelistversion |         self.userid_devicelistversion | ||||||
|             .update_and_fetch(&user_id.as_bytes(), utils::increment)? |             .increment(user_id.as_bytes())?; | ||||||
|             .expect("utils::increment will always put in a value"); |  | ||||||
| 
 | 
 | ||||||
|         self.userdeviceid_metadata.insert( |         self.userdeviceid_metadata.insert( | ||||||
|             userdeviceid, |             &userdeviceid, | ||||||
|             serde_json::to_string(device) |             &serde_json::to_vec(device).expect("Device::to_string always works"), | ||||||
|                 .expect("Device::to_string always works") |  | ||||||
|                 .as_bytes(), |  | ||||||
|         )?; |         )?; | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|  | @ -861,15 +853,17 @@ impl Users { | ||||||
|             }) |             }) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn all_devices_metadata(&self, user_id: &UserId) -> impl Iterator<Item = Result<Device>> { |     pub fn all_devices_metadata<'a>( | ||||||
|  |         &'a self, | ||||||
|  |         user_id: &UserId, | ||||||
|  |     ) -> impl Iterator<Item = Result<Device>> + 'a { | ||||||
|         let mut key = user_id.as_bytes().to_vec(); |         let mut key = user_id.as_bytes().to_vec(); | ||||||
|         key.push(0xff); |         key.push(0xff); | ||||||
| 
 | 
 | ||||||
|         self.userdeviceid_metadata |         self.userdeviceid_metadata | ||||||
|             .scan_prefix(key) |             .scan_prefix(key) | ||||||
|             .values() |             .map(|(_, bytes)| { | ||||||
|             .map(|bytes| { |                 Ok(serde_json::from_slice::<Device>(&bytes).map_err(|_| { | ||||||
|                 Ok(serde_json::from_slice::<Device>(&bytes?).map_err(|_| { |  | ||||||
|                     Error::bad_database("Device in userdeviceid_metadata is invalid.") |                     Error::bad_database("Device in userdeviceid_metadata is invalid.") | ||||||
|                 })?) |                 })?) | ||||||
|             }) |             }) | ||||||
|  | @ -885,7 +879,7 @@ impl Users { | ||||||
|         // Set the password to "" to indicate a deactivated account. Hashes will never result in an
 |         // Set the password to "" to indicate a deactivated account. Hashes will never result in an
 | ||||||
|         // empty string, so the user will not be able to log in again. Systems like changing the
 |         // empty string, so the user will not be able to log in again. Systems like changing the
 | ||||||
|         // password without logging in should check if the account is deactivated.
 |         // password without logging in should check if the account is deactivated.
 | ||||||
|         self.userid_password.insert(user_id.to_string(), "")?; |         self.userid_password.insert(user_id.as_bytes(), &[])?; | ||||||
| 
 | 
 | ||||||
|         // TODO: Unhook 3PID
 |         // TODO: Unhook 3PID
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|  |  | ||||||
							
								
								
									
										14
									
								
								src/error.rs
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								src/error.rs
									
									
									
									
									
								
							|  | @ -23,11 +23,18 @@ pub type Result<T> = std::result::Result<T, Error>; | ||||||
| 
 | 
 | ||||||
| #[derive(Error, Debug)] | #[derive(Error, Debug)] | ||||||
| pub enum Error { | pub enum Error { | ||||||
|     #[error("There was a problem with the connection to the database.")] |     #[cfg(feature = "sled")] | ||||||
|  |     #[error("There was a problem with the connection to the sled database.")] | ||||||
|     SledError { |     SledError { | ||||||
|         #[from] |         #[from] | ||||||
|         source: sled::Error, |         source: sled::Error, | ||||||
|     }, |     }, | ||||||
|  |     #[cfg(feature = "rocksdb")] | ||||||
|  |     #[error("There was a problem with the connection to the rocksdb database: {source}")] | ||||||
|  |     RocksDbError { | ||||||
|  |         #[from] | ||||||
|  |         source: rocksdb::Error, | ||||||
|  |     }, | ||||||
|     #[error("Could not generate an image.")] |     #[error("Could not generate an image.")] | ||||||
|     ImageError { |     ImageError { | ||||||
|         #[from] |         #[from] | ||||||
|  | @ -40,6 +47,11 @@ pub enum Error { | ||||||
|     }, |     }, | ||||||
|     #[error("{0}")] |     #[error("{0}")] | ||||||
|     FederationError(Box<ServerName>, RumaError), |     FederationError(Box<ServerName>, RumaError), | ||||||
|  |     #[error("Could not do this io: {source}")] | ||||||
|  |     IoError { | ||||||
|  |         #[from] | ||||||
|  |         source: std::io::Error, | ||||||
|  |     }, | ||||||
|     #[error("{0}")] |     #[error("{0}")] | ||||||
|     BadServerResponse(&'static str), |     BadServerResponse(&'static str), | ||||||
|     #[error("{0}")] |     #[error("{0}")] | ||||||
|  |  | ||||||
|  | @ -12,6 +12,8 @@ mod pdu; | ||||||
| mod ruma_wrapper; | mod ruma_wrapper; | ||||||
| mod utils; | mod utils; | ||||||
| 
 | 
 | ||||||
|  | use std::sync::Arc; | ||||||
|  | 
 | ||||||
| use database::Config; | use database::Config; | ||||||
| pub use database::Database; | pub use database::Database; | ||||||
| pub use error::{Error, Result}; | pub use error::{Error, Result}; | ||||||
|  | @ -31,7 +33,7 @@ use rocket::{ | ||||||
| use tracing::span; | use tracing::span; | ||||||
| use tracing_subscriber::{prelude::*, Registry}; | use tracing_subscriber::{prelude::*, Registry}; | ||||||
| 
 | 
 | ||||||
| fn setup_rocket(config: Figment, data: Database) -> rocket::Rocket<rocket::Build> { | fn setup_rocket(config: Figment, data: Arc<Database>) -> rocket::Rocket<rocket::Build> { | ||||||
|     rocket::custom(config) |     rocket::custom(config) | ||||||
|         .manage(data) |         .manage(data) | ||||||
|         .mount( |         .mount( | ||||||
|  | @ -197,8 +199,6 @@ async fn main() { | ||||||
|         .await |         .await | ||||||
|         .expect("config is valid"); |         .expect("config is valid"); | ||||||
| 
 | 
 | ||||||
|     db.sending.start_handler(&db); |  | ||||||
| 
 |  | ||||||
|     if config.allow_jaeger { |     if config.allow_jaeger { | ||||||
|         let (tracer, _uninstall) = opentelemetry_jaeger::new_pipeline() |         let (tracer, _uninstall) = opentelemetry_jaeger::new_pipeline() | ||||||
|             .with_service_name("conduit") |             .with_service_name("conduit") | ||||||
|  |  | ||||||
|  | @ -2,13 +2,14 @@ use crate::Error; | ||||||
| use ruma::{ | use ruma::{ | ||||||
|     api::OutgoingResponse, |     api::OutgoingResponse, | ||||||
|     identifiers::{DeviceId, UserId}, |     identifiers::{DeviceId, UserId}, | ||||||
|     Outgoing, |     signatures::CanonicalJsonValue, | ||||||
|  |     Outgoing, ServerName, | ||||||
| }; | }; | ||||||
| use std::ops::Deref; | use std::ops::Deref; | ||||||
| 
 | 
 | ||||||
| #[cfg(feature = "conduit_bin")] | #[cfg(feature = "conduit_bin")] | ||||||
| use { | use { | ||||||
|     crate::server_server, |     crate::{server_server, Database}, | ||||||
|     log::{debug, warn}, |     log::{debug, warn}, | ||||||
|     rocket::{ |     rocket::{ | ||||||
|         data::{self, ByteUnit, Data, FromData}, |         data::{self, ByteUnit, Data, FromData}, | ||||||
|  | @ -18,14 +19,11 @@ use { | ||||||
|         tokio::io::AsyncReadExt, |         tokio::io::AsyncReadExt, | ||||||
|         Request, State, |         Request, State, | ||||||
|     }, |     }, | ||||||
|     ruma::{ |     ruma::api::{AuthScheme, IncomingRequest}, | ||||||
|         api::{AuthScheme, IncomingRequest}, |  | ||||||
|         signatures::CanonicalJsonValue, |  | ||||||
|         ServerName, |  | ||||||
|     }, |  | ||||||
|     std::collections::BTreeMap, |     std::collections::BTreeMap, | ||||||
|     std::convert::TryFrom, |     std::convert::TryFrom, | ||||||
|     std::io::Cursor, |     std::io::Cursor, | ||||||
|  |     std::sync::Arc, | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| /// This struct converts rocket requests into ruma structs by converting them into http requests
 | /// This struct converts rocket requests into ruma structs by converting them into http requests
 | ||||||
|  | @ -51,7 +49,7 @@ where | ||||||
|     async fn from_data(request: &'a Request<'_>, data: Data) -> data::Outcome<Self, Self::Error> { |     async fn from_data(request: &'a Request<'_>, data: Data) -> data::Outcome<Self, Self::Error> { | ||||||
|         let metadata = T::Incoming::METADATA; |         let metadata = T::Incoming::METADATA; | ||||||
|         let db = request |         let db = request | ||||||
|             .guard::<State<'_, crate::Database>>() |             .guard::<State<'_, Arc<Database>>>() | ||||||
|             .await |             .await | ||||||
|             .expect("database was loaded"); |             .expect("database was loaded"); | ||||||
| 
 | 
 | ||||||
|  | @ -75,6 +73,7 @@ where | ||||||
|         )) = db |         )) = db | ||||||
|             .appservice |             .appservice | ||||||
|             .iter_all() |             .iter_all() | ||||||
|  |             .unwrap() | ||||||
|             .filter_map(|r| r.ok()) |             .filter_map(|r| r.ok()) | ||||||
|             .find(|(_id, registration)| { |             .find(|(_id, registration)| { | ||||||
|                 registration |                 registration | ||||||
|  |  | ||||||
|  | @ -433,7 +433,7 @@ pub async fn request_well_known( | ||||||
| #[cfg_attr(feature = "conduit_bin", get("/_matrix/federation/v1/version"))] | #[cfg_attr(feature = "conduit_bin", get("/_matrix/federation/v1/version"))] | ||||||
| #[tracing::instrument(skip(db))] | #[tracing::instrument(skip(db))] | ||||||
| pub fn get_server_version_route( | pub fn get_server_version_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
| ) -> ConduitResult<get_server_version::v1::Response> { | ) -> ConduitResult<get_server_version::v1::Response> { | ||||||
|     if !db.globals.allow_federation() { |     if !db.globals.allow_federation() { | ||||||
|         return Err(Error::bad_config("Federation is disabled.")); |         return Err(Error::bad_config("Federation is disabled.")); | ||||||
|  | @ -451,7 +451,7 @@ pub fn get_server_version_route( | ||||||
| // Response type for this endpoint is Json because we need to calculate a signature for the response
 | // Response type for this endpoint is Json because we need to calculate a signature for the response
 | ||||||
| #[cfg_attr(feature = "conduit_bin", get("/_matrix/key/v2/server"))] | #[cfg_attr(feature = "conduit_bin", get("/_matrix/key/v2/server"))] | ||||||
| #[tracing::instrument(skip(db))] | #[tracing::instrument(skip(db))] | ||||||
| pub fn get_server_keys_route(db: State<'_, Database>) -> Json<String> { | pub fn get_server_keys_route(db: State<'_, Arc<Database>>) -> Json<String> { | ||||||
|     if !db.globals.allow_federation() { |     if !db.globals.allow_federation() { | ||||||
|         // TODO: Use proper types
 |         // TODO: Use proper types
 | ||||||
|         return Json("Federation is disabled.".to_owned()); |         return Json("Federation is disabled.".to_owned()); | ||||||
|  | @ -498,7 +498,7 @@ pub fn get_server_keys_route(db: State<'_, Database>) -> Json<String> { | ||||||
| 
 | 
 | ||||||
| #[cfg_attr(feature = "conduit_bin", get("/_matrix/key/v2/server/<_>"))] | #[cfg_attr(feature = "conduit_bin", get("/_matrix/key/v2/server/<_>"))] | ||||||
| #[tracing::instrument(skip(db))] | #[tracing::instrument(skip(db))] | ||||||
| pub fn get_server_keys_deprecated_route(db: State<'_, Database>) -> Json<String> { | pub fn get_server_keys_deprecated_route(db: State<'_, Arc<Database>>) -> Json<String> { | ||||||
|     get_server_keys_route(db) |     get_server_keys_route(db) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -508,7 +508,7 @@ pub fn get_server_keys_deprecated_route(db: State<'_, Database>) -> Json<String> | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_public_rooms_filtered_route( | pub async fn get_public_rooms_filtered_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_public_rooms_filtered::v1::Request<'_>>, |     body: Ruma<get_public_rooms_filtered::v1::Request<'_>>, | ||||||
| ) -> ConduitResult<get_public_rooms_filtered::v1::Response> { | ) -> ConduitResult<get_public_rooms_filtered::v1::Response> { | ||||||
|     if !db.globals.allow_federation() { |     if !db.globals.allow_federation() { | ||||||
|  | @ -556,7 +556,7 @@ pub async fn get_public_rooms_filtered_route( | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn get_public_rooms_route( | pub async fn get_public_rooms_route( | ||||||
|     db: State<'_, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_public_rooms::v1::Request<'_>>, |     body: Ruma<get_public_rooms::v1::Request<'_>>, | ||||||
| ) -> ConduitResult<get_public_rooms::v1::Response> { | ) -> ConduitResult<get_public_rooms::v1::Response> { | ||||||
|     if !db.globals.allow_federation() { |     if !db.globals.allow_federation() { | ||||||
|  | @ -603,8 +603,8 @@ pub async fn get_public_rooms_route( | ||||||
|     put("/_matrix/federation/v1/send/<_>", data = "<body>") |     put("/_matrix/federation/v1/send/<_>", data = "<body>") | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn send_transaction_message_route<'a>( | pub async fn send_transaction_message_route( | ||||||
|     db: State<'a, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<send_transaction_message::v1::Request<'_>>, |     body: Ruma<send_transaction_message::v1::Request<'_>>, | ||||||
| ) -> ConduitResult<send_transaction_message::v1::Response> { | ) -> ConduitResult<send_transaction_message::v1::Response> { | ||||||
|     if !db.globals.allow_federation() { |     if !db.globals.allow_federation() { | ||||||
|  | @ -1585,7 +1585,7 @@ pub(crate) async fn fetch_signing_keys( | ||||||
|         .await |         .await | ||||||
|     { |     { | ||||||
|         db.globals |         db.globals | ||||||
|             .add_signing_key(origin, &get_keys_response.server_key)?; |             .add_signing_key(origin, get_keys_response.server_key.clone())?; | ||||||
| 
 | 
 | ||||||
|         result.extend( |         result.extend( | ||||||
|             get_keys_response |             get_keys_response | ||||||
|  | @ -1628,7 +1628,7 @@ pub(crate) async fn fetch_signing_keys( | ||||||
|         { |         { | ||||||
|             trace!("Got signing keys: {:?}", keys); |             trace!("Got signing keys: {:?}", keys); | ||||||
|             for k in keys.server_keys { |             for k in keys.server_keys { | ||||||
|                 db.globals.add_signing_key(origin, &k)?; |                 db.globals.add_signing_key(origin, k.clone())?; | ||||||
|                 result.extend( |                 result.extend( | ||||||
|                     k.verify_keys |                     k.verify_keys | ||||||
|                         .into_iter() |                         .into_iter() | ||||||
|  | @ -1681,12 +1681,12 @@ pub(crate) fn append_incoming_pdu( | ||||||
|         pdu, |         pdu, | ||||||
|         pdu_json, |         pdu_json, | ||||||
|         count, |         count, | ||||||
|         pdu_id.clone().into(), |         &pdu_id, | ||||||
|         &new_room_leaves.into_iter().collect::<Vec<_>>(), |         &new_room_leaves.into_iter().collect::<Vec<_>>(), | ||||||
|         &db, |         &db, | ||||||
|     )?; |     )?; | ||||||
| 
 | 
 | ||||||
|     for appservice in db.appservice.iter_all().filter_map(|r| r.ok()) { |     for appservice in db.appservice.iter_all()?.filter_map(|r| r.ok()) { | ||||||
|         if let Some(namespaces) = appservice.1.get("namespaces") { |         if let Some(namespaces) = appservice.1.get("namespaces") { | ||||||
|             let users = namespaces |             let users = namespaces | ||||||
|                 .get("users") |                 .get("users") | ||||||
|  | @ -1758,8 +1758,8 @@ pub(crate) fn append_incoming_pdu( | ||||||
|     get("/_matrix/federation/v1/event/<_>", data = "<body>") |     get("/_matrix/federation/v1/event/<_>", data = "<body>") | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub fn get_event_route<'a>( | pub fn get_event_route( | ||||||
|     db: State<'a, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_event::v1::Request<'_>>, |     body: Ruma<get_event::v1::Request<'_>>, | ||||||
| ) -> ConduitResult<get_event::v1::Response> { | ) -> ConduitResult<get_event::v1::Response> { | ||||||
|     if !db.globals.allow_federation() { |     if !db.globals.allow_federation() { | ||||||
|  | @ -1783,8 +1783,8 @@ pub fn get_event_route<'a>( | ||||||
|     post("/_matrix/federation/v1/get_missing_events/<_>", data = "<body>") |     post("/_matrix/federation/v1/get_missing_events/<_>", data = "<body>") | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub fn get_missing_events_route<'a>( | pub fn get_missing_events_route( | ||||||
|     db: State<'a, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_missing_events::v1::Request<'_>>, |     body: Ruma<get_missing_events::v1::Request<'_>>, | ||||||
| ) -> ConduitResult<get_missing_events::v1::Response> { | ) -> ConduitResult<get_missing_events::v1::Response> { | ||||||
|     if !db.globals.allow_federation() { |     if !db.globals.allow_federation() { | ||||||
|  | @ -1832,8 +1832,8 @@ pub fn get_missing_events_route<'a>( | ||||||
|     get("/_matrix/federation/v1/state_ids/<_>", data = "<body>") |     get("/_matrix/federation/v1/state_ids/<_>", data = "<body>") | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub fn get_room_state_ids_route<'a>( | pub fn get_room_state_ids_route( | ||||||
|     db: State<'a, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_room_state_ids::v1::Request<'_>>, |     body: Ruma<get_room_state_ids::v1::Request<'_>>, | ||||||
| ) -> ConduitResult<get_room_state_ids::v1::Response> { | ) -> ConduitResult<get_room_state_ids::v1::Response> { | ||||||
|     if !db.globals.allow_federation() { |     if !db.globals.allow_federation() { | ||||||
|  | @ -1884,8 +1884,8 @@ pub fn get_room_state_ids_route<'a>( | ||||||
|     get("/_matrix/federation/v1/make_join/<_>/<_>", data = "<body>") |     get("/_matrix/federation/v1/make_join/<_>/<_>", data = "<body>") | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub fn create_join_event_template_route<'a>( | pub fn create_join_event_template_route( | ||||||
|     db: State<'a, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<create_join_event_template::v1::Request<'_>>, |     body: Ruma<create_join_event_template::v1::Request<'_>>, | ||||||
| ) -> ConduitResult<create_join_event_template::v1::Response> { | ) -> ConduitResult<create_join_event_template::v1::Response> { | ||||||
|     if !db.globals.allow_federation() { |     if !db.globals.allow_federation() { | ||||||
|  | @ -2055,8 +2055,8 @@ pub fn create_join_event_template_route<'a>( | ||||||
|     put("/_matrix/federation/v2/send_join/<_>/<_>", data = "<body>") |     put("/_matrix/federation/v2/send_join/<_>/<_>", data = "<body>") | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn create_join_event_route<'a>( | pub async fn create_join_event_route( | ||||||
|     db: State<'a, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<create_join_event::v2::Request<'_>>, |     body: Ruma<create_join_event::v2::Request<'_>>, | ||||||
| ) -> ConduitResult<create_join_event::v2::Response> { | ) -> ConduitResult<create_join_event::v2::Response> { | ||||||
|     if !db.globals.allow_federation() { |     if !db.globals.allow_federation() { | ||||||
|  | @ -2171,8 +2171,8 @@ pub async fn create_join_event_route<'a>( | ||||||
|     put("/_matrix/federation/v2/invite/<_>/<_>", data = "<body>") |     put("/_matrix/federation/v2/invite/<_>/<_>", data = "<body>") | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn create_invite_route<'a>( | pub async fn create_invite_route( | ||||||
|     db: State<'a, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<create_invite::v2::Request>, |     body: Ruma<create_invite::v2::Request>, | ||||||
| ) -> ConduitResult<create_invite::v2::Response> { | ) -> ConduitResult<create_invite::v2::Response> { | ||||||
|     if !db.globals.allow_federation() { |     if !db.globals.allow_federation() { | ||||||
|  | @ -2276,8 +2276,8 @@ pub async fn create_invite_route<'a>( | ||||||
|     get("/_matrix/federation/v1/user/devices/<_>", data = "<body>") |     get("/_matrix/federation/v1/user/devices/<_>", data = "<body>") | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub fn get_devices_route<'a>( | pub fn get_devices_route( | ||||||
|     db: State<'a, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_devices::v1::Request<'_>>, |     body: Ruma<get_devices::v1::Request<'_>>, | ||||||
| ) -> ConduitResult<get_devices::v1::Response> { | ) -> ConduitResult<get_devices::v1::Response> { | ||||||
|     if !db.globals.allow_federation() { |     if !db.globals.allow_federation() { | ||||||
|  | @ -2316,8 +2316,8 @@ pub fn get_devices_route<'a>( | ||||||
|     get("/_matrix/federation/v1/query/directory", data = "<body>") |     get("/_matrix/federation/v1/query/directory", data = "<body>") | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub fn get_room_information_route<'a>( | pub fn get_room_information_route( | ||||||
|     db: State<'a, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_room_information::v1::Request<'_>>, |     body: Ruma<get_room_information::v1::Request<'_>>, | ||||||
| ) -> ConduitResult<get_room_information::v1::Response> { | ) -> ConduitResult<get_room_information::v1::Response> { | ||||||
|     if !db.globals.allow_federation() { |     if !db.globals.allow_federation() { | ||||||
|  | @ -2344,8 +2344,8 @@ pub fn get_room_information_route<'a>( | ||||||
|     get("/_matrix/federation/v1/query/profile", data = "<body>") |     get("/_matrix/federation/v1/query/profile", data = "<body>") | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub fn get_profile_information_route<'a>( | pub fn get_profile_information_route( | ||||||
|     db: State<'a, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_profile_information::v1::Request<'_>>, |     body: Ruma<get_profile_information::v1::Request<'_>>, | ||||||
| ) -> ConduitResult<get_profile_information::v1::Response> { | ) -> ConduitResult<get_profile_information::v1::Response> { | ||||||
|     if !db.globals.allow_federation() { |     if !db.globals.allow_federation() { | ||||||
|  | @ -2378,8 +2378,8 @@ pub fn get_profile_information_route<'a>( | ||||||
|     post("/_matrix/federation/v1/user/keys/query", data = "<body>") |     post("/_matrix/federation/v1/user/keys/query", data = "<body>") | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub fn get_keys_route<'a>( | pub fn get_keys_route( | ||||||
|     db: State<'a, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<get_keys::v1::Request>, |     body: Ruma<get_keys::v1::Request>, | ||||||
| ) -> ConduitResult<get_keys::v1::Response> { | ) -> ConduitResult<get_keys::v1::Response> { | ||||||
|     if !db.globals.allow_federation() { |     if !db.globals.allow_federation() { | ||||||
|  | @ -2406,8 +2406,8 @@ pub fn get_keys_route<'a>( | ||||||
|     post("/_matrix/federation/v1/user/keys/claim", data = "<body>") |     post("/_matrix/federation/v1/user/keys/claim", data = "<body>") | ||||||
| )] | )] | ||||||
| #[tracing::instrument(skip(db, body))] | #[tracing::instrument(skip(db, body))] | ||||||
| pub async fn claim_keys_route<'a>( | pub async fn claim_keys_route( | ||||||
|     db: State<'a, Database>, |     db: State<'_, Arc<Database>>, | ||||||
|     body: Ruma<claim_keys::v1::Request>, |     body: Ruma<claim_keys::v1::Request>, | ||||||
| ) -> ConduitResult<claim_keys::v1::Response> { | ) -> ConduitResult<claim_keys::v1::Response> { | ||||||
|     if !db.globals.allow_federation() { |     if !db.globals.allow_federation() { | ||||||
|  |  | ||||||
							
								
								
									
										27
									
								
								src/utils.rs
									
									
									
									
									
								
							
							
						
						
									
										27
									
								
								src/utils.rs
									
									
									
									
									
								
							|  | @ -15,6 +15,15 @@ pub fn millis_since_unix_epoch() -> u64 { | ||||||
|         .as_millis() as u64 |         .as_millis() as u64 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | #[cfg(feature = "rocksdb")] | ||||||
|  | pub fn increment_rocksdb( | ||||||
|  |     _new_key: &[u8], | ||||||
|  |     old: Option<&[u8]>, | ||||||
|  |     _operands: &mut rocksdb::MergeOperands, | ||||||
|  | ) -> Option<Vec<u8>> { | ||||||
|  |     increment(old) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| pub fn increment(old: Option<&[u8]>) -> Option<Vec<u8>> { | pub fn increment(old: Option<&[u8]>) -> Option<Vec<u8>> { | ||||||
|     let number = match old.map(|bytes| bytes.try_into()) { |     let number = match old.map(|bytes| bytes.try_into()) { | ||||||
|         Some(Ok(bytes)) => { |         Some(Ok(bytes)) => { | ||||||
|  | @ -27,16 +36,14 @@ pub fn increment(old: Option<&[u8]>) -> Option<Vec<u8>> { | ||||||
|     Some(number.to_be_bytes().to_vec()) |     Some(number.to_be_bytes().to_vec()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| pub fn generate_keypair(old: Option<&[u8]>) -> Option<Vec<u8>> { | pub fn generate_keypair() -> Vec<u8> { | ||||||
|     Some(old.map(|s| s.to_vec()).unwrap_or_else(|| { |     let mut value = random_string(8).as_bytes().to_vec(); | ||||||
|         let mut value = random_string(8).as_bytes().to_vec(); |     value.push(0xff); | ||||||
|         value.push(0xff); |     value.extend_from_slice( | ||||||
|         value.extend_from_slice( |         &ruma::signatures::Ed25519KeyPair::generate() | ||||||
|             &ruma::signatures::Ed25519KeyPair::generate() |             .expect("Ed25519KeyPair generation always works (?)"), | ||||||
|                 .expect("Ed25519KeyPair generation always works (?)"), |     ); | ||||||
|         ); |     value | ||||||
|         value |  | ||||||
|     })) |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /// Parses the bytes into an u64.
 | /// Parses the bytes into an u64.
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in a new issue