improvement: efficient /sync, mutex for federation transactions
This commit is contained in:
		
							parent
							
								
									1c25492a7e
								
							
						
					
					
						commit
						e15e6d4405
					
				
					 5 changed files with 359 additions and 263 deletions
				
			
		|  | @ -902,8 +902,25 @@ pub async fn invite_helper( | |||
|         ) | ||||
|         .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; | ||||
| 
 | ||||
|         let pdu_id = | ||||
|             server_server::handle_incoming_pdu(&origin, &event_id, value, true, &db, &pub_key_map) | ||||
|         let mutex = Arc::clone( | ||||
|             db.globals | ||||
|                 .roomid_mutex | ||||
|                 .write() | ||||
|                 .unwrap() | ||||
|                 .entry(room_id.clone()) | ||||
|                 .or_default(), | ||||
|         ); | ||||
|         let mutex_lock = mutex.lock().await; | ||||
| 
 | ||||
|         let pdu_id = server_server::handle_incoming_pdu( | ||||
|             &origin, | ||||
|             &event_id, | ||||
|             &room_id, | ||||
|             value, | ||||
|             true, | ||||
|             &db, | ||||
|             &pub_key_map, | ||||
|         ) | ||||
|         .await | ||||
|         .map_err(|_| { | ||||
|             Error::BadRequest( | ||||
|  | @ -915,6 +932,7 @@ pub async fn invite_helper( | |||
|             ErrorKind::InvalidParam, | ||||
|             "Could not accept incoming PDU as timeline event.", | ||||
|         ))?; | ||||
|         drop(mutex_lock); | ||||
| 
 | ||||
|         for server in db | ||||
|             .rooms | ||||
|  |  | |||
|  | @ -227,13 +227,16 @@ async fn sync_helper( | |||
| 
 | ||||
|         // Database queries:
 | ||||
| 
 | ||||
|         let current_shortstatehash = db.rooms.current_shortstatehash(&room_id)?; | ||||
|         let current_shortstatehash = db | ||||
|             .rooms | ||||
|             .current_shortstatehash(&room_id)? | ||||
|             .expect("All rooms have state"); | ||||
| 
 | ||||
|         // These type is Option<Option<_>>. The outer Option is None when there is no event between
 | ||||
|         // since and the current room state, meaning there should be no updates.
 | ||||
|         // The inner Option is None when there is an event, but there is no state hash associated
 | ||||
|         // with it. This can happen for the RoomCreate event, so all updates should arrive.
 | ||||
|         let first_pdu_before_since = db.rooms.pdus_until(&sender_user, &room_id, since).next(); | ||||
|         let first_pdu_before_since = db | ||||
|             .rooms | ||||
|             .pdus_until(&sender_user, &room_id, since) | ||||
|             .next() | ||||
|             .transpose()?; | ||||
| 
 | ||||
|         let pdus_after_since = db | ||||
|             .rooms | ||||
|  | @ -241,152 +244,18 @@ async fn sync_helper( | |||
|             .next() | ||||
|             .is_some(); | ||||
| 
 | ||||
|         let since_shortstatehash = first_pdu_before_since.as_ref().map(|pdu| { | ||||
|             db.rooms | ||||
|                 .pdu_shortstatehash(&pdu.as_ref().ok()?.1.event_id) | ||||
|                 .ok()? | ||||
|         }); | ||||
| 
 | ||||
|         let ( | ||||
|             heroes, | ||||
|             joined_member_count, | ||||
|             invited_member_count, | ||||
|             joined_since_last_sync, | ||||
|             state_events, | ||||
|         ) = if pdus_after_since && Some(current_shortstatehash) != since_shortstatehash { | ||||
|             let current_state = db.rooms.room_state_full(&room_id)?; | ||||
|             let current_members = current_state | ||||
|                 .iter() | ||||
|                 .filter(|(key, _)| key.0 == EventType::RoomMember) | ||||
|                 .map(|(key, value)| (&key.1, value)) // Only keep state key
 | ||||
|                 .collect::<Vec<_>>(); | ||||
|             let encrypted_room = current_state | ||||
|                 .get(&(EventType::RoomEncryption, "".to_owned())) | ||||
|                 .is_some(); | ||||
|             let since_state = since_shortstatehash | ||||
|         let since_shortstatehash = first_pdu_before_since | ||||
|             .as_ref() | ||||
|                 .map(|since_shortstatehash| { | ||||
|                     since_shortstatehash | ||||
|                         .map(|since_shortstatehash| db.rooms.state_full(since_shortstatehash)) | ||||
|             .map(|pdu| { | ||||
|                 db.rooms | ||||
|                     .pdu_shortstatehash(&pdu.1.event_id) | ||||
|                     .transpose() | ||||
|                     .expect("all pdus have state") | ||||
|             }) | ||||
|             .transpose()?; | ||||
| 
 | ||||
|             let since_encryption = since_state.as_ref().map(|state| { | ||||
|                 state | ||||
|                     .as_ref() | ||||
|                     .map(|state| state.get(&(EventType::RoomEncryption, "".to_owned()))) | ||||
|             }); | ||||
| 
 | ||||
|             // Calculations:
 | ||||
|             let new_encrypted_room = | ||||
|                 encrypted_room && since_encryption.map_or(true, |encryption| encryption.is_none()); | ||||
| 
 | ||||
|             let send_member_count = since_state.as_ref().map_or(true, |since_state| { | ||||
|                 since_state.as_ref().map_or(true, |since_state| { | ||||
|                     current_members.len() | ||||
|                         != since_state | ||||
|                             .iter() | ||||
|                             .filter(|(key, _)| key.0 == EventType::RoomMember) | ||||
|                             .count() | ||||
|                 }) | ||||
|             }); | ||||
| 
 | ||||
|             let since_sender_member = since_state.as_ref().map(|since_state| { | ||||
|                 since_state.as_ref().and_then(|state| { | ||||
|                     state | ||||
|                         .get(&(EventType::RoomMember, sender_user.as_str().to_owned())) | ||||
|                         .and_then(|pdu| { | ||||
|                             serde_json::from_value::< | ||||
|                                 Raw<ruma::events::room::member::MemberEventContent>, | ||||
|                             >(pdu.content.clone()) | ||||
|                             .expect("Raw::from_value always works") | ||||
|                             .deserialize() | ||||
|                             .map_err(|_| Error::bad_database("Invalid PDU in database.")) | ||||
|                             .ok() | ||||
|                         }) | ||||
|                 }) | ||||
|             }); | ||||
| 
 | ||||
|             if encrypted_room { | ||||
|                 for (user_id, current_member) in current_members { | ||||
|                     let current_membership = serde_json::from_value::< | ||||
|                         Raw<ruma::events::room::member::MemberEventContent>, | ||||
|                     >(current_member.content.clone()) | ||||
|                     .expect("Raw::from_value always works") | ||||
|                     .deserialize() | ||||
|                     .map_err(|_| Error::bad_database("Invalid PDU in database."))? | ||||
|                     .membership; | ||||
| 
 | ||||
|                     let since_membership = | ||||
|                         since_state | ||||
|                             .as_ref() | ||||
|                             .map_or(MembershipState::Leave, |since_state| { | ||||
|                                 since_state | ||||
|                                     .as_ref() | ||||
|                                     .and_then(|since_state| { | ||||
|                                         since_state | ||||
|                                             .get(&(EventType::RoomMember, user_id.clone())) | ||||
|                                             .and_then(|since_member| { | ||||
|                                                 serde_json::from_value::< | ||||
|                                                 Raw<ruma::events::room::member::MemberEventContent>, | ||||
|                                             >( | ||||
|                                                 since_member.content.clone() | ||||
|                                             ) | ||||
|                                             .expect("Raw::from_value always works") | ||||
|                                             .deserialize() | ||||
|                                             .map_err(|_| { | ||||
|                                                 Error::bad_database("Invalid PDU in database.") | ||||
|                                             }) | ||||
|                                             .ok() | ||||
|                                             }) | ||||
|                                     }) | ||||
|                                     .map_or(MembershipState::Leave, |member| member.membership) | ||||
|                             }); | ||||
| 
 | ||||
|                     let user_id = UserId::try_from(user_id.clone()) | ||||
|                         .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; | ||||
| 
 | ||||
|                     match (since_membership, current_membership) { | ||||
|                         (MembershipState::Leave, MembershipState::Join) => { | ||||
|                             // A new user joined an encrypted room
 | ||||
|                             if !share_encrypted_room(&db, &sender_user, &user_id, &room_id)? { | ||||
|                                 device_list_updates.insert(user_id); | ||||
|                             } | ||||
|                         } | ||||
|                         // TODO: Remove, this should never happen here, right?
 | ||||
|                         (MembershipState::Join, MembershipState::Leave) => { | ||||
|                             // Write down users that have left encrypted rooms we are in
 | ||||
|                             left_encrypted_users.insert(user_id); | ||||
|                         } | ||||
|                         _ => {} | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             let joined_since_last_sync = since_sender_member.map_or(true, |member| { | ||||
|                 member.map_or(true, |member| member.membership != MembershipState::Join) | ||||
|             }); | ||||
| 
 | ||||
|             if joined_since_last_sync && encrypted_room || new_encrypted_room { | ||||
|                 // If the user is in a new encrypted room, give them all joined users
 | ||||
|                 device_list_updates.extend( | ||||
|                     db.rooms | ||||
|                         .room_members(&room_id) | ||||
|                         .flatten() | ||||
|                         .filter(|user_id| { | ||||
|                             // Don't send key updates from the sender to the sender
 | ||||
|                             &sender_user != user_id | ||||
|                         }) | ||||
|                         .filter(|user_id| { | ||||
|                             // Only send keys if the sender doesn't share an encrypted room with the target already
 | ||||
|                             !share_encrypted_room(&db, &sender_user, user_id, &room_id) | ||||
|                                 .unwrap_or(false) | ||||
|                         }), | ||||
|                 ); | ||||
|             } | ||||
| 
 | ||||
|             let (joined_member_count, invited_member_count, heroes) = if send_member_count { | ||||
|         // Calculates joined_member_count, invited_member_count and heroes
 | ||||
|         let calculate_counts = || { | ||||
|             let joined_member_count = db.rooms.room_members(&room_id).count(); | ||||
|             let invited_member_count = db.rooms.room_members_invited(&room_id).count(); | ||||
| 
 | ||||
|  | @ -406,13 +275,10 @@ async fn sync_helper( | |||
|                         let content = serde_json::from_value::< | ||||
|                             ruma::events::room::member::MemberEventContent, | ||||
|                         >(pdu.content.clone()) | ||||
|                             .map_err(|_| { | ||||
|                                 Error::bad_database("Invalid member event in database.") | ||||
|                             })?; | ||||
|                         .map_err(|_| Error::bad_database("Invalid member event in database."))?; | ||||
| 
 | ||||
|                         if let Some(state_key) = &pdu.state_key { | ||||
|                                 let user_id = | ||||
|                                     UserId::try_from(state_key.clone()).map_err(|_| { | ||||
|                             let user_id = UserId::try_from(state_key.clone()).map_err(|_| { | ||||
|                                 Error::bad_database("Invalid UserId in member PDU.") | ||||
|                             })?; | ||||
| 
 | ||||
|  | @ -449,36 +315,183 @@ async fn sync_helper( | |||
|                 Some(invited_member_count), | ||||
|                 heroes, | ||||
|             ) | ||||
|             } else { | ||||
|                 (None, None, Vec::new()) | ||||
|         }; | ||||
| 
 | ||||
|             let state_events = if joined_since_last_sync { | ||||
|                 current_state | ||||
|         let ( | ||||
|             heroes, | ||||
|             joined_member_count, | ||||
|             invited_member_count, | ||||
|             joined_since_last_sync, | ||||
|             state_events, | ||||
|         ) = if since_shortstatehash.is_none() { | ||||
|             // Probably since = 0, we will do an initial sync
 | ||||
|             let (joined_member_count, invited_member_count, heroes) = calculate_counts(); | ||||
| 
 | ||||
|             let current_state_ids = db.rooms.state_full_ids(current_shortstatehash)?; | ||||
|             let state_events = current_state_ids | ||||
|                 .iter() | ||||
|                     .map(|(_, pdu)| pdu.to_sync_state_event()) | ||||
|                     .collect() | ||||
|                 .map(|id| db.rooms.get_pdu(id)) | ||||
|                 .filter_map(|r| r.ok().flatten()) | ||||
|                 .collect::<Vec<_>>(); | ||||
| 
 | ||||
|             ( | ||||
|                 heroes, | ||||
|                 joined_member_count, | ||||
|                 invited_member_count, | ||||
|                 true, | ||||
|                 state_events, | ||||
|             ) | ||||
|         } else if !pdus_after_since || since_shortstatehash == Some(current_shortstatehash) { | ||||
|             // No state changes
 | ||||
|             (Vec::new(), None, None, false, Vec::new()) | ||||
|         } else { | ||||
|                 match since_state { | ||||
|                     None => Vec::new(), | ||||
|                     Some(Some(since_state)) => current_state | ||||
|             // Incremental /sync
 | ||||
|             let since_shortstatehash = since_shortstatehash.unwrap(); | ||||
| 
 | ||||
|             let since_sender_member = db | ||||
|                 .rooms | ||||
|                 .state_get( | ||||
|                     since_shortstatehash, | ||||
|                     &EventType::RoomMember, | ||||
|                     sender_user.as_str(), | ||||
|                 )? | ||||
|                 .and_then(|pdu| { | ||||
|                     serde_json::from_value::<Raw<ruma::events::room::member::MemberEventContent>>( | ||||
|                         pdu.content.clone(), | ||||
|                     ) | ||||
|                     .expect("Raw::from_value always works") | ||||
|                     .deserialize() | ||||
|                     .map_err(|_| Error::bad_database("Invalid PDU in database.")) | ||||
|                     .ok() | ||||
|                 }); | ||||
| 
 | ||||
|             let joined_since_last_sync = since_sender_member | ||||
|                 .map_or(true, |member| member.membership != MembershipState::Join); | ||||
| 
 | ||||
|             let current_state_ids = db.rooms.state_full_ids(current_shortstatehash)?; | ||||
| 
 | ||||
|             let since_state_ids = db.rooms.state_full_ids(since_shortstatehash)?; | ||||
| 
 | ||||
|             let state_events = if joined_since_last_sync { | ||||
|                 current_state_ids | ||||
|                     .iter() | ||||
|                         .filter(|(key, value)| { | ||||
|                             since_state.get(key).map(|e| &e.event_id) != Some(&value.event_id) | ||||
|                         }) | ||||
|                         .filter(|(_, value)| { | ||||
|                             !timeline_pdus.iter().any(|(_, timeline_pdu)| { | ||||
|                                 timeline_pdu.kind == value.kind | ||||
|                                     && timeline_pdu.state_key == value.state_key | ||||
|                             }) | ||||
|                         }) | ||||
|                         .map(|(_, pdu)| pdu.to_sync_state_event()) | ||||
|                         .collect(), | ||||
|                     Some(None) => current_state | ||||
|                     .map(|id| db.rooms.get_pdu(id)) | ||||
|                     .filter_map(|r| r.ok().flatten()) | ||||
|                     .collect::<Vec<_>>() | ||||
|             } else { | ||||
|                 current_state_ids | ||||
|                     .difference(&since_state_ids) | ||||
|                     .filter(|id| { | ||||
|                         !timeline_pdus | ||||
|                             .iter() | ||||
|                         .map(|(_, pdu)| pdu.to_sync_state_event()) | ||||
|                         .collect(), | ||||
|                             .any(|(_, timeline_pdu)| timeline_pdu.event_id == **id) | ||||
|                     }) | ||||
|                     .map(|id| db.rooms.get_pdu(id)) | ||||
|                     .filter_map(|r| r.ok().flatten()) | ||||
|                     .collect() | ||||
|             }; | ||||
| 
 | ||||
|             let encrypted_room = db | ||||
|                 .rooms | ||||
|                 .state_get(current_shortstatehash, &EventType::RoomEncryption, "")? | ||||
|                 .is_some(); | ||||
| 
 | ||||
|             let since_encryption = | ||||
|                 db.rooms | ||||
|                     .state_get(since_shortstatehash, &EventType::RoomEncryption, "")?; | ||||
| 
 | ||||
|             // Calculations:
 | ||||
|             let new_encrypted_room = encrypted_room && since_encryption.is_none(); | ||||
| 
 | ||||
|             let send_member_count = state_events | ||||
|                 .iter() | ||||
|                 .any(|event| event.kind == EventType::RoomMember); | ||||
| 
 | ||||
|             if encrypted_room { | ||||
|                 for (user_id, current_member) in db | ||||
|                     .rooms | ||||
|                     .room_members(&room_id) | ||||
|                     .filter_map(|r| r.ok()) | ||||
|                     .filter_map(|user_id| { | ||||
|                         db.rooms | ||||
|                             .state_get( | ||||
|                                 current_shortstatehash, | ||||
|                                 &EventType::RoomMember, | ||||
|                                 user_id.as_str(), | ||||
|                             ) | ||||
|                             .ok() | ||||
|                             .flatten() | ||||
|                             .map(|current_member| (user_id, current_member)) | ||||
|                     }) | ||||
|                 { | ||||
|                     let current_membership = serde_json::from_value::< | ||||
|                         Raw<ruma::events::room::member::MemberEventContent>, | ||||
|                     >(current_member.content.clone()) | ||||
|                     .expect("Raw::from_value always works") | ||||
|                     .deserialize() | ||||
|                     .map_err(|_| Error::bad_database("Invalid PDU in database."))? | ||||
|                     .membership; | ||||
| 
 | ||||
|                     let since_membership = db | ||||
|                         .rooms | ||||
|                         .state_get( | ||||
|                             since_shortstatehash, | ||||
|                             &EventType::RoomMember, | ||||
|                             user_id.as_str(), | ||||
|                         )? | ||||
|                         .and_then(|since_member| { | ||||
|                             serde_json::from_value::< | ||||
|                                 Raw<ruma::events::room::member::MemberEventContent>, | ||||
|                             >(since_member.content.clone()) | ||||
|                             .expect("Raw::from_value always works") | ||||
|                             .deserialize() | ||||
|                             .map_err(|_| Error::bad_database("Invalid PDU in database.")) | ||||
|                             .ok() | ||||
|                         }) | ||||
|                         .map_or(MembershipState::Leave, |member| member.membership); | ||||
| 
 | ||||
|                     let user_id = UserId::try_from(user_id.clone()) | ||||
|                         .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; | ||||
| 
 | ||||
|                     match (since_membership, current_membership) { | ||||
|                         (MembershipState::Leave, MembershipState::Join) => { | ||||
|                             // A new user joined an encrypted room
 | ||||
|                             if !share_encrypted_room(&db, &sender_user, &user_id, &room_id)? { | ||||
|                                 device_list_updates.insert(user_id); | ||||
|                             } | ||||
|                         } | ||||
|                         // TODO: Remove, this should never happen here, right?
 | ||||
|                         (MembershipState::Join, MembershipState::Leave) => { | ||||
|                             // Write down users that have left encrypted rooms we are in
 | ||||
|                             left_encrypted_users.insert(user_id); | ||||
|                         } | ||||
|                         _ => {} | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             if joined_since_last_sync && encrypted_room || new_encrypted_room { | ||||
|                 // If the user is in a new encrypted room, give them all joined users
 | ||||
|                 device_list_updates.extend( | ||||
|                     db.rooms | ||||
|                         .room_members(&room_id) | ||||
|                         .flatten() | ||||
|                         .filter(|user_id| { | ||||
|                             // Don't send key updates from the sender to the sender
 | ||||
|                             &sender_user != user_id | ||||
|                         }) | ||||
|                         .filter(|user_id| { | ||||
|                             // Only send keys if the sender doesn't share an encrypted room with the target already
 | ||||
|                             !share_encrypted_room(&db, &sender_user, user_id, &room_id) | ||||
|                                 .unwrap_or(false) | ||||
|                         }), | ||||
|                 ); | ||||
|             } | ||||
| 
 | ||||
|             let (joined_member_count, invited_member_count, heroes) = if send_member_count { | ||||
|                 calculate_counts() | ||||
|             } else { | ||||
|                 (None, None, Vec::new()) | ||||
|             }; | ||||
| 
 | ||||
|             ( | ||||
|  | @ -488,8 +501,6 @@ async fn sync_helper( | |||
|                 joined_since_last_sync, | ||||
|                 state_events, | ||||
|             ) | ||||
|         } else { | ||||
|             (Vec::new(), None, None, false, Vec::new()) | ||||
|         }; | ||||
| 
 | ||||
|         // Look for device list updates in this room
 | ||||
|  | @ -580,7 +591,10 @@ async fn sync_helper( | |||
|                 events: room_events, | ||||
|             }, | ||||
|             state: sync_events::State { | ||||
|                 events: state_events, | ||||
|                 events: state_events | ||||
|                     .iter() | ||||
|                     .map(|pdu| pdu.to_sync_state_event()) | ||||
|                     .collect(), | ||||
|             }, | ||||
|             ephemeral: sync_events::Ephemeral { events: edus }, | ||||
|         }; | ||||
|  |  | |||
|  | @ -5,7 +5,7 @@ use ruma::{ | |||
|         client::r0::sync::sync_events, | ||||
|         federation::discovery::{ServerSigningKeys, VerifyKey}, | ||||
|     }, | ||||
|     DeviceId, EventId, MilliSecondsSinceUnixEpoch, ServerName, ServerSigningKeyId, UserId, | ||||
|     DeviceId, EventId, MilliSecondsSinceUnixEpoch, RoomId, ServerName, ServerSigningKeyId, UserId, | ||||
| }; | ||||
| use rustls::{ServerCertVerifier, WebPKIVerifier}; | ||||
| use std::{ | ||||
|  | @ -16,7 +16,7 @@ use std::{ | |||
|     sync::{Arc, RwLock}, | ||||
|     time::{Duration, Instant}, | ||||
| }; | ||||
| use tokio::sync::{broadcast, watch::Receiver, Semaphore}; | ||||
| use tokio::sync::{broadcast, watch::Receiver, Mutex, Semaphore}; | ||||
| use trust_dns_resolver::TokioAsyncResolver; | ||||
| 
 | ||||
| use super::abstraction::Tree; | ||||
|  | @ -45,6 +45,7 @@ pub struct Globals { | |||
|     pub bad_signature_ratelimiter: Arc<RwLock<BTreeMap<Vec<String>, RateLimitState>>>, | ||||
|     pub servername_ratelimiter: Arc<RwLock<BTreeMap<Box<ServerName>, Arc<Semaphore>>>>, | ||||
|     pub sync_receivers: RwLock<BTreeMap<(UserId, Box<DeviceId>), SyncHandle>>, | ||||
|     pub roomid_mutex: RwLock<BTreeMap<RoomId, Arc<Mutex<()>>>>, | ||||
|     pub rotate: RotationHandler, | ||||
| } | ||||
| 
 | ||||
|  | @ -197,6 +198,7 @@ impl Globals { | |||
|             bad_event_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), | ||||
|             bad_signature_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), | ||||
|             servername_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), | ||||
|             roomid_mutex: RwLock::new(BTreeMap::new()), | ||||
|             sync_receivers: RwLock::new(BTreeMap::new()), | ||||
|             rotate: RotationHandler::new(), | ||||
|         }; | ||||
|  |  | |||
|  | @ -21,7 +21,7 @@ use ruma::{ | |||
|     uint, EventId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId, | ||||
| }; | ||||
| use std::{ | ||||
|     collections::{BTreeMap, HashMap, HashSet}, | ||||
|     collections::{BTreeMap, BTreeSet, HashMap, HashSet}, | ||||
|     convert::{TryFrom, TryInto}, | ||||
|     mem, | ||||
|     sync::{Arc, RwLock}, | ||||
|  | @ -89,7 +89,7 @@ pub struct Rooms { | |||
| impl Rooms { | ||||
|     /// Builds a StateMap by iterating over all keys that start
 | ||||
|     /// with state_hash, this gives the full state for the given state_hash.
 | ||||
|     pub fn state_full_ids(&self, shortstatehash: u64) -> Result<Vec<EventId>> { | ||||
|     pub fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeSet<EventId>> { | ||||
|         Ok(self | ||||
|             .stateid_shorteventid | ||||
|             .scan_prefix(shortstatehash.to_be_bytes().to_vec()) | ||||
|  | @ -1215,6 +1215,7 @@ impl Rooms { | |||
|             state_key, | ||||
|             redacts, | ||||
|         } = pdu_builder; | ||||
| 
 | ||||
|         // TODO: Make sure this isn't called twice in parallel
 | ||||
|         let prev_events = self | ||||
|             .get_pdu_leaves(&room_id)? | ||||
|  |  | |||
|  | @ -625,13 +625,44 @@ pub async fn send_transaction_message_route( | |||
|             } | ||||
|         }; | ||||
| 
 | ||||
|         // 0. Check the server is in the room
 | ||||
|         let room_id = match value | ||||
|             .get("room_id") | ||||
|             .and_then(|id| RoomId::try_from(id.as_str()?).ok()) | ||||
|         { | ||||
|             Some(id) => id, | ||||
|             None => { | ||||
|                 // Event is invalid
 | ||||
|                 resolved_map.insert(event_id, Err("Event needs a valid RoomId.".to_string())); | ||||
|                 continue; | ||||
|             } | ||||
|         }; | ||||
| 
 | ||||
|         let mutex = Arc::clone( | ||||
|             db.globals | ||||
|                 .roomid_mutex | ||||
|                 .write() | ||||
|                 .unwrap() | ||||
|                 .entry(room_id.clone()) | ||||
|                 .or_default(), | ||||
|         ); | ||||
|         let mutex_lock = mutex.lock().await; | ||||
|         let start_time = Instant::now(); | ||||
|         resolved_map.insert( | ||||
|             event_id.clone(), | ||||
|             handle_incoming_pdu(&body.origin, &event_id, value, true, &db, &pub_key_map) | ||||
|             handle_incoming_pdu( | ||||
|                 &body.origin, | ||||
|                 &event_id, | ||||
|                 &room_id, | ||||
|                 value, | ||||
|                 true, | ||||
|                 &db, | ||||
|                 &pub_key_map, | ||||
|             ) | ||||
|             .await | ||||
|             .map(|_| ()), | ||||
|         ); | ||||
|         drop(mutex_lock); | ||||
| 
 | ||||
|         let elapsed = start_time.elapsed(); | ||||
|         if elapsed > Duration::from_secs(1) { | ||||
|  | @ -784,8 +815,8 @@ pub async fn send_transaction_message_route( | |||
| type AsyncRecursiveResult<'a, T, E> = Pin<Box<dyn Future<Output = StdResult<T, E>> + 'a + Send>>; | ||||
| 
 | ||||
| /// When receiving an event one needs to:
 | ||||
| /// 0. Skip the PDU if we already know about it
 | ||||
| /// 1. Check the server is in the room
 | ||||
| /// 0. Check the server is in the room
 | ||||
| /// 1. Skip the PDU if we already know about it
 | ||||
| /// 2. Check signatures, otherwise drop
 | ||||
| /// 3. Check content hash, redact if doesn't match
 | ||||
| /// 4. Fetch any missing auth events doing all checks listed here starting at 1. These are not
 | ||||
|  | @ -810,6 +841,7 @@ type AsyncRecursiveResult<'a, T, E> = Pin<Box<dyn Future<Output = StdResult<T, E | |||
| pub fn handle_incoming_pdu<'a>( | ||||
|     origin: &'a ServerName, | ||||
|     event_id: &'a EventId, | ||||
|     room_id: &'a RoomId, | ||||
|     value: BTreeMap<String, CanonicalJsonValue>, | ||||
|     is_timeline_event: bool, | ||||
|     db: &'a Database, | ||||
|  | @ -817,24 +849,6 @@ pub fn handle_incoming_pdu<'a>( | |||
| ) -> AsyncRecursiveResult<'a, Option<Vec<u8>>, String> { | ||||
|     Box::pin(async move { | ||||
|         // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json
 | ||||
| 
 | ||||
|         // 0. Skip the PDU if we already have it as a timeline event
 | ||||
|         if let Ok(Some(pdu_id)) = db.rooms.get_pdu_id(&event_id) { | ||||
|             return Ok(Some(pdu_id.to_vec())); | ||||
|         } | ||||
| 
 | ||||
|         // 1. Check the server is in the room
 | ||||
|         let room_id = match value | ||||
|             .get("room_id") | ||||
|             .and_then(|id| RoomId::try_from(id.as_str()?).ok()) | ||||
|         { | ||||
|             Some(id) => id, | ||||
|             None => { | ||||
|                 // Event is invalid
 | ||||
|                 return Err("Event needs a valid RoomId.".to_string()); | ||||
|             } | ||||
|         }; | ||||
| 
 | ||||
|         match db.rooms.exists(&room_id) { | ||||
|             Ok(true) => {} | ||||
|             _ => { | ||||
|  | @ -842,6 +856,11 @@ pub fn handle_incoming_pdu<'a>( | |||
|             } | ||||
|         } | ||||
| 
 | ||||
|         // 1. Skip the PDU if we already have it as a timeline event
 | ||||
|         if let Ok(Some(pdu_id)) = db.rooms.get_pdu_id(&event_id) { | ||||
|             return Ok(Some(pdu_id.to_vec())); | ||||
|         } | ||||
| 
 | ||||
|         // We go through all the signatures we see on the value and fetch the corresponding signing
 | ||||
|         // keys
 | ||||
|         fetch_required_signing_keys(&value, &pub_key_map, db) | ||||
|  | @ -901,7 +920,7 @@ pub fn handle_incoming_pdu<'a>( | |||
|         // 5. Reject "due to auth events" if can't get all the auth events or some of the auth events are also rejected "due to auth events"
 | ||||
|         // EDIT: Step 5 is not applied anymore because it failed too often
 | ||||
|         debug!("Fetching auth events for {}", incoming_pdu.event_id); | ||||
|         fetch_and_handle_events(db, origin, &incoming_pdu.auth_events, pub_key_map) | ||||
|         fetch_and_handle_events(db, origin, &incoming_pdu.auth_events, &room_id, pub_key_map) | ||||
|             .await | ||||
|             .map_err(|e| e.to_string())?; | ||||
| 
 | ||||
|  | @ -1002,13 +1021,13 @@ pub fn handle_incoming_pdu<'a>( | |||
| 
 | ||||
|         if incoming_pdu.prev_events.len() == 1 { | ||||
|             let prev_event = &incoming_pdu.prev_events[0]; | ||||
|             let state_vec = db | ||||
|             let state = db | ||||
|                 .rooms | ||||
|                 .pdu_shortstatehash(prev_event) | ||||
|                 .map_err(|_| "Failed talking to db".to_owned())? | ||||
|                 .map(|shortstatehash| db.rooms.state_full_ids(shortstatehash).ok()) | ||||
|                 .flatten(); | ||||
|             if let Some(mut state_vec) = state_vec { | ||||
|             if let Some(mut state) = state { | ||||
|                 if db | ||||
|                     .rooms | ||||
|                     .get_pdu(prev_event) | ||||
|  | @ -1018,10 +1037,16 @@ pub fn handle_incoming_pdu<'a>( | |||
|                     .state_key | ||||
|                     .is_some() | ||||
|                 { | ||||
|                     state_vec.push(prev_event.clone()); | ||||
|                     state.insert(prev_event.clone()); | ||||
|                 } | ||||
|                 state_at_incoming_event = Some( | ||||
|                     fetch_and_handle_events(db, origin, &state_vec, pub_key_map) | ||||
|                     fetch_and_handle_events( | ||||
|                         db, | ||||
|                         origin, | ||||
|                         &state.into_iter().collect::<Vec<_>>(), | ||||
|                         &room_id, | ||||
|                         pub_key_map, | ||||
|                     ) | ||||
|                     .await | ||||
|                     .map_err(|_| "Failed to fetch state events locally".to_owned())? | ||||
|                     .into_iter() | ||||
|  | @ -1059,8 +1084,14 @@ pub fn handle_incoming_pdu<'a>( | |||
|             { | ||||
|                 Ok(res) => { | ||||
|                     debug!("Fetching state events at event."); | ||||
|                     let state_vec = | ||||
|                         match fetch_and_handle_events(&db, origin, &res.pdu_ids, pub_key_map).await | ||||
|                     let state_vec = match fetch_and_handle_events( | ||||
|                         &db, | ||||
|                         origin, | ||||
|                         &res.pdu_ids, | ||||
|                         &room_id, | ||||
|                         pub_key_map, | ||||
|                     ) | ||||
|                     .await | ||||
|                     { | ||||
|                         Ok(state) => state, | ||||
|                         Err(_) => return Err("Failed to fetch state events.".to_owned()), | ||||
|  | @ -1090,7 +1121,13 @@ pub fn handle_incoming_pdu<'a>( | |||
|                     } | ||||
| 
 | ||||
|                     debug!("Fetching auth chain events at event."); | ||||
|                     match fetch_and_handle_events(&db, origin, &res.auth_chain_ids, pub_key_map) | ||||
|                     match fetch_and_handle_events( | ||||
|                         &db, | ||||
|                         origin, | ||||
|                         &res.auth_chain_ids, | ||||
|                         &room_id, | ||||
|                         pub_key_map, | ||||
|                     ) | ||||
|                     .await | ||||
|                     { | ||||
|                         Ok(state) => state, | ||||
|  | @ -1313,6 +1350,7 @@ pub(crate) fn fetch_and_handle_events<'a>( | |||
|     db: &'a Database, | ||||
|     origin: &'a ServerName, | ||||
|     events: &'a [EventId], | ||||
|     room_id: &'a RoomId, | ||||
|     pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, String>>>, | ||||
| ) -> AsyncRecursiveResult<'a, Vec<Arc<PduEvent>>, Error> { | ||||
|     Box::pin(async move { | ||||
|  | @ -1366,6 +1404,7 @@ pub(crate) fn fetch_and_handle_events<'a>( | |||
|                             match handle_incoming_pdu( | ||||
|                                 origin, | ||||
|                                 &event_id, | ||||
|                                 &room_id, | ||||
|                                 value.clone(), | ||||
|                                 false, | ||||
|                                 db, | ||||
|  | @ -1854,7 +1893,11 @@ pub fn get_room_state_ids_route( | |||
|             "Pdu state not found.", | ||||
|         ))?; | ||||
| 
 | ||||
|     let pdu_ids = db.rooms.state_full_ids(shortstatehash)?; | ||||
|     let pdu_ids = db | ||||
|         .rooms | ||||
|         .state_full_ids(shortstatehash)? | ||||
|         .into_iter() | ||||
|         .collect(); | ||||
| 
 | ||||
|     let mut auth_chain_ids = BTreeSet::<EventId>::new(); | ||||
|     let mut todo = BTreeSet::new(); | ||||
|  | @ -2100,7 +2143,24 @@ pub async fn create_join_event_route( | |||
|     ) | ||||
|     .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; | ||||
| 
 | ||||
|     let pdu_id = handle_incoming_pdu(&origin, &event_id, value, true, &db, &pub_key_map) | ||||
|     let mutex = Arc::clone( | ||||
|         db.globals | ||||
|             .roomid_mutex | ||||
|             .write() | ||||
|             .unwrap() | ||||
|             .entry(body.room_id.clone()) | ||||
|             .or_default(), | ||||
|     ); | ||||
|     let mutex_lock = mutex.lock().await; | ||||
|     let pdu_id = handle_incoming_pdu( | ||||
|         &origin, | ||||
|         &event_id, | ||||
|         &body.room_id, | ||||
|         value, | ||||
|         true, | ||||
|         &db, | ||||
|         &pub_key_map, | ||||
|     ) | ||||
|     .await | ||||
|     .map_err(|_| { | ||||
|         Error::BadRequest( | ||||
|  | @ -2112,6 +2172,7 @@ pub async fn create_join_event_route( | |||
|         ErrorKind::InvalidParam, | ||||
|         "Could not accept incoming PDU as timeline event.", | ||||
|     ))?; | ||||
|     drop(mutex_lock); | ||||
| 
 | ||||
|     let state_ids = db.rooms.state_full_ids(shortstatehash)?; | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue