finish implementing better state store
This commit is contained in:
		
							parent
							
								
									31f60ad6fd
								
							
						
					
					
						commit
						3eabaa2a95
					
				
					 10 changed files with 645 additions and 526 deletions
				
			
		|  | @ -249,6 +249,8 @@ pub async fn register_route( | |||
| 
 | ||||
|         let room_id = RoomId::new(db.globals.server_name()); | ||||
| 
 | ||||
|         db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?; | ||||
| 
 | ||||
|         let mutex_state = Arc::clone( | ||||
|             db.globals | ||||
|                 .roomid_mutex_state | ||||
|  |  | |||
|  | @ -44,7 +44,7 @@ pub async fn get_context_route( | |||
| 
 | ||||
|     let events_before = db | ||||
|         .rooms | ||||
|         .pdus_until(&sender_user, &body.room_id, base_token) | ||||
|         .pdus_until(&sender_user, &body.room_id, base_token)? | ||||
|         .take( | ||||
|             u32::try_from(body.limit).map_err(|_| { | ||||
|                 Error::BadRequest(ErrorKind::InvalidParam, "Limit value is invalid.") | ||||
|  | @ -66,7 +66,7 @@ pub async fn get_context_route( | |||
| 
 | ||||
|     let events_after = db | ||||
|         .rooms | ||||
|         .pdus_after(&sender_user, &body.room_id, base_token) | ||||
|         .pdus_after(&sender_user, &body.room_id, base_token)? | ||||
|         .take( | ||||
|             u32::try_from(body.limit).map_err(|_| { | ||||
|                 Error::BadRequest(ErrorKind::InvalidParam, "Limit value is invalid.") | ||||
|  |  | |||
|  | @ -609,6 +609,8 @@ async fn join_room_by_id_helper( | |||
|             ) | ||||
|             .await?; | ||||
| 
 | ||||
|         db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?; | ||||
| 
 | ||||
|         let pdu = PduEvent::from_id_val(&event_id, join_event.clone()) | ||||
|             .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; | ||||
| 
 | ||||
|  |  | |||
|  | @ -128,7 +128,7 @@ pub async fn get_message_events_route( | |||
|         get_message_events::Direction::Forward => { | ||||
|             let events_after = db | ||||
|                 .rooms | ||||
|                 .pdus_after(&sender_user, &body.room_id, from) | ||||
|                 .pdus_after(&sender_user, &body.room_id, from)? | ||||
|                 .take(limit) | ||||
|                 .filter_map(|r| r.ok()) // Filter out buggy events
 | ||||
|                 .filter_map(|(pdu_id, pdu)| { | ||||
|  | @ -158,7 +158,7 @@ pub async fn get_message_events_route( | |||
|         get_message_events::Direction::Backward => { | ||||
|             let events_before = db | ||||
|                 .rooms | ||||
|                 .pdus_until(&sender_user, &body.room_id, from) | ||||
|                 .pdus_until(&sender_user, &body.room_id, from)? | ||||
|                 .take(limit) | ||||
|                 .filter_map(|r| r.ok()) // Filter out buggy events
 | ||||
|                 .filter_map(|(pdu_id, pdu)| { | ||||
|  |  | |||
|  | @ -33,6 +33,8 @@ pub async fn create_room_route( | |||
| 
 | ||||
|     let room_id = RoomId::new(db.globals.server_name()); | ||||
| 
 | ||||
|     db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?; | ||||
| 
 | ||||
|     let mutex_state = Arc::clone( | ||||
|         db.globals | ||||
|             .roomid_mutex_state | ||||
|  | @ -173,7 +175,6 @@ pub async fn create_room_route( | |||
|     )?; | ||||
| 
 | ||||
|     // 4. Canonical room alias
 | ||||
| 
 | ||||
|     if let Some(room_alias_id) = &alias { | ||||
|         db.rooms.build_and_append_pdu( | ||||
|             PduBuilder { | ||||
|  | @ -193,7 +194,7 @@ pub async fn create_room_route( | |||
|             &room_id, | ||||
|             &db, | ||||
|             &state_lock, | ||||
|         ); | ||||
|         )?; | ||||
|     } | ||||
| 
 | ||||
|     // 5. Events set by preset
 | ||||
|  |  | |||
|  | @ -205,7 +205,7 @@ async fn sync_helper( | |||
| 
 | ||||
|         let mut non_timeline_pdus = db | ||||
|             .rooms | ||||
|             .pdus_until(&sender_user, &room_id, u64::MAX) | ||||
|             .pdus_until(&sender_user, &room_id, u64::MAX)? | ||||
|             .filter_map(|r| { | ||||
|                 // Filter out buggy events
 | ||||
|                 if r.is_err() { | ||||
|  | @ -248,13 +248,13 @@ async fn sync_helper( | |||
| 
 | ||||
|         let first_pdu_before_since = db | ||||
|             .rooms | ||||
|             .pdus_until(&sender_user, &room_id, since) | ||||
|             .pdus_until(&sender_user, &room_id, since)? | ||||
|             .next() | ||||
|             .transpose()?; | ||||
| 
 | ||||
|         let pdus_after_since = db | ||||
|             .rooms | ||||
|             .pdus_after(&sender_user, &room_id, since) | ||||
|             .pdus_after(&sender_user, &room_id, since)? | ||||
|             .next() | ||||
|             .is_some(); | ||||
| 
 | ||||
|  | @ -286,7 +286,7 @@ async fn sync_helper( | |||
| 
 | ||||
|                 for hero in db | ||||
|                     .rooms | ||||
|                     .all_pdus(&sender_user, &room_id) | ||||
|                     .all_pdus(&sender_user, &room_id)? | ||||
|                     .filter_map(|pdu| pdu.ok()) // Ignore all broken pdus
 | ||||
|                     .filter(|(_, pdu)| pdu.kind == EventType::RoomMember) | ||||
|                     .map(|(_, pdu)| { | ||||
|  | @ -328,11 +328,11 @@ async fn sync_helper( | |||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             ( | ||||
|             Ok::<_, Error>(( | ||||
|                 Some(joined_member_count), | ||||
|                 Some(invited_member_count), | ||||
|                 heroes, | ||||
|             ) | ||||
|             )) | ||||
|         }; | ||||
| 
 | ||||
|         let ( | ||||
|  | @ -343,7 +343,7 @@ async fn sync_helper( | |||
|             state_events, | ||||
|         ) = if since_shortstatehash.is_none() { | ||||
|             // Probably since = 0, we will do an initial sync
 | ||||
|             let (joined_member_count, invited_member_count, heroes) = calculate_counts(); | ||||
|             let (joined_member_count, invited_member_count, heroes) = calculate_counts()?; | ||||
| 
 | ||||
|             let current_state_ids = db.rooms.state_full_ids(current_shortstatehash)?; | ||||
|             let state_events = current_state_ids | ||||
|  | @ -510,7 +510,7 @@ async fn sync_helper( | |||
|             } | ||||
| 
 | ||||
|             let (joined_member_count, invited_member_count, heroes) = if send_member_count { | ||||
|                 calculate_counts() | ||||
|                 calculate_counts()? | ||||
|             } else { | ||||
|                 (None, None, Vec::new()) | ||||
|             }; | ||||
|  |  | |||
							
								
								
									
										373
									
								
								src/database.rs
									
									
									
									
									
								
							
							
						
						
									
										373
									
								
								src/database.rs
									
									
									
									
									
								
							|  | @ -28,7 +28,7 @@ use ruma::{DeviceId, EventId, RoomId, ServerName, UserId}; | |||
| use serde::{de::IgnoredAny, Deserialize}; | ||||
| use std::{ | ||||
|     collections::{BTreeMap, HashMap, HashSet}, | ||||
|     convert::TryFrom, | ||||
|     convert::{TryFrom, TryInto}, | ||||
|     fs::{self, remove_dir_all}, | ||||
|     io::Write, | ||||
|     mem::size_of, | ||||
|  | @ -266,7 +266,6 @@ impl Database { | |||
|                 shortroomid_roomid: builder.open_tree("shortroomid_roomid")?, | ||||
|                 roomid_shortroomid: builder.open_tree("roomid_shortroomid")?, | ||||
| 
 | ||||
|                 stateid_shorteventid: builder.open_tree("stateid_shorteventid")?, | ||||
|                 shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?, | ||||
|                 eventid_shorteventid: builder.open_tree("eventid_shorteventid")?, | ||||
|                 shorteventid_eventid: builder.open_tree("shorteventid_eventid")?, | ||||
|  | @ -431,7 +430,6 @@ impl Database { | |||
|             } | ||||
| 
 | ||||
|             if db.globals.database_version()? < 6 { | ||||
|                 // TODO update to 6
 | ||||
|                 // Set room member count
 | ||||
|                 for (roomid, _) in db.rooms.roomid_shortstatehash.iter() { | ||||
|                     let room_id = | ||||
|  | @ -445,263 +443,98 @@ impl Database { | |||
|                 println!("Migration: 5 -> 6 finished"); | ||||
|             } | ||||
| 
 | ||||
|             fn load_shortstatehash_info( | ||||
|                 shortstatehash: &[u8], | ||||
|                 db: &Database, | ||||
|                 lru: &mut LruCache< | ||||
|                     Vec<u8>, | ||||
|                     Vec<( | ||||
|                         Vec<u8>, | ||||
|                         HashSet<Vec<u8>>, | ||||
|                         HashSet<Vec<u8>>, | ||||
|                         HashSet<Vec<u8>>, | ||||
|                     )>, | ||||
|                 >, | ||||
|             ) -> Result< | ||||
|                 Vec<( | ||||
|                     Vec<u8>,          // sstatehash
 | ||||
|                     HashSet<Vec<u8>>, // full state
 | ||||
|                     HashSet<Vec<u8>>, // added
 | ||||
|                     HashSet<Vec<u8>>, // removed
 | ||||
|                 )>, | ||||
|             > { | ||||
|                 if let Some(result) = lru.get_mut(shortstatehash) { | ||||
|                     return Ok(result.clone()); | ||||
|                 } | ||||
| 
 | ||||
|                 let value = db | ||||
|                     .rooms | ||||
|                     .shortstatehash_statediff | ||||
|                     .get(shortstatehash)? | ||||
|                     .ok_or_else(|| Error::bad_database("State hash does not exist"))?; | ||||
|                 let parent = value[0..size_of::<u64>()].to_vec(); | ||||
| 
 | ||||
|                 let mut add_mode = true; | ||||
|                 let mut added = HashSet::new(); | ||||
|                 let mut removed = HashSet::new(); | ||||
| 
 | ||||
|                 let mut i = size_of::<u64>(); | ||||
|                 while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) { | ||||
|                     if add_mode && v.starts_with(&0_u64.to_be_bytes()) { | ||||
|                         add_mode = false; | ||||
|                         i += size_of::<u64>(); | ||||
|                         continue; | ||||
|                     } | ||||
|                     if add_mode { | ||||
|                         added.insert(v.to_vec()); | ||||
|                     } else { | ||||
|                         removed.insert(v.to_vec()); | ||||
|                     } | ||||
|                     i += 2 * size_of::<u64>(); | ||||
|                 } | ||||
| 
 | ||||
|                 if parent != 0_u64.to_be_bytes() { | ||||
|                     let mut response = load_shortstatehash_info(&parent, db, lru)?; | ||||
|                     let mut state = response.last().unwrap().1.clone(); | ||||
|                     state.extend(added.iter().cloned()); | ||||
|                     for r in &removed { | ||||
|                         state.remove(r); | ||||
|                     } | ||||
| 
 | ||||
|                     response.push((shortstatehash.to_vec(), state, added, removed)); | ||||
| 
 | ||||
|                     lru.insert(shortstatehash.to_vec(), response.clone()); | ||||
|                     Ok(response) | ||||
|                 } else { | ||||
|                     let mut response = Vec::new(); | ||||
|                     response.push((shortstatehash.to_vec(), added.clone(), added, removed)); | ||||
|                     lru.insert(shortstatehash.to_vec(), response.clone()); | ||||
|                     Ok(response) | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             fn update_shortstatehash_level( | ||||
|                 current_shortstatehash: &[u8], | ||||
|                 statediffnew: HashSet<Vec<u8>>, | ||||
|                 statediffremoved: HashSet<Vec<u8>>, | ||||
|                 diff_to_sibling: usize, | ||||
|                 mut parent_states: Vec<( | ||||
|                     Vec<u8>,          // sstatehash
 | ||||
|                     HashSet<Vec<u8>>, // full state
 | ||||
|                     HashSet<Vec<u8>>, // added
 | ||||
|                     HashSet<Vec<u8>>, // removed
 | ||||
|                 )>, | ||||
|                 db: &Database, | ||||
|             ) -> Result<()> { | ||||
|                 let diffsum = statediffnew.len() + statediffremoved.len(); | ||||
| 
 | ||||
|                 if parent_states.len() > 3 { | ||||
|                     // Number of layers
 | ||||
|                     // To many layers, we have to go deeper
 | ||||
|                     let parent = parent_states.pop().unwrap(); | ||||
| 
 | ||||
|                     let mut parent_new = parent.2; | ||||
|                     let mut parent_removed = parent.3; | ||||
| 
 | ||||
|                     for removed in statediffremoved { | ||||
|                         if !parent_new.remove(&removed) { | ||||
|                             parent_removed.insert(removed); | ||||
|                         } | ||||
|                     } | ||||
|                     parent_new.extend(statediffnew); | ||||
| 
 | ||||
|                     update_shortstatehash_level( | ||||
|                         current_shortstatehash, | ||||
|                         parent_new, | ||||
|                         parent_removed, | ||||
|                         diffsum, | ||||
|                         parent_states, | ||||
|                         db, | ||||
|                     )?; | ||||
| 
 | ||||
|                     return Ok(()); | ||||
|                 } | ||||
| 
 | ||||
|                 if parent_states.len() == 0 { | ||||
|                     // There is no parent layer, create a new state
 | ||||
|                     let mut value = 0_u64.to_be_bytes().to_vec(); // 0 means no parent
 | ||||
|                     for new in &statediffnew { | ||||
|                         value.extend_from_slice(&new); | ||||
|                     } | ||||
| 
 | ||||
|                     if !statediffremoved.is_empty() { | ||||
|                         warn!("Tried to create new state with removals"); | ||||
|                     } | ||||
| 
 | ||||
|                     db.rooms | ||||
|                         .shortstatehash_statediff | ||||
|                         .insert(¤t_shortstatehash, &value)?; | ||||
| 
 | ||||
|                     return Ok(()); | ||||
|                 }; | ||||
| 
 | ||||
|                 // Else we have two options.
 | ||||
|                 // 1. We add the current diff on top of the parent layer.
 | ||||
|                 // 2. We replace a layer above
 | ||||
| 
 | ||||
|                 let parent = parent_states.pop().unwrap(); | ||||
|                 let parent_diff = parent.2.len() + parent.3.len(); | ||||
| 
 | ||||
|                 if diffsum * diffsum >= 2 * diff_to_sibling * parent_diff { | ||||
|                     // Diff too big, we replace above layer(s)
 | ||||
|                     let mut parent_new = parent.2; | ||||
|                     let mut parent_removed = parent.3; | ||||
| 
 | ||||
|                     for removed in statediffremoved { | ||||
|                         if !parent_new.remove(&removed) { | ||||
|                             parent_removed.insert(removed); | ||||
|                         } | ||||
|                     } | ||||
| 
 | ||||
|                     parent_new.extend(statediffnew); | ||||
|                     update_shortstatehash_level( | ||||
|                         current_shortstatehash, | ||||
|                         parent_new, | ||||
|                         parent_removed, | ||||
|                         diffsum, | ||||
|                         parent_states, | ||||
|                         db, | ||||
|                     )?; | ||||
|                 } else { | ||||
|                     // Diff small enough, we add diff as layer on top of parent
 | ||||
|                     let mut value = parent.0.clone(); | ||||
|                     for new in &statediffnew { | ||||
|                         value.extend_from_slice(&new); | ||||
|                     } | ||||
| 
 | ||||
|                     if !statediffremoved.is_empty() { | ||||
|                         value.extend_from_slice(&0_u64.to_be_bytes()); | ||||
|                         for removed in &statediffremoved { | ||||
|                             value.extend_from_slice(&removed); | ||||
|                         } | ||||
|                     } | ||||
| 
 | ||||
|                     db.rooms | ||||
|                         .shortstatehash_statediff | ||||
|                         .insert(¤t_shortstatehash, &value)?; | ||||
|                 } | ||||
| 
 | ||||
|                 Ok(()) | ||||
|             } | ||||
| 
 | ||||
|             if db.globals.database_version()? < 7 { | ||||
|                 // Upgrade state store
 | ||||
|                 let mut lru = LruCache::new(1000); | ||||
|                 let mut last_roomstates: HashMap<RoomId, Vec<u8>> = HashMap::new(); | ||||
|                 let mut current_sstatehash: Vec<u8> = Vec::new(); | ||||
|                 let mut last_roomstates: HashMap<RoomId, u64> = HashMap::new(); | ||||
|                 let mut current_sstatehash: Option<u64> = None; | ||||
|                 let mut current_room = None; | ||||
|                 let mut current_state = HashSet::new(); | ||||
|                 let mut counter = 0; | ||||
| 
 | ||||
|                 let mut handle_state = | ||||
|                     |current_sstatehash: u64, | ||||
|                      current_room: &RoomId, | ||||
|                      current_state: HashSet<_>, | ||||
|                      last_roomstates: &mut HashMap<_, _>| { | ||||
|                         counter += 1; | ||||
|                         println!("counter: {}", counter); | ||||
|                         let last_roomsstatehash = last_roomstates.get(current_room); | ||||
| 
 | ||||
|                         let states_parents = last_roomsstatehash.map_or_else( | ||||
|                             || Ok(Vec::new()), | ||||
|                             |&last_roomsstatehash| { | ||||
|                                 db.rooms.load_shortstatehash_info(dbg!(last_roomsstatehash)) | ||||
|                             }, | ||||
|                         )?; | ||||
| 
 | ||||
|                         let (statediffnew, statediffremoved) = | ||||
|                             if let Some(parent_stateinfo) = states_parents.last() { | ||||
|                                 let statediffnew = current_state | ||||
|                                     .difference(&parent_stateinfo.1) | ||||
|                                     .cloned() | ||||
|                                     .collect::<HashSet<_>>(); | ||||
| 
 | ||||
|                                 let statediffremoved = parent_stateinfo | ||||
|                                     .1 | ||||
|                                     .difference(¤t_state) | ||||
|                                     .cloned() | ||||
|                                     .collect::<HashSet<_>>(); | ||||
| 
 | ||||
|                                 (statediffnew, statediffremoved) | ||||
|                             } else { | ||||
|                                 (current_state, HashSet::new()) | ||||
|                             }; | ||||
| 
 | ||||
|                         db.rooms.save_state_from_diff( | ||||
|                             dbg!(current_sstatehash), | ||||
|                             statediffnew, | ||||
|                             statediffremoved, | ||||
|                             2, // every state change is 2 event changes on average
 | ||||
|                             states_parents, | ||||
|                         )?; | ||||
| 
 | ||||
|                         /* | ||||
|                         let mut tmp = db.rooms.load_shortstatehash_info(¤t_sstatehash, &db)?; | ||||
|                         let state = tmp.pop().unwrap(); | ||||
|                         println!( | ||||
|                             "{}\t{}{:?}: {:?} + {:?} - {:?}", | ||||
|                             current_room, | ||||
|                             "  ".repeat(tmp.len()), | ||||
|                             utils::u64_from_bytes(¤t_sstatehash).unwrap(), | ||||
|                             tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()), | ||||
|                             state | ||||
|                                 .2 | ||||
|                                 .iter() | ||||
|                                 .map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap()) | ||||
|                                 .collect::<Vec<_>>(), | ||||
|                             state | ||||
|                                 .3 | ||||
|                                 .iter() | ||||
|                                 .map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap()) | ||||
|                                 .collect::<Vec<_>>() | ||||
|                         ); | ||||
|                         */ | ||||
| 
 | ||||
|                         Ok::<_, Error>(()) | ||||
|                     }; | ||||
| 
 | ||||
|                 for (k, seventid) in db._db.open_tree("stateid_shorteventid")?.iter() { | ||||
|                     let sstatehash = k[0..size_of::<u64>()].to_vec(); | ||||
|                     let sstatehash = utils::u64_from_bytes(&k[0..size_of::<u64>()]) | ||||
|                         .expect("number of bytes is correct"); | ||||
|                     let sstatekey = k[size_of::<u64>()..].to_vec(); | ||||
|                     if sstatehash != current_sstatehash { | ||||
|                         if !current_sstatehash.is_empty() { | ||||
|                             counter += 1; | ||||
|                             println!("counter: {}", counter); | ||||
|                             let current_room = current_room.as_ref().unwrap(); | ||||
|                             let last_roomsstatehash = last_roomstates.get(¤t_room); | ||||
| 
 | ||||
|                             let states_parents = last_roomsstatehash.map_or_else( | ||||
|                                 || Ok(Vec::new()), | ||||
|                                 |last_roomsstatehash| { | ||||
|                                     load_shortstatehash_info(&last_roomsstatehash, &db, &mut lru) | ||||
|                                 }, | ||||
|                     if Some(sstatehash) != current_sstatehash { | ||||
|                         if let Some(current_sstatehash) = current_sstatehash { | ||||
|                             handle_state( | ||||
|                                 current_sstatehash, | ||||
|                                 current_room.as_ref().unwrap(), | ||||
|                                 current_state, | ||||
|                                 &mut last_roomstates, | ||||
|                             )?; | ||||
| 
 | ||||
|                             let (statediffnew, statediffremoved) = | ||||
|                                 if let Some(parent_stateinfo) = states_parents.last() { | ||||
|                                     let statediffnew = current_state | ||||
|                                         .difference(&parent_stateinfo.1) | ||||
|                                         .cloned() | ||||
|                                         .collect::<HashSet<_>>(); | ||||
| 
 | ||||
|                                     let statediffremoved = parent_stateinfo | ||||
|                                         .1 | ||||
|                                         .difference(¤t_state) | ||||
|                                         .cloned() | ||||
|                                         .collect::<HashSet<_>>(); | ||||
| 
 | ||||
|                                     (statediffnew, statediffremoved) | ||||
|                                 } else { | ||||
|                                     (current_state, HashSet::new()) | ||||
|                                 }; | ||||
| 
 | ||||
|                             update_shortstatehash_level( | ||||
|                                 ¤t_sstatehash, | ||||
|                                 statediffnew, | ||||
|                                 statediffremoved, | ||||
|                                 2, // every state change is 2 event changes on average
 | ||||
|                                 states_parents, | ||||
|                                 &db, | ||||
|                             )?; | ||||
| 
 | ||||
|                             /* | ||||
|                             let mut tmp = load_shortstatehash_info(¤t_sstatehash, &db)?; | ||||
|                             let state = tmp.pop().unwrap(); | ||||
|                             println!( | ||||
|                                 "{}\t{}{:?}: {:?} + {:?} - {:?}", | ||||
|                                 current_room, | ||||
|                                 "  ".repeat(tmp.len()), | ||||
|                                 utils::u64_from_bytes(¤t_sstatehash).unwrap(), | ||||
|                                 tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()), | ||||
|                                 state | ||||
|                                     .2 | ||||
|                                     .iter() | ||||
|                                     .map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap()) | ||||
|                                     .collect::<Vec<_>>(), | ||||
|                                 state | ||||
|                                     .3 | ||||
|                                     .iter() | ||||
|                                     .map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap()) | ||||
|                                     .collect::<Vec<_>>() | ||||
|                             ); | ||||
|                             */ | ||||
| 
 | ||||
|                             last_roomstates.insert(current_room.clone(), current_sstatehash); | ||||
|                             last_roomstates | ||||
|                                 .insert(current_room.clone().unwrap(), current_sstatehash); | ||||
|                         } | ||||
|                         current_state = HashSet::new(); | ||||
|                         current_sstatehash = sstatehash; | ||||
|                         current_sstatehash = Some(sstatehash); | ||||
| 
 | ||||
|                         let event_id = db | ||||
|                             .rooms | ||||
|  | @ -721,7 +554,16 @@ impl Database { | |||
| 
 | ||||
|                     let mut val = sstatekey; | ||||
|                     val.extend_from_slice(&seventid); | ||||
|                     current_state.insert(val); | ||||
|                     current_state.insert(val.try_into().expect("size is correct")); | ||||
|                 } | ||||
| 
 | ||||
|                 if let Some(current_sstatehash) = current_sstatehash { | ||||
|                     handle_state( | ||||
|                         current_sstatehash, | ||||
|                         current_room.as_ref().unwrap(), | ||||
|                         current_state, | ||||
|                         &mut last_roomstates, | ||||
|                     )?; | ||||
|                 } | ||||
| 
 | ||||
|                 db.globals.bump_database_version(7)?; | ||||
|  | @ -761,11 +603,28 @@ impl Database { | |||
| 
 | ||||
|                 db.rooms.pduid_pdu.insert_batch(&mut batch)?; | ||||
| 
 | ||||
|                 for (key, _) in db.rooms.pduid_pdu.iter() { | ||||
|                     if key.starts_with(b"!") { | ||||
|                         db.rooms.pduid_pdu.remove(&key); | ||||
|                 let mut batch2 = db.rooms.eventid_pduid.iter().filter_map(|(k, value)| { | ||||
|                     if !value.starts_with(b"!") { | ||||
|                         return None; | ||||
|                     } | ||||
|                 } | ||||
|                     let mut parts = value.splitn(2, |&b| b == 0xff); | ||||
|                     let room_id = parts.next().unwrap(); | ||||
|                     let count = parts.next().unwrap(); | ||||
| 
 | ||||
|                     let short_room_id = db | ||||
|                         .rooms | ||||
|                         .roomid_shortroomid | ||||
|                         .get(&room_id) | ||||
|                         .unwrap() | ||||
|                         .expect("shortroomid should exist"); | ||||
| 
 | ||||
|                     let mut new_value = short_room_id; | ||||
|                     new_value.extend_from_slice(count); | ||||
| 
 | ||||
|                     Some((k, new_value)) | ||||
|                 }); | ||||
| 
 | ||||
|                 db.rooms.eventid_pduid.insert_batch(&mut batch2)?; | ||||
| 
 | ||||
|                 db.globals.bump_database_version(8)?; | ||||
| 
 | ||||
|  | @ -803,7 +662,7 @@ impl Database { | |||
| 
 | ||||
|                 for (key, _) in db.rooms.tokenids.iter() { | ||||
|                     if key.starts_with(b"!") { | ||||
|                         db.rooms.pduid_pdu.remove(&key)?; | ||||
|                         db.rooms.tokenids.remove(&key)?; | ||||
|                     } | ||||
|                 } | ||||
| 
 | ||||
|  | @ -811,8 +670,6 @@ impl Database { | |||
| 
 | ||||
|                 println!("Migration: 8 -> 9 finished"); | ||||
|             } | ||||
| 
 | ||||
|             panic!(); | ||||
|         } | ||||
| 
 | ||||
|         let guard = db.read().await; | ||||
|  |  | |||
|  | @ -9,13 +9,13 @@ use std::{ | |||
|     path::{Path, PathBuf}, | ||||
|     pin::Pin, | ||||
|     sync::Arc, | ||||
|     time::{Duration, Instant}, | ||||
| }; | ||||
| use tokio::sync::oneshot::Sender; | ||||
| use tracing::debug; | ||||
| 
 | ||||
| thread_local! { | ||||
|     static READ_CONNECTION: RefCell<Option<&'static Connection>> = RefCell::new(None); | ||||
|     static READ_CONNECTION_ITERATOR: RefCell<Option<&'static Connection>> = RefCell::new(None); | ||||
| } | ||||
| 
 | ||||
| struct PreparedStatementIterator<'a> { | ||||
|  | @ -77,6 +77,21 @@ impl Engine { | |||
|         }) | ||||
|     } | ||||
| 
 | ||||
|     fn read_lock_iterator(&self) -> &'static Connection { | ||||
|         READ_CONNECTION_ITERATOR.with(|cell| { | ||||
|             let connection = &mut cell.borrow_mut(); | ||||
| 
 | ||||
|             if (*connection).is_none() { | ||||
|                 let c = Box::leak(Box::new( | ||||
|                     Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap(), | ||||
|                 )); | ||||
|                 **connection = Some(c); | ||||
|             } | ||||
| 
 | ||||
|             connection.unwrap() | ||||
|         }) | ||||
|     } | ||||
| 
 | ||||
|     pub fn flush_wal(self: &Arc<Self>) -> Result<()> { | ||||
|         self.write_lock() | ||||
|             .pragma_update(Some(Main), "wal_checkpoint", &"TRUNCATE")?; | ||||
|  | @ -151,6 +166,34 @@ impl SqliteTable { | |||
|         )?; | ||||
|         Ok(()) | ||||
|     } | ||||
| 
 | ||||
|     pub fn iter_with_guard<'a>( | ||||
|         &'a self, | ||||
|         guard: &'a Connection, | ||||
|     ) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> { | ||||
|         let statement = Box::leak(Box::new( | ||||
|             guard | ||||
|                 .prepare(&format!( | ||||
|                     "SELECT key, value FROM {} ORDER BY key ASC", | ||||
|                     &self.name | ||||
|                 )) | ||||
|                 .unwrap(), | ||||
|         )); | ||||
| 
 | ||||
|         let statement_ref = NonAliasingBox(statement); | ||||
| 
 | ||||
|         let iterator = Box::new( | ||||
|             statement | ||||
|                 .query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) | ||||
|                 .unwrap() | ||||
|                 .map(|r| r.unwrap()), | ||||
|         ); | ||||
| 
 | ||||
|         Box::new(PreparedStatementIterator { | ||||
|             iterator, | ||||
|             statement_ref, | ||||
|         }) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Tree for SqliteTable { | ||||
|  | @ -219,30 +262,9 @@ impl Tree for SqliteTable { | |||
| 
 | ||||
|     #[tracing::instrument(skip(self))] | ||||
|     fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> { | ||||
|         let guard = self.engine.read_lock(); | ||||
|         let guard = self.engine.read_lock_iterator(); | ||||
| 
 | ||||
|         let statement = Box::leak(Box::new( | ||||
|             guard | ||||
|                 .prepare(&format!( | ||||
|                     "SELECT key, value FROM {} ORDER BY key ASC", | ||||
|                     &self.name | ||||
|                 )) | ||||
|                 .unwrap(), | ||||
|         )); | ||||
| 
 | ||||
|         let statement_ref = NonAliasingBox(statement); | ||||
| 
 | ||||
|         let iterator = Box::new( | ||||
|             statement | ||||
|                 .query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) | ||||
|                 .unwrap() | ||||
|                 .map(|r| r.unwrap()), | ||||
|         ); | ||||
| 
 | ||||
|         Box::new(PreparedStatementIterator { | ||||
|             iterator, | ||||
|             statement_ref, | ||||
|         }) | ||||
|         self.iter_with_guard(&guard) | ||||
|     } | ||||
| 
 | ||||
|     #[tracing::instrument(skip(self, from, backwards))] | ||||
|  | @ -251,7 +273,7 @@ impl Tree for SqliteTable { | |||
|         from: &[u8], | ||||
|         backwards: bool, | ||||
|     ) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> { | ||||
|         let guard = self.engine.read_lock(); | ||||
|         let guard = self.engine.read_lock_iterator(); | ||||
|         let from = from.to_vec(); // TODO change interface?
 | ||||
| 
 | ||||
|         if backwards { | ||||
|  |  | |||
|  | @ -24,7 +24,7 @@ use ruma::{ | |||
| use std::{ | ||||
|     collections::{BTreeMap, BTreeSet, HashMap, HashSet}, | ||||
|     convert::{TryFrom, TryInto}, | ||||
|     mem, | ||||
|     mem::size_of, | ||||
|     sync::{Arc, Mutex}, | ||||
| }; | ||||
| use tokio::sync::MutexGuard; | ||||
|  | @ -37,10 +37,11 @@ use super::{abstraction::Tree, admin::AdminCommand, pusher}; | |||
| /// This is created when a state group is added to the database by
 | ||||
| /// hashing the entire state.
 | ||||
| pub type StateHashId = Vec<u8>; | ||||
| pub type CompressedStateEvent = [u8; 2 * size_of::<u64>()]; | ||||
| 
 | ||||
| pub struct Rooms { | ||||
|     pub edus: edus::RoomEdus, | ||||
|     pub(super) pduid_pdu: Arc<dyn Tree>, // PduId = RoomId + Count
 | ||||
|     pub(super) pduid_pdu: Arc<dyn Tree>, // PduId = ShortRoomId + Count
 | ||||
|     pub(super) eventid_pduid: Arc<dyn Tree>, | ||||
|     pub(super) roomid_pduleaves: Arc<dyn Tree>, | ||||
|     pub(super) alias_roomid: Arc<dyn Tree>, | ||||
|  | @ -79,9 +80,6 @@ pub struct Rooms { | |||
|     pub(super) eventid_shorteventid: Arc<dyn Tree>, | ||||
| 
 | ||||
|     pub(super) statehash_shortstatehash: Arc<dyn Tree>, | ||||
|     /// ShortStateHash = Count
 | ||||
|     /// StateId = ShortStateHash
 | ||||
|     pub(super) stateid_shorteventid: Arc<dyn Tree>, | ||||
|     pub(super) shortstatehash_statediff: Arc<dyn Tree>, // StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 + (shortstatekey+shorteventid--)
 | ||||
| 
 | ||||
|     /// RoomId + EventId -> outlier PDU.
 | ||||
|  | @ -100,29 +98,30 @@ impl Rooms { | |||
|     /// Builds a StateMap by iterating over all keys that start
 | ||||
|     /// with state_hash, this gives the full state for the given state_hash.
 | ||||
|     pub fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeSet<EventId>> { | ||||
|         Ok(self | ||||
|             .stateid_shorteventid | ||||
|             .scan_prefix(shortstatehash.to_be_bytes().to_vec()) | ||||
|             .map(|(_, bytes)| { | ||||
|                 self.get_eventid_from_short(utils::u64_from_bytes(&bytes).unwrap()) | ||||
|                     .ok() | ||||
|             }) | ||||
|             .flatten() | ||||
|             .collect()) | ||||
|         let full_state = self | ||||
|             .load_shortstatehash_info(shortstatehash)? | ||||
|             .pop() | ||||
|             .expect("there is always one layer") | ||||
|             .1; | ||||
|         full_state | ||||
|             .into_iter() | ||||
|             .map(|compressed| self.parse_compressed_state_event(compressed)) | ||||
|             .collect() | ||||
|     } | ||||
| 
 | ||||
|     pub fn state_full( | ||||
|         &self, | ||||
|         shortstatehash: u64, | ||||
|     ) -> Result<HashMap<(EventType, String), Arc<PduEvent>>> { | ||||
|         let state = self | ||||
|             .stateid_shorteventid | ||||
|             .scan_prefix(shortstatehash.to_be_bytes().to_vec()) | ||||
|             .map(|(_, bytes)| { | ||||
|                 self.get_eventid_from_short(utils::u64_from_bytes(&bytes).unwrap()) | ||||
|                     .ok() | ||||
|             }) | ||||
|             .flatten() | ||||
|         let full_state = self | ||||
|             .load_shortstatehash_info(shortstatehash)? | ||||
|             .pop() | ||||
|             .expect("there is always one layer") | ||||
|             .1; | ||||
|         Ok(full_state | ||||
|             .into_iter() | ||||
|             .map(|compressed| self.parse_compressed_state_event(compressed)) | ||||
|             .filter_map(|r| r.ok()) | ||||
|             .map(|eventid| self.get_pdu(&eventid)) | ||||
|             .filter_map(|r| r.ok().flatten()) | ||||
|             .map(|pdu| { | ||||
|  | @ -138,9 +137,7 @@ impl Rooms { | |||
|                 )) | ||||
|             }) | ||||
|             .filter_map(|r| r.ok()) | ||||
|             .collect(); | ||||
| 
 | ||||
|         Ok(state) | ||||
|             .collect()) | ||||
|     } | ||||
| 
 | ||||
|     /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
 | ||||
|  | @ -151,27 +148,19 @@ impl Rooms { | |||
|         event_type: &EventType, | ||||
|         state_key: &str, | ||||
|     ) -> Result<Option<EventId>> { | ||||
|         let mut key = event_type.as_ref().as_bytes().to_vec(); | ||||
|         key.push(0xff); | ||||
|         key.extend_from_slice(&state_key.as_bytes()); | ||||
| 
 | ||||
|         let shortstatekey = self.statekey_shortstatekey.get(&key)?; | ||||
| 
 | ||||
|         if let Some(shortstatekey) = shortstatekey { | ||||
|             let mut stateid = shortstatehash.to_be_bytes().to_vec(); | ||||
|             stateid.extend_from_slice(&shortstatekey); | ||||
| 
 | ||||
|             Ok(self | ||||
|                 .stateid_shorteventid | ||||
|                 .get(&stateid)? | ||||
|                 .map(|bytes| { | ||||
|                     self.get_eventid_from_short(utils::u64_from_bytes(&bytes).unwrap()) | ||||
|                         .ok() | ||||
|                 }) | ||||
|                 .flatten()) | ||||
|         } else { | ||||
|             Ok(None) | ||||
|         } | ||||
|         let shortstatekey = match self.get_shortstatekey(event_type, state_key)? { | ||||
|             Some(s) => s, | ||||
|             None => return Ok(None), | ||||
|         }; | ||||
|         let full_state = self | ||||
|             .load_shortstatehash_info(shortstatehash)? | ||||
|             .pop() | ||||
|             .expect("there is always one layer") | ||||
|             .1; | ||||
|         Ok(full_state | ||||
|             .into_iter() | ||||
|             .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) | ||||
|             .and_then(|compressed| self.parse_compressed_state_event(compressed).ok())) | ||||
|     } | ||||
| 
 | ||||
|     /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
 | ||||
|  | @ -260,8 +249,7 @@ impl Rooms { | |||
| 
 | ||||
|     /// Checks if a room exists.
 | ||||
|     pub fn exists(&self, room_id: &RoomId) -> Result<bool> { | ||||
|         let mut prefix = room_id.as_bytes().to_vec(); | ||||
|         prefix.push(0xff); | ||||
|         let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); | ||||
| 
 | ||||
|         // Look for PDUs in that room.
 | ||||
|         Ok(self | ||||
|  | @ -274,8 +262,7 @@ impl Rooms { | |||
| 
 | ||||
|     /// Checks if a room exists.
 | ||||
|     pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> { | ||||
|         let mut prefix = room_id.as_bytes().to_vec(); | ||||
|         prefix.push(0xff); | ||||
|         let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); | ||||
| 
 | ||||
|         // Look for PDUs in that room.
 | ||||
|         self.pduid_pdu | ||||
|  | @ -292,74 +279,78 @@ impl Rooms { | |||
| 
 | ||||
|     /// Force the creation of a new StateHash and insert it into the db.
 | ||||
|     ///
 | ||||
|     /// Whatever `state` is supplied to `force_state` __is__ the current room state snapshot.
 | ||||
|     /// Whatever `state` is supplied to `force_state` becomes the new current room state snapshot.
 | ||||
|     pub fn force_state( | ||||
|         &self, | ||||
|         room_id: &RoomId, | ||||
|         state: HashMap<(EventType, String), EventId>, | ||||
|         new_state: HashMap<(EventType, String), EventId>, | ||||
|         db: &Database, | ||||
|     ) -> Result<()> { | ||||
|         let previous_shortstatehash = self.current_shortstatehash(&room_id)?; | ||||
| 
 | ||||
|         let new_state_ids_compressed = new_state | ||||
|             .iter() | ||||
|             .filter_map(|((event_type, state_key), event_id)| { | ||||
|                 let shortstatekey = self | ||||
|                     .get_or_create_shortstatekey(event_type, state_key, &db.globals) | ||||
|                     .ok()?; | ||||
|                 Some( | ||||
|                     self.compress_state_event(shortstatekey, event_id, &db.globals) | ||||
|                         .ok()?, | ||||
|                 ) | ||||
|             }) | ||||
|             .collect::<HashSet<_>>(); | ||||
| 
 | ||||
|         let state_hash = self.calculate_hash( | ||||
|             &state | ||||
|             &new_state | ||||
|                 .values() | ||||
|                 .map(|event_id| event_id.as_bytes()) | ||||
|                 .collect::<Vec<_>>(), | ||||
|         ); | ||||
| 
 | ||||
|         let (shortstatehash, already_existed) = | ||||
|         let (new_shortstatehash, already_existed) = | ||||
|             self.get_or_create_shortstatehash(&state_hash, &db.globals)?; | ||||
| 
 | ||||
|         let new_state = if !already_existed { | ||||
|             let mut new_state = HashSet::new(); | ||||
|         if Some(new_shortstatehash) == previous_shortstatehash { | ||||
|             return Ok(()); | ||||
|         } | ||||
| 
 | ||||
|             let batch = state | ||||
|                 .iter() | ||||
|                 .filter_map(|((event_type, state_key), eventid)| { | ||||
|                     new_state.insert(eventid.clone()); | ||||
|         let states_parents = previous_shortstatehash | ||||
|             .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; | ||||
| 
 | ||||
|                     let mut statekey = event_type.as_ref().as_bytes().to_vec(); | ||||
|                     statekey.push(0xff); | ||||
|                     statekey.extend_from_slice(&state_key.as_bytes()); | ||||
|         let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() | ||||
|         { | ||||
|             let statediffnew = new_state_ids_compressed | ||||
|                 .difference(&parent_stateinfo.1) | ||||
|                 .cloned() | ||||
|                 .collect::<HashSet<_>>(); | ||||
| 
 | ||||
|                     let shortstatekey = match self.statekey_shortstatekey.get(&statekey).ok()? { | ||||
|                         Some(shortstatekey) => shortstatekey.to_vec(), | ||||
|                         None => { | ||||
|                             let shortstatekey = db.globals.next_count().ok()?; | ||||
|                             self.statekey_shortstatekey | ||||
|                                 .insert(&statekey, &shortstatekey.to_be_bytes()) | ||||
|                                 .ok()?; | ||||
|                             shortstatekey.to_be_bytes().to_vec() | ||||
|                         } | ||||
|                     }; | ||||
|             let statediffremoved = parent_stateinfo | ||||
|                 .1 | ||||
|                 .difference(&new_state_ids_compressed) | ||||
|                 .cloned() | ||||
|                 .collect::<HashSet<_>>(); | ||||
| 
 | ||||
|                     let shorteventid = self | ||||
|                         .get_or_create_shorteventid(&eventid, &db.globals) | ||||
|                         .ok()?; | ||||
| 
 | ||||
|                     let mut state_id = shortstatehash.to_be_bytes().to_vec(); | ||||
|                     state_id.extend_from_slice(&shortstatekey); | ||||
| 
 | ||||
|                     Some((state_id, shorteventid.to_be_bytes().to_vec())) | ||||
|                 }) | ||||
|                 .collect::<Vec<_>>(); | ||||
| 
 | ||||
|             self.stateid_shorteventid | ||||
|                 .insert_batch(&mut batch.into_iter())?; | ||||
| 
 | ||||
|             new_state | ||||
|             (statediffnew, statediffremoved) | ||||
|         } else { | ||||
|             self.state_full_ids(shortstatehash)?.into_iter().collect() | ||||
|             (new_state_ids_compressed, HashSet::new()) | ||||
|         }; | ||||
| 
 | ||||
|         let old_state = self | ||||
|             .current_shortstatehash(&room_id)? | ||||
|             .map(|s| self.state_full_ids(s)) | ||||
|             .transpose()? | ||||
|             .map(|vec| vec.into_iter().collect::<HashSet<_>>()) | ||||
|             .unwrap_or_default(); | ||||
|         if !already_existed { | ||||
|             self.save_state_from_diff( | ||||
|                 new_shortstatehash, | ||||
|                 statediffnew.clone(), | ||||
|                 statediffremoved.clone(), | ||||
|                 2, // every state change is 2 event changes on average
 | ||||
|                 states_parents, | ||||
|             )?; | ||||
|         }; | ||||
| 
 | ||||
|         for event_id in new_state.difference(&old_state) { | ||||
|             if let Some(pdu) = self.get_pdu_json(event_id)? { | ||||
|         for event_id in statediffnew | ||||
|             .into_iter() | ||||
|             .filter_map(|new| self.parse_compressed_state_event(new).ok()) | ||||
|         { | ||||
|             if let Some(pdu) = self.get_pdu_json(&event_id)? { | ||||
|                 if pdu.get("type").and_then(|val| val.as_str()) == Some("m.room.member") { | ||||
|                     if let Ok(pdu) = serde_json::from_value::<PduEvent>( | ||||
|                         serde_json::to_value(&pdu).expect("CanonicalJsonObj is a valid JsonValue"), | ||||
|  | @ -392,7 +383,206 @@ impl Rooms { | |||
|         } | ||||
| 
 | ||||
|         self.roomid_shortstatehash | ||||
|             .insert(room_id.as_bytes(), &shortstatehash.to_be_bytes())?; | ||||
|             .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; | ||||
| 
 | ||||
|         Ok(()) | ||||
|     } | ||||
| 
 | ||||
|     /// Returns a stack with info on shortstatehash, full state, added diff and removed diff for the selected shortstatehash and each parent layer.
 | ||||
|     pub fn load_shortstatehash_info( | ||||
|         &self, | ||||
|         shortstatehash: u64, | ||||
|     ) -> Result< | ||||
|         Vec<( | ||||
|             u64,                           // sstatehash
 | ||||
|             HashSet<CompressedStateEvent>, // full state
 | ||||
|             HashSet<CompressedStateEvent>, // added
 | ||||
|             HashSet<CompressedStateEvent>, // removed
 | ||||
|         )>, | ||||
|     > { | ||||
|         let value = self | ||||
|             .shortstatehash_statediff | ||||
|             .get(&shortstatehash.to_be_bytes())? | ||||
|             .ok_or_else(|| Error::bad_database("State hash does not exist"))?; | ||||
|         let parent = | ||||
|             utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length"); | ||||
| 
 | ||||
|         let mut add_mode = true; | ||||
|         let mut added = HashSet::new(); | ||||
|         let mut removed = HashSet::new(); | ||||
| 
 | ||||
|         let mut i = size_of::<u64>(); | ||||
|         while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) { | ||||
|             if add_mode && v.starts_with(&0_u64.to_be_bytes()) { | ||||
|                 add_mode = false; | ||||
|                 i += size_of::<u64>(); | ||||
|                 continue; | ||||
|             } | ||||
|             if add_mode { | ||||
|                 added.insert(v.try_into().expect("we checked the size above")); | ||||
|             } else { | ||||
|                 removed.insert(v.try_into().expect("we checked the size above")); | ||||
|             } | ||||
|             i += 2 * size_of::<u64>(); | ||||
|         } | ||||
| 
 | ||||
|         if parent != 0_u64 { | ||||
|             let mut response = self.load_shortstatehash_info(parent)?; | ||||
|             let mut state = response.last().unwrap().1.clone(); | ||||
|             state.extend(added.iter().cloned()); | ||||
|             for r in &removed { | ||||
|                 state.remove(r); | ||||
|             } | ||||
| 
 | ||||
|             response.push((shortstatehash, state, added, removed)); | ||||
| 
 | ||||
|             Ok(response) | ||||
|         } else { | ||||
|             let mut response = Vec::new(); | ||||
|             response.push((shortstatehash, added.clone(), added, removed)); | ||||
|             Ok(response) | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn compress_state_event( | ||||
|         &self, | ||||
|         shortstatekey: u64, | ||||
|         event_id: &EventId, | ||||
|         globals: &super::globals::Globals, | ||||
|     ) -> Result<CompressedStateEvent> { | ||||
|         let mut v = shortstatekey.to_be_bytes().to_vec(); | ||||
|         v.extend_from_slice( | ||||
|             &self | ||||
|                 .get_or_create_shorteventid(event_id, globals)? | ||||
|                 .to_be_bytes(), | ||||
|         ); | ||||
|         Ok(v.try_into().expect("we checked the size above")) | ||||
|     } | ||||
| 
 | ||||
|     pub fn parse_compressed_state_event( | ||||
|         &self, | ||||
|         compressed_event: CompressedStateEvent, | ||||
|     ) -> Result<EventId> { | ||||
|         self.get_eventid_from_short( | ||||
|             utils::u64_from_bytes(&compressed_event[size_of::<u64>()..]) | ||||
|                 .expect("bytes have right length"), | ||||
|         ) | ||||
|     } | ||||
| 
 | ||||
|     /// Creates a new shortstatehash that often is just a diff to an already existing
 | ||||
|     /// shortstatehash and therefore very efficient.
 | ||||
|     ///
 | ||||
|     /// There are multiple layers of diffs. The bottom layer 0 always contains the full state. Layer
 | ||||
|     /// 1 contains diffs to states of layer 0, layer 2 diffs to layer 1 and so on. If layer n > 0
 | ||||
|     /// grows too big, it will be combined with layer n-1 to create a new diff on layer n-1 that's
 | ||||
|     /// based on layer n-2. If that layer is also too big, it will recursively fix above layers too.
 | ||||
|     ///
 | ||||
|     /// * `shortstatehash` - Shortstatehash of this state
 | ||||
|     /// * `statediffnew` - Added to base. Each vec is shortstatekey+shorteventid
 | ||||
|     /// * `statediffremoved` - Removed from base. Each vec is shortstatekey+shorteventid
 | ||||
|     /// * `diff_to_sibling` - Approximately how much the diff grows each time for this layer
 | ||||
|     /// * `parent_states` - A stack with info on shortstatehash, full state, added diff and removed diff for each parent layer
 | ||||
|     pub fn save_state_from_diff( | ||||
|         &self, | ||||
|         shortstatehash: u64, | ||||
|         statediffnew: HashSet<CompressedStateEvent>, | ||||
|         statediffremoved: HashSet<CompressedStateEvent>, | ||||
|         diff_to_sibling: usize, | ||||
|         mut parent_states: Vec<( | ||||
|             u64,                           // sstatehash
 | ||||
|             HashSet<CompressedStateEvent>, // full state
 | ||||
|             HashSet<CompressedStateEvent>, // added
 | ||||
|             HashSet<CompressedStateEvent>, // removed
 | ||||
|         )>, | ||||
|     ) -> Result<()> { | ||||
|         let diffsum = statediffnew.len() + statediffremoved.len(); | ||||
| 
 | ||||
|         if parent_states.len() > 3 { | ||||
|             // Number of layers
 | ||||
|             // To many layers, we have to go deeper
 | ||||
|             let parent = parent_states.pop().unwrap(); | ||||
| 
 | ||||
|             let mut parent_new = parent.2; | ||||
|             let mut parent_removed = parent.3; | ||||
| 
 | ||||
|             for removed in statediffremoved { | ||||
|                 if !parent_new.remove(&removed) { | ||||
|                     parent_removed.insert(removed); | ||||
|                 } | ||||
|             } | ||||
|             parent_new.extend(statediffnew); | ||||
| 
 | ||||
|             self.save_state_from_diff( | ||||
|                 shortstatehash, | ||||
|                 parent_new, | ||||
|                 parent_removed, | ||||
|                 diffsum, | ||||
|                 parent_states, | ||||
|             )?; | ||||
| 
 | ||||
|             return Ok(()); | ||||
|         } | ||||
| 
 | ||||
|         if parent_states.len() == 0 { | ||||
|             // There is no parent layer, create a new state
 | ||||
|             let mut value = 0_u64.to_be_bytes().to_vec(); // 0 means no parent
 | ||||
|             for new in &statediffnew { | ||||
|                 value.extend_from_slice(&new[..]); | ||||
|             } | ||||
| 
 | ||||
|             if !statediffremoved.is_empty() { | ||||
|                 warn!("Tried to create new state with removals"); | ||||
|             } | ||||
| 
 | ||||
|             self.shortstatehash_statediff | ||||
|                 .insert(&shortstatehash.to_be_bytes(), &value)?; | ||||
| 
 | ||||
|             return Ok(()); | ||||
|         }; | ||||
| 
 | ||||
|         // Else we have two options.
 | ||||
|         // 1. We add the current diff on top of the parent layer.
 | ||||
|         // 2. We replace a layer above
 | ||||
| 
 | ||||
|         let parent = parent_states.pop().unwrap(); | ||||
|         let parent_diff = parent.2.len() + parent.3.len(); | ||||
| 
 | ||||
|         if diffsum * diffsum >= 2 * diff_to_sibling * parent_diff { | ||||
|             // Diff too big, we replace above layer(s)
 | ||||
|             let mut parent_new = parent.2; | ||||
|             let mut parent_removed = parent.3; | ||||
| 
 | ||||
|             for removed in statediffremoved { | ||||
|                 if !parent_new.remove(&removed) { | ||||
|                     parent_removed.insert(removed); | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             parent_new.extend(statediffnew); | ||||
|             self.save_state_from_diff( | ||||
|                 shortstatehash, | ||||
|                 parent_new, | ||||
|                 parent_removed, | ||||
|                 diffsum, | ||||
|                 parent_states, | ||||
|             )?; | ||||
|         } else { | ||||
|             // Diff small enough, we add diff as layer on top of parent
 | ||||
|             let mut value = parent.0.to_be_bytes().to_vec(); | ||||
|             for new in &statediffnew { | ||||
|                 value.extend_from_slice(&new[..]); | ||||
|             } | ||||
| 
 | ||||
|             if !statediffremoved.is_empty() { | ||||
|                 value.extend_from_slice(&0_u64.to_be_bytes()); | ||||
|                 for removed in &statediffremoved { | ||||
|                     value.extend_from_slice(&removed[..]); | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             self.shortstatehash_statediff | ||||
|                 .insert(&shortstatehash.to_be_bytes(), &value)?; | ||||
|         } | ||||
| 
 | ||||
|         Ok(()) | ||||
|     } | ||||
|  | @ -418,7 +608,6 @@ impl Rooms { | |||
|         }) | ||||
|     } | ||||
| 
 | ||||
|     /// Returns (shortstatehash, already_existed)
 | ||||
|     pub fn get_or_create_shorteventid( | ||||
|         &self, | ||||
|         event_id: &EventId, | ||||
|  | @ -438,6 +627,71 @@ impl Rooms { | |||
|         }) | ||||
|     } | ||||
| 
 | ||||
|     pub fn get_shortroomid(&self, room_id: &RoomId) -> Result<u64> { | ||||
|         let bytes = self | ||||
|             .roomid_shortroomid | ||||
|             .get(&room_id.as_bytes())? | ||||
|             .expect("every room has a shortroomid"); | ||||
|         utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db.")) | ||||
|     } | ||||
| 
 | ||||
|     pub fn get_shortstatekey( | ||||
|         &self, | ||||
|         event_type: &EventType, | ||||
|         state_key: &str, | ||||
|     ) -> Result<Option<u64>> { | ||||
|         let mut statekey = event_type.as_ref().as_bytes().to_vec(); | ||||
|         statekey.push(0xff); | ||||
|         statekey.extend_from_slice(&state_key.as_bytes()); | ||||
| 
 | ||||
|         self.statekey_shortstatekey | ||||
|             .get(&statekey)? | ||||
|             .map(|shortstatekey| { | ||||
|                 utils::u64_from_bytes(&shortstatekey) | ||||
|                     .map_err(|_| Error::bad_database("Invalid shortstatekey in db.")) | ||||
|             }) | ||||
|             .transpose() | ||||
|     } | ||||
| 
 | ||||
|     pub fn get_or_create_shortroomid( | ||||
|         &self, | ||||
|         room_id: &RoomId, | ||||
|         globals: &super::globals::Globals, | ||||
|     ) -> Result<u64> { | ||||
|         Ok(match self.roomid_shortroomid.get(&room_id.as_bytes())? { | ||||
|             Some(short) => utils::u64_from_bytes(&short) | ||||
|                 .map_err(|_| Error::bad_database("Invalid shortroomid in db."))?, | ||||
|             None => { | ||||
|                 let short = globals.next_count()?; | ||||
|                 self.roomid_shortroomid | ||||
|                     .insert(&room_id.as_bytes(), &short.to_be_bytes())?; | ||||
|                 short | ||||
|             } | ||||
|         }) | ||||
|     } | ||||
| 
 | ||||
|     pub fn get_or_create_shortstatekey( | ||||
|         &self, | ||||
|         event_type: &EventType, | ||||
|         state_key: &str, | ||||
|         globals: &super::globals::Globals, | ||||
|     ) -> Result<u64> { | ||||
|         let mut statekey = event_type.as_ref().as_bytes().to_vec(); | ||||
|         statekey.push(0xff); | ||||
|         statekey.extend_from_slice(&state_key.as_bytes()); | ||||
| 
 | ||||
|         Ok(match self.statekey_shortstatekey.get(&statekey)? { | ||||
|             Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey) | ||||
|                 .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?, | ||||
|             None => { | ||||
|                 let shortstatekey = globals.next_count()?; | ||||
|                 self.statekey_shortstatekey | ||||
|                     .insert(&statekey, &shortstatekey.to_be_bytes())?; | ||||
|                 shortstatekey | ||||
|             } | ||||
|         }) | ||||
|     } | ||||
| 
 | ||||
|     pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result<EventId> { | ||||
|         if let Some(id) = self | ||||
|             .shorteventid_cache | ||||
|  | @ -514,7 +768,7 @@ impl Rooms { | |||
|     #[tracing::instrument(skip(self))] | ||||
|     pub fn pdu_count(&self, pdu_id: &[u8]) -> Result<u64> { | ||||
|         Ok( | ||||
|             utils::u64_from_bytes(&pdu_id[pdu_id.len() - mem::size_of::<u64>()..pdu_id.len()]) | ||||
|             utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..]) | ||||
|                 .map_err(|_| Error::bad_database("PDU has invalid count bytes."))?, | ||||
|         ) | ||||
|     } | ||||
|  | @ -527,8 +781,7 @@ impl Rooms { | |||
|     } | ||||
| 
 | ||||
|     pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result<u64> { | ||||
|         let mut prefix = room_id.as_bytes().to_vec(); | ||||
|         prefix.push(0xff); | ||||
|         let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); | ||||
| 
 | ||||
|         let mut last_possible_key = prefix.clone(); | ||||
|         last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); | ||||
|  | @ -758,6 +1011,8 @@ impl Rooms { | |||
|     ///
 | ||||
|     /// By this point the incoming event should be fully authenticated, no auth happens
 | ||||
|     /// in `append_pdu`.
 | ||||
|     ///
 | ||||
|     /// Returns pdu id
 | ||||
|     #[tracing::instrument(skip(self, pdu, pdu_json, leaves, db))] | ||||
|     pub fn append_pdu( | ||||
|         &self, | ||||
|  | @ -766,7 +1021,8 @@ impl Rooms { | |||
|         leaves: &[EventId], | ||||
|         db: &Database, | ||||
|     ) -> Result<Vec<u8>> { | ||||
|         // returns pdu id
 | ||||
|         let shortroomid = self.get_shortroomid(&pdu.room_id)?; | ||||
| 
 | ||||
|         // Make unsigned fields correct. This is not properly documented in the spec, but state
 | ||||
|         // events need to have previous content in the unsigned field, so clients can easily
 | ||||
|         // interpret things like membership changes
 | ||||
|  | @ -821,8 +1077,7 @@ impl Rooms { | |||
|         self.reset_notification_counts(&pdu.sender, &pdu.room_id)?; | ||||
| 
 | ||||
|         let count2 = db.globals.next_count()?; | ||||
|         let mut pdu_id = pdu.room_id.as_bytes().to_vec(); | ||||
|         pdu_id.push(0xff); | ||||
|         let mut pdu_id = shortroomid.to_be_bytes().to_vec(); | ||||
|         pdu_id.extend_from_slice(&count2.to_be_bytes()); | ||||
| 
 | ||||
|         // There's a brief moment of time here where the count is updated but the pdu does not
 | ||||
|  | @ -968,8 +1223,7 @@ impl Rooms { | |||
|                         .filter(|word| word.len() <= 50) | ||||
|                         .map(str::to_lowercase) | ||||
|                         .map(|word| { | ||||
|                             let mut key = pdu.room_id.as_bytes().to_vec(); | ||||
|                             key.push(0xff); | ||||
|                             let mut key = shortroomid.to_be_bytes().to_vec(); | ||||
|                             key.extend_from_slice(word.as_bytes()); | ||||
|                             key.push(0xff); | ||||
|                             key.extend_from_slice(&pdu_id); | ||||
|  | @ -1152,11 +1406,27 @@ impl Rooms { | |||
|     pub fn set_event_state( | ||||
|         &self, | ||||
|         event_id: &EventId, | ||||
|         room_id: &RoomId, | ||||
|         state: &StateMap<Arc<PduEvent>>, | ||||
|         globals: &super::globals::Globals, | ||||
|     ) -> Result<()> { | ||||
|         let shorteventid = self.get_or_create_shorteventid(&event_id, globals)?; | ||||
| 
 | ||||
|         let previous_shortstatehash = self.current_shortstatehash(&room_id)?; | ||||
| 
 | ||||
|         let state_ids_compressed = state | ||||
|             .iter() | ||||
|             .filter_map(|((event_type, state_key), pdu)| { | ||||
|                 let shortstatekey = self | ||||
|                     .get_or_create_shortstatekey(event_type, state_key, globals) | ||||
|                     .ok()?; | ||||
|                 Some( | ||||
|                     self.compress_state_event(shortstatekey, &pdu.event_id, globals) | ||||
|                         .ok()?, | ||||
|                 ) | ||||
|             }) | ||||
|             .collect::<HashSet<_>>(); | ||||
| 
 | ||||
|         let state_hash = self.calculate_hash( | ||||
|             &state | ||||
|                 .values() | ||||
|  | @ -1168,37 +1438,33 @@ impl Rooms { | |||
|             self.get_or_create_shortstatehash(&state_hash, globals)?; | ||||
| 
 | ||||
|         if !already_existed { | ||||
|             let batch = state | ||||
|                 .iter() | ||||
|                 .filter_map(|((event_type, state_key), pdu)| { | ||||
|                     let mut statekey = event_type.as_ref().as_bytes().to_vec(); | ||||
|                     statekey.push(0xff); | ||||
|                     statekey.extend_from_slice(&state_key.as_bytes()); | ||||
|             let states_parents = previous_shortstatehash | ||||
|                 .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; | ||||
| 
 | ||||
|                     let shortstatekey = match self.statekey_shortstatekey.get(&statekey).ok()? { | ||||
|                         Some(shortstatekey) => shortstatekey.to_vec(), | ||||
|                         None => { | ||||
|                             let shortstatekey = globals.next_count().ok()?; | ||||
|                             self.statekey_shortstatekey | ||||
|                                 .insert(&statekey, &shortstatekey.to_be_bytes()) | ||||
|                                 .ok()?; | ||||
|                             shortstatekey.to_be_bytes().to_vec() | ||||
|                         } | ||||
|                     }; | ||||
|             let (statediffnew, statediffremoved) = | ||||
|                 if let Some(parent_stateinfo) = states_parents.last() { | ||||
|                     let statediffnew = state_ids_compressed | ||||
|                         .difference(&parent_stateinfo.1) | ||||
|                         .cloned() | ||||
|                         .collect::<HashSet<_>>(); | ||||
| 
 | ||||
|                     let shorteventid = self | ||||
|                         .get_or_create_shorteventid(&pdu.event_id, globals) | ||||
|                         .ok()?; | ||||
|                     let statediffremoved = parent_stateinfo | ||||
|                         .1 | ||||
|                         .difference(&state_ids_compressed) | ||||
|                         .cloned() | ||||
|                         .collect::<HashSet<_>>(); | ||||
| 
 | ||||
|                     let mut state_id = shortstatehash.to_be_bytes().to_vec(); | ||||
|                     state_id.extend_from_slice(&shortstatekey); | ||||
| 
 | ||||
|                     Some((state_id, shorteventid.to_be_bytes().to_vec())) | ||||
|                 }) | ||||
|                 .collect::<Vec<_>>(); | ||||
| 
 | ||||
|             self.stateid_shorteventid | ||||
|                 .insert_batch(&mut batch.into_iter())?; | ||||
|                     (statediffnew, statediffremoved) | ||||
|                 } else { | ||||
|                     (state_ids_compressed, HashSet::new()) | ||||
|                 }; | ||||
|             self.save_state_from_diff( | ||||
|                 shortstatehash, | ||||
|                 statediffnew.clone(), | ||||
|                 statediffremoved.clone(), | ||||
|                 1_000_000, // high number because no state will be based on this one
 | ||||
|                 states_parents, | ||||
|             )?; | ||||
|         } | ||||
| 
 | ||||
|         self.shorteventid_shortstatehash | ||||
|  | @ -1219,82 +1485,52 @@ impl Rooms { | |||
|     ) -> Result<u64> { | ||||
|         let shorteventid = self.get_or_create_shorteventid(&new_pdu.event_id, globals)?; | ||||
| 
 | ||||
|         let old_state = if let Some(old_shortstatehash) = | ||||
|             self.roomid_shortstatehash.get(new_pdu.room_id.as_bytes())? | ||||
|         { | ||||
|             // Store state for event. The state does not include the event itself.
 | ||||
|             // Instead it's the state before the pdu, so the room's old state.
 | ||||
|         let previous_shortstatehash = self.current_shortstatehash(&new_pdu.room_id)?; | ||||
| 
 | ||||
|         if let Some(p) = previous_shortstatehash { | ||||
|             self.shorteventid_shortstatehash | ||||
|                 .insert(&shorteventid.to_be_bytes(), &old_shortstatehash)?; | ||||
| 
 | ||||
|             if new_pdu.state_key.is_none() { | ||||
|                 return utils::u64_from_bytes(&old_shortstatehash).map_err(|_| { | ||||
|                     Error::bad_database("Invalid shortstatehash in roomid_shortstatehash.") | ||||
|                 }); | ||||
|             } | ||||
| 
 | ||||
|             self.stateid_shorteventid | ||||
|                 .scan_prefix(old_shortstatehash.clone()) | ||||
|                 // Chop the old_shortstatehash out leaving behind the short state key
 | ||||
|                 .map(|(k, v)| (k[old_shortstatehash.len()..].to_vec(), v)) | ||||
|                 .collect::<HashMap<Vec<u8>, Vec<u8>>>() | ||||
|         } else { | ||||
|             HashMap::new() | ||||
|         }; | ||||
|                 .insert(&shorteventid.to_be_bytes(), &p.to_be_bytes())?; | ||||
|         } | ||||
| 
 | ||||
|         if let Some(state_key) = &new_pdu.state_key { | ||||
|             let mut new_state: HashMap<Vec<u8>, Vec<u8>> = old_state; | ||||
|             let states_parents = previous_shortstatehash | ||||
|                 .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; | ||||
| 
 | ||||
|             let mut new_state_key = new_pdu.kind.as_ref().as_bytes().to_vec(); | ||||
|             new_state_key.push(0xff); | ||||
|             new_state_key.extend_from_slice(state_key.as_bytes()); | ||||
|             let shortstatekey = | ||||
|                 self.get_or_create_shortstatekey(&new_pdu.kind, &state_key, globals)?; | ||||
| 
 | ||||
|             let shortstatekey = match self.statekey_shortstatekey.get(&new_state_key)? { | ||||
|                 Some(shortstatekey) => shortstatekey.to_vec(), | ||||
|                 None => { | ||||
|                     let shortstatekey = globals.next_count()?; | ||||
|                     self.statekey_shortstatekey | ||||
|                         .insert(&new_state_key, &shortstatekey.to_be_bytes())?; | ||||
|                     shortstatekey.to_be_bytes().to_vec() | ||||
|                 } | ||||
|             }; | ||||
|             let replaces = states_parents | ||||
|                 .last() | ||||
|                 .map(|info| { | ||||
|                     info.1 | ||||
|                         .iter() | ||||
|                         .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) | ||||
|                 }) | ||||
|                 .unwrap_or_default(); | ||||
| 
 | ||||
|             new_state.insert(shortstatekey, shorteventid.to_be_bytes().to_vec()); | ||||
|             // TODO: statehash with deterministic inputs
 | ||||
|             let shortstatehash = globals.next_count()?; | ||||
| 
 | ||||
|             let new_state_hash = self.calculate_hash( | ||||
|                 &new_state | ||||
|                     .values() | ||||
|                     .map(|event_id| &**event_id) | ||||
|                     .collect::<Vec<_>>(), | ||||
|             ); | ||||
|             let mut statediffnew = HashSet::new(); | ||||
|             let new = self.compress_state_event(shortstatekey, &new_pdu.event_id, globals)?; | ||||
|             statediffnew.insert(new); | ||||
| 
 | ||||
|             let shortstatehash = match self.statehash_shortstatehash.get(&new_state_hash)? { | ||||
|                 Some(shortstatehash) => { | ||||
|                     warn!("state hash already existed?!"); | ||||
|                     utils::u64_from_bytes(&shortstatehash) | ||||
|                         .map_err(|_| Error::bad_database("PDU has invalid count bytes."))? | ||||
|                 } | ||||
|                 None => { | ||||
|                     let shortstatehash = globals.next_count()?; | ||||
|                     self.statehash_shortstatehash | ||||
|                         .insert(&new_state_hash, &shortstatehash.to_be_bytes())?; | ||||
|                     shortstatehash | ||||
|                 } | ||||
|             }; | ||||
|             let mut statediffremoved = HashSet::new(); | ||||
|             if let Some(replaces) = replaces { | ||||
|                 statediffremoved.insert(replaces.clone()); | ||||
|             } | ||||
| 
 | ||||
|             let mut batch = new_state.into_iter().map(|(shortstatekey, shorteventid)| { | ||||
|                 let mut state_id = shortstatehash.to_be_bytes().to_vec(); | ||||
|                 state_id.extend_from_slice(&shortstatekey); | ||||
|                 (state_id, shorteventid) | ||||
|             }); | ||||
| 
 | ||||
|             self.stateid_shorteventid.insert_batch(&mut batch)?; | ||||
|             self.save_state_from_diff( | ||||
|                 shortstatehash, | ||||
|                 statediffnew, | ||||
|                 statediffremoved, | ||||
|                 2, | ||||
|                 states_parents, | ||||
|             )?; | ||||
| 
 | ||||
|             Ok(shortstatehash) | ||||
|         } else { | ||||
|             Err(Error::bad_database( | ||||
|                 "Tried to insert non-state event into room without a state.", | ||||
|             )) | ||||
|             Ok(previous_shortstatehash.expect("first event in room must be a state event")) | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|  | @ -1597,7 +1833,7 @@ impl Rooms { | |||
|         &'a self, | ||||
|         user_id: &UserId, | ||||
|         room_id: &RoomId, | ||||
|     ) -> impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a { | ||||
|     ) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> { | ||||
|         self.pdus_since(user_id, room_id, 0) | ||||
|     } | ||||
| 
 | ||||
|  | @ -1609,16 +1845,17 @@ impl Rooms { | |||
|         user_id: &UserId, | ||||
|         room_id: &RoomId, | ||||
|         since: u64, | ||||
|     ) -> impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a { | ||||
|         let mut prefix = room_id.as_bytes().to_vec(); | ||||
|         prefix.push(0xff); | ||||
|     ) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> { | ||||
|         let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); | ||||
| 
 | ||||
|         // Skip the first pdu if it's exactly at since, because we sent that last time
 | ||||
|         let mut first_pdu_id = prefix.clone(); | ||||
|         first_pdu_id.extend_from_slice(&(since + 1).to_be_bytes()); | ||||
| 
 | ||||
|         let user_id = user_id.clone(); | ||||
|         self.pduid_pdu | ||||
| 
 | ||||
|         Ok(self | ||||
|             .pduid_pdu | ||||
|             .iter_from(&first_pdu_id, false) | ||||
|             .take_while(move |(k, _)| k.starts_with(&prefix)) | ||||
|             .map(move |(pdu_id, v)| { | ||||
|  | @ -1628,7 +1865,7 @@ impl Rooms { | |||
|                     pdu.unsigned.remove("transaction_id"); | ||||
|                 } | ||||
|                 Ok((pdu_id, pdu)) | ||||
|             }) | ||||
|             })) | ||||
|     } | ||||
| 
 | ||||
|     /// Returns an iterator over all events and their tokens in a room that happened before the
 | ||||
|  | @ -1639,10 +1876,9 @@ impl Rooms { | |||
|         user_id: &UserId, | ||||
|         room_id: &RoomId, | ||||
|         until: u64, | ||||
|     ) -> impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a { | ||||
|     ) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> { | ||||
|         // Create the first part of the full pdu id
 | ||||
|         let mut prefix = room_id.as_bytes().to_vec(); | ||||
|         prefix.push(0xff); | ||||
|         let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); | ||||
| 
 | ||||
|         let mut current = prefix.clone(); | ||||
|         current.extend_from_slice(&(until.saturating_sub(1)).to_be_bytes()); // -1 because we don't want event at `until`
 | ||||
|  | @ -1650,7 +1886,9 @@ impl Rooms { | |||
|         let current: &[u8] = ¤t; | ||||
| 
 | ||||
|         let user_id = user_id.clone(); | ||||
|         self.pduid_pdu | ||||
| 
 | ||||
|         Ok(self | ||||
|             .pduid_pdu | ||||
|             .iter_from(current, true) | ||||
|             .take_while(move |(k, _)| k.starts_with(&prefix)) | ||||
|             .map(move |(pdu_id, v)| { | ||||
|  | @ -1660,7 +1898,7 @@ impl Rooms { | |||
|                     pdu.unsigned.remove("transaction_id"); | ||||
|                 } | ||||
|                 Ok((pdu_id, pdu)) | ||||
|             }) | ||||
|             })) | ||||
|     } | ||||
| 
 | ||||
|     /// Returns an iterator over all events and their token in a room that happened after the event
 | ||||
|  | @ -1671,10 +1909,9 @@ impl Rooms { | |||
|         user_id: &UserId, | ||||
|         room_id: &RoomId, | ||||
|         from: u64, | ||||
|     ) -> impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a { | ||||
|     ) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> { | ||||
|         // Create the first part of the full pdu id
 | ||||
|         let mut prefix = room_id.as_bytes().to_vec(); | ||||
|         prefix.push(0xff); | ||||
|         let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); | ||||
| 
 | ||||
|         let mut current = prefix.clone(); | ||||
|         current.extend_from_slice(&(from + 1).to_be_bytes()); // +1 so we don't send the base event
 | ||||
|  | @ -1682,7 +1919,9 @@ impl Rooms { | |||
|         let current: &[u8] = ¤t; | ||||
| 
 | ||||
|         let user_id = user_id.clone(); | ||||
|         self.pduid_pdu | ||||
| 
 | ||||
|         Ok(self | ||||
|             .pduid_pdu | ||||
|             .iter_from(current, false) | ||||
|             .take_while(move |(k, _)| k.starts_with(&prefix)) | ||||
|             .map(move |(pdu_id, v)| { | ||||
|  | @ -1692,7 +1931,7 @@ impl Rooms { | |||
|                     pdu.unsigned.remove("transaction_id"); | ||||
|                 } | ||||
|                 Ok((pdu_id, pdu)) | ||||
|             }) | ||||
|             })) | ||||
|     } | ||||
| 
 | ||||
|     /// Replace a PDU with the redacted form.
 | ||||
|  | @ -2223,8 +2462,8 @@ impl Rooms { | |||
|         room_id: &RoomId, | ||||
|         search_string: &str, | ||||
|     ) -> Result<(impl Iterator<Item = Vec<u8>> + 'a, Vec<String>)> { | ||||
|         let mut prefix = room_id.as_bytes().to_vec(); | ||||
|         prefix.push(0xff); | ||||
|         let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); | ||||
|         let prefix_clone = prefix.clone(); | ||||
| 
 | ||||
|         let words = search_string | ||||
|             .split_terminator(|c: char| !c.is_alphanumeric()) | ||||
|  | @ -2243,16 +2482,7 @@ impl Rooms { | |||
|                 .iter_from(&last_possible_id, true) // Newest pdus first
 | ||||
|                 .take_while(move |(k, _)| k.starts_with(&prefix2)) | ||||
|                 .map(|(key, _)| { | ||||
|                     let pduid_index = key | ||||
|                         .iter() | ||||
|                         .enumerate() | ||||
|                         .filter(|(_, &b)| b == 0xff) | ||||
|                         .nth(1) | ||||
|                         .ok_or_else(|| Error::bad_database("Invalid tokenid in db."))? | ||||
|                         .0 | ||||
|                         + 1; // +1 because the pdu id starts AFTER the separator
 | ||||
| 
 | ||||
|                     let pdu_id = key[pduid_index..].to_vec(); | ||||
|                     let pdu_id = key[key.len() - size_of::<u64>()..].to_vec(); | ||||
| 
 | ||||
|                     Ok::<_, Error>(pdu_id) | ||||
|                 }) | ||||
|  | @ -2264,7 +2494,12 @@ impl Rooms { | |||
|                 // We compare b with a because we reversed the iterator earlier
 | ||||
|                 b.cmp(a) | ||||
|             }) | ||||
|             .unwrap(), | ||||
|             .unwrap() | ||||
|             .map(move |id| { | ||||
|                 let mut pduid = prefix_clone.clone(); | ||||
|                 pduid.extend_from_slice(&id); | ||||
|                 pduid | ||||
|             }), | ||||
|             words, | ||||
|         )) | ||||
|     } | ||||
|  |  | |||
|  | @ -1704,7 +1704,7 @@ fn append_incoming_pdu( | |||
|     // We append to state before appending the pdu, so we don't have a moment in time with the
 | ||||
|     // pdu without it's state. This is okay because append_pdu can't fail.
 | ||||
|     db.rooms | ||||
|         .set_event_state(&pdu.event_id, state, &db.globals)?; | ||||
|         .set_event_state(&pdu.event_id, &pdu.room_id, state, &db.globals)?; | ||||
| 
 | ||||
|     let pdu_id = db.rooms.append_pdu( | ||||
|         pdu, | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue