Get required keys in batch when joining a room
We now ask the trusted server for all keys in 1 request, instead of asking each server individual for it's own keys.
This commit is contained in:
		
							parent
							
								
									9c3f1a9272
								
							
						
					
					
						commit
						a87519fb71
					
				
					 2 changed files with 221 additions and 21 deletions
				
			
		|  | @ -5,7 +5,6 @@ use crate::{ | |||
|     server_server, utils, ConduitResult, Database, Error, Result, Ruma, | ||||
| }; | ||||
| use member::{MemberEventContent, MembershipState}; | ||||
| use rocket::futures; | ||||
| use ruma::{ | ||||
|     api::{ | ||||
|         client::{ | ||||
|  | @ -667,14 +666,19 @@ async fn join_room_by_id_helper( | |||
|         let mut state = HashMap::new(); | ||||
|         let pub_key_map = RwLock::new(BTreeMap::new()); | ||||
| 
 | ||||
|         for result in futures::future::join_all( | ||||
|             send_join_response | ||||
|         server_server::fetch_join_signing_keys( | ||||
|             &send_join_response, | ||||
|             &room_version, | ||||
|             &pub_key_map, | ||||
|             &db, | ||||
|         ) | ||||
|         .await?; | ||||
| 
 | ||||
|         for result in send_join_response | ||||
|             .room_state | ||||
|             .state | ||||
|             .iter() | ||||
|                 .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map, &db)), | ||||
|         ) | ||||
|         .await | ||||
|             .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map, &db)) | ||||
|         { | ||||
|             let (event_id, value) = match result { | ||||
|                 Ok(t) => t, | ||||
|  | @ -723,14 +727,11 @@ async fn join_room_by_id_helper( | |||
|             &db, | ||||
|         )?; | ||||
| 
 | ||||
|         for result in futures::future::join_all( | ||||
|             send_join_response | ||||
|         for result in send_join_response | ||||
|             .room_state | ||||
|             .auth_chain | ||||
|             .iter() | ||||
|                 .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map, &db)), | ||||
|         ) | ||||
|         .await | ||||
|             .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map, &db)) | ||||
|         { | ||||
|             let (event_id, value) = match result { | ||||
|                 Ok(t) => t, | ||||
|  | @ -787,7 +788,7 @@ async fn join_room_by_id_helper( | |||
|     Ok(join_room_by_id::Response::new(room_id.clone()).into()) | ||||
| } | ||||
| 
 | ||||
| async fn validate_and_add_event_id( | ||||
| fn validate_and_add_event_id( | ||||
|     pdu: &Raw<Pdu>, | ||||
|     room_version: &RoomVersionId, | ||||
|     pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, String>>>, | ||||
|  | @ -830,7 +831,6 @@ async fn validate_and_add_event_id( | |||
|         } | ||||
|     } | ||||
| 
 | ||||
|     server_server::fetch_required_signing_keys(&value, pub_key_map, db).await?; | ||||
|     if let Err(e) = ruma::signatures::verify_event( | ||||
|         &*pub_key_map | ||||
|             .read() | ||||
|  |  | |||
|  | @ -6,7 +6,7 @@ use crate::{ | |||
| use get_profile_information::v1::ProfileField; | ||||
| use http::header::{HeaderValue, AUTHORIZATION}; | ||||
| use regex::Regex; | ||||
| use rocket::response::content::Json; | ||||
| use rocket::{futures, response::content::Json}; | ||||
| use ruma::{ | ||||
|     api::{ | ||||
|         client::error::{Error as RumaError, ErrorKind}, | ||||
|  | @ -15,8 +15,9 @@ use ruma::{ | |||
|             device::get_devices::{self, v1::UserDevice}, | ||||
|             directory::{get_public_rooms, get_public_rooms_filtered}, | ||||
|             discovery::{ | ||||
|                 get_remote_server_keys, get_server_keys, get_server_version, ServerSigningKeys, | ||||
|                 VerifyKey, | ||||
|                 get_remote_server_keys, get_remote_server_keys_batch, | ||||
|                 get_remote_server_keys_batch::v2::QueryCriteria, get_server_keys, | ||||
|                 get_server_version, ServerSigningKeys, VerifyKey, | ||||
|             }, | ||||
|             event::{get_event, get_missing_events, get_room_state, get_room_state_ids}, | ||||
|             keys::{claim_keys, get_keys}, | ||||
|  | @ -35,6 +36,7 @@ use ruma::{ | |||
|     }, | ||||
|     directory::{IncomingFilter, IncomingRoomNetwork}, | ||||
|     events::{ | ||||
|         pdu::Pdu, | ||||
|         receipt::{ReceiptEvent, ReceiptEventContent}, | ||||
|         room::{ | ||||
|             create::CreateEventContent, | ||||
|  | @ -3277,6 +3279,204 @@ pub(crate) async fn fetch_required_signing_keys( | |||
|     Ok(()) | ||||
| } | ||||
| 
 | ||||
| pub fn get_missing_signing_keys_for_pdus( | ||||
|     pdus: &Vec<Raw<Pdu>>, | ||||
|     servers: &mut BTreeMap<Box<ServerName>, BTreeMap<ServerSigningKeyId, QueryCriteria>>, | ||||
|     room_version: &RoomVersionId, | ||||
|     pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, String>>>, | ||||
|     db: &Database, | ||||
| ) -> Result<()> { | ||||
|     for pdu in pdus { | ||||
|         let value = serde_json::from_str::<CanonicalJsonObject>(pdu.json().get()).map_err(|e| { | ||||
|             error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); | ||||
|             Error::BadServerResponse("Invalid PDU in server response") | ||||
|         })?; | ||||
|         let event_id = EventId::try_from(&*format!( | ||||
|             "${}", | ||||
|             ruma::signatures::reference_hash(&value, &room_version) | ||||
|                 .expect("ruma can calculate reference hashes") | ||||
|         )) | ||||
|         .expect("ruma's reference hashes are valid event ids"); | ||||
| 
 | ||||
|         if let Some((time, tries)) = db | ||||
|             .globals | ||||
|             .bad_event_ratelimiter | ||||
|             .read() | ||||
|             .unwrap() | ||||
|             .get(&event_id) | ||||
|         { | ||||
|             // Exponential backoff
 | ||||
|             let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); | ||||
|             if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { | ||||
|                 min_elapsed_duration = Duration::from_secs(60 * 60 * 24); | ||||
|             } | ||||
| 
 | ||||
|             if time.elapsed() < min_elapsed_duration { | ||||
|                 debug!("Backing off from {}", event_id); | ||||
|                 return Err(Error::BadServerResponse("bad event, still backing off")); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         let signatures = value | ||||
|             .get("signatures") | ||||
|             .ok_or(Error::BadServerResponse( | ||||
|                 "No signatures in server response pdu.", | ||||
|             ))? | ||||
|             .as_object() | ||||
|             .ok_or(Error::BadServerResponse( | ||||
|                 "Invalid signatures object in server response pdu.", | ||||
|             ))?; | ||||
| 
 | ||||
|         for (signature_server, signature) in signatures { | ||||
|             let signature_object = signature.as_object().ok_or(Error::BadServerResponse( | ||||
|                 "Invalid signatures content object in server response pdu.", | ||||
|             ))?; | ||||
| 
 | ||||
|             let signature_ids = signature_object.keys().cloned().collect::<Vec<_>>(); | ||||
| 
 | ||||
|             let contains_all_ids = |keys: &BTreeMap<String, String>| { | ||||
|                 signature_ids.iter().all(|id| keys.contains_key(id)) | ||||
|             }; | ||||
| 
 | ||||
|             let origin = &Box::<ServerName>::try_from(&**signature_server).map_err(|_| { | ||||
|                 Error::BadServerResponse("Invalid servername in signatures of server response pdu.") | ||||
|             })?; | ||||
| 
 | ||||
|             trace!("Loading signing keys for {}", origin); | ||||
| 
 | ||||
|             let result = db | ||||
|                 .globals | ||||
|                 .signing_keys_for(origin)? | ||||
|                 .into_iter() | ||||
|                 .map(|(k, v)| (k.to_string(), v.key)) | ||||
|                 .collect::<BTreeMap<_, _>>(); | ||||
| 
 | ||||
|             if !contains_all_ids(&result) { | ||||
|                 trace!("Signing key not loaded for {}", origin); | ||||
|                 servers.insert( | ||||
|                     origin.clone(), | ||||
|                     BTreeMap::<ServerSigningKeyId, QueryCriteria>::new(), | ||||
|                 ); | ||||
|             } | ||||
| 
 | ||||
|             pub_key_map | ||||
|                 .write() | ||||
|                 .map_err(|_| Error::bad_database("RwLock is poisoned."))? | ||||
|                 .insert(origin.to_string(), result); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     Ok(()) | ||||
| } | ||||
| 
 | ||||
| pub async fn fetch_join_signing_keys( | ||||
|     event: &create_join_event::v2::Response, | ||||
|     room_version: &RoomVersionId, | ||||
|     pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, String>>>, | ||||
|     db: &Database, | ||||
| ) -> Result<()> { | ||||
|     let mut servers = | ||||
|         BTreeMap::<Box<ServerName>, BTreeMap<ServerSigningKeyId, QueryCriteria>>::new(); | ||||
| 
 | ||||
|     get_missing_signing_keys_for_pdus( | ||||
|         &event.room_state.state, | ||||
|         &mut servers, | ||||
|         &room_version, | ||||
|         &pub_key_map, | ||||
|         &db, | ||||
|     )?; | ||||
|     get_missing_signing_keys_for_pdus( | ||||
|         &event.room_state.auth_chain, | ||||
|         &mut servers, | ||||
|         &room_version, | ||||
|         &pub_key_map, | ||||
|         &db, | ||||
|     )?; | ||||
| 
 | ||||
|     if servers.is_empty() { | ||||
|         return Ok(()); | ||||
|     } | ||||
| 
 | ||||
|     for server in db.globals.trusted_servers() { | ||||
|         if db.globals.signing_keys_for(server)?.is_empty() { | ||||
|             servers.insert( | ||||
|                 server.clone(), | ||||
|                 BTreeMap::<ServerSigningKeyId, QueryCriteria>::new(), | ||||
|             ); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     for server in db.globals.trusted_servers() { | ||||
|         trace!("Asking batch signing keys from trusted server {}", server); | ||||
|         if let Ok(keys) = db | ||||
|             .sending | ||||
|             .send_federation_request( | ||||
|                 &db.globals, | ||||
|                 server, | ||||
|                 get_remote_server_keys_batch::v2::Request { | ||||
|                     server_keys: servers.clone(), | ||||
|                     minimum_valid_until_ts: MilliSecondsSinceUnixEpoch::from_system_time( | ||||
|                         SystemTime::now() + Duration::from_secs(60), | ||||
|                     ) | ||||
|                     .expect("time is valid"), | ||||
|                 }, | ||||
|             ) | ||||
|             .await | ||||
|         { | ||||
|             trace!("Got signing keys: {:?}", keys); | ||||
|             for k in keys.server_keys { | ||||
|                 // TODO: Check signature
 | ||||
|                 servers.remove(&k.server_name); | ||||
| 
 | ||||
|                 db.globals.add_signing_key(&k.server_name, k.clone())?; | ||||
| 
 | ||||
|                 let result = db | ||||
|                     .globals | ||||
|                     .signing_keys_for(&k.server_name)? | ||||
|                     .into_iter() | ||||
|                     .map(|(k, v)| (k.to_string(), v.key)) | ||||
|                     .collect::<BTreeMap<_, _>>(); | ||||
| 
 | ||||
|                 pub_key_map | ||||
|                     .write() | ||||
|                     .map_err(|_| Error::bad_database("RwLock is poisoned."))? | ||||
|                     .insert(k.server_name.to_string(), result); | ||||
|             } | ||||
|         } | ||||
|         if servers.is_empty() { | ||||
|             return Ok(()); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     for result in futures::future::join_all(servers.iter().map(|(server, _)| { | ||||
|         db.sending | ||||
|             .send_federation_request(&db.globals, server, get_server_keys::v2::Request::new()) | ||||
|     })) | ||||
|     .await | ||||
|     { | ||||
|         if let Ok(get_keys_response) = result { | ||||
|             // TODO: We should probably not trust the server_name in the response.
 | ||||
|             let server = &get_keys_response.server_key.server_name; | ||||
|             db.globals | ||||
|                 .add_signing_key(server, get_keys_response.server_key.clone())?; | ||||
| 
 | ||||
|             let result = db | ||||
|                 .globals | ||||
|                 .signing_keys_for(server)? | ||||
|                 .into_iter() | ||||
|                 .map(|(k, v)| (k.to_string(), v.key)) | ||||
|                 .collect::<BTreeMap<_, _>>(); | ||||
| 
 | ||||
|             pub_key_map | ||||
|                 .write() | ||||
|                 .map_err(|_| Error::bad_database("RwLock is poisoned."))? | ||||
|                 .insert(server.to_string(), result); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     Ok(()) | ||||
| } | ||||
| 
 | ||||
| #[cfg(test)] | ||||
| mod tests { | ||||
|     use super::{add_port_to_hostname, get_ip_with_port, FedDest}; | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue