Merge branch 'speed' into 'master'
Speed See merge request famedly/conduit!168
This commit is contained in:
		
						commit
						00c9ad12bd
					
				
					 11 changed files with 488 additions and 308 deletions
				
			
		
							
								
								
									
										77
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										77
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							|  | @ -324,9 +324,9 @@ checksum = "ea221b5284a47e40033bf9b66f35f984ec0ea2931eb03505246cd27a963f981b" | ||||||
| 
 | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "cpufeatures" | name = "cpufeatures" | ||||||
| version = "0.1.5" | version = "0.2.1" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
| checksum = "66c99696f6c9dd7f35d486b9d04d7e6e202aa3e8c40d553f2fdf5e7e0c6a71ef" | checksum = "95059428f66df56b63431fdb4e1947ed2190586af5c5a8a8b71122bdf5a7f469" | ||||||
| dependencies = [ | dependencies = [ | ||||||
|  "libc", |  "libc", | ||||||
| ] | ] | ||||||
|  | @ -2061,9 +2061,8 @@ dependencies = [ | ||||||
| 
 | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "ruma" | name = "ruma" | ||||||
| version = "0.4.0" | version = "0.3.0" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||||
| checksum = "668031e3108d6a2cfbe6eca271d8698f4593440e71a44afdadcf67ce3cb93c1f" |  | ||||||
| dependencies = [ | dependencies = [ | ||||||
|  "assign", |  "assign", | ||||||
|  "js_int", |  "js_int", | ||||||
|  | @ -2084,8 +2083,7 @@ dependencies = [ | ||||||
| [[package]] | [[package]] | ||||||
| name = "ruma-api" | name = "ruma-api" | ||||||
| version = "0.18.3" | version = "0.18.3" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||||
| checksum = "f5f1843792b6749ec1ece62595cf99ad30bf9589c96bb237515235e71da396ea" |  | ||||||
| dependencies = [ | dependencies = [ | ||||||
|  "bytes", |  "bytes", | ||||||
|  "http", |  "http", | ||||||
|  | @ -2101,8 +2099,7 @@ dependencies = [ | ||||||
| [[package]] | [[package]] | ||||||
| name = "ruma-api-macros" | name = "ruma-api-macros" | ||||||
| version = "0.18.3" | version = "0.18.3" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||||
| checksum = "7b18abda5cca94178d08b622bca042e1cbb5eb7d4ebf3a2a81590a3bb3c57008" |  | ||||||
| dependencies = [ | dependencies = [ | ||||||
|  "proc-macro-crate", |  "proc-macro-crate", | ||||||
|  "proc-macro2", |  "proc-macro2", | ||||||
|  | @ -2113,8 +2110,7 @@ dependencies = [ | ||||||
| [[package]] | [[package]] | ||||||
| name = "ruma-appservice-api" | name = "ruma-appservice-api" | ||||||
| version = "0.4.0" | version = "0.4.0" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||||
| checksum = "49369332a5f299e832e19661f92d49e08c345c3c6c4ab16e09cb31c5ff6da878" |  | ||||||
| dependencies = [ | dependencies = [ | ||||||
|  "ruma-api", |  "ruma-api", | ||||||
|  "ruma-common", |  "ruma-common", | ||||||
|  | @ -2128,8 +2124,7 @@ dependencies = [ | ||||||
| [[package]] | [[package]] | ||||||
| name = "ruma-client-api" | name = "ruma-client-api" | ||||||
| version = "0.12.2" | version = "0.12.2" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||||
| checksum = "9568a222c12cf6220e751484ab78feec28071f85965113a5bb802936a2920ff0" |  | ||||||
| dependencies = [ | dependencies = [ | ||||||
|  "assign", |  "assign", | ||||||
|  "bytes", |  "bytes", | ||||||
|  | @ -2149,8 +2144,7 @@ dependencies = [ | ||||||
| [[package]] | [[package]] | ||||||
| name = "ruma-common" | name = "ruma-common" | ||||||
| version = "0.6.0" | version = "0.6.0" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||||
| checksum = "41d5b7605f58dc0d9cf1848cc7f1af2bae4e4bcd1d2b7a87bbb9864c8a785b91" |  | ||||||
| dependencies = [ | dependencies = [ | ||||||
|  "indexmap", |  "indexmap", | ||||||
|  "js_int", |  "js_int", | ||||||
|  | @ -2164,9 +2158,8 @@ dependencies = [ | ||||||
| 
 | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "ruma-events" | name = "ruma-events" | ||||||
| version = "0.24.5" | version = "0.24.4" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||||
| checksum = "87801e1207cfebdee02e7997ebf181a1c9837260b78c1b8ce96b896a2bcb3763" |  | ||||||
| dependencies = [ | dependencies = [ | ||||||
|  "indoc", |  "indoc", | ||||||
|  "js_int", |  "js_int", | ||||||
|  | @ -2181,9 +2174,8 @@ dependencies = [ | ||||||
| 
 | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "ruma-events-macros" | name = "ruma-events-macros" | ||||||
| version = "0.24.5" | version = "0.24.4" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||||
| checksum = "5da4498845347de88adf1b7da4578e2ca7355ad4ce47b0976f6594bacf958660" |  | ||||||
| dependencies = [ | dependencies = [ | ||||||
|  "proc-macro-crate", |  "proc-macro-crate", | ||||||
|  "proc-macro2", |  "proc-macro2", | ||||||
|  | @ -2194,8 +2186,7 @@ dependencies = [ | ||||||
| [[package]] | [[package]] | ||||||
| name = "ruma-federation-api" | name = "ruma-federation-api" | ||||||
| version = "0.3.0" | version = "0.3.0" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||||
| checksum = "c61c9adbe1a29c301ae627604406d60102c89fc833b110cd35bbf29ae205ea6c" |  | ||||||
| dependencies = [ | dependencies = [ | ||||||
|  "js_int", |  "js_int", | ||||||
|  "ruma-api", |  "ruma-api", | ||||||
|  | @ -2210,8 +2201,7 @@ dependencies = [ | ||||||
| [[package]] | [[package]] | ||||||
| name = "ruma-identifiers" | name = "ruma-identifiers" | ||||||
| version = "0.20.0" | version = "0.20.0" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||||
| checksum = "cb417d091e8dd5a633e4e5998231a156049d7fcc221045cfdc0642eb72067732" |  | ||||||
| dependencies = [ | dependencies = [ | ||||||
|  "paste", |  "paste", | ||||||
|  "rand 0.8.4", |  "rand 0.8.4", | ||||||
|  | @ -2225,8 +2215,7 @@ dependencies = [ | ||||||
| [[package]] | [[package]] | ||||||
| name = "ruma-identifiers-macros" | name = "ruma-identifiers-macros" | ||||||
| version = "0.20.0" | version = "0.20.0" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||||
| checksum = "c708edad7f605638f26c951cbad7501fbf28ab01009e5ca65ea5a2db74a882b1" |  | ||||||
| dependencies = [ | dependencies = [ | ||||||
|  "quote", |  "quote", | ||||||
|  "ruma-identifiers-validation", |  "ruma-identifiers-validation", | ||||||
|  | @ -2236,14 +2225,15 @@ dependencies = [ | ||||||
| [[package]] | [[package]] | ||||||
| name = "ruma-identifiers-validation" | name = "ruma-identifiers-validation" | ||||||
| version = "0.5.0" | version = "0.5.0" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||||
| checksum = "42285e7fb5d5f2d5268e45bb683e36d5c6fd9fc1e11a4559ba3c3521f3bbb2cb" | dependencies = [ | ||||||
|  |  "thiserror", | ||||||
|  | ] | ||||||
| 
 | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "ruma-identity-service-api" | name = "ruma-identity-service-api" | ||||||
| version = "0.3.0" | version = "0.3.0" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||||
| checksum = "e76e66e24f2d5a31511fbf6c79e79f67a7a6a98ebf48d72381b7d5bb6c09f035" |  | ||||||
| dependencies = [ | dependencies = [ | ||||||
|  "js_int", |  "js_int", | ||||||
|  "ruma-api", |  "ruma-api", | ||||||
|  | @ -2256,8 +2246,7 @@ dependencies = [ | ||||||
| [[package]] | [[package]] | ||||||
| name = "ruma-push-gateway-api" | name = "ruma-push-gateway-api" | ||||||
| version = "0.3.0" | version = "0.3.0" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||||
| checksum = "5ef5b29da7065efc5b1e1a8f61add7543c9ab4ecce5ee0dd1c1c5ecec83fbeec" |  | ||||||
| dependencies = [ | dependencies = [ | ||||||
|  "js_int", |  "js_int", | ||||||
|  "ruma-api", |  "ruma-api", | ||||||
|  | @ -2272,8 +2261,7 @@ dependencies = [ | ||||||
| [[package]] | [[package]] | ||||||
| name = "ruma-serde" | name = "ruma-serde" | ||||||
| version = "0.5.0" | version = "0.5.0" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||||
| checksum = "8b2b22aae842e7ecda695e42b7b39d4558959d9d9a27acc2a16acf4f4f7f00c3" |  | ||||||
| dependencies = [ | dependencies = [ | ||||||
|  "bytes", |  "bytes", | ||||||
|  "form_urlencoded", |  "form_urlencoded", | ||||||
|  | @ -2287,8 +2275,7 @@ dependencies = [ | ||||||
| [[package]] | [[package]] | ||||||
| name = "ruma-serde-macros" | name = "ruma-serde-macros" | ||||||
| version = "0.5.0" | version = "0.5.0" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||||
| checksum = "243e9bef188b08f94c79bc2f8fd1eb307a9e636b2b8e4571acf8c7be16381d28" |  | ||||||
| dependencies = [ | dependencies = [ | ||||||
|  "proc-macro-crate", |  "proc-macro-crate", | ||||||
|  "proc-macro2", |  "proc-macro2", | ||||||
|  | @ -2299,8 +2286,7 @@ dependencies = [ | ||||||
| [[package]] | [[package]] | ||||||
| name = "ruma-signatures" | name = "ruma-signatures" | ||||||
| version = "0.9.0" | version = "0.9.0" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||||
| checksum = "4a4f64027165b59500162d10d435b1253898bf3ad4f5002cb0d56913fe7f76d7" |  | ||||||
| dependencies = [ | dependencies = [ | ||||||
|  "base64 0.13.0", |  "base64 0.13.0", | ||||||
|  "ed25519-dalek", |  "ed25519-dalek", | ||||||
|  | @ -2316,9 +2302,8 @@ dependencies = [ | ||||||
| 
 | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "ruma-state-res" | name = "ruma-state-res" | ||||||
| version = "0.4.0" | version = "0.3.0" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||||
| checksum = "796427aaa2d266238c5c1b1a6ca4640a4d282ec2cb2e844c69a8f3a262d3db15" |  | ||||||
| dependencies = [ | dependencies = [ | ||||||
|  "itertools 0.10.1", |  "itertools 0.10.1", | ||||||
|  "js_int", |  "js_int", | ||||||
|  | @ -2522,9 +2507,9 @@ dependencies = [ | ||||||
| 
 | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "serde_yaml" | name = "serde_yaml" | ||||||
| version = "0.8.19" | version = "0.8.20" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
| checksum = "6375dbd828ed6964c3748e4ef6d18e7a175d408ffe184bca01698d0c73f915a9" | checksum = "ad104641f3c958dab30eb3010e834c2622d1f3f4c530fef1dee20ad9485f3c09" | ||||||
| dependencies = [ | dependencies = [ | ||||||
|  "dtoa", |  "dtoa", | ||||||
|  "indexmap", |  "indexmap", | ||||||
|  | @ -2540,9 +2525,9 @@ checksum = "2579985fda508104f7587689507983eadd6a6e84dd35d6d115361f530916fa0d" | ||||||
| 
 | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "sha2" | name = "sha2" | ||||||
| version = "0.9.5" | version = "0.9.6" | ||||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
| checksum = "b362ae5752fd2137731f9fa25fd4d9058af34666ca1966fb969119cc35719f12" | checksum = "9204c41a1597a8c5af23c82d1c921cb01ec0a4c59e07a9c7306062829a3903f3" | ||||||
| dependencies = [ | dependencies = [ | ||||||
|  "block-buffer", |  "block-buffer", | ||||||
|  "cfg-if 1.0.0", |  "cfg-if 1.0.0", | ||||||
|  |  | ||||||
|  | @ -18,9 +18,9 @@ edition = "2018" | ||||||
| rocket = { version = "0.5.0-rc.1", features = ["tls"] } # Used to handle requests | rocket = { version = "0.5.0-rc.1", features = ["tls"] } # Used to handle requests | ||||||
| 
 | 
 | ||||||
| # Used for matrix spec type definitions and helpers | # Used for matrix spec type definitions and helpers | ||||||
| ruma = { version = "0.4.0", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } | #ruma = { version = "0.4.0", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } | ||||||
| #ruma = { git = "https://github.com/ruma/ruma", rev = "f5ab038e22421ed338396ece977b6b2844772ced", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } | #ruma = { git = "https://github.com/ruma/ruma", rev = "f5ab038e22421ed338396ece977b6b2844772ced", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } | ||||||
| #ruma = { git = "https://github.com/timokoesters/ruma", rev = "2215049b60a1c3358f5a52215adf1e7bb88619a1", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } | ruma = { git = "https://github.com/DevinR528/ruma", rev = "c7860fcb89dbde636e2c83d0636934fb9924f40c", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } | ||||||
| #ruma = { path = "../ruma/crates/ruma", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } | #ruma = { path = "../ruma/crates/ruma", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } | ||||||
| 
 | 
 | ||||||
| # Used for long polling and federation sender, should be the same as rocket::tokio | # Used for long polling and federation sender, should be the same as rocket::tokio | ||||||
|  |  | ||||||
|  | @ -46,7 +46,11 @@ where | ||||||
|     *reqwest_request.timeout_mut() = Some(Duration::from_secs(30)); |     *reqwest_request.timeout_mut() = Some(Duration::from_secs(30)); | ||||||
| 
 | 
 | ||||||
|     let url = reqwest_request.url().clone(); |     let url = reqwest_request.url().clone(); | ||||||
|     let mut response = globals.reqwest_client().execute(reqwest_request).await?; |     let mut response = globals | ||||||
|  |         .reqwest_client()? | ||||||
|  |         .build()? | ||||||
|  |         .execute(reqwest_request) | ||||||
|  |         .await?; | ||||||
| 
 | 
 | ||||||
|     // reqwest::Response -> http::Response conversion
 |     // reqwest::Response -> http::Response conversion
 | ||||||
|     let status = response.status(); |     let status = response.status(); | ||||||
|  |  | ||||||
|  | @ -1,5 +1,6 @@ | ||||||
| use super::SESSION_ID_LENGTH; | use super::SESSION_ID_LENGTH; | ||||||
| use crate::{database::DatabaseGuard, utils, ConduitResult, Database, Error, Result, Ruma}; | use crate::{database::DatabaseGuard, utils, ConduitResult, Database, Error, Result, Ruma}; | ||||||
|  | use rocket::futures::{prelude::*, stream::FuturesUnordered}; | ||||||
| use ruma::{ | use ruma::{ | ||||||
|     api::{ |     api::{ | ||||||
|         client::{ |         client::{ | ||||||
|  | @ -18,7 +19,7 @@ use ruma::{ | ||||||
|     DeviceId, DeviceKeyAlgorithm, UserId, |     DeviceId, DeviceKeyAlgorithm, UserId, | ||||||
| }; | }; | ||||||
| use serde_json::json; | use serde_json::json; | ||||||
| use std::collections::{BTreeMap, HashSet}; | use std::collections::{BTreeMap, HashMap, HashSet}; | ||||||
| 
 | 
 | ||||||
| #[cfg(feature = "conduit_bin")] | #[cfg(feature = "conduit_bin")] | ||||||
| use rocket::{get, post}; | use rocket::{get, post}; | ||||||
|  | @ -294,7 +295,7 @@ pub async fn get_keys_helper<F: Fn(&UserId) -> bool>( | ||||||
|     let mut user_signing_keys = BTreeMap::new(); |     let mut user_signing_keys = BTreeMap::new(); | ||||||
|     let mut device_keys = BTreeMap::new(); |     let mut device_keys = BTreeMap::new(); | ||||||
| 
 | 
 | ||||||
|     let mut get_over_federation = BTreeMap::new(); |     let mut get_over_federation = HashMap::new(); | ||||||
| 
 | 
 | ||||||
|     for (user_id, device_ids) in device_keys_input { |     for (user_id, device_ids) in device_keys_input { | ||||||
|         if user_id.server_name() != db.globals.server_name() { |         if user_id.server_name() != db.globals.server_name() { | ||||||
|  | @ -364,13 +365,16 @@ pub async fn get_keys_helper<F: Fn(&UserId) -> bool>( | ||||||
| 
 | 
 | ||||||
|     let mut failures = BTreeMap::new(); |     let mut failures = BTreeMap::new(); | ||||||
| 
 | 
 | ||||||
|     for (server, vec) in get_over_federation { |     let mut futures = get_over_federation | ||||||
|  |         .into_iter() | ||||||
|  |         .map(|(server, vec)| async move { | ||||||
|             let mut device_keys_input_fed = BTreeMap::new(); |             let mut device_keys_input_fed = BTreeMap::new(); | ||||||
|             for (user_id, keys) in vec { |             for (user_id, keys) in vec { | ||||||
|                 device_keys_input_fed.insert(user_id.clone(), keys.clone()); |                 device_keys_input_fed.insert(user_id.clone(), keys.clone()); | ||||||
|             } |             } | ||||||
|         match db |             ( | ||||||
|             .sending |                 server, | ||||||
|  |                 db.sending | ||||||
|                     .send_federation_request( |                     .send_federation_request( | ||||||
|                         &db.globals, |                         &db.globals, | ||||||
|                         server, |                         server, | ||||||
|  | @ -378,8 +382,13 @@ pub async fn get_keys_helper<F: Fn(&UserId) -> bool>( | ||||||
|                             device_keys: device_keys_input_fed, |                             device_keys: device_keys_input_fed, | ||||||
|                         }, |                         }, | ||||||
|                     ) |                     ) | ||||||
|             .await |                     .await, | ||||||
|         { |             ) | ||||||
|  |         }) | ||||||
|  |         .collect::<FuturesUnordered<_>>(); | ||||||
|  | 
 | ||||||
|  |     while let Some((server, response)) = futures.next().await { | ||||||
|  |         match response { | ||||||
|             Ok(response) => { |             Ok(response) => { | ||||||
|                 master_keys.extend(response.master_keys); |                 master_keys.extend(response.master_keys); | ||||||
|                 self_signing_keys.extend(response.self_signing_keys); |                 self_signing_keys.extend(response.self_signing_keys); | ||||||
|  | @ -430,13 +439,15 @@ pub async fn claim_keys_helper( | ||||||
|         one_time_keys.insert(user_id.clone(), container); |         one_time_keys.insert(user_id.clone(), container); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     let mut failures = BTreeMap::new(); | ||||||
|  | 
 | ||||||
|     for (server, vec) in get_over_federation { |     for (server, vec) in get_over_federation { | ||||||
|         let mut one_time_keys_input_fed = BTreeMap::new(); |         let mut one_time_keys_input_fed = BTreeMap::new(); | ||||||
|         for (user_id, keys) in vec { |         for (user_id, keys) in vec { | ||||||
|             one_time_keys_input_fed.insert(user_id.clone(), keys.clone()); |             one_time_keys_input_fed.insert(user_id.clone(), keys.clone()); | ||||||
|         } |         } | ||||||
|         // Ignore failures
 |         // Ignore failures
 | ||||||
|         let keys = db |         if let Ok(keys) = db | ||||||
|             .sending |             .sending | ||||||
|             .send_federation_request( |             .send_federation_request( | ||||||
|                 &db.globals, |                 &db.globals, | ||||||
|  | @ -445,13 +456,16 @@ pub async fn claim_keys_helper( | ||||||
|                     one_time_keys: one_time_keys_input_fed, |                     one_time_keys: one_time_keys_input_fed, | ||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
|             .await?; |             .await | ||||||
| 
 |         { | ||||||
|             one_time_keys.extend(keys.one_time_keys); |             one_time_keys.extend(keys.one_time_keys); | ||||||
|  |         } else { | ||||||
|  |             failures.insert(server.to_string(), json!({})); | ||||||
|  |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     Ok(claim_keys::Response { |     Ok(claim_keys::Response { | ||||||
|         failures: BTreeMap::new(), |         failures, | ||||||
|         one_time_keys, |         one_time_keys, | ||||||
|     }) |     }) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -256,8 +256,8 @@ async fn sync_helper( | ||||||
| 
 | 
 | ||||||
|         // Calculates joined_member_count, invited_member_count and heroes
 |         // Calculates joined_member_count, invited_member_count and heroes
 | ||||||
|         let calculate_counts = || { |         let calculate_counts = || { | ||||||
|             let joined_member_count = db.rooms.room_members(&room_id).count(); |             let joined_member_count = db.rooms.room_joined_count(&room_id)?.unwrap_or(0); | ||||||
|             let invited_member_count = db.rooms.room_members_invited(&room_id).count(); |             let invited_member_count = db.rooms.room_invited_count(&room_id)?.unwrap_or(0); | ||||||
| 
 | 
 | ||||||
|             // Recalculate heroes (first 5 members)
 |             // Recalculate heroes (first 5 members)
 | ||||||
|             let mut heroes = Vec::new(); |             let mut heroes = Vec::new(); | ||||||
|  | @ -407,60 +407,35 @@ async fn sync_helper( | ||||||
|                 }); |                 }); | ||||||
| 
 | 
 | ||||||
|             if encrypted_room { |             if encrypted_room { | ||||||
|                 for (user_id, current_member) in db |                 for state_event in &state_events { | ||||||
|                     .rooms |                     if state_event.kind != EventType::RoomMember { | ||||||
|                     .room_members(&room_id) |                         continue; | ||||||
|                     .filter_map(|r| r.ok()) |                     } | ||||||
|                     .filter_map(|user_id| { | 
 | ||||||
|                         db.rooms |                     if let Some(state_key) = &state_event.state_key { | ||||||
|                             .state_get( |                         let user_id = UserId::try_from(state_key.clone()) | ||||||
|                                 current_shortstatehash, |                             .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; | ||||||
|                                 &EventType::RoomMember, | 
 | ||||||
|                                 user_id.as_str(), |                         if user_id == sender_user { | ||||||
|                             ) |                             continue; | ||||||
|                             .ok() |                         } | ||||||
|                             .flatten() | 
 | ||||||
|                             .map(|current_member| (user_id, current_member)) |                         let new_membership = serde_json::from_value::< | ||||||
|                     }) |  | ||||||
|                 { |  | ||||||
|                     let current_membership = serde_json::from_value::< |  | ||||||
|                             Raw<ruma::events::room::member::MemberEventContent>, |                             Raw<ruma::events::room::member::MemberEventContent>, | ||||||
|                     >(current_member.content.clone()) |                         >(state_event.content.clone()) | ||||||
|                         .expect("Raw::from_value always works") |                         .expect("Raw::from_value always works") | ||||||
|                         .deserialize() |                         .deserialize() | ||||||
|                         .map_err(|_| Error::bad_database("Invalid PDU in database."))? |                         .map_err(|_| Error::bad_database("Invalid PDU in database."))? | ||||||
|                         .membership; |                         .membership; | ||||||
| 
 | 
 | ||||||
|                     let since_membership = db |                         match new_membership { | ||||||
|                         .rooms |                             MembershipState::Join => { | ||||||
|                         .state_get( |  | ||||||
|                             since_shortstatehash, |  | ||||||
|                             &EventType::RoomMember, |  | ||||||
|                             user_id.as_str(), |  | ||||||
|                         )? |  | ||||||
|                         .and_then(|since_member| { |  | ||||||
|                             serde_json::from_value::< |  | ||||||
|                                 Raw<ruma::events::room::member::MemberEventContent>, |  | ||||||
|                             >(since_member.content.clone()) |  | ||||||
|                             .expect("Raw::from_value always works") |  | ||||||
|                             .deserialize() |  | ||||||
|                             .map_err(|_| Error::bad_database("Invalid PDU in database.")) |  | ||||||
|                             .ok() |  | ||||||
|                         }) |  | ||||||
|                         .map_or(MembershipState::Leave, |member| member.membership); |  | ||||||
| 
 |  | ||||||
|                     let user_id = UserId::try_from(user_id.clone()) |  | ||||||
|                         .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; |  | ||||||
| 
 |  | ||||||
|                     match (since_membership, current_membership) { |  | ||||||
|                         (MembershipState::Leave, MembershipState::Join) => { |  | ||||||
|                                 // A new user joined an encrypted room
 |                                 // A new user joined an encrypted room
 | ||||||
|                                 if !share_encrypted_room(&db, &sender_user, &user_id, &room_id)? { |                                 if !share_encrypted_room(&db, &sender_user, &user_id, &room_id)? { | ||||||
|                                     device_list_updates.insert(user_id); |                                     device_list_updates.insert(user_id); | ||||||
|                                 } |                                 } | ||||||
|                             } |                             } | ||||||
|                         // TODO: Remove, this should never happen here, right?
 |                             MembershipState::Leave => { | ||||||
|                         (MembershipState::Join, MembershipState::Leave) => { |  | ||||||
|                                 // Write down users that have left encrypted rooms we are in
 |                                 // Write down users that have left encrypted rooms we are in
 | ||||||
|                                 left_encrypted_users.insert(user_id); |                                 left_encrypted_users.insert(user_id); | ||||||
|                             } |                             } | ||||||
|  | @ -468,6 +443,7 @@ async fn sync_helper( | ||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|  |             } | ||||||
| 
 | 
 | ||||||
|             if joined_since_last_sync && encrypted_room || new_encrypted_room { |             if joined_since_last_sync && encrypted_room || new_encrypted_room { | ||||||
|                 // If the user is in a new encrypted room, give them all joined users
 |                 // If the user is in a new encrypted room, give them all joined users
 | ||||||
|  |  | ||||||
|  | @ -252,6 +252,7 @@ impl Database { | ||||||
|                 userroomid_joined: builder.open_tree("userroomid_joined")?, |                 userroomid_joined: builder.open_tree("userroomid_joined")?, | ||||||
|                 roomuserid_joined: builder.open_tree("roomuserid_joined")?, |                 roomuserid_joined: builder.open_tree("roomuserid_joined")?, | ||||||
|                 roomid_joinedcount: builder.open_tree("roomid_joinedcount")?, |                 roomid_joinedcount: builder.open_tree("roomid_joinedcount")?, | ||||||
|  |                 roomid_invitedcount: builder.open_tree("roomid_invitedcount")?, | ||||||
|                 roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?, |                 roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?, | ||||||
|                 userroomid_invitestate: builder.open_tree("userroomid_invitestate")?, |                 userroomid_invitestate: builder.open_tree("userroomid_invitestate")?, | ||||||
|                 roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?, |                 roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?, | ||||||
|  | @ -277,6 +278,8 @@ impl Database { | ||||||
|                 statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?, |                 statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?, | ||||||
| 
 | 
 | ||||||
|                 eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, |                 eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, | ||||||
|  |                 softfailedeventids: builder.open_tree("softfailedeventids")?, | ||||||
|  | 
 | ||||||
|                 referencedevents: builder.open_tree("referencedevents")?, |                 referencedevents: builder.open_tree("referencedevents")?, | ||||||
|                 pdu_cache: Mutex::new(LruCache::new(100_000)), |                 pdu_cache: Mutex::new(LruCache::new(100_000)), | ||||||
|                 auth_chain_cache: Mutex::new(LruCache::new(1_000_000)), |                 auth_chain_cache: Mutex::new(LruCache::new(1_000_000)), | ||||||
|  | @ -285,6 +288,7 @@ impl Database { | ||||||
|                 shortstatekey_cache: Mutex::new(LruCache::new(1_000_000)), |                 shortstatekey_cache: Mutex::new(LruCache::new(1_000_000)), | ||||||
|                 statekeyshort_cache: Mutex::new(LruCache::new(1_000_000)), |                 statekeyshort_cache: Mutex::new(LruCache::new(1_000_000)), | ||||||
|                 stateinfo_cache: Mutex::new(LruCache::new(1000)), |                 stateinfo_cache: Mutex::new(LruCache::new(1000)), | ||||||
|  |                 our_real_users_cache: RwLock::new(HashMap::new()), | ||||||
|             }, |             }, | ||||||
|             account_data: account_data::AccountData { |             account_data: account_data::AccountData { | ||||||
|                 roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, |                 roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, | ||||||
|  | @ -442,7 +446,7 @@ impl Database { | ||||||
|                     let room_id = |                     let room_id = | ||||||
|                         RoomId::try_from(utils::string_from_bytes(&roomid).unwrap()).unwrap(); |                         RoomId::try_from(utils::string_from_bytes(&roomid).unwrap()).unwrap(); | ||||||
| 
 | 
 | ||||||
|                     db.rooms.update_joined_count(&room_id)?; |                     db.rooms.update_joined_count(&room_id, &db)?; | ||||||
|                 } |                 } | ||||||
| 
 | 
 | ||||||
|                 db.globals.bump_database_version(6)?; |                 db.globals.bump_database_version(6)?; | ||||||
|  |  | ||||||
|  | @ -93,9 +93,8 @@ impl Engine { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn flush_wal(self: &Arc<Self>) -> Result<()> { |     pub fn flush_wal(self: &Arc<Self>) -> Result<()> { | ||||||
|         // We use autocheckpoints
 |         self.write_lock() | ||||||
|         //self.write_lock()
 |             .pragma_update(Some(Main), "wal_checkpoint", &"TRUNCATE")?; | ||||||
|         //.pragma_update(Some(Main), "wal_checkpoint", &"TRUNCATE")?;
 |  | ||||||
|         Ok(()) |         Ok(()) | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -1,4 +1,4 @@ | ||||||
| use crate::{database::Config, utils, ConduitResult, Error, Result}; | use crate::{database::Config, server_server::FedDest, utils, ConduitResult, Error, Result}; | ||||||
| use ruma::{ | use ruma::{ | ||||||
|     api::{ |     api::{ | ||||||
|         client::r0::sync::sync_events, |         client::r0::sync::sync_events, | ||||||
|  | @ -6,25 +6,25 @@ use ruma::{ | ||||||
|     }, |     }, | ||||||
|     DeviceId, EventId, MilliSecondsSinceUnixEpoch, RoomId, ServerName, ServerSigningKeyId, UserId, |     DeviceId, EventId, MilliSecondsSinceUnixEpoch, RoomId, ServerName, ServerSigningKeyId, UserId, | ||||||
| }; | }; | ||||||
| use rustls::{ServerCertVerifier, WebPKIVerifier}; |  | ||||||
| use std::{ | use std::{ | ||||||
|     collections::{BTreeMap, HashMap}, |     collections::{BTreeMap, HashMap}, | ||||||
|     fs, |     fs, | ||||||
|     future::Future, |     future::Future, | ||||||
|  |     net::IpAddr, | ||||||
|     path::PathBuf, |     path::PathBuf, | ||||||
|     sync::{Arc, Mutex, RwLock}, |     sync::{Arc, Mutex, RwLock}, | ||||||
|     time::{Duration, Instant}, |     time::{Duration, Instant}, | ||||||
| }; | }; | ||||||
| use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore}; | use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore}; | ||||||
| use tracing::{error, info}; | use tracing::error; | ||||||
| use trust_dns_resolver::TokioAsyncResolver; | use trust_dns_resolver::TokioAsyncResolver; | ||||||
| 
 | 
 | ||||||
| use super::abstraction::Tree; | use super::abstraction::Tree; | ||||||
| 
 | 
 | ||||||
| pub const COUNTER: &[u8] = b"c"; | pub const COUNTER: &[u8] = b"c"; | ||||||
| 
 | 
 | ||||||
| type WellKnownMap = HashMap<Box<ServerName>, (String, String)>; | type WellKnownMap = HashMap<Box<ServerName>, (FedDest, String)>; | ||||||
| type TlsNameMap = HashMap<String, webpki::DNSName>; | type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>; | ||||||
| type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
 | type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
 | ||||||
| type SyncHandle = ( | type SyncHandle = ( | ||||||
|     Option<String>,                                         // since
 |     Option<String>,                                         // since
 | ||||||
|  | @ -37,7 +37,6 @@ pub struct Globals { | ||||||
|     pub(super) globals: Arc<dyn Tree>, |     pub(super) globals: Arc<dyn Tree>, | ||||||
|     config: Config, |     config: Config, | ||||||
|     keypair: Arc<ruma::signatures::Ed25519KeyPair>, |     keypair: Arc<ruma::signatures::Ed25519KeyPair>, | ||||||
|     reqwest_client: reqwest::Client, |  | ||||||
|     dns_resolver: TokioAsyncResolver, |     dns_resolver: TokioAsyncResolver, | ||||||
|     jwt_decoding_key: Option<jsonwebtoken::DecodingKey<'static>>, |     jwt_decoding_key: Option<jsonwebtoken::DecodingKey<'static>>, | ||||||
|     pub(super) server_signingkeys: Arc<dyn Tree>, |     pub(super) server_signingkeys: Arc<dyn Tree>, | ||||||
|  | @ -51,40 +50,6 @@ pub struct Globals { | ||||||
|     pub rotate: RotationHandler, |     pub rotate: RotationHandler, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| struct MatrixServerVerifier { |  | ||||||
|     inner: WebPKIVerifier, |  | ||||||
|     tls_name_override: Arc<RwLock<TlsNameMap>>, |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl ServerCertVerifier for MatrixServerVerifier { |  | ||||||
|     #[tracing::instrument(skip(self, roots, presented_certs, dns_name, ocsp_response))] |  | ||||||
|     fn verify_server_cert( |  | ||||||
|         &self, |  | ||||||
|         roots: &rustls::RootCertStore, |  | ||||||
|         presented_certs: &[rustls::Certificate], |  | ||||||
|         dns_name: webpki::DNSNameRef<'_>, |  | ||||||
|         ocsp_response: &[u8], |  | ||||||
|     ) -> std::result::Result<rustls::ServerCertVerified, rustls::TLSError> { |  | ||||||
|         if let Some(override_name) = self.tls_name_override.read().unwrap().get(dns_name.into()) { |  | ||||||
|             let result = self.inner.verify_server_cert( |  | ||||||
|                 roots, |  | ||||||
|                 presented_certs, |  | ||||||
|                 override_name.as_ref(), |  | ||||||
|                 ocsp_response, |  | ||||||
|             ); |  | ||||||
|             if result.is_ok() { |  | ||||||
|                 return result; |  | ||||||
|             } |  | ||||||
|             info!( |  | ||||||
|                 "Server {:?} is non-compliant, retrying TLS verification with original name", |  | ||||||
|                 dns_name |  | ||||||
|             ); |  | ||||||
|         } |  | ||||||
|         self.inner |  | ||||||
|             .verify_server_cert(roots, presented_certs, dns_name, ocsp_response) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /// Handles "rotation" of long-polling requests. "Rotation" in this context is similar to "rotation" of log files and the like.
 | /// Handles "rotation" of long-polling requests. "Rotation" in this context is similar to "rotation" of log files and the like.
 | ||||||
| ///
 | ///
 | ||||||
| /// This is utilized to have sync workers return early and release read locks on the database.
 | /// This is utilized to have sync workers return early and release read locks on the database.
 | ||||||
|  | @ -162,24 +127,6 @@ impl Globals { | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|         let tls_name_override = Arc::new(RwLock::new(TlsNameMap::new())); |         let tls_name_override = Arc::new(RwLock::new(TlsNameMap::new())); | ||||||
|         let verifier = Arc::new(MatrixServerVerifier { |  | ||||||
|             inner: WebPKIVerifier::new(), |  | ||||||
|             tls_name_override: tls_name_override.clone(), |  | ||||||
|         }); |  | ||||||
|         let mut tlsconfig = rustls::ClientConfig::new(); |  | ||||||
|         tlsconfig.dangerous().set_certificate_verifier(verifier); |  | ||||||
|         tlsconfig.root_store = |  | ||||||
|             rustls_native_certs::load_native_certs().expect("Error loading system certificates"); |  | ||||||
| 
 |  | ||||||
|         let mut reqwest_client_builder = reqwest::Client::builder() |  | ||||||
|             .connect_timeout(Duration::from_secs(30)) |  | ||||||
|             .timeout(Duration::from_secs(60 * 3)) |  | ||||||
|             .pool_max_idle_per_host(1) |  | ||||||
|             .use_preconfigured_tls(tlsconfig); |  | ||||||
|         if let Some(proxy) = config.proxy.to_proxy()? { |  | ||||||
|             reqwest_client_builder = reqwest_client_builder.proxy(proxy); |  | ||||||
|         } |  | ||||||
|         let reqwest_client = reqwest_client_builder.build().unwrap(); |  | ||||||
| 
 | 
 | ||||||
|         let jwt_decoding_key = config |         let jwt_decoding_key = config | ||||||
|             .jwt_secret |             .jwt_secret | ||||||
|  | @ -190,7 +137,6 @@ impl Globals { | ||||||
|             globals, |             globals, | ||||||
|             config, |             config, | ||||||
|             keypair: Arc::new(keypair), |             keypair: Arc::new(keypair), | ||||||
|             reqwest_client, |  | ||||||
|             dns_resolver: TokioAsyncResolver::tokio_from_system_conf().map_err(|_| { |             dns_resolver: TokioAsyncResolver::tokio_from_system_conf().map_err(|_| { | ||||||
|                 Error::bad_config("Failed to set up trust dns resolver with system config.") |                 Error::bad_config("Failed to set up trust dns resolver with system config.") | ||||||
|             })?, |             })?, | ||||||
|  | @ -219,8 +165,16 @@ impl Globals { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Returns a reqwest client which can be used to send requests.
 |     /// Returns a reqwest client which can be used to send requests.
 | ||||||
|     pub fn reqwest_client(&self) -> &reqwest::Client { |     pub fn reqwest_client(&self) -> Result<reqwest::ClientBuilder> { | ||||||
|         &self.reqwest_client |         let mut reqwest_client_builder = reqwest::Client::builder() | ||||||
|  |             .connect_timeout(Duration::from_secs(30)) | ||||||
|  |             .timeout(Duration::from_secs(60 * 3)) | ||||||
|  |             .pool_max_idle_per_host(1); | ||||||
|  |         if let Some(proxy) = self.config.proxy.to_proxy()? { | ||||||
|  |             reqwest_client_builder = reqwest_client_builder.proxy(proxy); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         Ok(reqwest_client_builder) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     #[tracing::instrument(skip(self))] |     #[tracing::instrument(skip(self))] | ||||||
|  |  | ||||||
|  | @ -113,7 +113,11 @@ where | ||||||
|     //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5));
 |     //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5));
 | ||||||
| 
 | 
 | ||||||
|     let url = reqwest_request.url().clone(); |     let url = reqwest_request.url().clone(); | ||||||
|     let response = globals.reqwest_client().execute(reqwest_request).await; |     let response = globals | ||||||
|  |         .reqwest_client()? | ||||||
|  |         .build()? | ||||||
|  |         .execute(reqwest_request) | ||||||
|  |         .await; | ||||||
| 
 | 
 | ||||||
|     match response { |     match response { | ||||||
|         Ok(mut response) => { |         Ok(mut response) => { | ||||||
|  |  | ||||||
|  | @ -3,7 +3,7 @@ mod edus; | ||||||
| pub use edus::RoomEdus; | pub use edus::RoomEdus; | ||||||
| use member::MembershipState; | use member::MembershipState; | ||||||
| 
 | 
 | ||||||
| use crate::{pdu::PduBuilder, utils, Database, Error, PduEvent, Result}; | use crate::{pdu::PduBuilder, server_server, utils, Database, Error, PduEvent, Result}; | ||||||
| use lru_cache::LruCache; | use lru_cache::LruCache; | ||||||
| use regex::Regex; | use regex::Regex; | ||||||
| use ring::digest; | use ring::digest; | ||||||
|  | @ -26,7 +26,8 @@ use std::{ | ||||||
|     collections::{BTreeMap, HashMap, HashSet}, |     collections::{BTreeMap, HashMap, HashSet}, | ||||||
|     convert::{TryFrom, TryInto}, |     convert::{TryFrom, TryInto}, | ||||||
|     mem::size_of, |     mem::size_of, | ||||||
|     sync::{Arc, Mutex}, |     sync::{Arc, Mutex, RwLock}, | ||||||
|  |     time::Instant, | ||||||
| }; | }; | ||||||
| use tokio::sync::MutexGuard; | use tokio::sync::MutexGuard; | ||||||
| use tracing::{error, warn}; | use tracing::{error, warn}; | ||||||
|  | @ -58,6 +59,7 @@ pub struct Rooms { | ||||||
|     pub(super) userroomid_joined: Arc<dyn Tree>, |     pub(super) userroomid_joined: Arc<dyn Tree>, | ||||||
|     pub(super) roomuserid_joined: Arc<dyn Tree>, |     pub(super) roomuserid_joined: Arc<dyn Tree>, | ||||||
|     pub(super) roomid_joinedcount: Arc<dyn Tree>, |     pub(super) roomid_joinedcount: Arc<dyn Tree>, | ||||||
|  |     pub(super) roomid_invitedcount: Arc<dyn Tree>, | ||||||
|     pub(super) roomuseroncejoinedids: Arc<dyn Tree>, |     pub(super) roomuseroncejoinedids: Arc<dyn Tree>, | ||||||
|     pub(super) userroomid_invitestate: Arc<dyn Tree>, // InviteState = Vec<Raw<Pdu>>
 |     pub(super) userroomid_invitestate: Arc<dyn Tree>, // InviteState = Vec<Raw<Pdu>>
 | ||||||
|     pub(super) roomuserid_invitecount: Arc<dyn Tree>, // InviteCount = Count
 |     pub(super) roomuserid_invitecount: Arc<dyn Tree>, // InviteCount = Count
 | ||||||
|  | @ -89,16 +91,18 @@ pub struct Rooms { | ||||||
|     /// RoomId + EventId -> outlier PDU.
 |     /// RoomId + EventId -> outlier PDU.
 | ||||||
|     /// Any pdu that has passed the steps 1-8 in the incoming event /federation/send/txn.
 |     /// Any pdu that has passed the steps 1-8 in the incoming event /federation/send/txn.
 | ||||||
|     pub(super) eventid_outlierpdu: Arc<dyn Tree>, |     pub(super) eventid_outlierpdu: Arc<dyn Tree>, | ||||||
|  |     pub(super) softfailedeventids: Arc<dyn Tree>, | ||||||
| 
 | 
 | ||||||
|     /// RoomId + EventId -> Parent PDU EventId.
 |     /// RoomId + EventId -> Parent PDU EventId.
 | ||||||
|     pub(super) referencedevents: Arc<dyn Tree>, |     pub(super) referencedevents: Arc<dyn Tree>, | ||||||
| 
 | 
 | ||||||
|     pub(super) pdu_cache: Mutex<LruCache<EventId, Arc<PduEvent>>>, |     pub(super) pdu_cache: Mutex<LruCache<EventId, Arc<PduEvent>>>, | ||||||
|  |     pub(super) shorteventid_cache: Mutex<LruCache<u64, Arc<EventId>>>, | ||||||
|     pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<HashSet<u64>>>>, |     pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<HashSet<u64>>>>, | ||||||
|     pub(super) shorteventid_cache: Mutex<LruCache<u64, EventId>>, |  | ||||||
|     pub(super) eventidshort_cache: Mutex<LruCache<EventId, u64>>, |     pub(super) eventidshort_cache: Mutex<LruCache<EventId, u64>>, | ||||||
|     pub(super) statekeyshort_cache: Mutex<LruCache<(EventType, String), u64>>, |     pub(super) statekeyshort_cache: Mutex<LruCache<(EventType, String), u64>>, | ||||||
|     pub(super) shortstatekey_cache: Mutex<LruCache<u64, (EventType, String)>>, |     pub(super) shortstatekey_cache: Mutex<LruCache<u64, (EventType, String)>>, | ||||||
|  |     pub(super) our_real_users_cache: RwLock<HashMap<RoomId, Arc<HashSet<UserId>>>>, | ||||||
|     pub(super) stateinfo_cache: Mutex< |     pub(super) stateinfo_cache: Mutex< | ||||||
|         LruCache< |         LruCache< | ||||||
|             u64, |             u64, | ||||||
|  | @ -116,7 +120,7 @@ impl Rooms { | ||||||
|     /// Builds a StateMap by iterating over all keys that start
 |     /// Builds a StateMap by iterating over all keys that start
 | ||||||
|     /// with state_hash, this gives the full state for the given state_hash.
 |     /// with state_hash, this gives the full state for the given state_hash.
 | ||||||
|     #[tracing::instrument(skip(self))] |     #[tracing::instrument(skip(self))] | ||||||
|     pub fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, EventId>> { |     pub fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, Arc<EventId>>> { | ||||||
|         let full_state = self |         let full_state = self | ||||||
|             .load_shortstatehash_info(shortstatehash)? |             .load_shortstatehash_info(shortstatehash)? | ||||||
|             .pop() |             .pop() | ||||||
|  | @ -167,7 +171,7 @@ impl Rooms { | ||||||
|         shortstatehash: u64, |         shortstatehash: u64, | ||||||
|         event_type: &EventType, |         event_type: &EventType, | ||||||
|         state_key: &str, |         state_key: &str, | ||||||
|     ) -> Result<Option<EventId>> { |     ) -> Result<Option<Arc<EventId>>> { | ||||||
|         let shortstatekey = match self.get_shortstatekey(event_type, state_key)? { |         let shortstatekey = match self.get_shortstatekey(event_type, state_key)? { | ||||||
|             Some(s) => s, |             Some(s) => s, | ||||||
|             None => return Ok(None), |             None => return Ok(None), | ||||||
|  | @ -424,7 +428,7 @@ impl Rooms { | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         self.update_joined_count(room_id)?; |         self.update_joined_count(room_id, &db)?; | ||||||
| 
 | 
 | ||||||
|         self.roomid_shortstatehash |         self.roomid_shortstatehash | ||||||
|             .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; |             .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; | ||||||
|  | @ -523,7 +527,7 @@ impl Rooms { | ||||||
|     pub fn parse_compressed_state_event( |     pub fn parse_compressed_state_event( | ||||||
|         &self, |         &self, | ||||||
|         compressed_event: CompressedStateEvent, |         compressed_event: CompressedStateEvent, | ||||||
|     ) -> Result<(u64, EventId)> { |     ) -> Result<(u64, Arc<EventId>)> { | ||||||
|         Ok(( |         Ok(( | ||||||
|             utils::u64_from_bytes(&compressed_event[0..size_of::<u64>()]) |             utils::u64_from_bytes(&compressed_event[0..size_of::<u64>()]) | ||||||
|                 .expect("bytes have right length"), |                 .expect("bytes have right length"), | ||||||
|  | @ -839,14 +843,14 @@ impl Rooms { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     #[tracing::instrument(skip(self))] |     #[tracing::instrument(skip(self))] | ||||||
|     pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result<EventId> { |     pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> { | ||||||
|         if let Some(id) = self |         if let Some(id) = self | ||||||
|             .shorteventid_cache |             .shorteventid_cache | ||||||
|             .lock() |             .lock() | ||||||
|             .unwrap() |             .unwrap() | ||||||
|             .get_mut(&shorteventid) |             .get_mut(&shorteventid) | ||||||
|         { |         { | ||||||
|             return Ok(id.clone()); |             return Ok(Arc::clone(id)); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         let bytes = self |         let bytes = self | ||||||
|  | @ -854,15 +858,17 @@ impl Rooms { | ||||||
|             .get(&shorteventid.to_be_bytes())? |             .get(&shorteventid.to_be_bytes())? | ||||||
|             .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; |             .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; | ||||||
| 
 | 
 | ||||||
|         let event_id = EventId::try_from(utils::string_from_bytes(&bytes).map_err(|_| { |         let event_id = Arc::new( | ||||||
|  |             EventId::try_from(utils::string_from_bytes(&bytes).map_err(|_| { | ||||||
|                 Error::bad_database("EventID in shorteventid_eventid is invalid unicode.") |                 Error::bad_database("EventID in shorteventid_eventid is invalid unicode.") | ||||||
|             })?) |             })?) | ||||||
|         .map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?; |             .map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?, | ||||||
|  |         ); | ||||||
| 
 | 
 | ||||||
|         self.shorteventid_cache |         self.shorteventid_cache | ||||||
|             .lock() |             .lock() | ||||||
|             .unwrap() |             .unwrap() | ||||||
|             .insert(shorteventid, event_id.clone()); |             .insert(shorteventid, Arc::clone(&event_id)); | ||||||
| 
 | 
 | ||||||
|         Ok(event_id) |         Ok(event_id) | ||||||
|     } |     } | ||||||
|  | @ -929,7 +935,7 @@ impl Rooms { | ||||||
|         room_id: &RoomId, |         room_id: &RoomId, | ||||||
|         event_type: &EventType, |         event_type: &EventType, | ||||||
|         state_key: &str, |         state_key: &str, | ||||||
|     ) -> Result<Option<EventId>> { |     ) -> Result<Option<Arc<EventId>>> { | ||||||
|         if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { |         if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { | ||||||
|             self.state_get_id(current_shortstatehash, event_type, state_key) |             self.state_get_id(current_shortstatehash, event_type, state_key) | ||||||
|         } else { |         } else { | ||||||
|  | @ -1226,9 +1232,19 @@ impl Rooms { | ||||||
|         self.eventid_outlierpdu.insert( |         self.eventid_outlierpdu.insert( | ||||||
|             &event_id.as_bytes(), |             &event_id.as_bytes(), | ||||||
|             &serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"), |             &serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"), | ||||||
|         )?; |         ) | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|         Ok(()) |     #[tracing::instrument(skip(self))] | ||||||
|  |     pub fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { | ||||||
|  |         self.softfailedeventids.insert(&event_id.as_bytes(), &[]) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     #[tracing::instrument(skip(self))] | ||||||
|  |     pub fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> { | ||||||
|  |         self.softfailedeventids | ||||||
|  |             .get(&event_id.as_bytes()) | ||||||
|  |             .map(|o| o.is_some()) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Creates a new persisted data unit and adds it to a room.
 |     /// Creates a new persisted data unit and adds it to a room.
 | ||||||
|  | @ -1331,15 +1347,9 @@ impl Rooms { | ||||||
|         let mut notifies = Vec::new(); |         let mut notifies = Vec::new(); | ||||||
|         let mut highlights = Vec::new(); |         let mut highlights = Vec::new(); | ||||||
| 
 | 
 | ||||||
|         for user in db |         for user in self.get_our_real_users(&pdu.room_id, db)?.iter() { | ||||||
|             .rooms |  | ||||||
|             .room_members(&pdu.room_id) |  | ||||||
|             .filter_map(|r| r.ok()) |  | ||||||
|             .filter(|user_id| user_id.server_name() == db.globals.server_name()) |  | ||||||
|             .filter(|user_id| !db.users.is_deactivated(user_id).unwrap_or(true)) |  | ||||||
|         { |  | ||||||
|             // Don't notify the user of their own events
 |             // Don't notify the user of their own events
 | ||||||
|             if user == pdu.sender { |             if user == &pdu.sender { | ||||||
|                 continue; |                 continue; | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|  | @ -1515,6 +1525,85 @@ impl Rooms { | ||||||
|                                 "list_appservices" => { |                                 "list_appservices" => { | ||||||
|                                     db.admin.send(AdminCommand::ListAppservices); |                                     db.admin.send(AdminCommand::ListAppservices); | ||||||
|                                 } |                                 } | ||||||
|  |                                 "get_auth_chain" => { | ||||||
|  |                                     if args.len() == 1 { | ||||||
|  |                                         if let Ok(event_id) = EventId::try_from(args[0]) { | ||||||
|  |                                             let start = Instant::now(); | ||||||
|  |                                             let count = server_server::get_auth_chain( | ||||||
|  |                                                 vec![Arc::new(event_id)], | ||||||
|  |                                                 db, | ||||||
|  |                                             )? | ||||||
|  |                                             .count(); | ||||||
|  |                                             let elapsed = start.elapsed(); | ||||||
|  |                                             db.admin.send(AdminCommand::SendMessage( | ||||||
|  |                                                 message::MessageEventContent::text_plain(format!( | ||||||
|  |                                                     "Loaded auth chain with length {} in {:?}", | ||||||
|  |                                                     count, elapsed | ||||||
|  |                                                 )), | ||||||
|  |                                             )); | ||||||
|  |                                         } | ||||||
|  |                                     } | ||||||
|  |                                 } | ||||||
|  |                                 "parse_pdu" => { | ||||||
|  |                                     if body.len() > 2 | ||||||
|  |                                         && body[0].trim() == "```" | ||||||
|  |                                         && body.last().unwrap().trim() == "```" | ||||||
|  |                                     { | ||||||
|  |                                         let string = body[1..body.len() - 1].join("\n"); | ||||||
|  |                                         match serde_json::from_str(&string) { | ||||||
|  |                                             Ok(value) => { | ||||||
|  |                                                 let event_id = EventId::try_from(&*format!( | ||||||
|  |                                                     "${}", | ||||||
|  |                                                     // Anything higher than version3 behaves the same
 | ||||||
|  |                                                     ruma::signatures::reference_hash( | ||||||
|  |                                                         &value, | ||||||
|  |                                                         &RoomVersionId::Version6 | ||||||
|  |                                                     ) | ||||||
|  |                                                     .expect("ruma can calculate reference hashes") | ||||||
|  |                                                 )) | ||||||
|  |                                                 .expect( | ||||||
|  |                                                     "ruma's reference hashes are valid event ids", | ||||||
|  |                                                 ); | ||||||
|  | 
 | ||||||
|  |                                                 match serde_json::from_value::<PduEvent>( | ||||||
|  |                                                     serde_json::to_value(value) | ||||||
|  |                                                         .expect("value is json"), | ||||||
|  |                                                 ) { | ||||||
|  |                                                     Ok(pdu) => { | ||||||
|  |                                                         db.admin.send(AdminCommand::SendMessage( | ||||||
|  |                                                             message::MessageEventContent::text_plain( | ||||||
|  |                                                                 format!("EventId: {:?}\n{:#?}", event_id, pdu), | ||||||
|  |                                                             ), | ||||||
|  |                                                         )); | ||||||
|  |                                                     } | ||||||
|  |                                                     Err(e) => { | ||||||
|  |                                                         db.admin.send(AdminCommand::SendMessage( | ||||||
|  |                                                             message::MessageEventContent::text_plain( | ||||||
|  |                                                                 format!("EventId: {:?}\nCould not parse event: {}", event_id, e), | ||||||
|  |                                                             ), | ||||||
|  |                                                         )); | ||||||
|  |                                                     } | ||||||
|  |                                                 } | ||||||
|  |                                             } | ||||||
|  |                                             Err(e) => { | ||||||
|  |                                                 db.admin.send(AdminCommand::SendMessage( | ||||||
|  |                                                     message::MessageEventContent::text_plain( | ||||||
|  |                                                         format!( | ||||||
|  |                                                             "Invalid json in command body: {}", | ||||||
|  |                                                             e | ||||||
|  |                                                         ), | ||||||
|  |                                                     ), | ||||||
|  |                                                 )); | ||||||
|  |                                             } | ||||||
|  |                                         } | ||||||
|  |                                     } else { | ||||||
|  |                                         db.admin.send(AdminCommand::SendMessage( | ||||||
|  |                                             message::MessageEventContent::text_plain( | ||||||
|  |                                                 "Expected code block in command body.", | ||||||
|  |                                             ), | ||||||
|  |                                         )); | ||||||
|  |                                     } | ||||||
|  |                                 } | ||||||
|                                 "get_pdu" => { |                                 "get_pdu" => { | ||||||
|                                     if args.len() == 1 { |                                     if args.len() == 1 { | ||||||
|                                         if let Ok(event_id) = EventId::try_from(args[0]) { |                                         if let Ok(event_id) = EventId::try_from(args[0]) { | ||||||
|  | @ -2421,29 +2510,45 @@ impl Rooms { | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         if update_joined_count { |         if update_joined_count { | ||||||
|             self.update_joined_count(room_id)?; |             self.update_joined_count(room_id, db)?; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     #[tracing::instrument(skip(self))] |     #[tracing::instrument(skip(self, room_id, db))] | ||||||
|     pub fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { |     pub fn update_joined_count(&self, room_id: &RoomId, db: &Database) -> Result<()> { | ||||||
|         let mut joinedcount = 0_u64; |         let mut joinedcount = 0_u64; | ||||||
|  |         let mut invitedcount = 0_u64; | ||||||
|         let mut joined_servers = HashSet::new(); |         let mut joined_servers = HashSet::new(); | ||||||
|  |         let mut real_users = HashSet::new(); | ||||||
| 
 | 
 | ||||||
|         for joined in self.room_members(&room_id).filter_map(|r| r.ok()) { |         for joined in self.room_members(&room_id).filter_map(|r| r.ok()) { | ||||||
|             joined_servers.insert(joined.server_name().to_owned()); |             joined_servers.insert(joined.server_name().to_owned()); | ||||||
|  |             if joined.server_name() == db.globals.server_name() | ||||||
|  |                 && !db.users.is_deactivated(&joined).unwrap_or(true) | ||||||
|  |             { | ||||||
|  |                 real_users.insert(joined); | ||||||
|  |             } | ||||||
|             joinedcount += 1; |             joinedcount += 1; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         for invited in self.room_members_invited(&room_id).filter_map(|r| r.ok()) { |         for invited in self.room_members_invited(&room_id).filter_map(|r| r.ok()) { | ||||||
|             joined_servers.insert(invited.server_name().to_owned()); |             joined_servers.insert(invited.server_name().to_owned()); | ||||||
|  |             invitedcount += 1; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         self.roomid_joinedcount |         self.roomid_joinedcount | ||||||
|             .insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?; |             .insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?; | ||||||
| 
 | 
 | ||||||
|  |         self.roomid_invitedcount | ||||||
|  |             .insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?; | ||||||
|  | 
 | ||||||
|  |         self.our_real_users_cache | ||||||
|  |             .write() | ||||||
|  |             .unwrap() | ||||||
|  |             .insert(room_id.clone(), Arc::new(real_users)); | ||||||
|  | 
 | ||||||
|         for old_joined_server in self.room_servers(room_id).filter_map(|r| r.ok()) { |         for old_joined_server in self.room_servers(room_id).filter_map(|r| r.ok()) { | ||||||
|             if !joined_servers.remove(&old_joined_server) { |             if !joined_servers.remove(&old_joined_server) { | ||||||
|                 // Server not in room anymore
 |                 // Server not in room anymore
 | ||||||
|  | @ -2477,6 +2582,32 @@ impl Rooms { | ||||||
|         Ok(()) |         Ok(()) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     #[tracing::instrument(skip(self, room_id, db))] | ||||||
|  |     pub fn get_our_real_users( | ||||||
|  |         &self, | ||||||
|  |         room_id: &RoomId, | ||||||
|  |         db: &Database, | ||||||
|  |     ) -> Result<Arc<HashSet<UserId>>> { | ||||||
|  |         let maybe = self | ||||||
|  |             .our_real_users_cache | ||||||
|  |             .read() | ||||||
|  |             .unwrap() | ||||||
|  |             .get(room_id) | ||||||
|  |             .cloned(); | ||||||
|  |         if let Some(users) = maybe { | ||||||
|  |             Ok(users) | ||||||
|  |         } else { | ||||||
|  |             self.update_joined_count(room_id, &db)?; | ||||||
|  |             Ok(Arc::clone( | ||||||
|  |                 self.our_real_users_cache | ||||||
|  |                     .read() | ||||||
|  |                     .unwrap() | ||||||
|  |                     .get(room_id) | ||||||
|  |                     .unwrap(), | ||||||
|  |             )) | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     #[tracing::instrument(skip(self, db))] |     #[tracing::instrument(skip(self, db))] | ||||||
|     pub async fn leave_room( |     pub async fn leave_room( | ||||||
|         &self, |         &self, | ||||||
|  | @ -2955,6 +3086,18 @@ impl Rooms { | ||||||
|             .transpose()?) |             .transpose()?) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     #[tracing::instrument(skip(self))] | ||||||
|  |     pub fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> { | ||||||
|  |         Ok(self | ||||||
|  |             .roomid_invitedcount | ||||||
|  |             .get(room_id.as_bytes())? | ||||||
|  |             .map(|b| { | ||||||
|  |                 utils::u64_from_bytes(&b) | ||||||
|  |                     .map_err(|_| Error::bad_database("Invalid joinedcount in db.")) | ||||||
|  |             }) | ||||||
|  |             .transpose()?) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     /// Returns an iterator over all User IDs who ever joined a room.
 |     /// Returns an iterator over all User IDs who ever joined a room.
 | ||||||
|     #[tracing::instrument(skip(self))] |     #[tracing::instrument(skip(self))] | ||||||
|     pub fn room_useroncejoined<'a>( |     pub fn room_useroncejoined<'a>( | ||||||
|  |  | ||||||
|  | @ -4,7 +4,7 @@ use crate::{ | ||||||
|     utils, ConduitResult, Database, Error, PduEvent, Result, Ruma, |     utils, ConduitResult, Database, Error, PduEvent, Result, Ruma, | ||||||
| }; | }; | ||||||
| use get_profile_information::v1::ProfileField; | use get_profile_information::v1::ProfileField; | ||||||
| use http::header::{HeaderValue, AUTHORIZATION, HOST}; | use http::header::{HeaderValue, AUTHORIZATION}; | ||||||
| use regex::Regex; | use regex::Regex; | ||||||
| use rocket::response::content::Json; | use rocket::response::content::Json; | ||||||
| use ruma::{ | use ruma::{ | ||||||
|  | @ -83,7 +83,7 @@ use rocket::{get, post, put}; | ||||||
| /// FedDest::Named("198.51.100.5".to_owned(), "".to_owned());
 | /// FedDest::Named("198.51.100.5".to_owned(), "".to_owned());
 | ||||||
| /// ```
 | /// ```
 | ||||||
| #[derive(Clone, Debug, PartialEq)] | #[derive(Clone, Debug, PartialEq)] | ||||||
| enum FedDest { | pub enum FedDest { | ||||||
|     Literal(SocketAddr), |     Literal(SocketAddr), | ||||||
|     Named(String, String), |     Named(String, String), | ||||||
| } | } | ||||||
|  | @ -109,6 +109,13 @@ impl FedDest { | ||||||
|             Self::Named(host, _) => host.clone(), |             Self::Named(host, _) => host.clone(), | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  | 
 | ||||||
|  |     fn port(&self) -> Option<u16> { | ||||||
|  |         match &self { | ||||||
|  |             Self::Literal(addr) => Some(addr.port()), | ||||||
|  |             Self::Named(_, port) => port[1..].parse().ok(), | ||||||
|  |         } | ||||||
|  |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| #[tracing::instrument(skip(globals, request))] | #[tracing::instrument(skip(globals, request))] | ||||||
|  | @ -124,41 +131,34 @@ where | ||||||
|         return Err(Error::bad_config("Federation is disabled.")); |         return Err(Error::bad_config("Federation is disabled.")); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     let maybe_result = globals |     let mut write_destination_to_cache = false; | ||||||
|  | 
 | ||||||
|  |     let cached_result = globals | ||||||
|         .actual_destination_cache |         .actual_destination_cache | ||||||
|         .read() |         .read() | ||||||
|         .unwrap() |         .unwrap() | ||||||
|         .get(destination) |         .get(destination) | ||||||
|         .cloned(); |         .cloned(); | ||||||
| 
 | 
 | ||||||
|     let (actual_destination, host) = if let Some(result) = maybe_result { |     let (actual_destination, host) = if let Some(result) = cached_result { | ||||||
|         result |         result | ||||||
|     } else { |     } else { | ||||||
|  |         write_destination_to_cache = true; | ||||||
|  | 
 | ||||||
|         let result = find_actual_destination(globals, &destination).await; |         let result = find_actual_destination(globals, &destination).await; | ||||||
|         let (actual_destination, host) = result.clone(); | 
 | ||||||
|         let result_string = (result.0.into_https_string(), result.1.into_uri_string()); |         (result.0, result.1.clone().into_uri_string()) | ||||||
|         globals |  | ||||||
|             .actual_destination_cache |  | ||||||
|             .write() |  | ||||||
|             .unwrap() |  | ||||||
|             .insert(Box::<ServerName>::from(destination), result_string.clone()); |  | ||||||
|         let dest_hostname = actual_destination.hostname(); |  | ||||||
|         let host_hostname = host.hostname(); |  | ||||||
|         if dest_hostname != host_hostname { |  | ||||||
|             globals.tls_name_override.write().unwrap().insert( |  | ||||||
|                 dest_hostname, |  | ||||||
|                 webpki::DNSNameRef::try_from_ascii_str(&host_hostname) |  | ||||||
|                     .unwrap() |  | ||||||
|                     .to_owned(), |  | ||||||
|             ); |  | ||||||
|         } |  | ||||||
|         result_string |  | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|  |     let actual_destination_str = actual_destination.clone().into_https_string(); | ||||||
|  | 
 | ||||||
|     let mut http_request = request |     let mut http_request = request | ||||||
|         .try_into_http_request::<Vec<u8>>(&actual_destination, SendAccessToken::IfRequired("")) |         .try_into_http_request::<Vec<u8>>(&actual_destination_str, SendAccessToken::IfRequired("")) | ||||||
|         .map_err(|e| { |         .map_err(|e| { | ||||||
|             warn!("Failed to find destination {}: {}", actual_destination, e); |             warn!( | ||||||
|  |                 "Failed to find destination {}: {}", | ||||||
|  |                 actual_destination_str, e | ||||||
|  |             ); | ||||||
|             Error::BadServerResponse("Invalid destination") |             Error::BadServerResponse("Invalid destination") | ||||||
|         })?; |         })?; | ||||||
| 
 | 
 | ||||||
|  | @ -224,15 +224,26 @@ where | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     http_request |  | ||||||
|         .headers_mut() |  | ||||||
|         .insert(HOST, HeaderValue::from_str(&host).unwrap()); |  | ||||||
| 
 |  | ||||||
|     let reqwest_request = reqwest::Request::try_from(http_request) |     let reqwest_request = reqwest::Request::try_from(http_request) | ||||||
|         .expect("all http requests are valid reqwest requests"); |         .expect("all http requests are valid reqwest requests"); | ||||||
| 
 | 
 | ||||||
|     let url = reqwest_request.url().clone(); |     let url = reqwest_request.url().clone(); | ||||||
|     let response = globals.reqwest_client().execute(reqwest_request).await; | 
 | ||||||
|  |     let mut client = globals.reqwest_client()?; | ||||||
|  |     if let Some((override_name, port)) = globals | ||||||
|  |         .tls_name_override | ||||||
|  |         .read() | ||||||
|  |         .unwrap() | ||||||
|  |         .get(&actual_destination.hostname()) | ||||||
|  |     { | ||||||
|  |         client = client.resolve( | ||||||
|  |             &actual_destination.hostname(), | ||||||
|  |             SocketAddr::new(override_name[0], *port), | ||||||
|  |         ); | ||||||
|  |         // port will be ignored
 | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     let response = client.build()?.execute(reqwest_request).await; | ||||||
| 
 | 
 | ||||||
|     match response { |     match response { | ||||||
|         Ok(mut response) => { |         Ok(mut response) => { | ||||||
|  | @ -271,6 +282,13 @@ where | ||||||
| 
 | 
 | ||||||
|             if status == 200 { |             if status == 200 { | ||||||
|                 let response = T::IncomingResponse::try_from_http_response(http_response); |                 let response = T::IncomingResponse::try_from_http_response(http_response); | ||||||
|  |                 if response.is_ok() && write_destination_to_cache { | ||||||
|  |                     globals.actual_destination_cache.write().unwrap().insert( | ||||||
|  |                         Box::<ServerName>::from(destination), | ||||||
|  |                         (actual_destination, host), | ||||||
|  |                     ); | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|                 response.map_err(|e| { |                 response.map_err(|e| { | ||||||
|                     warn!( |                     warn!( | ||||||
|                         "Invalid 200 response from {} on: {} {}", |                         "Invalid 200 response from {} on: {} {}", | ||||||
|  | @ -339,20 +357,49 @@ async fn find_actual_destination( | ||||||
|                 match request_well_known(globals, &destination.as_str()).await { |                 match request_well_known(globals, &destination.as_str()).await { | ||||||
|                     // 3: A .well-known file is available
 |                     // 3: A .well-known file is available
 | ||||||
|                     Some(delegated_hostname) => { |                     Some(delegated_hostname) => { | ||||||
|                         hostname = delegated_hostname.clone(); |                         hostname = add_port_to_hostname(&delegated_hostname).into_uri_string(); | ||||||
|                         match get_ip_with_port(&delegated_hostname) { |                         match get_ip_with_port(&delegated_hostname) { | ||||||
|                             Some(host_and_port) => host_and_port, // 3.1: IP literal in .well-known file
 |                             Some(host_and_port) => host_and_port, // 3.1: IP literal in .well-known file
 | ||||||
|                             None => { |                             None => { | ||||||
|                                 if let Some(pos) = destination_str.find(':') { |                                 if let Some(pos) = delegated_hostname.find(':') { | ||||||
|                                     // 3.2: Hostname with port in .well-known file
 |                                     // 3.2: Hostname with port in .well-known file
 | ||||||
|                                     let (host, port) = destination_str.split_at(pos); |                                     let (host, port) = delegated_hostname.split_at(pos); | ||||||
|                                     FedDest::Named(host.to_string(), port.to_string()) |                                     FedDest::Named(host.to_string(), port.to_string()) | ||||||
|                                 } else { |                                 } else { | ||||||
|                                     match query_srv_record(globals, &delegated_hostname).await { |                                     // Delegated hostname has no port in this branch
 | ||||||
|  |                                     if let Some(hostname_override) = | ||||||
|  |                                         query_srv_record(globals, &delegated_hostname).await | ||||||
|  |                                     { | ||||||
|                                         // 3.3: SRV lookup successful
 |                                         // 3.3: SRV lookup successful
 | ||||||
|                                         Some(hostname) => hostname, |                                         let force_port = hostname_override.port(); | ||||||
|  | 
 | ||||||
|  |                                         if let Ok(override_ip) = globals | ||||||
|  |                                             .dns_resolver() | ||||||
|  |                                             .lookup_ip(hostname_override.hostname()) | ||||||
|  |                                             .await | ||||||
|  |                                         { | ||||||
|  |                                             globals.tls_name_override.write().unwrap().insert( | ||||||
|  |                                                 delegated_hostname.clone(), | ||||||
|  |                                                 ( | ||||||
|  |                                                     override_ip.iter().collect(), | ||||||
|  |                                                     force_port.unwrap_or(8448), | ||||||
|  |                                                 ), | ||||||
|  |                                             ); | ||||||
|  |                                         } else { | ||||||
|  |                                             warn!("Using SRV record, but could not resolve to IP"); | ||||||
|  |                                         } | ||||||
|  | 
 | ||||||
|  |                                         if let Some(port) = force_port { | ||||||
|  |                                             FedDest::Named( | ||||||
|  |                                                 delegated_hostname, | ||||||
|  |                                                 format!(":{}", port.to_string()), | ||||||
|  |                                             ) | ||||||
|  |                                         } else { | ||||||
|  |                                             add_port_to_hostname(&delegated_hostname) | ||||||
|  |                                         } | ||||||
|  |                                     } else { | ||||||
|                                         // 3.4: No SRV records, just use the hostname from .well-known
 |                                         // 3.4: No SRV records, just use the hostname from .well-known
 | ||||||
|                                         None => add_port_to_hostname(&delegated_hostname), |                                         add_port_to_hostname(&delegated_hostname) | ||||||
|                                     } |                                     } | ||||||
|                                 } |                                 } | ||||||
|                             } |                             } | ||||||
|  | @ -362,7 +409,31 @@ async fn find_actual_destination( | ||||||
|                     None => { |                     None => { | ||||||
|                         match query_srv_record(globals, &destination_str).await { |                         match query_srv_record(globals, &destination_str).await { | ||||||
|                             // 4: SRV record found
 |                             // 4: SRV record found
 | ||||||
|                             Some(hostname) => hostname, |                             Some(hostname_override) => { | ||||||
|  |                                 let force_port = hostname_override.port(); | ||||||
|  | 
 | ||||||
|  |                                 if let Ok(override_ip) = globals | ||||||
|  |                                     .dns_resolver() | ||||||
|  |                                     .lookup_ip(hostname_override.hostname()) | ||||||
|  |                                     .await | ||||||
|  |                                 { | ||||||
|  |                                     globals.tls_name_override.write().unwrap().insert( | ||||||
|  |                                         hostname.clone(), | ||||||
|  |                                         (override_ip.iter().collect(), force_port.unwrap_or(8448)), | ||||||
|  |                                     ); | ||||||
|  |                                 } else { | ||||||
|  |                                     warn!("Using SRV record, but could not resolve to IP"); | ||||||
|  |                                 } | ||||||
|  | 
 | ||||||
|  |                                 if let Some(port) = force_port { | ||||||
|  |                                     FedDest::Named( | ||||||
|  |                                         hostname.clone(), | ||||||
|  |                                         format!(":{}", port.to_string()), | ||||||
|  |                                     ) | ||||||
|  |                                 } else { | ||||||
|  |                                     add_port_to_hostname(&hostname) | ||||||
|  |                                 } | ||||||
|  |                             } | ||||||
|                             // 5: No SRV record found
 |                             // 5: No SRV record found
 | ||||||
|                             None => add_port_to_hostname(&destination_str), |                             None => add_port_to_hostname(&destination_str), | ||||||
|                         } |                         } | ||||||
|  | @ -377,12 +448,12 @@ async fn find_actual_destination( | ||||||
|     let hostname = if let Ok(addr) = hostname.parse::<SocketAddr>() { |     let hostname = if let Ok(addr) = hostname.parse::<SocketAddr>() { | ||||||
|         FedDest::Literal(addr) |         FedDest::Literal(addr) | ||||||
|     } else if let Ok(addr) = hostname.parse::<IpAddr>() { |     } else if let Ok(addr) = hostname.parse::<IpAddr>() { | ||||||
|         FedDest::Named(addr.to_string(), "".to_string()) |         FedDest::Named(addr.to_string(), ":8448".to_string()) | ||||||
|     } else if let Some(pos) = hostname.find(':') { |     } else if let Some(pos) = hostname.find(':') { | ||||||
|         let (host, port) = hostname.split_at(pos); |         let (host, port) = hostname.split_at(pos); | ||||||
|         FedDest::Named(host.to_string(), port.to_string()) |         FedDest::Named(host.to_string(), port.to_string()) | ||||||
|     } else { |     } else { | ||||||
|         FedDest::Named(hostname, "".to_string()) |         FedDest::Named(hostname, ":8448".to_string()) | ||||||
|     }; |     }; | ||||||
|     (actual_destination, hostname) |     (actual_destination, hostname) | ||||||
| } | } | ||||||
|  | @ -423,6 +494,9 @@ pub async fn request_well_known( | ||||||
|     let body: serde_json::Value = serde_json::from_str( |     let body: serde_json::Value = serde_json::from_str( | ||||||
|         &globals |         &globals | ||||||
|             .reqwest_client() |             .reqwest_client() | ||||||
|  |             .ok()? | ||||||
|  |             .build() | ||||||
|  |             .ok()? | ||||||
|             .get(&format!( |             .get(&format!( | ||||||
|                 "https://{}/.well-known/matrix/server", |                 "https://{}/.well-known/matrix/server", | ||||||
|                 destination |                 destination | ||||||
|  | @ -893,7 +967,12 @@ pub async fn handle_incoming_pdu<'a>( | ||||||
|     // 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline events
 |     // 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline events
 | ||||||
|     let mut graph = HashMap::new(); |     let mut graph = HashMap::new(); | ||||||
|     let mut eventid_info = HashMap::new(); |     let mut eventid_info = HashMap::new(); | ||||||
|     let mut todo_outlier_stack = incoming_pdu.prev_events.clone(); |     let mut todo_outlier_stack = incoming_pdu | ||||||
|  |         .prev_events | ||||||
|  |         .iter() | ||||||
|  |         .cloned() | ||||||
|  |         .map(Arc::new) | ||||||
|  |         .collect::<Vec<_>>(); | ||||||
| 
 | 
 | ||||||
|     let mut amount = 0; |     let mut amount = 0; | ||||||
| 
 | 
 | ||||||
|  | @ -929,13 +1008,13 @@ pub async fn handle_incoming_pdu<'a>( | ||||||
|                     amount += 1; |                     amount += 1; | ||||||
|                     for prev_prev in &pdu.prev_events { |                     for prev_prev in &pdu.prev_events { | ||||||
|                         if !graph.contains_key(prev_prev) { |                         if !graph.contains_key(prev_prev) { | ||||||
|                             todo_outlier_stack.push(dbg!(prev_prev.clone())); |                             todo_outlier_stack.push(dbg!(Arc::new(prev_prev.clone()))); | ||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
| 
 | 
 | ||||||
|                     graph.insert( |                     graph.insert( | ||||||
|                         prev_event_id.clone(), |                         prev_event_id.clone(), | ||||||
|                         pdu.prev_events.iter().cloned().collect(), |                         pdu.prev_events.iter().cloned().map(Arc::new).collect(), | ||||||
|                     ); |                     ); | ||||||
|                     eventid_info.insert(prev_event_id.clone(), (pdu, json)); |                     eventid_info.insert(prev_event_id.clone(), (pdu, json)); | ||||||
|                 } else { |                 } else { | ||||||
|  | @ -964,9 +1043,9 @@ pub async fn handle_incoming_pdu<'a>( | ||||||
|                 MilliSecondsSinceUnixEpoch( |                 MilliSecondsSinceUnixEpoch( | ||||||
|                     eventid_info |                     eventid_info | ||||||
|                         .get(event_id) |                         .get(event_id) | ||||||
|                         .map_or_else(|| uint!(0), |info| info.0.origin_server_ts.clone()), |                         .map_or_else(|| uint!(0), |info| info.0.origin_server_ts), | ||||||
|                 ), |                 ), | ||||||
|                 ruma::event_id!("$notimportant"), |                 Arc::new(ruma::event_id!("$notimportant")), | ||||||
|             )) |             )) | ||||||
|         }) |         }) | ||||||
|         .map_err(|_| "Error sorting prev events".to_owned())?; |         .map_err(|_| "Error sorting prev events".to_owned())?; | ||||||
|  | @ -1084,7 +1163,12 @@ fn handle_outlier_pdu<'a>( | ||||||
|         fetch_and_handle_outliers( |         fetch_and_handle_outliers( | ||||||
|             db, |             db, | ||||||
|             origin, |             origin, | ||||||
|             &incoming_pdu.auth_events, |             &incoming_pdu | ||||||
|  |                 .auth_events | ||||||
|  |                 .iter() | ||||||
|  |                 .cloned() | ||||||
|  |                 .map(Arc::new) | ||||||
|  |                 .collect::<Vec<_>>(), | ||||||
|             &create_event, |             &create_event, | ||||||
|             &room_id, |             &room_id, | ||||||
|             pub_key_map, |             pub_key_map, | ||||||
|  | @ -1100,13 +1184,13 @@ fn handle_outlier_pdu<'a>( | ||||||
|         // Build map of auth events
 |         // Build map of auth events
 | ||||||
|         let mut auth_events = HashMap::new(); |         let mut auth_events = HashMap::new(); | ||||||
|         for id in &incoming_pdu.auth_events { |         for id in &incoming_pdu.auth_events { | ||||||
|             let auth_event = db |             let auth_event = match db.rooms.get_pdu(id).map_err(|e| e.to_string())? { | ||||||
|                 .rooms |                 Some(e) => e, | ||||||
|                 .get_pdu(id) |                 None => { | ||||||
|                 .map_err(|e| e.to_string())? |                     warn!("Could not find auth event {}", id); | ||||||
|                 .ok_or_else(|| { |                     continue; | ||||||
|                     "Auth event not found, event failed recursive auth checks.".to_string() |                 } | ||||||
|                 })?; |             }; | ||||||
| 
 | 
 | ||||||
|             match auth_events.entry(( |             match auth_events.entry(( | ||||||
|                 auth_event.kind.clone(), |                 auth_event.kind.clone(), | ||||||
|  | @ -1153,7 +1237,7 @@ fn handle_outlier_pdu<'a>( | ||||||
|         if !state_res::event_auth::auth_check( |         if !state_res::event_auth::auth_check( | ||||||
|             &room_version, |             &room_version, | ||||||
|             &incoming_pdu, |             &incoming_pdu, | ||||||
|             previous_create.clone(), |             previous_create, | ||||||
|             None, // TODO: third party invite
 |             None, // TODO: third party invite
 | ||||||
|             |k, s| auth_events.get(&(k.clone(), s.to_owned())).map(Arc::clone), |             |k, s| auth_events.get(&(k.clone(), s.to_owned())).map(Arc::clone), | ||||||
|         ) |         ) | ||||||
|  | @ -1187,6 +1271,15 @@ async fn upgrade_outlier_to_timeline_pdu( | ||||||
|     if let Ok(Some(pduid)) = db.rooms.get_pdu_id(&incoming_pdu.event_id) { |     if let Ok(Some(pduid)) = db.rooms.get_pdu_id(&incoming_pdu.event_id) { | ||||||
|         return Ok(Some(pduid)); |         return Ok(Some(pduid)); | ||||||
|     } |     } | ||||||
|  | 
 | ||||||
|  |     if db | ||||||
|  |         .rooms | ||||||
|  |         .is_event_soft_failed(&incoming_pdu.event_id) | ||||||
|  |         .map_err(|_| "Failed to ask db for soft fail".to_owned())? | ||||||
|  |     { | ||||||
|  |         return Err("Event has been soft failed".into()); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     // 10. Fetch missing state and auth chain events by calling /state_ids at backwards extremities
 |     // 10. Fetch missing state and auth chain events by calling /state_ids at backwards extremities
 | ||||||
|     //     doing all the checks in this list starting at 1. These are not timeline events.
 |     //     doing all the checks in this list starting at 1. These are not timeline events.
 | ||||||
| 
 | 
 | ||||||
|  | @ -1219,7 +1312,7 @@ async fn upgrade_outlier_to_timeline_pdu( | ||||||
|                     .get_or_create_shortstatekey(&prev_pdu.kind, state_key, &db.globals) |                     .get_or_create_shortstatekey(&prev_pdu.kind, state_key, &db.globals) | ||||||
|                     .map_err(|_| "Failed to create shortstatekey.".to_owned())?; |                     .map_err(|_| "Failed to create shortstatekey.".to_owned())?; | ||||||
| 
 | 
 | ||||||
|                 state.insert(shortstatekey, prev_event.clone()); |                 state.insert(shortstatekey, Arc::new(prev_event.clone())); | ||||||
|                 // Now it's the state after the pdu
 |                 // Now it's the state after the pdu
 | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|  | @ -1249,7 +1342,11 @@ async fn upgrade_outlier_to_timeline_pdu( | ||||||
|                 let state_vec = fetch_and_handle_outliers( |                 let state_vec = fetch_and_handle_outliers( | ||||||
|                     &db, |                     &db, | ||||||
|                     origin, |                     origin, | ||||||
|                     &res.pdu_ids, |                     &res.pdu_ids | ||||||
|  |                         .iter() | ||||||
|  |                         .cloned() | ||||||
|  |                         .map(Arc::new) | ||||||
|  |                         .collect::<Vec<_>>(), | ||||||
|                     &create_event, |                     &create_event, | ||||||
|                     &room_id, |                     &room_id, | ||||||
|                     pub_key_map, |                     pub_key_map, | ||||||
|  | @ -1270,7 +1367,7 @@ async fn upgrade_outlier_to_timeline_pdu( | ||||||
| 
 | 
 | ||||||
|                     match state.entry(shortstatekey) { |                     match state.entry(shortstatekey) { | ||||||
|                         btree_map::Entry::Vacant(v) => { |                         btree_map::Entry::Vacant(v) => { | ||||||
|                             v.insert(pdu.event_id.clone()); |                             v.insert(Arc::new(pdu.event_id.clone())); | ||||||
|                         } |                         } | ||||||
|                         btree_map::Entry::Occupied(_) => return Err( |                         btree_map::Entry::Occupied(_) => return Err( | ||||||
|                             "State event's type and state_key combination exists multiple times." |                             "State event's type and state_key combination exists multiple times." | ||||||
|  | @ -1286,7 +1383,9 @@ async fn upgrade_outlier_to_timeline_pdu( | ||||||
|                     .map_err(|_| "Failed to talk to db.")? |                     .map_err(|_| "Failed to talk to db.")? | ||||||
|                     .expect("Room exists"); |                     .expect("Room exists"); | ||||||
| 
 | 
 | ||||||
|                 if state.get(&create_shortstatekey) != Some(&create_event.event_id) { |                 if state.get(&create_shortstatekey).map(|id| id.as_ref()) | ||||||
|  |                     != Some(&create_event.event_id) | ||||||
|  |                 { | ||||||
|                     return Err("Incoming event refers to wrong create event.".to_owned()); |                     return Err("Incoming event refers to wrong create event.".to_owned()); | ||||||
|                 } |                 } | ||||||
| 
 | 
 | ||||||
|  | @ -1451,7 +1550,7 @@ async fn upgrade_outlier_to_timeline_pdu( | ||||||
|                     .rooms |                     .rooms | ||||||
|                     .get_or_create_shortstatekey(&leaf_pdu.kind, state_key, &db.globals) |                     .get_or_create_shortstatekey(&leaf_pdu.kind, state_key, &db.globals) | ||||||
|                     .map_err(|_| "Failed to create shortstatekey.".to_owned())?; |                     .map_err(|_| "Failed to create shortstatekey.".to_owned())?; | ||||||
|                 leaf_state.insert(shortstatekey, leaf_pdu.event_id.clone()); |                 leaf_state.insert(shortstatekey, Arc::new(leaf_pdu.event_id.clone())); | ||||||
|                 // Now it's the state after the pdu
 |                 // Now it's the state after the pdu
 | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|  | @ -1466,9 +1565,9 @@ async fn upgrade_outlier_to_timeline_pdu( | ||||||
|                 .get_or_create_shortstatekey(&incoming_pdu.kind, state_key, &db.globals) |                 .get_or_create_shortstatekey(&incoming_pdu.kind, state_key, &db.globals) | ||||||
|                 .map_err(|_| "Failed to create shortstatekey.".to_owned())?; |                 .map_err(|_| "Failed to create shortstatekey.".to_owned())?; | ||||||
| 
 | 
 | ||||||
|             state_after.insert(shortstatekey, incoming_pdu.event_id.clone()); |             state_after.insert(shortstatekey, Arc::new(incoming_pdu.event_id.clone())); | ||||||
|         } |         } | ||||||
|         fork_states.push(state_after.clone()); |         fork_states.push(state_after); | ||||||
| 
 | 
 | ||||||
|         let mut update_state = false; |         let mut update_state = false; | ||||||
|         // 14. Use state resolution to find new room state
 |         // 14. Use state resolution to find new room state
 | ||||||
|  | @ -1593,6 +1692,9 @@ async fn upgrade_outlier_to_timeline_pdu( | ||||||
|     if soft_fail { |     if soft_fail { | ||||||
|         // Soft fail, we keep the event as an outlier but don't add it to the timeline
 |         // Soft fail, we keep the event as an outlier but don't add it to the timeline
 | ||||||
|         warn!("Event was soft failed: {:?}", incoming_pdu); |         warn!("Event was soft failed: {:?}", incoming_pdu); | ||||||
|  |         db.rooms | ||||||
|  |             .mark_event_soft_failed(&incoming_pdu.event_id) | ||||||
|  |             .map_err(|_| "Failed to set soft failed flag".to_owned())?; | ||||||
|         return Err("Event has been soft failed".into()); |         return Err("Event has been soft failed".into()); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | @ -1614,7 +1716,7 @@ async fn upgrade_outlier_to_timeline_pdu( | ||||||
| pub(crate) fn fetch_and_handle_outliers<'a>( | pub(crate) fn fetch_and_handle_outliers<'a>( | ||||||
|     db: &'a Database, |     db: &'a Database, | ||||||
|     origin: &'a ServerName, |     origin: &'a ServerName, | ||||||
|     events: &'a [EventId], |     events: &'a [Arc<EventId>], | ||||||
|     create_event: &'a PduEvent, |     create_event: &'a PduEvent, | ||||||
|     room_id: &'a RoomId, |     room_id: &'a RoomId, | ||||||
|     pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, String>>>, |     pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, String>>>, | ||||||
|  | @ -1665,20 +1767,25 @@ pub(crate) fn fetch_and_handle_outliers<'a>( | ||||||
|                     { |                     { | ||||||
|                         Ok(res) => { |                         Ok(res) => { | ||||||
|                             warn!("Got {} over federation", id); |                             warn!("Got {} over federation", id); | ||||||
|                             let (event_id, value) = |                             let (calculated_event_id, value) = | ||||||
|                                 match crate::pdu::gen_event_id_canonical_json(&res.pdu) { |                                 match crate::pdu::gen_event_id_canonical_json(&res.pdu) { | ||||||
|                                     Ok(t) => t, |                                     Ok(t) => t, | ||||||
|                                     Err(_) => { |                                     Err(_) => { | ||||||
|                                         back_off(id.clone()); |                                         back_off((**id).clone()); | ||||||
|                                         continue; |                                         continue; | ||||||
|                                     } |                                     } | ||||||
|                                 }; |                                 }; | ||||||
| 
 | 
 | ||||||
|  |                             if calculated_event_id != **id { | ||||||
|  |                                 warn!("Server didn't return event id we requested: requested: {}, we got {}. Event: {:?}", | ||||||
|  |                                     id, calculated_event_id, &res.pdu); | ||||||
|  |                             } | ||||||
|  | 
 | ||||||
|                             // This will also fetch the auth chain
 |                             // This will also fetch the auth chain
 | ||||||
|                             match handle_outlier_pdu( |                             match handle_outlier_pdu( | ||||||
|                                 origin, |                                 origin, | ||||||
|                                 create_event, |                                 create_event, | ||||||
|                                 &event_id, |                                 &id, | ||||||
|                                 &room_id, |                                 &room_id, | ||||||
|                                 value.clone(), |                                 value.clone(), | ||||||
|                                 db, |                                 db, | ||||||
|  | @ -1689,14 +1796,14 @@ pub(crate) fn fetch_and_handle_outliers<'a>( | ||||||
|                                 Ok((pdu, json)) => (pdu, Some(json)), |                                 Ok((pdu, json)) => (pdu, Some(json)), | ||||||
|                                 Err(e) => { |                                 Err(e) => { | ||||||
|                                     warn!("Authentication of event {} failed: {:?}", id, e); |                                     warn!("Authentication of event {} failed: {:?}", id, e); | ||||||
|                                     back_off(id.clone()); |                                     back_off((**id).clone()); | ||||||
|                                     continue; |                                     continue; | ||||||
|                                 } |                                 } | ||||||
|                             } |                             } | ||||||
|                         } |                         } | ||||||
|                         Err(_) => { |                         Err(_) => { | ||||||
|                             warn!("Failed to fetch event: {}", id); |                             warn!("Failed to fetch event: {}", id); | ||||||
|                             back_off(id.clone()); |                             back_off((**id).clone()); | ||||||
|                             continue; |                             continue; | ||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
|  | @ -1971,24 +2078,18 @@ fn append_incoming_pdu( | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| #[tracing::instrument(skip(starting_events, db))] | #[tracing::instrument(skip(starting_events, db))] | ||||||
| fn get_auth_chain( | pub fn get_auth_chain( | ||||||
|     starting_events: Vec<EventId>, |     starting_events: Vec<Arc<EventId>>, | ||||||
|     db: &Database, |     db: &Database, | ||||||
| ) -> Result<impl Iterator<Item = EventId> + '_> { | ) -> Result<impl Iterator<Item = Arc<EventId>> + '_> { | ||||||
|     const NUM_BUCKETS: usize = 50; |     const NUM_BUCKETS: usize = 50; | ||||||
| 
 | 
 | ||||||
|     let mut buckets = vec![BTreeSet::new(); NUM_BUCKETS]; |     let mut buckets = vec![BTreeSet::new(); NUM_BUCKETS]; | ||||||
| 
 | 
 | ||||||
|     for id in starting_events { |     for id in starting_events { | ||||||
|         if let Some(pdu) = db.rooms.get_pdu(&id)? { |         let short = db.rooms.get_or_create_shorteventid(&id, &db.globals)?; | ||||||
|             for auth_event in &pdu.auth_events { |  | ||||||
|                 let short = db |  | ||||||
|                     .rooms |  | ||||||
|                     .get_or_create_shorteventid(&auth_event, &db.globals)?; |  | ||||||
|         let bucket_id = (short % NUM_BUCKETS as u64) as usize; |         let bucket_id = (short % NUM_BUCKETS as u64) as usize; | ||||||
|                 buckets[bucket_id].insert((short, auth_event.clone())); |         buckets[bucket_id].insert((short, id.clone())); | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     let mut full_auth_chain = HashSet::new(); |     let mut full_auth_chain = HashSet::new(); | ||||||
|  | @ -2000,10 +2101,6 @@ fn get_auth_chain( | ||||||
|             continue; |             continue; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         // The code below will only get the auth chains, not the events in the chunk. So let's add
 |  | ||||||
|         // them first
 |  | ||||||
|         full_auth_chain.extend(chunk.iter().map(|(id, _)| id)); |  | ||||||
| 
 |  | ||||||
|         let chunk_key = chunk |         let chunk_key = chunk | ||||||
|             .iter() |             .iter() | ||||||
|             .map(|(short, _)| short) |             .map(|(short, _)| short) | ||||||
|  | @ -2178,12 +2275,12 @@ pub fn get_event_authorization_route( | ||||||
|         return Err(Error::bad_config("Federation is disabled.")); |         return Err(Error::bad_config("Federation is disabled.")); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     let auth_chain_ids = get_auth_chain(vec![body.event_id.clone()], &db)?; |     let auth_chain_ids = get_auth_chain(vec![Arc::new(body.event_id.clone())], &db)?; | ||||||
| 
 | 
 | ||||||
|     Ok(get_event_authorization::v1::Response { |     Ok(get_event_authorization::v1::Response { | ||||||
|         auth_chain: auth_chain_ids |         auth_chain: auth_chain_ids | ||||||
|             .filter_map(|id| Some(db.rooms.get_pdu_json(&id).ok()??)) |             .filter_map(|id| db.rooms.get_pdu_json(&id).ok()?) | ||||||
|             .map(|event| PduEvent::convert_to_outgoing_federation_event(event)) |             .map(PduEvent::convert_to_outgoing_federation_event) | ||||||
|             .collect(), |             .collect(), | ||||||
|     } |     } | ||||||
|     .into()) |     .into()) | ||||||
|  | @ -2221,7 +2318,7 @@ pub fn get_room_state_route( | ||||||
|         }) |         }) | ||||||
|         .collect(); |         .collect(); | ||||||
| 
 | 
 | ||||||
|     let auth_chain_ids = get_auth_chain(vec![body.event_id.clone()], &db)?; |     let auth_chain_ids = get_auth_chain(vec![Arc::new(body.event_id.clone())], &db)?; | ||||||
| 
 | 
 | ||||||
|     Ok(get_room_state::v1::Response { |     Ok(get_room_state::v1::Response { | ||||||
|         auth_chain: auth_chain_ids |         auth_chain: auth_chain_ids | ||||||
|  | @ -2262,13 +2359,13 @@ pub fn get_room_state_ids_route( | ||||||
|         .rooms |         .rooms | ||||||
|         .state_full_ids(shortstatehash)? |         .state_full_ids(shortstatehash)? | ||||||
|         .into_iter() |         .into_iter() | ||||||
|         .map(|(_, id)| id) |         .map(|(_, id)| (*id).clone()) | ||||||
|         .collect(); |         .collect(); | ||||||
| 
 | 
 | ||||||
|     let auth_chain_ids = get_auth_chain(vec![body.event_id.clone()], &db)?; |     let auth_chain_ids = get_auth_chain(vec![Arc::new(body.event_id.clone())], &db)?; | ||||||
| 
 | 
 | ||||||
|     Ok(get_room_state_ids::v1::Response { |     Ok(get_room_state_ids::v1::Response { | ||||||
|         auth_chain_ids: auth_chain_ids.collect(), |         auth_chain_ids: auth_chain_ids.map(|id| (*id).clone()).collect(), | ||||||
|         pdu_ids, |         pdu_ids, | ||||||
|     } |     } | ||||||
|     .into()) |     .into()) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in a new issue