fix: batch key fetching
This commit is contained in:
		
							parent
							
								
									c53d79e287
								
							
						
					
					
						commit
						4b39d7cb64
					
				
					 4 changed files with 147 additions and 146 deletions
				
			
		
							
								
								
									
										47
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										47
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							|  | @ -2062,8 +2062,8 @@ dependencies = [ | |||
| 
 | ||||
| [[package]] | ||||
| name = "ruma" | ||||
| version = "0.3.0" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| version = "0.4.0" | ||||
| source = "git+https://github.com/timokoesters/ruma?rev=50c1db7e0a3a21fc794b0cce3b64285a4c750c71#50c1db7e0a3a21fc794b0cce3b64285a4c750c71" | ||||
| dependencies = [ | ||||
|  "assign", | ||||
|  "js_int", | ||||
|  | @ -2084,7 +2084,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-api" | ||||
| version = "0.18.3" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| source = "git+https://github.com/timokoesters/ruma?rev=50c1db7e0a3a21fc794b0cce3b64285a4c750c71#50c1db7e0a3a21fc794b0cce3b64285a4c750c71" | ||||
| dependencies = [ | ||||
|  "bytes", | ||||
|  "http", | ||||
|  | @ -2100,7 +2100,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-api-macros" | ||||
| version = "0.18.3" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| source = "git+https://github.com/timokoesters/ruma?rev=50c1db7e0a3a21fc794b0cce3b64285a4c750c71#50c1db7e0a3a21fc794b0cce3b64285a4c750c71" | ||||
| dependencies = [ | ||||
|  "proc-macro-crate", | ||||
|  "proc-macro2", | ||||
|  | @ -2111,7 +2111,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-appservice-api" | ||||
| version = "0.4.0" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| source = "git+https://github.com/timokoesters/ruma?rev=50c1db7e0a3a21fc794b0cce3b64285a4c750c71#50c1db7e0a3a21fc794b0cce3b64285a4c750c71" | ||||
| dependencies = [ | ||||
|  "ruma-api", | ||||
|  "ruma-common", | ||||
|  | @ -2125,7 +2125,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-client-api" | ||||
| version = "0.12.2" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| source = "git+https://github.com/timokoesters/ruma?rev=50c1db7e0a3a21fc794b0cce3b64285a4c750c71#50c1db7e0a3a21fc794b0cce3b64285a4c750c71" | ||||
| dependencies = [ | ||||
|  "assign", | ||||
|  "bytes", | ||||
|  | @ -2145,7 +2145,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-common" | ||||
| version = "0.6.0" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| source = "git+https://github.com/timokoesters/ruma?rev=50c1db7e0a3a21fc794b0cce3b64285a4c750c71#50c1db7e0a3a21fc794b0cce3b64285a4c750c71" | ||||
| dependencies = [ | ||||
|  "indexmap", | ||||
|  "js_int", | ||||
|  | @ -2159,8 +2159,8 @@ dependencies = [ | |||
| 
 | ||||
| [[package]] | ||||
| name = "ruma-events" | ||||
| version = "0.24.4" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| version = "0.24.5" | ||||
| source = "git+https://github.com/timokoesters/ruma?rev=50c1db7e0a3a21fc794b0cce3b64285a4c750c71#50c1db7e0a3a21fc794b0cce3b64285a4c750c71" | ||||
| dependencies = [ | ||||
|  "indoc", | ||||
|  "js_int", | ||||
|  | @ -2175,8 +2175,8 @@ dependencies = [ | |||
| 
 | ||||
| [[package]] | ||||
| name = "ruma-events-macros" | ||||
| version = "0.24.4" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| version = "0.24.5" | ||||
| source = "git+https://github.com/timokoesters/ruma?rev=50c1db7e0a3a21fc794b0cce3b64285a4c750c71#50c1db7e0a3a21fc794b0cce3b64285a4c750c71" | ||||
| dependencies = [ | ||||
|  "proc-macro-crate", | ||||
|  "proc-macro2", | ||||
|  | @ -2186,8 +2186,8 @@ dependencies = [ | |||
| 
 | ||||
| [[package]] | ||||
| name = "ruma-federation-api" | ||||
| version = "0.3.0" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| version = "0.3.1" | ||||
| source = "git+https://github.com/timokoesters/ruma?rev=50c1db7e0a3a21fc794b0cce3b64285a4c750c71#50c1db7e0a3a21fc794b0cce3b64285a4c750c71" | ||||
| dependencies = [ | ||||
|  "js_int", | ||||
|  "ruma-api", | ||||
|  | @ -2202,7 +2202,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-identifiers" | ||||
| version = "0.20.0" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| source = "git+https://github.com/timokoesters/ruma?rev=50c1db7e0a3a21fc794b0cce3b64285a4c750c71#50c1db7e0a3a21fc794b0cce3b64285a4c750c71" | ||||
| dependencies = [ | ||||
|  "paste", | ||||
|  "rand 0.8.4", | ||||
|  | @ -2216,7 +2216,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-identifiers-macros" | ||||
| version = "0.20.0" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| source = "git+https://github.com/timokoesters/ruma?rev=50c1db7e0a3a21fc794b0cce3b64285a4c750c71#50c1db7e0a3a21fc794b0cce3b64285a4c750c71" | ||||
| dependencies = [ | ||||
|  "quote", | ||||
|  "ruma-identifiers-validation", | ||||
|  | @ -2226,7 +2226,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-identifiers-validation" | ||||
| version = "0.5.0" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| source = "git+https://github.com/timokoesters/ruma?rev=50c1db7e0a3a21fc794b0cce3b64285a4c750c71#50c1db7e0a3a21fc794b0cce3b64285a4c750c71" | ||||
| dependencies = [ | ||||
|  "thiserror", | ||||
| ] | ||||
|  | @ -2234,7 +2234,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-identity-service-api" | ||||
| version = "0.3.0" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| source = "git+https://github.com/timokoesters/ruma?rev=50c1db7e0a3a21fc794b0cce3b64285a4c750c71#50c1db7e0a3a21fc794b0cce3b64285a4c750c71" | ||||
| dependencies = [ | ||||
|  "js_int", | ||||
|  "ruma-api", | ||||
|  | @ -2247,7 +2247,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-push-gateway-api" | ||||
| version = "0.3.0" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| source = "git+https://github.com/timokoesters/ruma?rev=50c1db7e0a3a21fc794b0cce3b64285a4c750c71#50c1db7e0a3a21fc794b0cce3b64285a4c750c71" | ||||
| dependencies = [ | ||||
|  "js_int", | ||||
|  "ruma-api", | ||||
|  | @ -2262,7 +2262,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-serde" | ||||
| version = "0.5.0" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| source = "git+https://github.com/timokoesters/ruma?rev=50c1db7e0a3a21fc794b0cce3b64285a4c750c71#50c1db7e0a3a21fc794b0cce3b64285a4c750c71" | ||||
| dependencies = [ | ||||
|  "bytes", | ||||
|  "form_urlencoded", | ||||
|  | @ -2276,7 +2276,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-serde-macros" | ||||
| version = "0.5.0" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| source = "git+https://github.com/timokoesters/ruma?rev=50c1db7e0a3a21fc794b0cce3b64285a4c750c71#50c1db7e0a3a21fc794b0cce3b64285a4c750c71" | ||||
| dependencies = [ | ||||
|  "proc-macro-crate", | ||||
|  "proc-macro2", | ||||
|  | @ -2287,7 +2287,7 @@ dependencies = [ | |||
| [[package]] | ||||
| name = "ruma-signatures" | ||||
| version = "0.9.0" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| source = "git+https://github.com/timokoesters/ruma?rev=50c1db7e0a3a21fc794b0cce3b64285a4c750c71#50c1db7e0a3a21fc794b0cce3b64285a4c750c71" | ||||
| dependencies = [ | ||||
|  "base64 0.13.0", | ||||
|  "ed25519-dalek", | ||||
|  | @ -2303,12 +2303,11 @@ dependencies = [ | |||
| 
 | ||||
| [[package]] | ||||
| name = "ruma-state-res" | ||||
| version = "0.3.0" | ||||
| source = "git+https://github.com/DevinR528/ruma?rev=c7860fcb89dbde636e2c83d0636934fb9924f40c#c7860fcb89dbde636e2c83d0636934fb9924f40c" | ||||
| version = "0.4.0" | ||||
| source = "git+https://github.com/timokoesters/ruma?rev=50c1db7e0a3a21fc794b0cce3b64285a4c750c71#50c1db7e0a3a21fc794b0cce3b64285a4c750c71" | ||||
| dependencies = [ | ||||
|  "itertools 0.10.1", | ||||
|  "js_int", | ||||
|  "maplit", | ||||
|  "ruma-common", | ||||
|  "ruma-events", | ||||
|  "ruma-identifiers", | ||||
|  |  | |||
|  | @ -20,7 +20,7 @@ rocket = { version = "0.5.0-rc.1", features = ["tls"] } # Used to handle request | |||
| # Used for matrix spec type definitions and helpers | ||||
| #ruma = { version = "0.4.0", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } | ||||
| #ruma = { git = "https://github.com/ruma/ruma", rev = "f5ab038e22421ed338396ece977b6b2844772ced", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } | ||||
| ruma = { git = "https://github.com/DevinR528/ruma", rev = "c7860fcb89dbde636e2c83d0636934fb9924f40c", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } | ||||
| ruma = { git = "https://github.com/timokoesters/ruma", rev = "50c1db7e0a3a21fc794b0cce3b64285a4c750c71", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } | ||||
| #ruma = { path = "../ruma/crates/ruma", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } | ||||
| 
 | ||||
| # Used for long polling and federation sender, should be the same as rocket::tokio | ||||
|  |  | |||
|  | @ -422,7 +422,7 @@ impl RoomEdus { | |||
|     } | ||||
| 
 | ||||
|     /// Sets all users to offline who have been quiet for too long.
 | ||||
|     fn presence_maintain( | ||||
|     fn _presence_maintain( | ||||
|         &self, | ||||
|         rooms: &super::Rooms, | ||||
|         globals: &super::super::globals::Globals, | ||||
|  | @ -489,13 +489,13 @@ impl RoomEdus { | |||
|     } | ||||
| 
 | ||||
|     /// Returns an iterator over the most recent presence updates that happened after the event with id `since`.
 | ||||
|     #[tracing::instrument(skip(self, globals, rooms))] | ||||
|     #[tracing::instrument(skip(self, since, _rooms, _globals))] | ||||
|     pub fn presence_since( | ||||
|         &self, | ||||
|         room_id: &RoomId, | ||||
|         since: u64, | ||||
|         rooms: &super::Rooms, | ||||
|         globals: &super::super::globals::Globals, | ||||
|         _rooms: &super::Rooms, | ||||
|         _globals: &super::super::globals::Globals, | ||||
|     ) -> Result<HashMap<UserId, PresenceEvent>> { | ||||
|         //self.presence_maintain(rooms, globals)?;
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -6,7 +6,10 @@ use crate::{ | |||
| use get_profile_information::v1::ProfileField; | ||||
| use http::header::{HeaderValue, AUTHORIZATION}; | ||||
| use regex::Regex; | ||||
| use rocket::{futures, response::content::Json}; | ||||
| use rocket::{ | ||||
|     futures::{prelude::*, stream::FuturesUnordered}, | ||||
|     response::content::Json, | ||||
| }; | ||||
| use ruma::{ | ||||
|     api::{ | ||||
|         client::error::{Error as RumaError, ErrorKind}, | ||||
|  | @ -61,7 +64,7 @@ use std::{ | |||
|     net::{IpAddr, SocketAddr}, | ||||
|     pin::Pin, | ||||
|     result::Result as StdResult, | ||||
|     sync::{Arc, RwLock}, | ||||
|     sync::{Arc, RwLock, RwLockWriteGuard}, | ||||
|     time::{Duration, Instant, SystemTime}, | ||||
| }; | ||||
| use tokio::sync::{MutexGuard, Semaphore}; | ||||
|  | @ -3281,101 +3284,96 @@ pub(crate) async fn fetch_required_signing_keys( | |||
| 
 | ||||
| // Gets a list of servers for which we don't have the signing key yet. We go over
 | ||||
| // the PDUs and either cache the key or add it to the list that needs to be retrieved.
 | ||||
| fn get_missing_servers_for_pdus( | ||||
|     pdus: &Vec<Raw<Pdu>>, | ||||
| fn get_server_keys_from_cache( | ||||
|     pdu: &Raw<Pdu>, | ||||
|     servers: &mut BTreeMap<Box<ServerName>, BTreeMap<ServerSigningKeyId, QueryCriteria>>, | ||||
|     room_version: &RoomVersionId, | ||||
|     pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, String>>>, | ||||
|     pub_key_map: &mut RwLockWriteGuard<'_, BTreeMap<String, BTreeMap<String, String>>>, | ||||
|     db: &Database, | ||||
| ) -> Result<()> { | ||||
|     let mut pkm = pub_key_map | ||||
|         .write() | ||||
|         .map_err(|_| Error::bad_database("RwLock is poisoned."))?; | ||||
|     for pdu in pdus { | ||||
|         let value = serde_json::from_str::<CanonicalJsonObject>(pdu.json().get()).map_err(|e| { | ||||
|             error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); | ||||
|             Error::BadServerResponse("Invalid PDU in server response") | ||||
|     let value = serde_json::from_str::<CanonicalJsonObject>(pdu.json().get()).map_err(|e| { | ||||
|         error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); | ||||
|         Error::BadServerResponse("Invalid PDU in server response") | ||||
|     })?; | ||||
| 
 | ||||
|     let event_id = EventId::try_from(&*format!( | ||||
|         "${}", | ||||
|         ruma::signatures::reference_hash(&value, &room_version) | ||||
|             .expect("ruma can calculate reference hashes") | ||||
|     )) | ||||
|     .expect("ruma's reference hashes are valid event ids"); | ||||
| 
 | ||||
|     if let Some((time, tries)) = db | ||||
|         .globals | ||||
|         .bad_event_ratelimiter | ||||
|         .read() | ||||
|         .unwrap() | ||||
|         .get(&event_id) | ||||
|     { | ||||
|         // Exponential backoff
 | ||||
|         let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); | ||||
|         if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { | ||||
|             min_elapsed_duration = Duration::from_secs(60 * 60 * 24); | ||||
|         } | ||||
| 
 | ||||
|         if time.elapsed() < min_elapsed_duration { | ||||
|             debug!("Backing off from {}", event_id); | ||||
|             return Err(Error::BadServerResponse("bad event, still backing off")); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     let signatures = value | ||||
|         .get("signatures") | ||||
|         .ok_or(Error::BadServerResponse( | ||||
|             "No signatures in server response pdu.", | ||||
|         ))? | ||||
|         .as_object() | ||||
|         .ok_or(Error::BadServerResponse( | ||||
|             "Invalid signatures object in server response pdu.", | ||||
|         ))?; | ||||
| 
 | ||||
|     for (signature_server, signature) in signatures { | ||||
|         let signature_object = signature.as_object().ok_or(Error::BadServerResponse( | ||||
|             "Invalid signatures content object in server response pdu.", | ||||
|         ))?; | ||||
| 
 | ||||
|         let signature_ids = signature_object.keys().cloned().collect::<Vec<_>>(); | ||||
| 
 | ||||
|         let contains_all_ids = | ||||
|             |keys: &BTreeMap<String, String>| signature_ids.iter().all(|id| keys.contains_key(id)); | ||||
| 
 | ||||
|         let origin = &Box::<ServerName>::try_from(&**signature_server).map_err(|_| { | ||||
|             Error::BadServerResponse("Invalid servername in signatures of server response pdu.") | ||||
|         })?; | ||||
|         let event_id = EventId::try_from(&*format!( | ||||
|             "${}", | ||||
|             ruma::signatures::reference_hash(&value, &room_version) | ||||
|                 .expect("ruma can calculate reference hashes") | ||||
|         )) | ||||
|         .expect("ruma's reference hashes are valid event ids"); | ||||
| 
 | ||||
|         if let Some((time, tries)) = db | ||||
|         if servers.contains_key(origin) || pub_key_map.contains_key(origin.as_str()) { | ||||
|             continue; | ||||
|         } | ||||
| 
 | ||||
|         trace!("Loading signing keys for {}", origin); | ||||
| 
 | ||||
|         let result = db | ||||
|             .globals | ||||
|             .bad_event_ratelimiter | ||||
|             .read() | ||||
|             .unwrap() | ||||
|             .get(&event_id) | ||||
|         { | ||||
|             // Exponential backoff
 | ||||
|             let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); | ||||
|             if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { | ||||
|                 min_elapsed_duration = Duration::from_secs(60 * 60 * 24); | ||||
|             } | ||||
|             .signing_keys_for(origin)? | ||||
|             .into_iter() | ||||
|             .map(|(k, v)| (k.to_string(), v.key)) | ||||
|             .collect::<BTreeMap<_, _>>(); | ||||
| 
 | ||||
|             if time.elapsed() < min_elapsed_duration { | ||||
|                 debug!("Backing off from {}", event_id); | ||||
|                 return Err(Error::BadServerResponse("bad event, still backing off")); | ||||
|             } | ||||
|         if !contains_all_ids(&result) { | ||||
|             trace!("Signing key not loaded for {}", origin); | ||||
|             servers.insert( | ||||
|                 origin.clone(), | ||||
|                 BTreeMap::<ServerSigningKeyId, QueryCriteria>::new(), | ||||
|             ); | ||||
|         } | ||||
| 
 | ||||
|         let signatures = value | ||||
|             .get("signatures") | ||||
|             .ok_or(Error::BadServerResponse( | ||||
|                 "No signatures in server response pdu.", | ||||
|             ))? | ||||
|             .as_object() | ||||
|             .ok_or(Error::BadServerResponse( | ||||
|                 "Invalid signatures object in server response pdu.", | ||||
|             ))?; | ||||
| 
 | ||||
|         for (signature_server, signature) in signatures { | ||||
|             let signature_object = signature.as_object().ok_or(Error::BadServerResponse( | ||||
|                 "Invalid signatures content object in server response pdu.", | ||||
|             ))?; | ||||
| 
 | ||||
|             let signature_ids = signature_object.keys().cloned().collect::<Vec<_>>(); | ||||
| 
 | ||||
|             let contains_all_ids = |keys: &BTreeMap<String, String>| { | ||||
|                 signature_ids.iter().all(|id| keys.contains_key(id)) | ||||
|             }; | ||||
| 
 | ||||
|             let origin = &Box::<ServerName>::try_from(&**signature_server).map_err(|_| { | ||||
|                 Error::BadServerResponse("Invalid servername in signatures of server response pdu.") | ||||
|             })?; | ||||
| 
 | ||||
|             if servers.contains_key(origin) { | ||||
|                 continue; | ||||
|             } | ||||
| 
 | ||||
|             trace!("Loading signing keys for {}", origin); | ||||
| 
 | ||||
|             let result = db | ||||
|                 .globals | ||||
|                 .signing_keys_for(origin)? | ||||
|                 .into_iter() | ||||
|                 .map(|(k, v)| (k.to_string(), v.key)) | ||||
|                 .collect::<BTreeMap<_, _>>(); | ||||
| 
 | ||||
|             if !contains_all_ids(&result) { | ||||
|                 trace!("Signing key not loaded for {}", origin); | ||||
|                 servers.insert( | ||||
|                     origin.clone(), | ||||
|                     BTreeMap::<ServerSigningKeyId, QueryCriteria>::new(), | ||||
|                 ); | ||||
|             } | ||||
| 
 | ||||
|             pkm.insert(origin.to_string(), result); | ||||
|         } | ||||
|         pub_key_map.insert(origin.to_string(), result); | ||||
|     } | ||||
| 
 | ||||
|     Ok(()) | ||||
| } | ||||
| 
 | ||||
| pub async fn fetch_join_signing_keys( | ||||
| pub(crate) async fn fetch_join_signing_keys( | ||||
|     event: &create_join_event::v2::Response, | ||||
|     room_version: &RoomVersionId, | ||||
|     pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, String>>>, | ||||
|  | @ -3384,32 +3382,26 @@ pub async fn fetch_join_signing_keys( | |||
|     let mut servers = | ||||
|         BTreeMap::<Box<ServerName>, BTreeMap<ServerSigningKeyId, QueryCriteria>>::new(); | ||||
| 
 | ||||
|     get_missing_servers_for_pdus( | ||||
|         &event.room_state.state, | ||||
|         &mut servers, | ||||
|         &room_version, | ||||
|         &pub_key_map, | ||||
|         &db, | ||||
|     )?; | ||||
|     get_missing_servers_for_pdus( | ||||
|         &event.room_state.auth_chain, | ||||
|         &mut servers, | ||||
|         &room_version, | ||||
|         &pub_key_map, | ||||
|         &db, | ||||
|     )?; | ||||
|     { | ||||
|         let mut pkm = pub_key_map | ||||
|             .write() | ||||
|             .map_err(|_| Error::bad_database("RwLock is poisoned."))?; | ||||
| 
 | ||||
|     if servers.is_empty() { | ||||
|         return Ok(()); | ||||
|         // Try to fetch keys, failure is okay
 | ||||
|         // Servers we couldn't find in the cache will be added to `servers`
 | ||||
|         for pdu in &event.room_state.state { | ||||
|             let _ = get_server_keys_from_cache(pdu, &mut servers, &room_version, &mut pkm, &db); | ||||
|         } | ||||
|         for pdu in &event.room_state.auth_chain { | ||||
|             let _ = get_server_keys_from_cache(pdu, &mut servers, &room_version, &mut pkm, &db); | ||||
|         } | ||||
| 
 | ||||
|         drop(pkm); | ||||
|     } | ||||
| 
 | ||||
|     for server in db.globals.trusted_servers() { | ||||
|         if db.globals.signing_keys_for(server)?.is_empty() { | ||||
|             servers.insert( | ||||
|                 server.clone(), | ||||
|                 BTreeMap::<ServerSigningKeyId, QueryCriteria>::new(), | ||||
|             ); | ||||
|         } | ||||
|     if servers.is_empty() { | ||||
|         // We had all keys locally
 | ||||
|         return Ok(()); | ||||
|     } | ||||
| 
 | ||||
|     for server in db.globals.trusted_servers() { | ||||
|  | @ -3434,7 +3426,7 @@ pub async fn fetch_join_signing_keys( | |||
|                 .write() | ||||
|                 .map_err(|_| Error::bad_database("RwLock is poisoned."))?; | ||||
|             for k in keys.server_keys { | ||||
|                 // TODO: Check signature
 | ||||
|                 // TODO: Check signature from trusted server?
 | ||||
|                 servers.remove(&k.server_name); | ||||
| 
 | ||||
|                 let result = db | ||||
|  | @ -3447,23 +3439,33 @@ pub async fn fetch_join_signing_keys( | |||
|                 pkm.insert(k.server_name.to_string(), result); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         if servers.is_empty() { | ||||
|             return Ok(()); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     for result in futures::future::join_all(servers.iter().map(|(server, _)| { | ||||
|         db.sending | ||||
|             .send_federation_request(&db.globals, server, get_server_keys::v2::Request::new()) | ||||
|     })) | ||||
|     .await | ||||
|     { | ||||
|         if let Ok(get_keys_response) = result { | ||||
|             // TODO: We should probably not trust the server_name in the response.
 | ||||
|             let server = &get_keys_response.server_key.server_name; | ||||
|     let mut futures = servers | ||||
|         .into_iter() | ||||
|         .map(|(server, _)| async move { | ||||
|             ( | ||||
|                 db.sending | ||||
|                     .send_federation_request( | ||||
|                         &db.globals, | ||||
|                         &server, | ||||
|                         get_server_keys::v2::Request::new(), | ||||
|                     ) | ||||
|                     .await, | ||||
|                 server, | ||||
|             ) | ||||
|         }) | ||||
|         .collect::<FuturesUnordered<_>>(); | ||||
| 
 | ||||
|     while let Some(result) = futures.next().await { | ||||
|         if let (Ok(get_keys_response), origin) = result { | ||||
|             let result = db | ||||
|                 .globals | ||||
|                 .add_signing_key(server, get_keys_response.server_key.clone())? | ||||
|                 .add_signing_key(&origin, get_keys_response.server_key.clone())? | ||||
|                 .into_iter() | ||||
|                 .map(|(k, v)| (k.to_string(), v.key)) | ||||
|                 .collect::<BTreeMap<_, _>>(); | ||||
|  | @ -3471,7 +3473,7 @@ pub async fn fetch_join_signing_keys( | |||
|             pub_key_map | ||||
|                 .write() | ||||
|                 .map_err(|_| Error::bad_database("RwLock is poisoned."))? | ||||
|                 .insert(server.to_string(), result); | ||||
|                 .insert(origin.to_string(), result); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue