Get required keys in batch when joining a room
We now ask the trusted server for all keys in 1 request, instead of asking each server individual for it's own keys.
This commit is contained in:
		
							parent
							
								
									9c3f1a9272
								
							
						
					
					
						commit
						a87519fb71
					
				
					 2 changed files with 221 additions and 21 deletions
				
			
		|  | @ -5,7 +5,6 @@ use crate::{ | ||||||
|     server_server, utils, ConduitResult, Database, Error, Result, Ruma, |     server_server, utils, ConduitResult, Database, Error, Result, Ruma, | ||||||
| }; | }; | ||||||
| use member::{MemberEventContent, MembershipState}; | use member::{MemberEventContent, MembershipState}; | ||||||
| use rocket::futures; |  | ||||||
| use ruma::{ | use ruma::{ | ||||||
|     api::{ |     api::{ | ||||||
|         client::{ |         client::{ | ||||||
|  | @ -667,14 +666,19 @@ async fn join_room_by_id_helper( | ||||||
|         let mut state = HashMap::new(); |         let mut state = HashMap::new(); | ||||||
|         let pub_key_map = RwLock::new(BTreeMap::new()); |         let pub_key_map = RwLock::new(BTreeMap::new()); | ||||||
| 
 | 
 | ||||||
|         for result in futures::future::join_all( |         server_server::fetch_join_signing_keys( | ||||||
|             send_join_response |             &send_join_response, | ||||||
|  |             &room_version, | ||||||
|  |             &pub_key_map, | ||||||
|  |             &db, | ||||||
|  |         ) | ||||||
|  |         .await?; | ||||||
|  | 
 | ||||||
|  |         for result in send_join_response | ||||||
|             .room_state |             .room_state | ||||||
|             .state |             .state | ||||||
|             .iter() |             .iter() | ||||||
|                 .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map, &db)), |             .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map, &db)) | ||||||
|         ) |  | ||||||
|         .await |  | ||||||
|         { |         { | ||||||
|             let (event_id, value) = match result { |             let (event_id, value) = match result { | ||||||
|                 Ok(t) => t, |                 Ok(t) => t, | ||||||
|  | @ -723,14 +727,11 @@ async fn join_room_by_id_helper( | ||||||
|             &db, |             &db, | ||||||
|         )?; |         )?; | ||||||
| 
 | 
 | ||||||
|         for result in futures::future::join_all( |         for result in send_join_response | ||||||
|             send_join_response |  | ||||||
|             .room_state |             .room_state | ||||||
|             .auth_chain |             .auth_chain | ||||||
|             .iter() |             .iter() | ||||||
|                 .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map, &db)), |             .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map, &db)) | ||||||
|         ) |  | ||||||
|         .await |  | ||||||
|         { |         { | ||||||
|             let (event_id, value) = match result { |             let (event_id, value) = match result { | ||||||
|                 Ok(t) => t, |                 Ok(t) => t, | ||||||
|  | @ -787,7 +788,7 @@ async fn join_room_by_id_helper( | ||||||
|     Ok(join_room_by_id::Response::new(room_id.clone()).into()) |     Ok(join_room_by_id::Response::new(room_id.clone()).into()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| async fn validate_and_add_event_id( | fn validate_and_add_event_id( | ||||||
|     pdu: &Raw<Pdu>, |     pdu: &Raw<Pdu>, | ||||||
|     room_version: &RoomVersionId, |     room_version: &RoomVersionId, | ||||||
|     pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, String>>>, |     pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, String>>>, | ||||||
|  | @ -830,7 +831,6 @@ async fn validate_and_add_event_id( | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     server_server::fetch_required_signing_keys(&value, pub_key_map, db).await?; |  | ||||||
|     if let Err(e) = ruma::signatures::verify_event( |     if let Err(e) = ruma::signatures::verify_event( | ||||||
|         &*pub_key_map |         &*pub_key_map | ||||||
|             .read() |             .read() | ||||||
|  |  | ||||||
|  | @ -6,7 +6,7 @@ use crate::{ | ||||||
| use get_profile_information::v1::ProfileField; | use get_profile_information::v1::ProfileField; | ||||||
| use http::header::{HeaderValue, AUTHORIZATION}; | use http::header::{HeaderValue, AUTHORIZATION}; | ||||||
| use regex::Regex; | use regex::Regex; | ||||||
| use rocket::response::content::Json; | use rocket::{futures, response::content::Json}; | ||||||
| use ruma::{ | use ruma::{ | ||||||
|     api::{ |     api::{ | ||||||
|         client::error::{Error as RumaError, ErrorKind}, |         client::error::{Error as RumaError, ErrorKind}, | ||||||
|  | @ -15,8 +15,9 @@ use ruma::{ | ||||||
|             device::get_devices::{self, v1::UserDevice}, |             device::get_devices::{self, v1::UserDevice}, | ||||||
|             directory::{get_public_rooms, get_public_rooms_filtered}, |             directory::{get_public_rooms, get_public_rooms_filtered}, | ||||||
|             discovery::{ |             discovery::{ | ||||||
|                 get_remote_server_keys, get_server_keys, get_server_version, ServerSigningKeys, |                 get_remote_server_keys, get_remote_server_keys_batch, | ||||||
|                 VerifyKey, |                 get_remote_server_keys_batch::v2::QueryCriteria, get_server_keys, | ||||||
|  |                 get_server_version, ServerSigningKeys, VerifyKey, | ||||||
|             }, |             }, | ||||||
|             event::{get_event, get_missing_events, get_room_state, get_room_state_ids}, |             event::{get_event, get_missing_events, get_room_state, get_room_state_ids}, | ||||||
|             keys::{claim_keys, get_keys}, |             keys::{claim_keys, get_keys}, | ||||||
|  | @ -35,6 +36,7 @@ use ruma::{ | ||||||
|     }, |     }, | ||||||
|     directory::{IncomingFilter, IncomingRoomNetwork}, |     directory::{IncomingFilter, IncomingRoomNetwork}, | ||||||
|     events::{ |     events::{ | ||||||
|  |         pdu::Pdu, | ||||||
|         receipt::{ReceiptEvent, ReceiptEventContent}, |         receipt::{ReceiptEvent, ReceiptEventContent}, | ||||||
|         room::{ |         room::{ | ||||||
|             create::CreateEventContent, |             create::CreateEventContent, | ||||||
|  | @ -3277,6 +3279,204 @@ pub(crate) async fn fetch_required_signing_keys( | ||||||
|     Ok(()) |     Ok(()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | pub fn get_missing_signing_keys_for_pdus( | ||||||
|  |     pdus: &Vec<Raw<Pdu>>, | ||||||
|  |     servers: &mut BTreeMap<Box<ServerName>, BTreeMap<ServerSigningKeyId, QueryCriteria>>, | ||||||
|  |     room_version: &RoomVersionId, | ||||||
|  |     pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, String>>>, | ||||||
|  |     db: &Database, | ||||||
|  | ) -> Result<()> { | ||||||
|  |     for pdu in pdus { | ||||||
|  |         let value = serde_json::from_str::<CanonicalJsonObject>(pdu.json().get()).map_err(|e| { | ||||||
|  |             error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); | ||||||
|  |             Error::BadServerResponse("Invalid PDU in server response") | ||||||
|  |         })?; | ||||||
|  |         let event_id = EventId::try_from(&*format!( | ||||||
|  |             "${}", | ||||||
|  |             ruma::signatures::reference_hash(&value, &room_version) | ||||||
|  |                 .expect("ruma can calculate reference hashes") | ||||||
|  |         )) | ||||||
|  |         .expect("ruma's reference hashes are valid event ids"); | ||||||
|  | 
 | ||||||
|  |         if let Some((time, tries)) = db | ||||||
|  |             .globals | ||||||
|  |             .bad_event_ratelimiter | ||||||
|  |             .read() | ||||||
|  |             .unwrap() | ||||||
|  |             .get(&event_id) | ||||||
|  |         { | ||||||
|  |             // Exponential backoff
 | ||||||
|  |             let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); | ||||||
|  |             if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { | ||||||
|  |                 min_elapsed_duration = Duration::from_secs(60 * 60 * 24); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             if time.elapsed() < min_elapsed_duration { | ||||||
|  |                 debug!("Backing off from {}", event_id); | ||||||
|  |                 return Err(Error::BadServerResponse("bad event, still backing off")); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         let signatures = value | ||||||
|  |             .get("signatures") | ||||||
|  |             .ok_or(Error::BadServerResponse( | ||||||
|  |                 "No signatures in server response pdu.", | ||||||
|  |             ))? | ||||||
|  |             .as_object() | ||||||
|  |             .ok_or(Error::BadServerResponse( | ||||||
|  |                 "Invalid signatures object in server response pdu.", | ||||||
|  |             ))?; | ||||||
|  | 
 | ||||||
|  |         for (signature_server, signature) in signatures { | ||||||
|  |             let signature_object = signature.as_object().ok_or(Error::BadServerResponse( | ||||||
|  |                 "Invalid signatures content object in server response pdu.", | ||||||
|  |             ))?; | ||||||
|  | 
 | ||||||
|  |             let signature_ids = signature_object.keys().cloned().collect::<Vec<_>>(); | ||||||
|  | 
 | ||||||
|  |             let contains_all_ids = |keys: &BTreeMap<String, String>| { | ||||||
|  |                 signature_ids.iter().all(|id| keys.contains_key(id)) | ||||||
|  |             }; | ||||||
|  | 
 | ||||||
|  |             let origin = &Box::<ServerName>::try_from(&**signature_server).map_err(|_| { | ||||||
|  |                 Error::BadServerResponse("Invalid servername in signatures of server response pdu.") | ||||||
|  |             })?; | ||||||
|  | 
 | ||||||
|  |             trace!("Loading signing keys for {}", origin); | ||||||
|  | 
 | ||||||
|  |             let result = db | ||||||
|  |                 .globals | ||||||
|  |                 .signing_keys_for(origin)? | ||||||
|  |                 .into_iter() | ||||||
|  |                 .map(|(k, v)| (k.to_string(), v.key)) | ||||||
|  |                 .collect::<BTreeMap<_, _>>(); | ||||||
|  | 
 | ||||||
|  |             if !contains_all_ids(&result) { | ||||||
|  |                 trace!("Signing key not loaded for {}", origin); | ||||||
|  |                 servers.insert( | ||||||
|  |                     origin.clone(), | ||||||
|  |                     BTreeMap::<ServerSigningKeyId, QueryCriteria>::new(), | ||||||
|  |                 ); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             pub_key_map | ||||||
|  |                 .write() | ||||||
|  |                 .map_err(|_| Error::bad_database("RwLock is poisoned."))? | ||||||
|  |                 .insert(origin.to_string(), result); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     Ok(()) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | pub async fn fetch_join_signing_keys( | ||||||
|  |     event: &create_join_event::v2::Response, | ||||||
|  |     room_version: &RoomVersionId, | ||||||
|  |     pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, String>>>, | ||||||
|  |     db: &Database, | ||||||
|  | ) -> Result<()> { | ||||||
|  |     let mut servers = | ||||||
|  |         BTreeMap::<Box<ServerName>, BTreeMap<ServerSigningKeyId, QueryCriteria>>::new(); | ||||||
|  | 
 | ||||||
|  |     get_missing_signing_keys_for_pdus( | ||||||
|  |         &event.room_state.state, | ||||||
|  |         &mut servers, | ||||||
|  |         &room_version, | ||||||
|  |         &pub_key_map, | ||||||
|  |         &db, | ||||||
|  |     )?; | ||||||
|  |     get_missing_signing_keys_for_pdus( | ||||||
|  |         &event.room_state.auth_chain, | ||||||
|  |         &mut servers, | ||||||
|  |         &room_version, | ||||||
|  |         &pub_key_map, | ||||||
|  |         &db, | ||||||
|  |     )?; | ||||||
|  | 
 | ||||||
|  |     if servers.is_empty() { | ||||||
|  |         return Ok(()); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     for server in db.globals.trusted_servers() { | ||||||
|  |         if db.globals.signing_keys_for(server)?.is_empty() { | ||||||
|  |             servers.insert( | ||||||
|  |                 server.clone(), | ||||||
|  |                 BTreeMap::<ServerSigningKeyId, QueryCriteria>::new(), | ||||||
|  |             ); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     for server in db.globals.trusted_servers() { | ||||||
|  |         trace!("Asking batch signing keys from trusted server {}", server); | ||||||
|  |         if let Ok(keys) = db | ||||||
|  |             .sending | ||||||
|  |             .send_federation_request( | ||||||
|  |                 &db.globals, | ||||||
|  |                 server, | ||||||
|  |                 get_remote_server_keys_batch::v2::Request { | ||||||
|  |                     server_keys: servers.clone(), | ||||||
|  |                     minimum_valid_until_ts: MilliSecondsSinceUnixEpoch::from_system_time( | ||||||
|  |                         SystemTime::now() + Duration::from_secs(60), | ||||||
|  |                     ) | ||||||
|  |                     .expect("time is valid"), | ||||||
|  |                 }, | ||||||
|  |             ) | ||||||
|  |             .await | ||||||
|  |         { | ||||||
|  |             trace!("Got signing keys: {:?}", keys); | ||||||
|  |             for k in keys.server_keys { | ||||||
|  |                 // TODO: Check signature
 | ||||||
|  |                 servers.remove(&k.server_name); | ||||||
|  | 
 | ||||||
|  |                 db.globals.add_signing_key(&k.server_name, k.clone())?; | ||||||
|  | 
 | ||||||
|  |                 let result = db | ||||||
|  |                     .globals | ||||||
|  |                     .signing_keys_for(&k.server_name)? | ||||||
|  |                     .into_iter() | ||||||
|  |                     .map(|(k, v)| (k.to_string(), v.key)) | ||||||
|  |                     .collect::<BTreeMap<_, _>>(); | ||||||
|  | 
 | ||||||
|  |                 pub_key_map | ||||||
|  |                     .write() | ||||||
|  |                     .map_err(|_| Error::bad_database("RwLock is poisoned."))? | ||||||
|  |                     .insert(k.server_name.to_string(), result); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         if servers.is_empty() { | ||||||
|  |             return Ok(()); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     for result in futures::future::join_all(servers.iter().map(|(server, _)| { | ||||||
|  |         db.sending | ||||||
|  |             .send_federation_request(&db.globals, server, get_server_keys::v2::Request::new()) | ||||||
|  |     })) | ||||||
|  |     .await | ||||||
|  |     { | ||||||
|  |         if let Ok(get_keys_response) = result { | ||||||
|  |             // TODO: We should probably not trust the server_name in the response.
 | ||||||
|  |             let server = &get_keys_response.server_key.server_name; | ||||||
|  |             db.globals | ||||||
|  |                 .add_signing_key(server, get_keys_response.server_key.clone())?; | ||||||
|  | 
 | ||||||
|  |             let result = db | ||||||
|  |                 .globals | ||||||
|  |                 .signing_keys_for(server)? | ||||||
|  |                 .into_iter() | ||||||
|  |                 .map(|(k, v)| (k.to_string(), v.key)) | ||||||
|  |                 .collect::<BTreeMap<_, _>>(); | ||||||
|  | 
 | ||||||
|  |             pub_key_map | ||||||
|  |                 .write() | ||||||
|  |                 .map_err(|_| Error::bad_database("RwLock is poisoned."))? | ||||||
|  |                 .insert(server.to_string(), result); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     Ok(()) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| #[cfg(test)] | #[cfg(test)] | ||||||
| mod tests { | mod tests { | ||||||
|     use super::{add_port_to_hostname, get_ip_with_port, FedDest}; |     use super::{add_port_to_hostname, get_ip_with_port, FedDest}; | ||||||
|  |  | ||||||
		Loading…
	
		Reference in a new issue