improvement: device list works better
The only situation that isn't working yet is sending `left` events for users when the sender leaves the room
This commit is contained in:
		
							parent
							
								
									f23fb32e95
								
							
						
					
					
						commit
						4323cf5fec
					
				
					 5 changed files with 246 additions and 151 deletions
				
			
		|  | @ -19,16 +19,9 @@ pub fn search_events_route( | ||||||
|     let sender_id = body.sender_id.as_ref().expect("user is authenticated"); |     let sender_id = body.sender_id.as_ref().expect("user is authenticated"); | ||||||
| 
 | 
 | ||||||
|     let search_criteria = body.search_categories.room_events.as_ref().unwrap(); |     let search_criteria = body.search_categories.room_events.as_ref().unwrap(); | ||||||
|     let filter = search_criteria |     let filter = search_criteria.filter.as_ref().unwrap(); | ||||||
|         .filter |  | ||||||
|         .as_ref() |  | ||||||
|         .unwrap(); |  | ||||||
| 
 | 
 | ||||||
|     let room_id = filter.rooms |     let room_id = filter.rooms.as_ref().unwrap().first().unwrap(); | ||||||
|         .as_ref() |  | ||||||
|         .unwrap() |  | ||||||
|         .first() |  | ||||||
|         .unwrap(); |  | ||||||
| 
 | 
 | ||||||
|     let limit = filter.limit.map_or(10, |l| u64::from(l) as usize); |     let limit = filter.limit.map_or(10, |l| u64::from(l) as usize); | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -2,14 +2,15 @@ use super::State; | ||||||
| use crate::{ConduitResult, Database, Error, Ruma}; | use crate::{ConduitResult, Database, Error, Ruma}; | ||||||
| use ruma::{ | use ruma::{ | ||||||
|     api::client::r0::sync::sync_events, |     api::client::r0::sync::sync_events, | ||||||
|     events::{AnySyncEphemeralRoomEvent, EventType}, |     events::{room::member::MembershipState, AnySyncEphemeralRoomEvent, EventType}, | ||||||
|     Raw, |     Raw, RoomId, UserId, | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| #[cfg(feature = "conduit_bin")] | #[cfg(feature = "conduit_bin")] | ||||||
| use rocket::{get, tokio}; | use rocket::{get, tokio}; | ||||||
| use std::{ | use std::{ | ||||||
|     collections::{hash_map, BTreeMap, HashMap, HashSet}, |     collections::{hash_map, BTreeMap, HashMap, HashSet}, | ||||||
|  |     convert::TryFrom, | ||||||
|     time::Duration, |     time::Duration, | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | @ -40,7 +41,9 @@ pub async fn sync_events_route( | ||||||
|         .unwrap_or(0); |         .unwrap_or(0); | ||||||
| 
 | 
 | ||||||
|     let mut presence_updates = HashMap::new(); |     let mut presence_updates = HashMap::new(); | ||||||
|  |     let mut left_encrypted_users = HashSet::new(); // Users that have left any encrypted rooms the sender was in
 | ||||||
|     let mut device_list_updates = HashSet::new(); |     let mut device_list_updates = HashSet::new(); | ||||||
|  |     let mut device_list_left = HashSet::new(); | ||||||
| 
 | 
 | ||||||
|     // Look for device list updates of this account
 |     // Look for device list updates of this account
 | ||||||
|     device_list_updates.extend( |     device_list_updates.extend( | ||||||
|  | @ -67,6 +70,8 @@ pub async fn sync_events_route( | ||||||
|             .rev() |             .rev() | ||||||
|             .collect::<Vec<_>>(); |             .collect::<Vec<_>>(); | ||||||
| 
 | 
 | ||||||
|  |         let send_notification_counts = !timeline_pdus.is_empty(); | ||||||
|  | 
 | ||||||
|         // They /sync response doesn't always return all messages, so we say the output is
 |         // They /sync response doesn't always return all messages, so we say the output is
 | ||||||
|         // limited unless there are events in non_timeline_pdus
 |         // limited unless there are events in non_timeline_pdus
 | ||||||
|         let mut limited = false; |         let mut limited = false; | ||||||
|  | @ -79,32 +84,86 @@ pub async fn sync_events_route( | ||||||
|             limited = true; |             limited = true; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  |         let encrypted_room = db | ||||||
|  |             .rooms | ||||||
|  |             .room_state_get(&room_id, &EventType::RoomEncryption, "")? | ||||||
|  |             .is_some(); | ||||||
|  | 
 | ||||||
|  |         // TODO: optimize this?
 | ||||||
|         let mut send_member_count = false; |         let mut send_member_count = false; | ||||||
|         let mut joined_since_last_sync = false; |         let mut joined_since_last_sync = false; | ||||||
|         let mut send_notification_counts = false; |         let mut new_encrypted_room = false; | ||||||
|         for pdu in db.rooms.pdus_since(&sender_id, &room_id, since)? { |         for (state_key, pdu) in db | ||||||
|             let pdu = pdu?; |             .rooms | ||||||
|             send_notification_counts = true; |             .pdus_since(&sender_id, &room_id, since)? | ||||||
|  |             .filter_map(|r| r.ok()) | ||||||
|  |             .filter_map(|pdu| Some((pdu.state_key.clone()?, pdu))) | ||||||
|  |         { | ||||||
|             if pdu.kind == EventType::RoomMember { |             if pdu.kind == EventType::RoomMember { | ||||||
|                 send_member_count = true; |                 send_member_count = true; | ||||||
|                 if !joined_since_last_sync && pdu.state_key == Some(sender_id.to_string()) { | 
 | ||||||
|                     let content = serde_json::from_value::< |                 let content = serde_json::from_value::< | ||||||
|                         Raw<ruma::events::room::member::MemberEventContent>, |                     Raw<ruma::events::room::member::MemberEventContent>, | ||||||
|                     >(pdu.content.clone()) |                 >(pdu.content.clone()) | ||||||
|                     .expect("Raw::from_value always works") |                 .expect("Raw::from_value always works") | ||||||
|                     .deserialize() |                 .deserialize() | ||||||
|                     .map_err(|_| Error::bad_database("Invalid PDU in database."))?; |                 .map_err(|_| Error::bad_database("Invalid PDU in database."))?; | ||||||
|                     if content.membership == ruma::events::room::member::MembershipState::Join { | 
 | ||||||
|                         joined_since_last_sync = true; |                 if pdu.state_key == Some(sender_id.to_string()) | ||||||
|                         // Both send_member_count and joined_since_last_sync are set. There's
 |                     && content.membership == MembershipState::Join | ||||||
|                         // nothing more to do
 |                 { | ||||||
|                         break; |                     joined_since_last_sync = true; | ||||||
|  |                 } else if encrypted_room && content.membership == MembershipState::Join { | ||||||
|  |                     // A new user joined an encrypted room
 | ||||||
|  |                     let user_id = UserId::try_from(state_key) | ||||||
|  |                         .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; | ||||||
|  |                     // Add encryption update if we didn't share an encrypted room already
 | ||||||
|  |                     if !share_encrypted_room(&db, &sender_id, &user_id, &room_id) { | ||||||
|  |                         device_list_updates.insert(user_id); | ||||||
|                     } |                     } | ||||||
|  |                 } else if encrypted_room && content.membership == MembershipState::Leave { | ||||||
|  |                     // Write down users that have left encrypted rooms we are in
 | ||||||
|  |                     left_encrypted_users.insert( | ||||||
|  |                         UserId::try_from(state_key) | ||||||
|  |                             .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?, | ||||||
|  |                     ); | ||||||
|                 } |                 } | ||||||
|  |             } else if pdu.kind == EventType::RoomEncryption { | ||||||
|  |                 new_encrypted_room = true; | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         let members = db.rooms.room_state_type(&room_id, &EventType::RoomMember)?; |         if joined_since_last_sync && encrypted_room || new_encrypted_room { | ||||||
|  |             // If the user is in a new encrypted room, give them all joined users
 | ||||||
|  |             device_list_updates.extend( | ||||||
|  |                 db.rooms | ||||||
|  |                     .room_members(&room_id) | ||||||
|  |                     .filter_map(|user_id| { | ||||||
|  |                         Some( | ||||||
|  |                             UserId::try_from(user_id.ok()?.clone()) | ||||||
|  |                                 .map_err(|_| { | ||||||
|  |                                     Error::bad_database("Invalid member event state key in db.") | ||||||
|  |                                 }) | ||||||
|  |                                 .ok()?, | ||||||
|  |                         ) | ||||||
|  |                     }) | ||||||
|  |                     .filter(|user_id| { | ||||||
|  |                         // Don't send key updates from the sender to the sender
 | ||||||
|  |                         sender_id != user_id | ||||||
|  |                     }) | ||||||
|  |                     .filter(|user_id| { | ||||||
|  |                         // Only send keys if the sender doesn't share an encrypted room with the target already
 | ||||||
|  |                         !share_encrypted_room(&db, sender_id, user_id, &room_id) | ||||||
|  |                     }), | ||||||
|  |             ); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Look for device list updates in this room
 | ||||||
|  |         device_list_updates.extend( | ||||||
|  |             db.users | ||||||
|  |                 .keys_changed(&room_id.to_string(), since, None) | ||||||
|  |                 .filter_map(|r| r.ok()), | ||||||
|  |         ); | ||||||
| 
 | 
 | ||||||
|         let (joined_member_count, invited_member_count, heroes) = if send_member_count { |         let (joined_member_count, invited_member_count, heroes) = if send_member_count { | ||||||
|             let joined_member_count = db.rooms.room_members(&room_id).count(); |             let joined_member_count = db.rooms.room_members(&room_id).count(); | ||||||
|  | @ -131,35 +190,17 @@ pub async fn sync_events_route( | ||||||
|                         .map_err(|_| Error::bad_database("Invalid member event in database."))?; |                         .map_err(|_| Error::bad_database("Invalid member event in database."))?; | ||||||
| 
 | 
 | ||||||
|                         if let Some(state_key) = &pdu.state_key { |                         if let Some(state_key) = &pdu.state_key { | ||||||
|                             let current_content = serde_json::from_value::< |                             let user_id = UserId::try_from(state_key.clone()).map_err(|_| { | ||||||
|                                 Raw<ruma::events::room::member::MemberEventContent>, |                                 Error::bad_database("Invalid UserId in member PDU.") | ||||||
|                             >( |  | ||||||
|                                 members |  | ||||||
|                                     .get(state_key) |  | ||||||
|                                     .ok_or_else(|| { |  | ||||||
|                                         Error::bad_database( |  | ||||||
|                                             "A user that joined once has no member event anymore.", |  | ||||||
|                                         ) |  | ||||||
|                                     })? |  | ||||||
|                                     .content |  | ||||||
|                                     .clone(), |  | ||||||
|                             ) |  | ||||||
|                             .expect("Raw::from_value always works") |  | ||||||
|                             .deserialize() |  | ||||||
|                             .map_err(|_| { |  | ||||||
|                                 Error::bad_database("Invalid member event in database.") |  | ||||||
|                             })?; |                             })?; | ||||||
| 
 | 
 | ||||||
|                             // The membership was and still is invite or join
 |                             // The membership was and still is invite or join
 | ||||||
|                             if matches!( |                             if matches!( | ||||||
|                                 content.membership, |                                 content.membership, | ||||||
|                                 ruma::events::room::member::MembershipState::Join |                                 MembershipState::Join | MembershipState::Invite | ||||||
|                                     | ruma::events::room::member::MembershipState::Invite |                             ) && (db.rooms.is_joined(&user_id, &room_id)? | ||||||
|                             ) && matches!( |                                 || db.rooms.is_invited(&user_id, &room_id)?) | ||||||
|                                 current_content.membership, |                             { | ||||||
|                                 ruma::events::room::member::MembershipState::Join |  | ||||||
|                                     | ruma::events::room::member::MembershipState::Invite |  | ||||||
|                             ) { |  | ||||||
|                                 Ok::<_, Error>(Some(state_key.clone())) |                                 Ok::<_, Error>(Some(state_key.clone())) | ||||||
|                             } else { |                             } else { | ||||||
|                                 Ok(None) |                                 Ok(None) | ||||||
|  | @ -295,13 +336,6 @@ pub async fn sync_events_route( | ||||||
|             joined_rooms.insert(room_id.clone(), joined_room); |             joined_rooms.insert(room_id.clone(), joined_room); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         // Look for device list updates in this room
 |  | ||||||
|         device_list_updates.extend( |  | ||||||
|             db.users |  | ||||||
|                 .keys_changed(&room_id.to_string(), since, None) |  | ||||||
|                 .filter_map(|r| r.ok()), |  | ||||||
|         ); |  | ||||||
| 
 |  | ||||||
|         // Take presence updates from this room
 |         // Take presence updates from this room
 | ||||||
|         for (user_id, presence) in |         for (user_id, presence) in | ||||||
|             db.rooms |             db.rooms | ||||||
|  | @ -392,18 +426,17 @@ pub async fn sync_events_route( | ||||||
|         let mut invited_since_last_sync = false; |         let mut invited_since_last_sync = false; | ||||||
|         for pdu in db.rooms.pdus_since(&sender_id, &room_id, since)? { |         for pdu in db.rooms.pdus_since(&sender_id, &room_id, since)? { | ||||||
|             let pdu = pdu?; |             let pdu = pdu?; | ||||||
|             if pdu.kind == EventType::RoomMember { |             if pdu.kind == EventType::RoomMember && pdu.state_key == Some(sender_id.to_string()) { | ||||||
|                 if pdu.state_key == Some(sender_id.to_string()) { |                 let content = serde_json::from_value::< | ||||||
|                     let content = serde_json::from_value::< |                     Raw<ruma::events::room::member::MemberEventContent>, | ||||||
|                         Raw<ruma::events::room::member::MemberEventContent>, |                 >(pdu.content.clone()) | ||||||
|                     >(pdu.content.clone()) |                 .expect("Raw::from_value always works") | ||||||
|                     .expect("Raw::from_value always works") |                 .deserialize() | ||||||
|                     .deserialize() |                 .map_err(|_| Error::bad_database("Invalid PDU in database."))?; | ||||||
|                     .map_err(|_| Error::bad_database("Invalid PDU in database."))?; | 
 | ||||||
|                     if content.membership == ruma::events::room::member::MembershipState::Invite { |                 if content.membership == MembershipState::Invite { | ||||||
|                         invited_since_last_sync = true; |                     invited_since_last_sync = true; | ||||||
|                         break; |                     break; | ||||||
|                     } |  | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  | @ -428,6 +461,28 @@ pub async fn sync_events_route( | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     // TODO: mark users as left when WE left an encrypted room they were in
 | ||||||
|  |     for user_id in left_encrypted_users { | ||||||
|  |         // If the user doesn't share an encrypted room with the target anymore, we need to tell
 | ||||||
|  |         // them
 | ||||||
|  |         if db | ||||||
|  |             .rooms | ||||||
|  |             .get_shared_rooms(vec![sender_id.clone(), user_id.clone()]) | ||||||
|  |             .filter_map(|r| r.ok()) | ||||||
|  |             .filter_map(|other_room_id| { | ||||||
|  |                 Some( | ||||||
|  |                     db.rooms | ||||||
|  |                         .room_state_get(&other_room_id, &EventType::RoomEncryption, "") | ||||||
|  |                         .ok()? | ||||||
|  |                         .is_some(), | ||||||
|  |                 ) | ||||||
|  |             }) | ||||||
|  |             .all(|encrypted| !encrypted) | ||||||
|  |         { | ||||||
|  |             device_list_left.insert(user_id); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     // Remove all to-device events the device received *last time*
 |     // Remove all to-device events the device received *last time*
 | ||||||
|     db.users |     db.users | ||||||
|         .remove_to_device_events(sender_id, device_id, since)?; |         .remove_to_device_events(sender_id, device_id, since)?; | ||||||
|  | @ -459,7 +514,7 @@ pub async fn sync_events_route( | ||||||
|         }, |         }, | ||||||
|         device_lists: sync_events::DeviceLists { |         device_lists: sync_events::DeviceLists { | ||||||
|             changed: device_list_updates.into_iter().collect(), |             changed: device_list_updates.into_iter().collect(), | ||||||
|             left: Vec::new(), // TODO
 |             left: device_list_left.into_iter().collect(), | ||||||
|         }, |         }, | ||||||
|         device_one_time_keys_count: if db.users.last_one_time_keys_update(sender_id)? > since { |         device_one_time_keys_count: if db.users.last_one_time_keys_update(sender_id)? > since { | ||||||
|             db.users.count_one_time_keys(sender_id, device_id)? |             db.users.count_one_time_keys(sender_id, device_id)? | ||||||
|  | @ -495,3 +550,24 @@ pub async fn sync_events_route( | ||||||
| 
 | 
 | ||||||
|     Ok(response.into()) |     Ok(response.into()) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | fn share_encrypted_room( | ||||||
|  |     db: &Database, | ||||||
|  |     sender_id: &UserId, | ||||||
|  |     user_id: &UserId, | ||||||
|  |     ignore_room: &RoomId, | ||||||
|  | ) -> bool { | ||||||
|  |     db.rooms | ||||||
|  |         .get_shared_rooms(vec![sender_id.clone(), user_id.clone()]) | ||||||
|  |         .filter_map(|r| r.ok()) | ||||||
|  |         .filter(|room_id| room_id != ignore_room) | ||||||
|  |         .filter_map(|other_room_id| { | ||||||
|  |             Some( | ||||||
|  |                 db.rooms | ||||||
|  |                     .room_state_get(&other_room_id, &EventType::RoomEncryption, "") | ||||||
|  |                     .ok()? | ||||||
|  |                     .is_some(), | ||||||
|  |             ) | ||||||
|  |         }) | ||||||
|  |         .any(|encrypted| encrypted) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -945,11 +945,11 @@ impl Rooms { | ||||||
|         }) |         }) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn search_pdus( |     pub fn search_pdus<'a>( | ||||||
|         &self, |         &'a self, | ||||||
|         room_id: &RoomId, |         room_id: &RoomId, | ||||||
|         search_string: &str, |         search_string: &str, | ||||||
|     ) -> Result<(impl Iterator<Item = IVec>, Vec<String>)> { |     ) -> Result<(impl Iterator<Item = IVec> + 'a, Vec<String>)> { | ||||||
|         let mut prefix = room_id.to_string().as_bytes().to_vec(); |         let mut prefix = room_id.to_string().as_bytes().to_vec(); | ||||||
|         prefix.push(0xff); |         prefix.push(0xff); | ||||||
| 
 | 
 | ||||||
|  | @ -958,7 +958,7 @@ impl Rooms { | ||||||
|             .map(str::to_lowercase) |             .map(str::to_lowercase) | ||||||
|             .collect::<Vec<_>>(); |             .collect::<Vec<_>>(); | ||||||
| 
 | 
 | ||||||
|         let mut iterators = words.iter().map(|word| { |         let iterators = words.clone().into_iter().map(move |word| { | ||||||
|             let mut prefix2 = prefix.clone(); |             let mut prefix2 = prefix.clone(); | ||||||
|             prefix2.extend_from_slice(word.as_bytes()); |             prefix2.extend_from_slice(word.as_bytes()); | ||||||
|             prefix2.push(0xff); |             prefix2.push(0xff); | ||||||
|  | @ -973,50 +973,56 @@ impl Rooms { | ||||||
|                         .filter(|(_, &b)| b == 0xff) |                         .filter(|(_, &b)| b == 0xff) | ||||||
|                         .nth(1) |                         .nth(1) | ||||||
|                         .ok_or_else(|| Error::bad_database("Invalid tokenid in db."))? |                         .ok_or_else(|| Error::bad_database("Invalid tokenid in db."))? | ||||||
|                         .0 + 1; // +1 because the pdu id starts AFTER the separator
 |                         .0 | ||||||
|  |                         + 1; // +1 because the pdu id starts AFTER the separator
 | ||||||
| 
 | 
 | ||||||
|                     let pdu_id = |                     let pdu_id = key.subslice(pduid_index, key.len() - pduid_index); | ||||||
|                         key.subslice(pduid_index, key.len() - pduid_index); |  | ||||||
| 
 | 
 | ||||||
|                     Ok::<_, Error>(pdu_id) |                     Ok::<_, Error>(pdu_id) | ||||||
|                 }) |                 }) | ||||||
|                 .filter_map(|r| r.ok()) |                 .filter_map(|r| r.ok()) | ||||||
|                 .peekable() |  | ||||||
|         }); |         }); | ||||||
| 
 | 
 | ||||||
|         let first_iterator = match iterators.next() { |         Ok((utils::common_elements(iterators).unwrap(), words)) | ||||||
|             Some(i) => i, |     } | ||||||
|             None => { |  | ||||||
|                 return Err(Error::BadRequest( |  | ||||||
|                     ErrorKind::InvalidParam, |  | ||||||
|                     "search_term needs to contain at least one word.", |  | ||||||
|                 )) |  | ||||||
|             } |  | ||||||
|         }; |  | ||||||
| 
 | 
 | ||||||
|         let mut other_iterators = iterators.collect::<Vec<_>>(); |     pub fn get_shared_rooms<'a>( | ||||||
|  |         &'a self, | ||||||
|  |         users: Vec<UserId>, | ||||||
|  |     ) -> impl Iterator<Item = Result<RoomId>> + 'a { | ||||||
|  |         let iterators = users.into_iter().map(move |user_id| { | ||||||
|  |             let mut prefix = user_id.as_bytes().to_vec(); | ||||||
|  |             prefix.push(0xff); | ||||||
| 
 | 
 | ||||||
|         Ok(( |             self.userroomid_joined | ||||||
|             first_iterator.filter(move |target| { |                 .scan_prefix(&prefix) | ||||||
|                 other_iterators |                 .keys() | ||||||
|                     .iter_mut() |                 .filter_map(|r| r.ok()) | ||||||
|                     .map(|it| { |                 .map(|key| { | ||||||
|                         while let Some(element) = it.peek() { |                     let roomid_index = key | ||||||
|                             if element > target { |                         .iter() | ||||||
|                                 return false; |                         .enumerate() | ||||||
|                             } else if element == target { |                         .filter(|(_, &b)| b == 0xff) | ||||||
|                                 return true; |                         .nth(0) | ||||||
|                             } else { |                         .ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? | ||||||
|                                 it.next(); |                         .0 | ||||||
|                             } |                         + 1; // +1 because the room id starts AFTER the separator
 | ||||||
|                         } |  | ||||||
| 
 | 
 | ||||||
|                         false |                     let room_id = key.subslice(roomid_index, key.len() - roomid_index); | ||||||
|                     }) | 
 | ||||||
|                     .all(|b| b) |                     Ok::<_, Error>(room_id) | ||||||
|             }), |                 }) | ||||||
|             words, |                 .filter_map(|r| r.ok()) | ||||||
|         )) |         }); | ||||||
|  | 
 | ||||||
|  |         utils::common_elements(iterators) | ||||||
|  |             .expect("users is not empty") | ||||||
|  |             .map(|bytes| { | ||||||
|  |                 RoomId::try_from(utils::string_from_bytes(&*bytes).map_err(|_| { | ||||||
|  |                     Error::bad_database("Invalid RoomId bytes in userroomid_joined") | ||||||
|  |                 })?) | ||||||
|  |                 .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) | ||||||
|  |             }) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Returns an iterator over all joined members of a room.
 |     /// Returns an iterator over all joined members of a room.
 | ||||||
|  |  | ||||||
|  | @ -408,19 +408,7 @@ impl Users { | ||||||
|             &*serde_json::to_string(&device_keys).expect("DeviceKeys::to_string always works"), |             &*serde_json::to_string(&device_keys).expect("DeviceKeys::to_string always works"), | ||||||
|         )?; |         )?; | ||||||
| 
 | 
 | ||||||
|         let count = globals.next_count()?.to_be_bytes(); |         self.mark_device_key_update(user_id, rooms, globals)?; | ||||||
|         for room_id in rooms.rooms_joined(&user_id) { |  | ||||||
|             let mut key = room_id?.to_string().as_bytes().to_vec(); |  | ||||||
|             key.push(0xff); |  | ||||||
|             key.extend_from_slice(&count); |  | ||||||
| 
 |  | ||||||
|             self.keychangeid_userid.insert(key, &*user_id.to_string())?; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         let mut key = user_id.to_string().as_bytes().to_vec(); |  | ||||||
|         key.push(0xff); |  | ||||||
|         key.extend_from_slice(&count); |  | ||||||
|         self.keychangeid_userid.insert(key, &*user_id.to_string())?; |  | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|     } |     } | ||||||
|  | @ -520,19 +508,7 @@ impl Users { | ||||||
|                 .insert(&*user_id.to_string(), user_signing_key_key)?; |                 .insert(&*user_id.to_string(), user_signing_key_key)?; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         let count = globals.next_count()?.to_be_bytes(); |         self.mark_device_key_update(user_id, rooms, globals)?; | ||||||
|         for room_id in rooms.rooms_joined(&user_id) { |  | ||||||
|             let mut key = room_id?.to_string().as_bytes().to_vec(); |  | ||||||
|             key.push(0xff); |  | ||||||
|             key.extend_from_slice(&count); |  | ||||||
| 
 |  | ||||||
|             self.keychangeid_userid.insert(key, &*user_id.to_string())?; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         let mut key = user_id.to_string().as_bytes().to_vec(); |  | ||||||
|         key.push(0xff); |  | ||||||
|         key.extend_from_slice(&count); |  | ||||||
|         self.keychangeid_userid.insert(key, &*user_id.to_string())?; |  | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|     } |     } | ||||||
|  | @ -576,21 +552,7 @@ impl Users { | ||||||
|         )?; |         )?; | ||||||
| 
 | 
 | ||||||
|         // TODO: Should we notify about this change?
 |         // TODO: Should we notify about this change?
 | ||||||
|         let count = globals.next_count()?.to_be_bytes(); |         self.mark_device_key_update(target_id, rooms, globals)?; | ||||||
|         for room_id in rooms.rooms_joined(&target_id) { |  | ||||||
|             let mut key = room_id?.to_string().as_bytes().to_vec(); |  | ||||||
|             key.push(0xff); |  | ||||||
|             key.extend_from_slice(&count); |  | ||||||
| 
 |  | ||||||
|             self.keychangeid_userid |  | ||||||
|                 .insert(key, &*target_id.to_string())?; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         let mut key = target_id.to_string().as_bytes().to_vec(); |  | ||||||
|         key.push(0xff); |  | ||||||
|         key.extend_from_slice(&count); |  | ||||||
|         self.keychangeid_userid |  | ||||||
|             .insert(key, &*target_id.to_string())?; |  | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|     } |     } | ||||||
|  | @ -628,6 +590,37 @@ impl Users { | ||||||
|             }) |             }) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     fn mark_device_key_update( | ||||||
|  |         &self, | ||||||
|  |         user_id: &UserId, | ||||||
|  |         rooms: &super::rooms::Rooms, | ||||||
|  |         globals: &super::globals::Globals, | ||||||
|  |     ) -> Result<()> { | ||||||
|  |         let count = globals.next_count()?.to_be_bytes(); | ||||||
|  |         for room_id in rooms.rooms_joined(&user_id).filter_map(|r| r.ok()) { | ||||||
|  |             // Don't send key updates to unencrypted rooms
 | ||||||
|  |             if rooms | ||||||
|  |                 .room_state_get(&room_id, &EventType::RoomEncryption, "")? | ||||||
|  |                 .is_none() | ||||||
|  |             { | ||||||
|  |                 return Ok(()); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             let mut key = room_id.to_string().as_bytes().to_vec(); | ||||||
|  |             key.push(0xff); | ||||||
|  |             key.extend_from_slice(&count); | ||||||
|  | 
 | ||||||
|  |             self.keychangeid_userid.insert(key, &*user_id.to_string())?; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         let mut key = user_id.to_string().as_bytes().to_vec(); | ||||||
|  |         key.push(0xff); | ||||||
|  |         key.extend_from_slice(&count); | ||||||
|  |         self.keychangeid_userid.insert(key, &*user_id.to_string())?; | ||||||
|  | 
 | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     pub fn get_device_keys( |     pub fn get_device_keys( | ||||||
|         &self, |         &self, | ||||||
|         user_id: &UserId, |         user_id: &UserId, | ||||||
|  |  | ||||||
							
								
								
									
										27
									
								
								src/utils.rs
									
									
									
									
									
								
							
							
						
						
									
										27
									
								
								src/utils.rs
									
									
									
									
									
								
							|  | @ -1,5 +1,6 @@ | ||||||
| use argon2::{Config, Variant}; | use argon2::{Config, Variant}; | ||||||
| use rand::prelude::*; | use rand::prelude::*; | ||||||
|  | use sled::IVec; | ||||||
| use std::{ | use std::{ | ||||||
|     convert::TryInto, |     convert::TryInto, | ||||||
|     time::{SystemTime, UNIX_EPOCH}, |     time::{SystemTime, UNIX_EPOCH}, | ||||||
|  | @ -59,3 +60,29 @@ pub fn calculate_hash(password: &str) -> Result<String, argon2::Error> { | ||||||
|     let salt = random_string(32); |     let salt = random_string(32); | ||||||
|     argon2::hash_encoded(password.as_bytes(), salt.as_bytes(), &hashing_config) |     argon2::hash_encoded(password.as_bytes(), salt.as_bytes(), &hashing_config) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | pub fn common_elements( | ||||||
|  |     mut iterators: impl Iterator<Item = impl Iterator<Item = IVec>>, | ||||||
|  | ) -> Option<impl Iterator<Item = IVec>> { | ||||||
|  |     let first_iterator = iterators.next()?; | ||||||
|  |     let mut other_iterators = iterators.map(|i| i.peekable()).collect::<Vec<_>>(); | ||||||
|  | 
 | ||||||
|  |     Some(first_iterator.filter(move |target| { | ||||||
|  |         other_iterators | ||||||
|  |             .iter_mut() | ||||||
|  |             .map(|it| { | ||||||
|  |                 while let Some(element) = it.peek() { | ||||||
|  |                     if element > target { | ||||||
|  |                         return false; | ||||||
|  |                     } else if element == target { | ||||||
|  |                         return true; | ||||||
|  |                     } else { | ||||||
|  |                         it.next(); | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 false | ||||||
|  |             }) | ||||||
|  |             .all(|b| b) | ||||||
|  |     })) | ||||||
|  | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in a new issue