Merge branch 'speed' into 'master'
Speed See merge request famedly/conduit!168
This commit is contained in:
		
						commit
						00c9ad12bd
					
				
					 11 changed files with 488 additions and 308 deletions
				
			
		
							
								
								
									
										77
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										77
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							|  | @ -324,9 +324,9 @@ checksum = "ea221b5284a47e40033bf9b66f35f984ec0ea2931eb03505246cd27a963f981b" | |||
| 
 | ||||
| [[package]] | ||||
| name = "cpufeatures" | ||||
| version = "0.1.5" | ||||
| version = "0.2.1" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "66c99696f6c9dd7f35d486b9d04d7e6e202aa3e8c40d553f2fdf5e7e0c6a71ef" | ||||
| checksum = "95059428f66df56b63431fdb4e1947ed2190586af5c5a8a8b71122bdf5a7f469" | ||||
| dependencies = [ | ||||
|  "libc", | ||||
| ] | ||||
|  | @ -2061,9 +2061,8 @@ dependencies = [ | |||
| 
 | ||||
| [[package]] | ||||
| name = "ruma" | ||||
| version = "0.4.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "668031e3108d6a2cfbe6eca271d8698f4593440e71a44afdadcf67ce3cb93c1f" | ||||
| version = "0.3.0" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| dependencies = [ | ||||
|  "assign", | ||||
|  "js_int", | ||||
|  | @ -2084,8 +2083,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-api" | ||||
| version = "0.18.3" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "f5f1843792b6749ec1ece62595cf99ad30bf9589c96bb237515235e71da396ea" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| dependencies = [ | ||||
|  "bytes", | ||||
|  "http", | ||||
|  | @ -2101,8 +2099,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-api-macros" | ||||
| version = "0.18.3" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "7b18abda5cca94178d08b622bca042e1cbb5eb7d4ebf3a2a81590a3bb3c57008" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| dependencies = [ | ||||
|  "proc-macro-crate", | ||||
|  "proc-macro2", | ||||
|  | @ -2113,8 +2110,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-appservice-api" | ||||
| version = "0.4.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "49369332a5f299e832e19661f92d49e08c345c3c6c4ab16e09cb31c5ff6da878" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| dependencies = [ | ||||
|  "ruma-api", | ||||
|  "ruma-common", | ||||
|  | @ -2128,8 +2124,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-client-api" | ||||
| version = "0.12.2" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "9568a222c12cf6220e751484ab78feec28071f85965113a5bb802936a2920ff0" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| dependencies = [ | ||||
|  "assign", | ||||
|  "bytes", | ||||
|  | @ -2149,8 +2144,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-common" | ||||
| version = "0.6.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "41d5b7605f58dc0d9cf1848cc7f1af2bae4e4bcd1d2b7a87bbb9864c8a785b91" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| dependencies = [ | ||||
|  "indexmap", | ||||
|  "js_int", | ||||
|  | @ -2164,9 +2158,8 @@ dependencies = [ | |||
| 
 | ||||
| [[package]] | ||||
| name = "ruma-events" | ||||
| version = "0.24.5" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "87801e1207cfebdee02e7997ebf181a1c9837260b78c1b8ce96b896a2bcb3763" | ||||
| version = "0.24.4" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| dependencies = [ | ||||
|  "indoc", | ||||
|  "js_int", | ||||
|  | @ -2181,9 +2174,8 @@ dependencies = [ | |||
| 
 | ||||
| [[package]] | ||||
| name = "ruma-events-macros" | ||||
| version = "0.24.5" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "5da4498845347de88adf1b7da4578e2ca7355ad4ce47b0976f6594bacf958660" | ||||
| version = "0.24.4" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| dependencies = [ | ||||
|  "proc-macro-crate", | ||||
|  "proc-macro2", | ||||
|  | @ -2194,8 +2186,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-federation-api" | ||||
| version = "0.3.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "c61c9adbe1a29c301ae627604406d60102c89fc833b110cd35bbf29ae205ea6c" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| dependencies = [ | ||||
|  "js_int", | ||||
|  "ruma-api", | ||||
|  | @ -2210,8 +2201,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-identifiers" | ||||
| version = "0.20.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "cb417d091e8dd5a633e4e5998231a156049d7fcc221045cfdc0642eb72067732" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| dependencies = [ | ||||
|  "paste", | ||||
|  "rand 0.8.4", | ||||
|  | @ -2225,8 +2215,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-identifiers-macros" | ||||
| version = "0.20.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "c708edad7f605638f26c951cbad7501fbf28ab01009e5ca65ea5a2db74a882b1" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| dependencies = [ | ||||
|  "quote", | ||||
|  "ruma-identifiers-validation", | ||||
|  | @ -2236,14 +2225,15 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-identifiers-validation" | ||||
| version = "0.5.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "42285e7fb5d5f2d5268e45bb683e36d5c6fd9fc1e11a4559ba3c3521f3bbb2cb" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| dependencies = [ | ||||
|  "thiserror", | ||||
| ] | ||||
| 
 | ||||
| [[package]] | ||||
| name = "ruma-identity-service-api" | ||||
| version = "0.3.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "e76e66e24f2d5a31511fbf6c79e79f67a7a6a98ebf48d72381b7d5bb6c09f035" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| dependencies = [ | ||||
|  "js_int", | ||||
|  "ruma-api", | ||||
|  | @ -2256,8 +2246,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-push-gateway-api" | ||||
| version = "0.3.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "5ef5b29da7065efc5b1e1a8f61add7543c9ab4ecce5ee0dd1c1c5ecec83fbeec" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| dependencies = [ | ||||
|  "js_int", | ||||
|  "ruma-api", | ||||
|  | @ -2272,8 +2261,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-serde" | ||||
| version = "0.5.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "8b2b22aae842e7ecda695e42b7b39d4558959d9d9a27acc2a16acf4f4f7f00c3" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| dependencies = [ | ||||
|  "bytes", | ||||
|  "form_urlencoded", | ||||
|  | @ -2287,8 +2275,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-serde-macros" | ||||
| version = "0.5.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "243e9bef188b08f94c79bc2f8fd1eb307a9e636b2b8e4571acf8c7be16381d28" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| dependencies = [ | ||||
|  "proc-macro-crate", | ||||
|  "proc-macro2", | ||||
|  | @ -2299,8 +2286,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-signatures" | ||||
| version = "0.9.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "4a4f64027165b59500162d10d435b1253898bf3ad4f5002cb0d56913fe7f76d7" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| dependencies = [ | ||||
|  "base64 0.13.0", | ||||
|  "ed25519-dalek", | ||||
|  | @ -2316,9 +2302,8 @@ dependencies = [ | |||
| 
 | ||||
| [[package]] | ||||
| name = "ruma-state-res" | ||||
| version = "0.4.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "796427aaa2d266238c5c1b1a6ca4640a4d282ec2cb2e844c69a8f3a262d3db15" | ||||
| version = "0.3.0" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| dependencies = [ | ||||
|  "itertools 0.10.1", | ||||
|  "js_int", | ||||
|  | @ -2522,9 +2507,9 @@ dependencies = [ | |||
| 
 | ||||
| [[package]] | ||||
| name = "serde_yaml" | ||||
| version = "0.8.19" | ||||
| version = "0.8.20" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "6375dbd828ed6964c3748e4ef6d18e7a175d408ffe184bca01698d0c73f915a9" | ||||
| checksum = "ad104641f3c958dab30eb3010e834c2622d1f3f4c530fef1dee20ad9485f3c09" | ||||
| dependencies = [ | ||||
|  "dtoa", | ||||
|  "indexmap", | ||||
|  | @ -2540,9 +2525,9 @@ checksum = "2579985fda508104f7587689507983eadd6a6e84dd35d6d115361f530916fa0d" | |||
| 
 | ||||
| [[package]] | ||||
| name = "sha2" | ||||
| version = "0.9.5" | ||||
| version = "0.9.6" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "b362ae5752fd2137731f9fa25fd4d9058af34666ca1966fb969119cc35719f12" | ||||
| checksum = "9204c41a1597a8c5af23c82d1c921cb01ec0a4c59e07a9c7306062829a3903f3" | ||||
| dependencies = [ | ||||
|  "block-buffer", | ||||
|  "cfg-if 1.0.0", | ||||
|  |  | |||
|  | @ -18,9 +18,9 @@ edition = "2018" | |||
| rocket = { version = "0.5.0-rc.1", features = ["tls"] } # Used to handle requests | ||||
| 
 | ||||
| # Used for matrix spec type definitions and helpers | ||||
| ruma = { version = "0.4.0", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } | ||||
| #ruma = { version = "0.4.0", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } | ||||
| #ruma = { git = "https://github.com/ruma/ruma", rev = "f5ab038e22421ed338396ece977b6b2844772ced", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } | ||||
| #ruma = { git = "https://github.com/timokoesters/ruma", rev = "2215049b60a1c3358f5a52215adf1e7bb88619a1", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } | ||||
| ruma = { git = "https://github.com/DevinR528/ruma", rev = "c7860fcb89dbde636e2c83d0636934fb9924f40c", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } | ||||
| #ruma = { path = "../ruma/crates/ruma", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } | ||||
| 
 | ||||
| # Used for long polling and federation sender, should be the same as rocket::tokio | ||||
|  |  | |||
|  | @ -46,7 +46,11 @@ where | |||
|     *reqwest_request.timeout_mut() = Some(Duration::from_secs(30)); | ||||
| 
 | ||||
|     let url = reqwest_request.url().clone(); | ||||
|     let mut response = globals.reqwest_client().execute(reqwest_request).await?; | ||||
|     let mut response = globals | ||||
|         .reqwest_client()? | ||||
|         .build()? | ||||
|         .execute(reqwest_request) | ||||
|         .await?; | ||||
| 
 | ||||
|     // reqwest::Response -> http::Response conversion
 | ||||
|     let status = response.status(); | ||||
|  |  | |||
|  | @ -1,5 +1,6 @@ | |||
| use super::SESSION_ID_LENGTH; | ||||
| use crate::{database::DatabaseGuard, utils, ConduitResult, Database, Error, Result, Ruma}; | ||||
| use rocket::futures::{prelude::*, stream::FuturesUnordered}; | ||||
| use ruma::{ | ||||
|     api::{ | ||||
|         client::{ | ||||
|  | @ -18,7 +19,7 @@ use ruma::{ | |||
|     DeviceId, DeviceKeyAlgorithm, UserId, | ||||
| }; | ||||
| use serde_json::json; | ||||
| use std::collections::{BTreeMap, HashSet}; | ||||
| use std::collections::{BTreeMap, HashMap, HashSet}; | ||||
| 
 | ||||
| #[cfg(feature = "conduit_bin")] | ||||
| use rocket::{get, post}; | ||||
|  | @ -294,7 +295,7 @@ pub async fn get_keys_helper<F: Fn(&UserId) -> bool>( | |||
|     let mut user_signing_keys = BTreeMap::new(); | ||||
|     let mut device_keys = BTreeMap::new(); | ||||
| 
 | ||||
|     let mut get_over_federation = BTreeMap::new(); | ||||
|     let mut get_over_federation = HashMap::new(); | ||||
| 
 | ||||
|     for (user_id, device_ids) in device_keys_input { | ||||
|         if user_id.server_name() != db.globals.server_name() { | ||||
|  | @ -364,22 +365,30 @@ pub async fn get_keys_helper<F: Fn(&UserId) -> bool>( | |||
| 
 | ||||
|     let mut failures = BTreeMap::new(); | ||||
| 
 | ||||
|     for (server, vec) in get_over_federation { | ||||
|         let mut device_keys_input_fed = BTreeMap::new(); | ||||
|         for (user_id, keys) in vec { | ||||
|             device_keys_input_fed.insert(user_id.clone(), keys.clone()); | ||||
|         } | ||||
|         match db | ||||
|             .sending | ||||
|             .send_federation_request( | ||||
|                 &db.globals, | ||||
|     let mut futures = get_over_federation | ||||
|         .into_iter() | ||||
|         .map(|(server, vec)| async move { | ||||
|             let mut device_keys_input_fed = BTreeMap::new(); | ||||
|             for (user_id, keys) in vec { | ||||
|                 device_keys_input_fed.insert(user_id.clone(), keys.clone()); | ||||
|             } | ||||
|             ( | ||||
|                 server, | ||||
|                 federation::keys::get_keys::v1::Request { | ||||
|                     device_keys: device_keys_input_fed, | ||||
|                 }, | ||||
|                 db.sending | ||||
|                     .send_federation_request( | ||||
|                         &db.globals, | ||||
|                         server, | ||||
|                         federation::keys::get_keys::v1::Request { | ||||
|                             device_keys: device_keys_input_fed, | ||||
|                         }, | ||||
|                     ) | ||||
|                     .await, | ||||
|             ) | ||||
|             .await | ||||
|         { | ||||
|         }) | ||||
|         .collect::<FuturesUnordered<_>>(); | ||||
| 
 | ||||
|     while let Some((server, response)) = futures.next().await { | ||||
|         match response { | ||||
|             Ok(response) => { | ||||
|                 master_keys.extend(response.master_keys); | ||||
|                 self_signing_keys.extend(response.self_signing_keys); | ||||
|  | @ -430,13 +439,15 @@ pub async fn claim_keys_helper( | |||
|         one_time_keys.insert(user_id.clone(), container); | ||||
|     } | ||||
| 
 | ||||
|     let mut failures = BTreeMap::new(); | ||||
| 
 | ||||
|     for (server, vec) in get_over_federation { | ||||
|         let mut one_time_keys_input_fed = BTreeMap::new(); | ||||
|         for (user_id, keys) in vec { | ||||
|             one_time_keys_input_fed.insert(user_id.clone(), keys.clone()); | ||||
|         } | ||||
|         // Ignore failures
 | ||||
|         let keys = db | ||||
|         if let Ok(keys) = db | ||||
|             .sending | ||||
|             .send_federation_request( | ||||
|                 &db.globals, | ||||
|  | @ -445,13 +456,16 @@ pub async fn claim_keys_helper( | |||
|                     one_time_keys: one_time_keys_input_fed, | ||||
|                 }, | ||||
|             ) | ||||
|             .await?; | ||||
| 
 | ||||
|         one_time_keys.extend(keys.one_time_keys); | ||||
|             .await | ||||
|         { | ||||
|             one_time_keys.extend(keys.one_time_keys); | ||||
|         } else { | ||||
|             failures.insert(server.to_string(), json!({})); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     Ok(claim_keys::Response { | ||||
|         failures: BTreeMap::new(), | ||||
|         failures, | ||||
|         one_time_keys, | ||||
|     }) | ||||
| } | ||||
|  |  | |||
|  | @ -256,8 +256,8 @@ async fn sync_helper( | |||
| 
 | ||||
|         // Calculates joined_member_count, invited_member_count and heroes
 | ||||
|         let calculate_counts = || { | ||||
|             let joined_member_count = db.rooms.room_members(&room_id).count(); | ||||
|             let invited_member_count = db.rooms.room_members_invited(&room_id).count(); | ||||
|             let joined_member_count = db.rooms.room_joined_count(&room_id)?.unwrap_or(0); | ||||
|             let invited_member_count = db.rooms.room_invited_count(&room_id)?.unwrap_or(0); | ||||
| 
 | ||||
|             // Recalculate heroes (first 5 members)
 | ||||
|             let mut heroes = Vec::new(); | ||||
|  | @ -407,64 +407,40 @@ async fn sync_helper( | |||
|                 }); | ||||
| 
 | ||||
|             if encrypted_room { | ||||
|                 for (user_id, current_member) in db | ||||
|                     .rooms | ||||
|                     .room_members(&room_id) | ||||
|                     .filter_map(|r| r.ok()) | ||||
|                     .filter_map(|user_id| { | ||||
|                         db.rooms | ||||
|                             .state_get( | ||||
|                                 current_shortstatehash, | ||||
|                                 &EventType::RoomMember, | ||||
|                                 user_id.as_str(), | ||||
|                             ) | ||||
|                             .ok() | ||||
|                             .flatten() | ||||
|                             .map(|current_member| (user_id, current_member)) | ||||
|                     }) | ||||
|                 { | ||||
|                     let current_membership = serde_json::from_value::< | ||||
|                         Raw<ruma::events::room::member::MemberEventContent>, | ||||
|                     >(current_member.content.clone()) | ||||
|                     .expect("Raw::from_value always works") | ||||
|                     .deserialize() | ||||
|                     .map_err(|_| Error::bad_database("Invalid PDU in database."))? | ||||
|                     .membership; | ||||
|                 for state_event in &state_events { | ||||
|                     if state_event.kind != EventType::RoomMember { | ||||
|                         continue; | ||||
|                     } | ||||
| 
 | ||||
|                     let since_membership = db | ||||
|                         .rooms | ||||
|                         .state_get( | ||||
|                             since_shortstatehash, | ||||
|                             &EventType::RoomMember, | ||||
|                             user_id.as_str(), | ||||
|                         )? | ||||
|                         .and_then(|since_member| { | ||||
|                             serde_json::from_value::< | ||||
|                                 Raw<ruma::events::room::member::MemberEventContent>, | ||||
|                             >(since_member.content.clone()) | ||||
|                             .expect("Raw::from_value always works") | ||||
|                             .deserialize() | ||||
|                             .map_err(|_| Error::bad_database("Invalid PDU in database.")) | ||||
|                             .ok() | ||||
|                         }) | ||||
|                         .map_or(MembershipState::Leave, |member| member.membership); | ||||
|                     if let Some(state_key) = &state_event.state_key { | ||||
|                         let user_id = UserId::try_from(state_key.clone()) | ||||
|                             .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; | ||||
| 
 | ||||
|                     let user_id = UserId::try_from(user_id.clone()) | ||||
|                         .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; | ||||
|                         if user_id == sender_user { | ||||
|                             continue; | ||||
|                         } | ||||
| 
 | ||||
|                     match (since_membership, current_membership) { | ||||
|                         (MembershipState::Leave, MembershipState::Join) => { | ||||
|                             // A new user joined an encrypted room
 | ||||
|                             if !share_encrypted_room(&db, &sender_user, &user_id, &room_id)? { | ||||
|                                 device_list_updates.insert(user_id); | ||||
|                         let new_membership = serde_json::from_value::< | ||||
|                             Raw<ruma::events::room::member::MemberEventContent>, | ||||
|                         >(state_event.content.clone()) | ||||
|                         .expect("Raw::from_value always works") | ||||
|                         .deserialize() | ||||
|                         .map_err(|_| Error::bad_database("Invalid PDU in database."))? | ||||
|                         .membership; | ||||
| 
 | ||||
|                         match new_membership { | ||||
|                             MembershipState::Join => { | ||||
|                                 // A new user joined an encrypted room
 | ||||
|                                 if !share_encrypted_room(&db, &sender_user, &user_id, &room_id)? { | ||||
|                                     device_list_updates.insert(user_id); | ||||
|                                 } | ||||
|                             } | ||||
|                             MembershipState::Leave => { | ||||
|                                 // Write down users that have left encrypted rooms we are in
 | ||||
|                                 left_encrypted_users.insert(user_id); | ||||
|                             } | ||||
|                             _ => {} | ||||
|                         } | ||||
|                         // TODO: Remove, this should never happen here, right?
 | ||||
|                         (MembershipState::Join, MembershipState::Leave) => { | ||||
|                             // Write down users that have left encrypted rooms we are in
 | ||||
|                             left_encrypted_users.insert(user_id); | ||||
|                         } | ||||
|                         _ => {} | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|  |  | |||
|  | @ -252,6 +252,7 @@ impl Database { | |||
|                 userroomid_joined: builder.open_tree("userroomid_joined")?, | ||||
|                 roomuserid_joined: builder.open_tree("roomuserid_joined")?, | ||||
|                 roomid_joinedcount: builder.open_tree("roomid_joinedcount")?, | ||||
|                 roomid_invitedcount: builder.open_tree("roomid_invitedcount")?, | ||||
|                 roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?, | ||||
|                 userroomid_invitestate: builder.open_tree("userroomid_invitestate")?, | ||||
|                 roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?, | ||||
|  | @ -277,6 +278,8 @@ impl Database { | |||
|                 statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?, | ||||
| 
 | ||||
|                 eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, | ||||
|                 softfailedeventids: builder.open_tree("softfailedeventids")?, | ||||
| 
 | ||||
|                 referencedevents: builder.open_tree("referencedevents")?, | ||||
|                 pdu_cache: Mutex::new(LruCache::new(100_000)), | ||||
|                 auth_chain_cache: Mutex::new(LruCache::new(1_000_000)), | ||||
|  | @ -285,6 +288,7 @@ impl Database { | |||
|                 shortstatekey_cache: Mutex::new(LruCache::new(1_000_000)), | ||||
|                 statekeyshort_cache: Mutex::new(LruCache::new(1_000_000)), | ||||
|                 stateinfo_cache: Mutex::new(LruCache::new(1000)), | ||||
|                 our_real_users_cache: RwLock::new(HashMap::new()), | ||||
|             }, | ||||
|             account_data: account_data::AccountData { | ||||
|                 roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, | ||||
|  | @ -442,7 +446,7 @@ impl Database { | |||
|                     let room_id = | ||||
|                         RoomId::try_from(utils::string_from_bytes(&roomid).unwrap()).unwrap(); | ||||
| 
 | ||||
|                     db.rooms.update_joined_count(&room_id)?; | ||||
|                     db.rooms.update_joined_count(&room_id, &db)?; | ||||
|                 } | ||||
| 
 | ||||
|                 db.globals.bump_database_version(6)?; | ||||
|  |  | |||
|  | @ -93,9 +93,8 @@ impl Engine { | |||
|     } | ||||
| 
 | ||||
|     pub fn flush_wal(self: &Arc<Self>) -> Result<()> { | ||||
|         // We use autocheckpoints
 | ||||
|         //self.write_lock()
 | ||||
|         //.pragma_update(Some(Main), "wal_checkpoint", &"TRUNCATE")?;
 | ||||
|         self.write_lock() | ||||
|             .pragma_update(Some(Main), "wal_checkpoint", &"TRUNCATE")?; | ||||
|         Ok(()) | ||||
|     } | ||||
| } | ||||
|  |  | |||
|  | @ -1,4 +1,4 @@ | |||
| use crate::{database::Config, utils, ConduitResult, Error, Result}; | ||||
| use crate::{database::Config, server_server::FedDest, utils, ConduitResult, Error, Result}; | ||||
| use ruma::{ | ||||
|     api::{ | ||||
|         client::r0::sync::sync_events, | ||||
|  | @ -6,25 +6,25 @@ use ruma::{ | |||
|     }, | ||||
|     DeviceId, EventId, MilliSecondsSinceUnixEpoch, RoomId, ServerName, ServerSigningKeyId, UserId, | ||||
| }; | ||||
| use rustls::{ServerCertVerifier, WebPKIVerifier}; | ||||
| use std::{ | ||||
|     collections::{BTreeMap, HashMap}, | ||||
|     fs, | ||||
|     future::Future, | ||||
|     net::IpAddr, | ||||
|     path::PathBuf, | ||||
|     sync::{Arc, Mutex, RwLock}, | ||||
|     time::{Duration, Instant}, | ||||
| }; | ||||
| use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore}; | ||||
| use tracing::{error, info}; | ||||
| use tracing::error; | ||||
| use trust_dns_resolver::TokioAsyncResolver; | ||||
| 
 | ||||
| use super::abstraction::Tree; | ||||
| 
 | ||||
| pub const COUNTER: &[u8] = b"c"; | ||||
| 
 | ||||
| type WellKnownMap = HashMap<Box<ServerName>, (String, String)>; | ||||
| type TlsNameMap = HashMap<String, webpki::DNSName>; | ||||
| type WellKnownMap = HashMap<Box<ServerName>, (FedDest, String)>; | ||||
| type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>; | ||||
| type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
 | ||||
| type SyncHandle = ( | ||||
|     Option<String>,                                         // since
 | ||||
|  | @ -37,7 +37,6 @@ pub struct Globals { | |||
|     pub(super) globals: Arc<dyn Tree>, | ||||
|     config: Config, | ||||
|     keypair: Arc<ruma::signatures::Ed25519KeyPair>, | ||||
|     reqwest_client: reqwest::Client, | ||||
|     dns_resolver: TokioAsyncResolver, | ||||
|     jwt_decoding_key: Option<jsonwebtoken::DecodingKey<'static>>, | ||||
|     pub(super) server_signingkeys: Arc<dyn Tree>, | ||||
|  | @ -51,40 +50,6 @@ pub struct Globals { | |||
|     pub rotate: RotationHandler, | ||||
| } | ||||
| 
 | ||||
| struct MatrixServerVerifier { | ||||
|     inner: WebPKIVerifier, | ||||
|     tls_name_override: Arc<RwLock<TlsNameMap>>, | ||||
| } | ||||
| 
 | ||||
| impl ServerCertVerifier for MatrixServerVerifier { | ||||
|     #[tracing::instrument(skip(self, roots, presented_certs, dns_name, ocsp_response))] | ||||
|     fn verify_server_cert( | ||||
|         &self, | ||||
|         roots: &rustls::RootCertStore, | ||||
|         presented_certs: &[rustls::Certificate], | ||||
|         dns_name: webpki::DNSNameRef<'_>, | ||||
|         ocsp_response: &[u8], | ||||
|     ) -> std::result::Result<rustls::ServerCertVerified, rustls::TLSError> { | ||||
|         if let Some(override_name) = self.tls_name_override.read().unwrap().get(dns_name.into()) { | ||||
|             let result = self.inner.verify_server_cert( | ||||
|                 roots, | ||||
|                 presented_certs, | ||||
|                 override_name.as_ref(), | ||||
|                 ocsp_response, | ||||
|             ); | ||||
|             if result.is_ok() { | ||||
|                 return result; | ||||
|             } | ||||
|             info!( | ||||
|                 "Server {:?} is non-compliant, retrying TLS verification with original name", | ||||
|                 dns_name | ||||
|             ); | ||||
|         } | ||||
|         self.inner | ||||
|             .verify_server_cert(roots, presented_certs, dns_name, ocsp_response) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// Handles "rotation" of long-polling requests. "Rotation" in this context is similar to "rotation" of log files and the like.
 | ||||
| ///
 | ||||
| /// This is utilized to have sync workers return early and release read locks on the database.
 | ||||
|  | @ -162,24 +127,6 @@ impl Globals { | |||
|         }; | ||||
| 
 | ||||
|         let tls_name_override = Arc::new(RwLock::new(TlsNameMap::new())); | ||||
|         let verifier = Arc::new(MatrixServerVerifier { | ||||
|             inner: WebPKIVerifier::new(), | ||||
|             tls_name_override: tls_name_override.clone(), | ||||
|         }); | ||||
|         let mut tlsconfig = rustls::ClientConfig::new(); | ||||
|         tlsconfig.dangerous().set_certificate_verifier(verifier); | ||||
|         tlsconfig.root_store = | ||||
|             rustls_native_certs::load_native_certs().expect("Error loading system certificates"); | ||||
| 
 | ||||
|         let mut reqwest_client_builder = reqwest::Client::builder() | ||||
|             .connect_timeout(Duration::from_secs(30)) | ||||
|             .timeout(Duration::from_secs(60 * 3)) | ||||
|             .pool_max_idle_per_host(1) | ||||
|             .use_preconfigured_tls(tlsconfig); | ||||
|         if let Some(proxy) = config.proxy.to_proxy()? { | ||||
|             reqwest_client_builder = reqwest_client_builder.proxy(proxy); | ||||
|         } | ||||
|         let reqwest_client = reqwest_client_builder.build().unwrap(); | ||||
| 
 | ||||
|         let jwt_decoding_key = config | ||||
|             .jwt_secret | ||||
|  | @ -190,7 +137,6 @@ impl Globals { | |||
|             globals, | ||||
|             config, | ||||
|             keypair: Arc::new(keypair), | ||||
|             reqwest_client, | ||||
|             dns_resolver: TokioAsyncResolver::tokio_from_system_conf().map_err(|_| { | ||||
|                 Error::bad_config("Failed to set up trust dns resolver with system config.") | ||||
|             })?, | ||||
|  | @ -219,8 +165,16 @@ impl Globals { | |||
|     } | ||||
| 
 | ||||
|     /// Returns a reqwest client which can be used to send requests.
 | ||||
|     pub fn reqwest_client(&self) -> &reqwest::Client { | ||||
|         &self.reqwest_client | ||||
|     pub fn reqwest_client(&self) -> Result<reqwest::ClientBuilder> { | ||||
|         let mut reqwest_client_builder = reqwest::Client::builder() | ||||
|             .connect_timeout(Duration::from_secs(30)) | ||||
|             .timeout(Duration::from_secs(60 * 3)) | ||||
|             .pool_max_idle_per_host(1); | ||||
|         if let Some(proxy) = self.config.proxy.to_proxy()? { | ||||
|             reqwest_client_builder = reqwest_client_builder.proxy(proxy); | ||||
|         } | ||||
| 
 | ||||
|         Ok(reqwest_client_builder) | ||||
|     } | ||||
| 
 | ||||
|     #[tracing::instrument(skip(self))] | ||||
|  |  | |||
|  | @ -113,7 +113,11 @@ where | |||
|     //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5));
 | ||||
| 
 | ||||
|     let url = reqwest_request.url().clone(); | ||||
|     let response = globals.reqwest_client().execute(reqwest_request).await; | ||||
|     let response = globals | ||||
|         .reqwest_client()? | ||||
|         .build()? | ||||
|         .execute(reqwest_request) | ||||
|         .await; | ||||
| 
 | ||||
|     match response { | ||||
|         Ok(mut response) => { | ||||
|  |  | |||
|  | @ -3,7 +3,7 @@ mod edus; | |||
| pub use edus::RoomEdus; | ||||
| use member::MembershipState; | ||||
| 
 | ||||
| use crate::{pdu::PduBuilder, utils, Database, Error, PduEvent, Result}; | ||||
| use crate::{pdu::PduBuilder, server_server, utils, Database, Error, PduEvent, Result}; | ||||
| use lru_cache::LruCache; | ||||
| use regex::Regex; | ||||
| use ring::digest; | ||||
|  | @ -26,7 +26,8 @@ use std::{ | |||
|     collections::{BTreeMap, HashMap, HashSet}, | ||||
|     convert::{TryFrom, TryInto}, | ||||
|     mem::size_of, | ||||
|     sync::{Arc, Mutex}, | ||||
|     sync::{Arc, Mutex, RwLock}, | ||||
|     time::Instant, | ||||
| }; | ||||
| use tokio::sync::MutexGuard; | ||||
| use tracing::{error, warn}; | ||||
|  | @ -58,6 +59,7 @@ pub struct Rooms { | |||
|     pub(super) userroomid_joined: Arc<dyn Tree>, | ||||
|     pub(super) roomuserid_joined: Arc<dyn Tree>, | ||||
|     pub(super) roomid_joinedcount: Arc<dyn Tree>, | ||||
|     pub(super) roomid_invitedcount: Arc<dyn Tree>, | ||||
|     pub(super) roomuseroncejoinedids: Arc<dyn Tree>, | ||||
|     pub(super) userroomid_invitestate: Arc<dyn Tree>, // InviteState = Vec<Raw<Pdu>>
 | ||||
|     pub(super) roomuserid_invitecount: Arc<dyn Tree>, // InviteCount = Count
 | ||||
|  | @ -89,16 +91,18 @@ pub struct Rooms { | |||
|     /// RoomId + EventId -> outlier PDU.
 | ||||
|     /// Any pdu that has passed the steps 1-8 in the incoming event /federation/send/txn.
 | ||||
|     pub(super) eventid_outlierpdu: Arc<dyn Tree>, | ||||
|     pub(super) softfailedeventids: Arc<dyn Tree>, | ||||
| 
 | ||||
|     /// RoomId + EventId -> Parent PDU EventId.
 | ||||
|     pub(super) referencedevents: Arc<dyn Tree>, | ||||
| 
 | ||||
|     pub(super) pdu_cache: Mutex<LruCache<EventId, Arc<PduEvent>>>, | ||||
|     pub(super) shorteventid_cache: Mutex<LruCache<u64, Arc<EventId>>>, | ||||
|     pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<HashSet<u64>>>>, | ||||
|     pub(super) shorteventid_cache: Mutex<LruCache<u64, EventId>>, | ||||
|     pub(super) eventidshort_cache: Mutex<LruCache<EventId, u64>>, | ||||
|     pub(super) statekeyshort_cache: Mutex<LruCache<(EventType, String), u64>>, | ||||
|     pub(super) shortstatekey_cache: Mutex<LruCache<u64, (EventType, String)>>, | ||||
|     pub(super) our_real_users_cache: RwLock<HashMap<RoomId, Arc<HashSet<UserId>>>>, | ||||
|     pub(super) stateinfo_cache: Mutex< | ||||
|         LruCache< | ||||
|             u64, | ||||
|  | @ -116,7 +120,7 @@ impl Rooms { | |||
|     /// Builds a StateMap by iterating over all keys that start
 | ||||
|     /// with state_hash, this gives the full state for the given state_hash.
 | ||||
|     #[tracing::instrument(skip(self))] | ||||
|     pub fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, EventId>> { | ||||
|     pub fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, Arc<EventId>>> { | ||||
|         let full_state = self | ||||
|             .load_shortstatehash_info(shortstatehash)? | ||||
|             .pop() | ||||
|  | @ -167,7 +171,7 @@ impl Rooms { | |||
|         shortstatehash: u64, | ||||
|         event_type: &EventType, | ||||
|         state_key: &str, | ||||
|     ) -> Result<Option<EventId>> { | ||||
|     ) -> Result<Option<Arc<EventId>>> { | ||||
|         let shortstatekey = match self.get_shortstatekey(event_type, state_key)? { | ||||
|             Some(s) => s, | ||||
|             None => return Ok(None), | ||||
|  | @ -424,7 +428,7 @@ impl Rooms { | |||
|             } | ||||
|         } | ||||
| 
 | ||||
|         self.update_joined_count(room_id)?; | ||||
|         self.update_joined_count(room_id, &db)?; | ||||
| 
 | ||||
|         self.roomid_shortstatehash | ||||
|             .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; | ||||
|  | @ -523,7 +527,7 @@ impl Rooms { | |||
|     pub fn parse_compressed_state_event( | ||||
|         &self, | ||||
|         compressed_event: CompressedStateEvent, | ||||
|     ) -> Result<(u64, EventId)> { | ||||
|     ) -> Result<(u64, Arc<EventId>)> { | ||||
|         Ok(( | ||||
|             utils::u64_from_bytes(&compressed_event[0..size_of::<u64>()]) | ||||
|                 .expect("bytes have right length"), | ||||
|  | @ -839,14 +843,14 @@ impl Rooms { | |||
|     } | ||||
| 
 | ||||
|     #[tracing::instrument(skip(self))] | ||||
|     pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result<EventId> { | ||||
|     pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> { | ||||
|         if let Some(id) = self | ||||
|             .shorteventid_cache | ||||
|             .lock() | ||||
|             .unwrap() | ||||
|             .get_mut(&shorteventid) | ||||
|         { | ||||
|             return Ok(id.clone()); | ||||
|             return Ok(Arc::clone(id)); | ||||
|         } | ||||
| 
 | ||||
|         let bytes = self | ||||
|  | @ -854,15 +858,17 @@ impl Rooms { | |||
|             .get(&shorteventid.to_be_bytes())? | ||||
|             .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; | ||||
| 
 | ||||
|         let event_id = EventId::try_from(utils::string_from_bytes(&bytes).map_err(|_| { | ||||
|             Error::bad_database("EventID in shorteventid_eventid is invalid unicode.") | ||||
|         })?) | ||||
|         .map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?; | ||||
|         let event_id = Arc::new( | ||||
|             EventId::try_from(utils::string_from_bytes(&bytes).map_err(|_| { | ||||
|                 Error::bad_database("EventID in shorteventid_eventid is invalid unicode.") | ||||
|             })?) | ||||
|             .map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?, | ||||
|         ); | ||||
| 
 | ||||
|         self.shorteventid_cache | ||||
|             .lock() | ||||
|             .unwrap() | ||||
|             .insert(shorteventid, event_id.clone()); | ||||
|             .insert(shorteventid, Arc::clone(&event_id)); | ||||
| 
 | ||||
|         Ok(event_id) | ||||
|     } | ||||
|  | @ -929,7 +935,7 @@ impl Rooms { | |||
|         room_id: &RoomId, | ||||
|         event_type: &EventType, | ||||
|         state_key: &str, | ||||
|     ) -> Result<Option<EventId>> { | ||||
|     ) -> Result<Option<Arc<EventId>>> { | ||||
|         if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { | ||||
|             self.state_get_id(current_shortstatehash, event_type, state_key) | ||||
|         } else { | ||||
|  | @ -1226,9 +1232,19 @@ impl Rooms { | |||
|         self.eventid_outlierpdu.insert( | ||||
|             &event_id.as_bytes(), | ||||
|             &serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"), | ||||
|         )?; | ||||
|         ) | ||||
|     } | ||||
| 
 | ||||
|         Ok(()) | ||||
|     #[tracing::instrument(skip(self))] | ||||
|     pub fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { | ||||
|         self.softfailedeventids.insert(&event_id.as_bytes(), &[]) | ||||
|     } | ||||
| 
 | ||||
|     #[tracing::instrument(skip(self))] | ||||
|     pub fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> { | ||||
|         self.softfailedeventids | ||||
|             .get(&event_id.as_bytes()) | ||||
|             .map(|o| o.is_some()) | ||||
|     } | ||||
| 
 | ||||
|     /// Creates a new persisted data unit and adds it to a room.
 | ||||
|  | @ -1331,15 +1347,9 @@ impl Rooms { | |||
|         let mut notifies = Vec::new(); | ||||
|         let mut highlights = Vec::new(); | ||||
| 
 | ||||
|         for user in db | ||||
|             .rooms | ||||
|             .room_members(&pdu.room_id) | ||||
|             .filter_map(|r| r.ok()) | ||||
|             .filter(|user_id| user_id.server_name() == db.globals.server_name()) | ||||
|             .filter(|user_id| !db.users.is_deactivated(user_id).unwrap_or(true)) | ||||
|         { | ||||
|         for user in self.get_our_real_users(&pdu.room_id, db)?.iter() { | ||||
|             // Don't notify the user of their own events
 | ||||
|             if user == pdu.sender { | ||||
|             if user == &pdu.sender { | ||||
|                 continue; | ||||
|             } | ||||
| 
 | ||||
|  | @ -1515,6 +1525,85 @@ impl Rooms { | |||
|                                 "list_appservices" => { | ||||
|                                     db.admin.send(AdminCommand::ListAppservices); | ||||
|                                 } | ||||
|                                 "get_auth_chain" => { | ||||
|                                     if args.len() == 1 { | ||||
|                                         if let Ok(event_id) = EventId::try_from(args[0]) { | ||||
|                                             let start = Instant::now(); | ||||
|                                             let count = server_server::get_auth_chain( | ||||
|                                                 vec![Arc::new(event_id)], | ||||
|                                                 db, | ||||
|                                             )? | ||||
|                                             .count(); | ||||
|                                             let elapsed = start.elapsed(); | ||||
|                                             db.admin.send(AdminCommand::SendMessage( | ||||
|                                                 message::MessageEventContent::text_plain(format!( | ||||
|                                                     "Loaded auth chain with length {} in {:?}", | ||||
|                                                     count, elapsed | ||||
|                                                 )), | ||||
|                                             )); | ||||
|                                         } | ||||
|                                     } | ||||
|                                 } | ||||
|                                 "parse_pdu" => { | ||||
|                                     if body.len() > 2 | ||||
|                                         && body[0].trim() == "```" | ||||
|                                         && body.last().unwrap().trim() == "```" | ||||
|                                     { | ||||
|                                         let string = body[1..body.len() - 1].join("\n"); | ||||
|                                         match serde_json::from_str(&string) { | ||||
|                                             Ok(value) => { | ||||
|                                                 let event_id = EventId::try_from(&*format!( | ||||
|                                                     "${}", | ||||
|                                                     // Anything higher than version3 behaves the same
 | ||||
|                                                     ruma::signatures::reference_hash( | ||||
|                                                         &value, | ||||
|                                                         &RoomVersionId::Version6 | ||||
|                                                     ) | ||||
|                                                     .expect("ruma can calculate reference hashes") | ||||
|                                                 )) | ||||
|                                                 .expect( | ||||
|                                                     "ruma's reference hashes are valid event ids", | ||||
|                                                 ); | ||||
| 
 | ||||
|                                                 match serde_json::from_value::<PduEvent>( | ||||
|                                                     serde_json::to_value(value) | ||||
|                                                         .expect("value is json"), | ||||
|                                                 ) { | ||||
|                                                     Ok(pdu) => { | ||||
|                                                         db.admin.send(AdminCommand::SendMessage( | ||||
|                                                             message::MessageEventContent::text_plain( | ||||
|                                                                 format!("EventId: {:?}\n{:#?}", event_id, pdu), | ||||
|                                                             ), | ||||
|                                                         )); | ||||
|                                                     } | ||||
|                                                     Err(e) => { | ||||
|                                                         db.admin.send(AdminCommand::SendMessage( | ||||
|                                                             message::MessageEventContent::text_plain( | ||||
|                                                                 format!("EventId: {:?}\nCould not parse event: {}", event_id, e), | ||||
|                                                             ), | ||||
|                                                         )); | ||||
|                                                     } | ||||
|                                                 } | ||||
|                                             } | ||||
|                                             Err(e) => { | ||||
|                                                 db.admin.send(AdminCommand::SendMessage( | ||||
|                                                     message::MessageEventContent::text_plain( | ||||
|                                                         format!( | ||||
|                                                             "Invalid json in command body: {}", | ||||
|                                                             e | ||||
|                                                         ), | ||||
|                                                     ), | ||||
|                                                 )); | ||||
|                                             } | ||||
|                                         } | ||||
|                                     } else { | ||||
|                                         db.admin.send(AdminCommand::SendMessage( | ||||
|                                             message::MessageEventContent::text_plain( | ||||
|                                                 "Expected code block in command body.", | ||||
|                                             ), | ||||
|                                         )); | ||||
|                                     } | ||||
|                                 } | ||||
|                                 "get_pdu" => { | ||||
|                                     if args.len() == 1 { | ||||
|                                         if let Ok(event_id) = EventId::try_from(args[0]) { | ||||
|  | @ -1536,7 +1625,7 @@ impl Rooms { | |||
|                                                             if outlier { | ||||
|                                                                 "PDU is outlier" | ||||
|                                                             } else { "PDU was accepted"}, json_text), | ||||
|                                                             format!("<p>{}</p>\n<pre><code class=\"language-json\">{}\n</code></pre>\n", 
 | ||||
|                                                             format!("<p>{}</p>\n<pre><code class=\"language-json\">{}\n</code></pre>\n", | ||||
|                                                             if outlier { | ||||
|                                                                 "PDU is outlier" | ||||
|                                                             } else { "PDU was accepted"}, RawStr::new(&json_text).html_escape()) | ||||
|  | @ -2421,29 +2510,45 @@ impl Rooms { | |||
|         } | ||||
| 
 | ||||
|         if update_joined_count { | ||||
|             self.update_joined_count(room_id)?; | ||||
|             self.update_joined_count(room_id, db)?; | ||||
|         } | ||||
| 
 | ||||
|         Ok(()) | ||||
|     } | ||||
| 
 | ||||
|     #[tracing::instrument(skip(self))] | ||||
|     pub fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { | ||||
|     #[tracing::instrument(skip(self, room_id, db))] | ||||
|     pub fn update_joined_count(&self, room_id: &RoomId, db: &Database) -> Result<()> { | ||||
|         let mut joinedcount = 0_u64; | ||||
|         let mut invitedcount = 0_u64; | ||||
|         let mut joined_servers = HashSet::new(); | ||||
|         let mut real_users = HashSet::new(); | ||||
| 
 | ||||
|         for joined in self.room_members(&room_id).filter_map(|r| r.ok()) { | ||||
|             joined_servers.insert(joined.server_name().to_owned()); | ||||
|             if joined.server_name() == db.globals.server_name() | ||||
|                 && !db.users.is_deactivated(&joined).unwrap_or(true) | ||||
|             { | ||||
|                 real_users.insert(joined); | ||||
|             } | ||||
|             joinedcount += 1; | ||||
|         } | ||||
| 
 | ||||
|         for invited in self.room_members_invited(&room_id).filter_map(|r| r.ok()) { | ||||
|             joined_servers.insert(invited.server_name().to_owned()); | ||||
|             invitedcount += 1; | ||||
|         } | ||||
| 
 | ||||
|         self.roomid_joinedcount | ||||
|             .insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?; | ||||
| 
 | ||||
|         self.roomid_invitedcount | ||||
|             .insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?; | ||||
| 
 | ||||
|         self.our_real_users_cache | ||||
|             .write() | ||||
|             .unwrap() | ||||
|             .insert(room_id.clone(), Arc::new(real_users)); | ||||
| 
 | ||||
|         for old_joined_server in self.room_servers(room_id).filter_map(|r| r.ok()) { | ||||
|             if !joined_servers.remove(&old_joined_server) { | ||||
|                 // Server not in room anymore
 | ||||
|  | @ -2477,6 +2582,32 @@ impl Rooms { | |||
|         Ok(()) | ||||
|     } | ||||
| 
 | ||||
|     #[tracing::instrument(skip(self, room_id, db))] | ||||
|     pub fn get_our_real_users( | ||||
|         &self, | ||||
|         room_id: &RoomId, | ||||
|         db: &Database, | ||||
|     ) -> Result<Arc<HashSet<UserId>>> { | ||||
|         let maybe = self | ||||
|             .our_real_users_cache | ||||
|             .read() | ||||
|             .unwrap() | ||||
|             .get(room_id) | ||||
|             .cloned(); | ||||
|         if let Some(users) = maybe { | ||||
|             Ok(users) | ||||
|         } else { | ||||
|             self.update_joined_count(room_id, &db)?; | ||||
|             Ok(Arc::clone( | ||||
|                 self.our_real_users_cache | ||||
|                     .read() | ||||
|                     .unwrap() | ||||
|                     .get(room_id) | ||||
|                     .unwrap(), | ||||
|             )) | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     #[tracing::instrument(skip(self, db))] | ||||
|     pub async fn leave_room( | ||||
|         &self, | ||||
|  | @ -2955,6 +3086,18 @@ impl Rooms { | |||
|             .transpose()?) | ||||
|     } | ||||
| 
 | ||||
|     #[tracing::instrument(skip(self))] | ||||
|     pub fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> { | ||||
|         Ok(self | ||||
|             .roomid_invitedcount | ||||
|             .get(room_id.as_bytes())? | ||||
|             .map(|b| { | ||||
|                 utils::u64_from_bytes(&b) | ||||
|                     .map_err(|_| Error::bad_database("Invalid joinedcount in db.")) | ||||
|             }) | ||||
|             .transpose()?) | ||||
|     } | ||||
| 
 | ||||
|     /// Returns an iterator over all User IDs who ever joined a room.
 | ||||
|     #[tracing::instrument(skip(self))] | ||||
|     pub fn room_useroncejoined<'a>( | ||||
|  |  | |||
|  | @ -4,7 +4,7 @@ use crate::{ | |||
|     utils, ConduitResult, Database, Error, PduEvent, Result, Ruma, | ||||
| }; | ||||
| use get_profile_information::v1::ProfileField; | ||||
| use http::header::{HeaderValue, AUTHORIZATION, HOST}; | ||||
| use http::header::{HeaderValue, AUTHORIZATION}; | ||||
| use regex::Regex; | ||||
| use rocket::response::content::Json; | ||||
| use ruma::{ | ||||
|  | @ -83,7 +83,7 @@ use rocket::{get, post, put}; | |||
| /// FedDest::Named("198.51.100.5".to_owned(), "".to_owned());
 | ||||
| /// ```
 | ||||
| #[derive(Clone, Debug, PartialEq)] | ||||
| enum FedDest { | ||||
| pub enum FedDest { | ||||
|     Literal(SocketAddr), | ||||
|     Named(String, String), | ||||
| } | ||||
|  | @ -109,6 +109,13 @@ impl FedDest { | |||
|             Self::Named(host, _) => host.clone(), | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     fn port(&self) -> Option<u16> { | ||||
|         match &self { | ||||
|             Self::Literal(addr) => Some(addr.port()), | ||||
|             Self::Named(_, port) => port[1..].parse().ok(), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| #[tracing::instrument(skip(globals, request))] | ||||
|  | @ -124,41 +131,34 @@ where | |||
|         return Err(Error::bad_config("Federation is disabled.")); | ||||
|     } | ||||
| 
 | ||||
|     let maybe_result = globals | ||||
|     let mut write_destination_to_cache = false; | ||||
| 
 | ||||
|     let cached_result = globals | ||||
|         .actual_destination_cache | ||||
|         .read() | ||||
|         .unwrap() | ||||
|         .get(destination) | ||||
|         .cloned(); | ||||
| 
 | ||||
|     let (actual_destination, host) = if let Some(result) = maybe_result { | ||||
|     let (actual_destination, host) = if let Some(result) = cached_result { | ||||
|         result | ||||
|     } else { | ||||
|         write_destination_to_cache = true; | ||||
| 
 | ||||
|         let result = find_actual_destination(globals, &destination).await; | ||||
|         let (actual_destination, host) = result.clone(); | ||||
|         let result_string = (result.0.into_https_string(), result.1.into_uri_string()); | ||||
|         globals | ||||
|             .actual_destination_cache | ||||
|             .write() | ||||
|             .unwrap() | ||||
|             .insert(Box::<ServerName>::from(destination), result_string.clone()); | ||||
|         let dest_hostname = actual_destination.hostname(); | ||||
|         let host_hostname = host.hostname(); | ||||
|         if dest_hostname != host_hostname { | ||||
|             globals.tls_name_override.write().unwrap().insert( | ||||
|                 dest_hostname, | ||||
|                 webpki::DNSNameRef::try_from_ascii_str(&host_hostname) | ||||
|                     .unwrap() | ||||
|                     .to_owned(), | ||||
|             ); | ||||
|         } | ||||
|         result_string | ||||
| 
 | ||||
|         (result.0, result.1.clone().into_uri_string()) | ||||
|     }; | ||||
| 
 | ||||
|     let actual_destination_str = actual_destination.clone().into_https_string(); | ||||
| 
 | ||||
|     let mut http_request = request | ||||
|         .try_into_http_request::<Vec<u8>>(&actual_destination, SendAccessToken::IfRequired("")) | ||||
|         .try_into_http_request::<Vec<u8>>(&actual_destination_str, SendAccessToken::IfRequired("")) | ||||
|         .map_err(|e| { | ||||
|             warn!("Failed to find destination {}: {}", actual_destination, e); | ||||
|             warn!( | ||||
|                 "Failed to find destination {}: {}", | ||||
|                 actual_destination_str, e | ||||
|             ); | ||||
|             Error::BadServerResponse("Invalid destination") | ||||
|         })?; | ||||
| 
 | ||||
|  | @ -224,15 +224,26 @@ where | |||
|         } | ||||
|     } | ||||
| 
 | ||||
|     http_request | ||||
|         .headers_mut() | ||||
|         .insert(HOST, HeaderValue::from_str(&host).unwrap()); | ||||
| 
 | ||||
|     let reqwest_request = reqwest::Request::try_from(http_request) | ||||
|         .expect("all http requests are valid reqwest requests"); | ||||
| 
 | ||||
|     let url = reqwest_request.url().clone(); | ||||
|     let response = globals.reqwest_client().execute(reqwest_request).await; | ||||
| 
 | ||||
|     let mut client = globals.reqwest_client()?; | ||||
|     if let Some((override_name, port)) = globals | ||||
|         .tls_name_override | ||||
|         .read() | ||||
|         .unwrap() | ||||
|         .get(&actual_destination.hostname()) | ||||
|     { | ||||
|         client = client.resolve( | ||||
|             &actual_destination.hostname(), | ||||
|             SocketAddr::new(override_name[0], *port), | ||||
|         ); | ||||
|         // port will be ignored
 | ||||
|     } | ||||
| 
 | ||||
|     let response = client.build()?.execute(reqwest_request).await; | ||||
| 
 | ||||
|     match response { | ||||
|         Ok(mut response) => { | ||||
|  | @ -271,6 +282,13 @@ where | |||
| 
 | ||||
|             if status == 200 { | ||||
|                 let response = T::IncomingResponse::try_from_http_response(http_response); | ||||
|                 if response.is_ok() && write_destination_to_cache { | ||||
|                     globals.actual_destination_cache.write().unwrap().insert( | ||||
|                         Box::<ServerName>::from(destination), | ||||
|                         (actual_destination, host), | ||||
|                     ); | ||||
|                 } | ||||
| 
 | ||||
|                 response.map_err(|e| { | ||||
|                     warn!( | ||||
|                         "Invalid 200 response from {} on: {} {}", | ||||
|  | @ -339,20 +357,49 @@ async fn find_actual_destination( | |||
|                 match request_well_known(globals, &destination.as_str()).await { | ||||
|                     // 3: A .well-known file is available
 | ||||
|                     Some(delegated_hostname) => { | ||||
|                         hostname = delegated_hostname.clone(); | ||||
|                         hostname = add_port_to_hostname(&delegated_hostname).into_uri_string(); | ||||
|                         match get_ip_with_port(&delegated_hostname) { | ||||
|                             Some(host_and_port) => host_and_port, // 3.1: IP literal in .well-known file
 | ||||
|                             None => { | ||||
|                                 if let Some(pos) = destination_str.find(':') { | ||||
|                                 if let Some(pos) = delegated_hostname.find(':') { | ||||
|                                     // 3.2: Hostname with port in .well-known file
 | ||||
|                                     let (host, port) = destination_str.split_at(pos); | ||||
|                                     let (host, port) = delegated_hostname.split_at(pos); | ||||
|                                     FedDest::Named(host.to_string(), port.to_string()) | ||||
|                                 } else { | ||||
|                                     match query_srv_record(globals, &delegated_hostname).await { | ||||
|                                     // Delegated hostname has no port in this branch
 | ||||
|                                     if let Some(hostname_override) = | ||||
|                                         query_srv_record(globals, &delegated_hostname).await | ||||
|                                     { | ||||
|                                         // 3.3: SRV lookup successful
 | ||||
|                                         Some(hostname) => hostname, | ||||
|                                         let force_port = hostname_override.port(); | ||||
| 
 | ||||
|                                         if let Ok(override_ip) = globals | ||||
|                                             .dns_resolver() | ||||
|                                             .lookup_ip(hostname_override.hostname()) | ||||
|                                             .await | ||||
|                                         { | ||||
|                                             globals.tls_name_override.write().unwrap().insert( | ||||
|                                                 delegated_hostname.clone(), | ||||
|                                                 ( | ||||
|                                                     override_ip.iter().collect(), | ||||
|                                                     force_port.unwrap_or(8448), | ||||
|                                                 ), | ||||
|                                             ); | ||||
|                                         } else { | ||||
|                                             warn!("Using SRV record, but could not resolve to IP"); | ||||
|                                         } | ||||
| 
 | ||||
|                                         if let Some(port) = force_port { | ||||
|                                             FedDest::Named( | ||||
|                                                 delegated_hostname, | ||||
|                                                 format!(":{}", port.to_string()), | ||||
|                                             ) | ||||
|                                         } else { | ||||
|                                             add_port_to_hostname(&delegated_hostname) | ||||
|                                         } | ||||
|                                     } else { | ||||
|                                         // 3.4: No SRV records, just use the hostname from .well-known
 | ||||
|                                         None => add_port_to_hostname(&delegated_hostname), | ||||
|                                         add_port_to_hostname(&delegated_hostname) | ||||
|                                     } | ||||
|                                 } | ||||
|                             } | ||||
|  | @ -362,7 +409,31 @@ async fn find_actual_destination( | |||
|                     None => { | ||||
|                         match query_srv_record(globals, &destination_str).await { | ||||
|                             // 4: SRV record found
 | ||||
|                             Some(hostname) => hostname, | ||||
|                             Some(hostname_override) => { | ||||
|                                 let force_port = hostname_override.port(); | ||||
| 
 | ||||
|                                 if let Ok(override_ip) = globals | ||||
|                                     .dns_resolver() | ||||
|                                     .lookup_ip(hostname_override.hostname()) | ||||
|                                     .await | ||||
|                                 { | ||||
|                                     globals.tls_name_override.write().unwrap().insert( | ||||
|                                         hostname.clone(), | ||||
|                                         (override_ip.iter().collect(), force_port.unwrap_or(8448)), | ||||
|                                     ); | ||||
|                                 } else { | ||||
|                                     warn!("Using SRV record, but could not resolve to IP"); | ||||
|                                 } | ||||
| 
 | ||||
|                                 if let Some(port) = force_port { | ||||
|                                     FedDest::Named( | ||||
|                                         hostname.clone(), | ||||
|                                         format!(":{}", port.to_string()), | ||||
|                                     ) | ||||
|                                 } else { | ||||
|                                     add_port_to_hostname(&hostname) | ||||
|                                 } | ||||
|                             } | ||||
|                             // 5: No SRV record found
 | ||||
|                             None => add_port_to_hostname(&destination_str), | ||||
|                         } | ||||
|  | @ -377,12 +448,12 @@ async fn find_actual_destination( | |||
|     let hostname = if let Ok(addr) = hostname.parse::<SocketAddr>() { | ||||
|         FedDest::Literal(addr) | ||||
|     } else if let Ok(addr) = hostname.parse::<IpAddr>() { | ||||
|         FedDest::Named(addr.to_string(), "".to_string()) | ||||
|         FedDest::Named(addr.to_string(), ":8448".to_string()) | ||||
|     } else if let Some(pos) = hostname.find(':') { | ||||
|         let (host, port) = hostname.split_at(pos); | ||||
|         FedDest::Named(host.to_string(), port.to_string()) | ||||
|     } else { | ||||
|         FedDest::Named(hostname, "".to_string()) | ||||
|         FedDest::Named(hostname, ":8448".to_string()) | ||||
|     }; | ||||
|     (actual_destination, hostname) | ||||
| } | ||||
|  | @ -423,6 +494,9 @@ pub async fn request_well_known( | |||
|     let body: serde_json::Value = serde_json::from_str( | ||||
|         &globals | ||||
|             .reqwest_client() | ||||
|             .ok()? | ||||
|             .build() | ||||
|             .ok()? | ||||
|             .get(&format!( | ||||
|                 "https://{}/.well-known/matrix/server", | ||||
|                 destination | ||||
|  | @ -893,7 +967,12 @@ pub async fn handle_incoming_pdu<'a>( | |||
|     // 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline events
 | ||||
|     let mut graph = HashMap::new(); | ||||
|     let mut eventid_info = HashMap::new(); | ||||
|     let mut todo_outlier_stack = incoming_pdu.prev_events.clone(); | ||||
|     let mut todo_outlier_stack = incoming_pdu | ||||
|         .prev_events | ||||
|         .iter() | ||||
|         .cloned() | ||||
|         .map(Arc::new) | ||||
|         .collect::<Vec<_>>(); | ||||
| 
 | ||||
|     let mut amount = 0; | ||||
| 
 | ||||
|  | @ -929,13 +1008,13 @@ pub async fn handle_incoming_pdu<'a>( | |||
|                     amount += 1; | ||||
|                     for prev_prev in &pdu.prev_events { | ||||
|                         if !graph.contains_key(prev_prev) { | ||||
|                             todo_outlier_stack.push(dbg!(prev_prev.clone())); | ||||
|                             todo_outlier_stack.push(dbg!(Arc::new(prev_prev.clone()))); | ||||
|                         } | ||||
|                     } | ||||
| 
 | ||||
|                     graph.insert( | ||||
|                         prev_event_id.clone(), | ||||
|                         pdu.prev_events.iter().cloned().collect(), | ||||
|                         pdu.prev_events.iter().cloned().map(Arc::new).collect(), | ||||
|                     ); | ||||
|                     eventid_info.insert(prev_event_id.clone(), (pdu, json)); | ||||
|                 } else { | ||||
|  | @ -964,9 +1043,9 @@ pub async fn handle_incoming_pdu<'a>( | |||
|                 MilliSecondsSinceUnixEpoch( | ||||
|                     eventid_info | ||||
|                         .get(event_id) | ||||
|                         .map_or_else(|| uint!(0), |info| info.0.origin_server_ts.clone()), | ||||
|                         .map_or_else(|| uint!(0), |info| info.0.origin_server_ts), | ||||
|                 ), | ||||
|                 ruma::event_id!("$notimportant"), | ||||
|                 Arc::new(ruma::event_id!("$notimportant")), | ||||
|             )) | ||||
|         }) | ||||
|         .map_err(|_| "Error sorting prev events".to_owned())?; | ||||
|  | @ -1084,7 +1163,12 @@ fn handle_outlier_pdu<'a>( | |||
|         fetch_and_handle_outliers( | ||||
|             db, | ||||
|             origin, | ||||
|             &incoming_pdu.auth_events, | ||||
|             &incoming_pdu | ||||
|                 .auth_events | ||||
|                 .iter() | ||||
|                 .cloned() | ||||
|                 .map(Arc::new) | ||||
|                 .collect::<Vec<_>>(), | ||||
|             &create_event, | ||||
|             &room_id, | ||||
|             pub_key_map, | ||||
|  | @ -1100,13 +1184,13 @@ fn handle_outlier_pdu<'a>( | |||
|         // Build map of auth events
 | ||||
|         let mut auth_events = HashMap::new(); | ||||
|         for id in &incoming_pdu.auth_events { | ||||
|             let auth_event = db | ||||
|                 .rooms | ||||
|                 .get_pdu(id) | ||||
|                 .map_err(|e| e.to_string())? | ||||
|                 .ok_or_else(|| { | ||||
|                     "Auth event not found, event failed recursive auth checks.".to_string() | ||||
|                 })?; | ||||
|             let auth_event = match db.rooms.get_pdu(id).map_err(|e| e.to_string())? { | ||||
|                 Some(e) => e, | ||||
|                 None => { | ||||
|                     warn!("Could not find auth event {}", id); | ||||
|                     continue; | ||||
|                 } | ||||
|             }; | ||||
| 
 | ||||
|             match auth_events.entry(( | ||||
|                 auth_event.kind.clone(), | ||||
|  | @ -1153,7 +1237,7 @@ fn handle_outlier_pdu<'a>( | |||
|         if !state_res::event_auth::auth_check( | ||||
|             &room_version, | ||||
|             &incoming_pdu, | ||||
|             previous_create.clone(), | ||||
|             previous_create, | ||||
|             None, // TODO: third party invite
 | ||||
|             |k, s| auth_events.get(&(k.clone(), s.to_owned())).map(Arc::clone), | ||||
|         ) | ||||
|  | @ -1187,6 +1271,15 @@ async fn upgrade_outlier_to_timeline_pdu( | |||
|     if let Ok(Some(pduid)) = db.rooms.get_pdu_id(&incoming_pdu.event_id) { | ||||
|         return Ok(Some(pduid)); | ||||
|     } | ||||
| 
 | ||||
|     if db | ||||
|         .rooms | ||||
|         .is_event_soft_failed(&incoming_pdu.event_id) | ||||
|         .map_err(|_| "Failed to ask db for soft fail".to_owned())? | ||||
|     { | ||||
|         return Err("Event has been soft failed".into()); | ||||
|     } | ||||
| 
 | ||||
|     // 10. Fetch missing state and auth chain events by calling /state_ids at backwards extremities
 | ||||
|     //     doing all the checks in this list starting at 1. These are not timeline events.
 | ||||
| 
 | ||||
|  | @ -1219,7 +1312,7 @@ async fn upgrade_outlier_to_timeline_pdu( | |||
|                     .get_or_create_shortstatekey(&prev_pdu.kind, state_key, &db.globals) | ||||
|                     .map_err(|_| "Failed to create shortstatekey.".to_owned())?; | ||||
| 
 | ||||
|                 state.insert(shortstatekey, prev_event.clone()); | ||||
|                 state.insert(shortstatekey, Arc::new(prev_event.clone())); | ||||
|                 // Now it's the state after the pdu
 | ||||
|             } | ||||
| 
 | ||||
|  | @ -1249,7 +1342,11 @@ async fn upgrade_outlier_to_timeline_pdu( | |||
|                 let state_vec = fetch_and_handle_outliers( | ||||
|                     &db, | ||||
|                     origin, | ||||
|                     &res.pdu_ids, | ||||
|                     &res.pdu_ids | ||||
|                         .iter() | ||||
|                         .cloned() | ||||
|                         .map(Arc::new) | ||||
|                         .collect::<Vec<_>>(), | ||||
|                     &create_event, | ||||
|                     &room_id, | ||||
|                     pub_key_map, | ||||
|  | @ -1270,7 +1367,7 @@ async fn upgrade_outlier_to_timeline_pdu( | |||
| 
 | ||||
|                     match state.entry(shortstatekey) { | ||||
|                         btree_map::Entry::Vacant(v) => { | ||||
|                             v.insert(pdu.event_id.clone()); | ||||
|                             v.insert(Arc::new(pdu.event_id.clone())); | ||||
|                         } | ||||
|                         btree_map::Entry::Occupied(_) => return Err( | ||||
|                             "State event's type and state_key combination exists multiple times." | ||||
|  | @ -1286,7 +1383,9 @@ async fn upgrade_outlier_to_timeline_pdu( | |||
|                     .map_err(|_| "Failed to talk to db.")? | ||||
|                     .expect("Room exists"); | ||||
| 
 | ||||
|                 if state.get(&create_shortstatekey) != Some(&create_event.event_id) { | ||||
|                 if state.get(&create_shortstatekey).map(|id| id.as_ref()) | ||||
|                     != Some(&create_event.event_id) | ||||
|                 { | ||||
|                     return Err("Incoming event refers to wrong create event.".to_owned()); | ||||
|                 } | ||||
| 
 | ||||
|  | @ -1451,7 +1550,7 @@ async fn upgrade_outlier_to_timeline_pdu( | |||
|                     .rooms | ||||
|                     .get_or_create_shortstatekey(&leaf_pdu.kind, state_key, &db.globals) | ||||
|                     .map_err(|_| "Failed to create shortstatekey.".to_owned())?; | ||||
|                 leaf_state.insert(shortstatekey, leaf_pdu.event_id.clone()); | ||||
|                 leaf_state.insert(shortstatekey, Arc::new(leaf_pdu.event_id.clone())); | ||||
|                 // Now it's the state after the pdu
 | ||||
|             } | ||||
| 
 | ||||
|  | @ -1466,9 +1565,9 @@ async fn upgrade_outlier_to_timeline_pdu( | |||
|                 .get_or_create_shortstatekey(&incoming_pdu.kind, state_key, &db.globals) | ||||
|                 .map_err(|_| "Failed to create shortstatekey.".to_owned())?; | ||||
| 
 | ||||
|             state_after.insert(shortstatekey, incoming_pdu.event_id.clone()); | ||||
|             state_after.insert(shortstatekey, Arc::new(incoming_pdu.event_id.clone())); | ||||
|         } | ||||
|         fork_states.push(state_after.clone()); | ||||
|         fork_states.push(state_after); | ||||
| 
 | ||||
|         let mut update_state = false; | ||||
|         // 14. Use state resolution to find new room state
 | ||||
|  | @ -1593,6 +1692,9 @@ async fn upgrade_outlier_to_timeline_pdu( | |||
|     if soft_fail { | ||||
|         // Soft fail, we keep the event as an outlier but don't add it to the timeline
 | ||||
|         warn!("Event was soft failed: {:?}", incoming_pdu); | ||||
|         db.rooms | ||||
|             .mark_event_soft_failed(&incoming_pdu.event_id) | ||||
|             .map_err(|_| "Failed to set soft failed flag".to_owned())?; | ||||
|         return Err("Event has been soft failed".into()); | ||||
|     } | ||||
| 
 | ||||
|  | @ -1614,7 +1716,7 @@ async fn upgrade_outlier_to_timeline_pdu( | |||
| pub(crate) fn fetch_and_handle_outliers<'a>( | ||||
|     db: &'a Database, | ||||
|     origin: &'a ServerName, | ||||
|     events: &'a [EventId], | ||||
|     events: &'a [Arc<EventId>], | ||||
|     create_event: &'a PduEvent, | ||||
|     room_id: &'a RoomId, | ||||
|     pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, String>>>, | ||||
|  | @ -1665,20 +1767,25 @@ pub(crate) fn fetch_and_handle_outliers<'a>( | |||
|                     { | ||||
|                         Ok(res) => { | ||||
|                             warn!("Got {} over federation", id); | ||||
|                             let (event_id, value) = | ||||
|                             let (calculated_event_id, value) = | ||||
|                                 match crate::pdu::gen_event_id_canonical_json(&res.pdu) { | ||||
|                                     Ok(t) => t, | ||||
|                                     Err(_) => { | ||||
|                                         back_off(id.clone()); | ||||
|                                         back_off((**id).clone()); | ||||
|                                         continue; | ||||
|                                     } | ||||
|                                 }; | ||||
| 
 | ||||
|                             if calculated_event_id != **id { | ||||
|                                 warn!("Server didn't return event id we requested: requested: {}, we got {}. Event: {:?}", | ||||
|                                     id, calculated_event_id, &res.pdu); | ||||
|                             } | ||||
| 
 | ||||
|                             // This will also fetch the auth chain
 | ||||
|                             match handle_outlier_pdu( | ||||
|                                 origin, | ||||
|                                 create_event, | ||||
|                                 &event_id, | ||||
|                                 &id, | ||||
|                                 &room_id, | ||||
|                                 value.clone(), | ||||
|                                 db, | ||||
|  | @ -1689,14 +1796,14 @@ pub(crate) fn fetch_and_handle_outliers<'a>( | |||
|                                 Ok((pdu, json)) => (pdu, Some(json)), | ||||
|                                 Err(e) => { | ||||
|                                     warn!("Authentication of event {} failed: {:?}", id, e); | ||||
|                                     back_off(id.clone()); | ||||
|                                     back_off((**id).clone()); | ||||
|                                     continue; | ||||
|                                 } | ||||
|                             } | ||||
|                         } | ||||
|                         Err(_) => { | ||||
|                             warn!("Failed to fetch event: {}", id); | ||||
|                             back_off(id.clone()); | ||||
|                             back_off((**id).clone()); | ||||
|                             continue; | ||||
|                         } | ||||
|                     } | ||||
|  | @ -1971,24 +2078,18 @@ fn append_incoming_pdu( | |||
| } | ||||
| 
 | ||||
| #[tracing::instrument(skip(starting_events, db))] | ||||
| fn get_auth_chain( | ||||
|     starting_events: Vec<EventId>, | ||||
| pub fn get_auth_chain( | ||||
|     starting_events: Vec<Arc<EventId>>, | ||||
|     db: &Database, | ||||
| ) -> Result<impl Iterator<Item = EventId> + '_> { | ||||
| ) -> Result<impl Iterator<Item = Arc<EventId>> + '_> { | ||||
|     const NUM_BUCKETS: usize = 50; | ||||
| 
 | ||||
|     let mut buckets = vec![BTreeSet::new(); NUM_BUCKETS]; | ||||
| 
 | ||||
|     for id in starting_events { | ||||
|         if let Some(pdu) = db.rooms.get_pdu(&id)? { | ||||
|             for auth_event in &pdu.auth_events { | ||||
|                 let short = db | ||||
|                     .rooms | ||||
|                     .get_or_create_shorteventid(&auth_event, &db.globals)?; | ||||
|                 let bucket_id = (short % NUM_BUCKETS as u64) as usize; | ||||
|                 buckets[bucket_id].insert((short, auth_event.clone())); | ||||
|             } | ||||
|         } | ||||
|         let short = db.rooms.get_or_create_shorteventid(&id, &db.globals)?; | ||||
|         let bucket_id = (short % NUM_BUCKETS as u64) as usize; | ||||
|         buckets[bucket_id].insert((short, id.clone())); | ||||
|     } | ||||
| 
 | ||||
|     let mut full_auth_chain = HashSet::new(); | ||||
|  | @ -2000,10 +2101,6 @@ fn get_auth_chain( | |||
|             continue; | ||||
|         } | ||||
| 
 | ||||
|         // The code below will only get the auth chains, not the events in the chunk. So let's add
 | ||||
|         // them first
 | ||||
|         full_auth_chain.extend(chunk.iter().map(|(id, _)| id)); | ||||
| 
 | ||||
|         let chunk_key = chunk | ||||
|             .iter() | ||||
|             .map(|(short, _)| short) | ||||
|  | @ -2178,12 +2275,12 @@ pub fn get_event_authorization_route( | |||
|         return Err(Error::bad_config("Federation is disabled.")); | ||||
|     } | ||||
| 
 | ||||
|     let auth_chain_ids = get_auth_chain(vec![body.event_id.clone()], &db)?; | ||||
|     let auth_chain_ids = get_auth_chain(vec![Arc::new(body.event_id.clone())], &db)?; | ||||
| 
 | ||||
|     Ok(get_event_authorization::v1::Response { | ||||
|         auth_chain: auth_chain_ids | ||||
|             .filter_map(|id| Some(db.rooms.get_pdu_json(&id).ok()??)) | ||||
|             .map(|event| PduEvent::convert_to_outgoing_federation_event(event)) | ||||
|             .filter_map(|id| db.rooms.get_pdu_json(&id).ok()?) | ||||
|             .map(PduEvent::convert_to_outgoing_federation_event) | ||||
|             .collect(), | ||||
|     } | ||||
|     .into()) | ||||
|  | @ -2221,7 +2318,7 @@ pub fn get_room_state_route( | |||
|         }) | ||||
|         .collect(); | ||||
| 
 | ||||
|     let auth_chain_ids = get_auth_chain(vec![body.event_id.clone()], &db)?; | ||||
|     let auth_chain_ids = get_auth_chain(vec![Arc::new(body.event_id.clone())], &db)?; | ||||
| 
 | ||||
|     Ok(get_room_state::v1::Response { | ||||
|         auth_chain: auth_chain_ids | ||||
|  | @ -2262,13 +2359,13 @@ pub fn get_room_state_ids_route( | |||
|         .rooms | ||||
|         .state_full_ids(shortstatehash)? | ||||
|         .into_iter() | ||||
|         .map(|(_, id)| id) | ||||
|         .map(|(_, id)| (*id).clone()) | ||||
|         .collect(); | ||||
| 
 | ||||
|     let auth_chain_ids = get_auth_chain(vec![body.event_id.clone()], &db)?; | ||||
|     let auth_chain_ids = get_auth_chain(vec![Arc::new(body.event_id.clone())], &db)?; | ||||
| 
 | ||||
|     Ok(get_room_state_ids::v1::Response { | ||||
|         auth_chain_ids: auth_chain_ids.collect(), | ||||
|         auth_chain_ids: auth_chain_ids.map(|id| (*id).clone()).collect(), | ||||
|         pdu_ids, | ||||
|     } | ||||
|     .into()) | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue