diff --git a/src/client_server/search.rs b/src/client_server/search.rs index 1107555..dec1ec9 100644 --- a/src/client_server/search.rs +++ b/src/client_server/search.rs @@ -19,16 +19,9 @@ pub fn search_events_route( let sender_id = body.sender_id.as_ref().expect("user is authenticated"); let search_criteria = body.search_categories.room_events.as_ref().unwrap(); - let filter = search_criteria - .filter - .as_ref() - .unwrap(); + let filter = search_criteria.filter.as_ref().unwrap(); - let room_id = filter.rooms - .as_ref() - .unwrap() - .first() - .unwrap(); + let room_id = filter.rooms.as_ref().unwrap().first().unwrap(); let limit = filter.limit.map_or(10, |l| u64::from(l) as usize); diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index 201e8bc..ecc7144 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -2,14 +2,15 @@ use super::State; use crate::{ConduitResult, Database, Error, Ruma}; use ruma::{ api::client::r0::sync::sync_events, - events::{AnySyncEphemeralRoomEvent, EventType}, - Raw, + events::{room::member::MembershipState, AnySyncEphemeralRoomEvent, EventType}, + Raw, RoomId, UserId, }; #[cfg(feature = "conduit_bin")] use rocket::{get, tokio}; use std::{ collections::{hash_map, BTreeMap, HashMap, HashSet}, + convert::TryFrom, time::Duration, }; @@ -40,7 +41,9 @@ pub async fn sync_events_route( .unwrap_or(0); let mut presence_updates = HashMap::new(); + let mut left_encrypted_users = HashSet::new(); // Users that have left any encrypted rooms the sender was in let mut device_list_updates = HashSet::new(); + let mut device_list_left = HashSet::new(); // Look for device list updates of this account device_list_updates.extend( @@ -67,6 +70,8 @@ pub async fn sync_events_route( .rev() .collect::>(); + let send_notification_counts = !timeline_pdus.is_empty(); + // They /sync response doesn't always return all messages, so we say the output is // limited unless there are events in non_timeline_pdus let mut limited = false; @@ -79,32 +84,86 @@ pub async fn sync_events_route( limited = true; } + let encrypted_room = db + .rooms + .room_state_get(&room_id, &EventType::RoomEncryption, "")? + .is_some(); + + // TODO: optimize this? let mut send_member_count = false; let mut joined_since_last_sync = false; - let mut send_notification_counts = false; - for pdu in db.rooms.pdus_since(&sender_id, &room_id, since)? { - let pdu = pdu?; - send_notification_counts = true; + let mut new_encrypted_room = false; + for (state_key, pdu) in db + .rooms + .pdus_since(&sender_id, &room_id, since)? + .filter_map(|r| r.ok()) + .filter_map(|pdu| Some((pdu.state_key.clone()?, pdu))) + { if pdu.kind == EventType::RoomMember { send_member_count = true; - if !joined_since_last_sync && pdu.state_key == Some(sender_id.to_string()) { - let content = serde_json::from_value::< - Raw, - >(pdu.content.clone()) - .expect("Raw::from_value always works") - .deserialize() - .map_err(|_| Error::bad_database("Invalid PDU in database."))?; - if content.membership == ruma::events::room::member::MembershipState::Join { - joined_since_last_sync = true; - // Both send_member_count and joined_since_last_sync are set. There's - // nothing more to do - break; + + let content = serde_json::from_value::< + Raw, + >(pdu.content.clone()) + .expect("Raw::from_value always works") + .deserialize() + .map_err(|_| Error::bad_database("Invalid PDU in database."))?; + + if pdu.state_key == Some(sender_id.to_string()) + && content.membership == MembershipState::Join + { + joined_since_last_sync = true; + } else if encrypted_room && content.membership == MembershipState::Join { + // A new user joined an encrypted room + let user_id = UserId::try_from(state_key) + .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; + // Add encryption update if we didn't share an encrypted room already + if !share_encrypted_room(&db, &sender_id, &user_id, &room_id) { + device_list_updates.insert(user_id); } + } else if encrypted_room && content.membership == MembershipState::Leave { + // Write down users that have left encrypted rooms we are in + left_encrypted_users.insert( + UserId::try_from(state_key) + .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?, + ); } + } else if pdu.kind == EventType::RoomEncryption { + new_encrypted_room = true; } } - let members = db.rooms.room_state_type(&room_id, &EventType::RoomMember)?; + if joined_since_last_sync && encrypted_room || new_encrypted_room { + // If the user is in a new encrypted room, give them all joined users + device_list_updates.extend( + db.rooms + .room_members(&room_id) + .filter_map(|user_id| { + Some( + UserId::try_from(user_id.ok()?.clone()) + .map_err(|_| { + Error::bad_database("Invalid member event state key in db.") + }) + .ok()?, + ) + }) + .filter(|user_id| { + // Don't send key updates from the sender to the sender + sender_id != user_id + }) + .filter(|user_id| { + // Only send keys if the sender doesn't share an encrypted room with the target already + !share_encrypted_room(&db, sender_id, user_id, &room_id) + }), + ); + } + + // Look for device list updates in this room + device_list_updates.extend( + db.users + .keys_changed(&room_id.to_string(), since, None) + .filter_map(|r| r.ok()), + ); let (joined_member_count, invited_member_count, heroes) = if send_member_count { let joined_member_count = db.rooms.room_members(&room_id).count(); @@ -131,35 +190,17 @@ pub async fn sync_events_route( .map_err(|_| Error::bad_database("Invalid member event in database."))?; if let Some(state_key) = &pdu.state_key { - let current_content = serde_json::from_value::< - Raw, - >( - members - .get(state_key) - .ok_or_else(|| { - Error::bad_database( - "A user that joined once has no member event anymore.", - ) - })? - .content - .clone(), - ) - .expect("Raw::from_value always works") - .deserialize() - .map_err(|_| { - Error::bad_database("Invalid member event in database.") + let user_id = UserId::try_from(state_key.clone()).map_err(|_| { + Error::bad_database("Invalid UserId in member PDU.") })?; // The membership was and still is invite or join if matches!( content.membership, - ruma::events::room::member::MembershipState::Join - | ruma::events::room::member::MembershipState::Invite - ) && matches!( - current_content.membership, - ruma::events::room::member::MembershipState::Join - | ruma::events::room::member::MembershipState::Invite - ) { + MembershipState::Join | MembershipState::Invite + ) && (db.rooms.is_joined(&user_id, &room_id)? + || db.rooms.is_invited(&user_id, &room_id)?) + { Ok::<_, Error>(Some(state_key.clone())) } else { Ok(None) @@ -295,13 +336,6 @@ pub async fn sync_events_route( joined_rooms.insert(room_id.clone(), joined_room); } - // Look for device list updates in this room - device_list_updates.extend( - db.users - .keys_changed(&room_id.to_string(), since, None) - .filter_map(|r| r.ok()), - ); - // Take presence updates from this room for (user_id, presence) in db.rooms @@ -392,18 +426,17 @@ pub async fn sync_events_route( let mut invited_since_last_sync = false; for pdu in db.rooms.pdus_since(&sender_id, &room_id, since)? { let pdu = pdu?; - if pdu.kind == EventType::RoomMember { - if pdu.state_key == Some(sender_id.to_string()) { - let content = serde_json::from_value::< - Raw, - >(pdu.content.clone()) - .expect("Raw::from_value always works") - .deserialize() - .map_err(|_| Error::bad_database("Invalid PDU in database."))?; - if content.membership == ruma::events::room::member::MembershipState::Invite { - invited_since_last_sync = true; - break; - } + if pdu.kind == EventType::RoomMember && pdu.state_key == Some(sender_id.to_string()) { + let content = serde_json::from_value::< + Raw, + >(pdu.content.clone()) + .expect("Raw::from_value always works") + .deserialize() + .map_err(|_| Error::bad_database("Invalid PDU in database."))?; + + if content.membership == MembershipState::Invite { + invited_since_last_sync = true; + break; } } } @@ -428,6 +461,28 @@ pub async fn sync_events_route( } } + // TODO: mark users as left when WE left an encrypted room they were in + for user_id in left_encrypted_users { + // If the user doesn't share an encrypted room with the target anymore, we need to tell + // them + if db + .rooms + .get_shared_rooms(vec![sender_id.clone(), user_id.clone()]) + .filter_map(|r| r.ok()) + .filter_map(|other_room_id| { + Some( + db.rooms + .room_state_get(&other_room_id, &EventType::RoomEncryption, "") + .ok()? + .is_some(), + ) + }) + .all(|encrypted| !encrypted) + { + device_list_left.insert(user_id); + } + } + // Remove all to-device events the device received *last time* db.users .remove_to_device_events(sender_id, device_id, since)?; @@ -459,7 +514,7 @@ pub async fn sync_events_route( }, device_lists: sync_events::DeviceLists { changed: device_list_updates.into_iter().collect(), - left: Vec::new(), // TODO + left: device_list_left.into_iter().collect(), }, device_one_time_keys_count: if db.users.last_one_time_keys_update(sender_id)? > since { db.users.count_one_time_keys(sender_id, device_id)? @@ -495,3 +550,24 @@ pub async fn sync_events_route( Ok(response.into()) } + +fn share_encrypted_room( + db: &Database, + sender_id: &UserId, + user_id: &UserId, + ignore_room: &RoomId, +) -> bool { + db.rooms + .get_shared_rooms(vec![sender_id.clone(), user_id.clone()]) + .filter_map(|r| r.ok()) + .filter(|room_id| room_id != ignore_room) + .filter_map(|other_room_id| { + Some( + db.rooms + .room_state_get(&other_room_id, &EventType::RoomEncryption, "") + .ok()? + .is_some(), + ) + }) + .any(|encrypted| encrypted) +} diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 294531e..767f581 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -945,11 +945,11 @@ impl Rooms { }) } - pub fn search_pdus( - &self, + pub fn search_pdus<'a>( + &'a self, room_id: &RoomId, search_string: &str, - ) -> Result<(impl Iterator, Vec)> { + ) -> Result<(impl Iterator + 'a, Vec)> { let mut prefix = room_id.to_string().as_bytes().to_vec(); prefix.push(0xff); @@ -958,7 +958,7 @@ impl Rooms { .map(str::to_lowercase) .collect::>(); - let mut iterators = words.iter().map(|word| { + let iterators = words.clone().into_iter().map(move |word| { let mut prefix2 = prefix.clone(); prefix2.extend_from_slice(word.as_bytes()); prefix2.push(0xff); @@ -973,50 +973,56 @@ impl Rooms { .filter(|(_, &b)| b == 0xff) .nth(1) .ok_or_else(|| Error::bad_database("Invalid tokenid in db."))? - .0 + 1; // +1 because the pdu id starts AFTER the separator + .0 + + 1; // +1 because the pdu id starts AFTER the separator - let pdu_id = - key.subslice(pduid_index, key.len() - pduid_index); + let pdu_id = key.subslice(pduid_index, key.len() - pduid_index); Ok::<_, Error>(pdu_id) }) .filter_map(|r| r.ok()) - .peekable() }); - let first_iterator = match iterators.next() { - Some(i) => i, - None => { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "search_term needs to contain at least one word.", - )) - } - }; + Ok((utils::common_elements(iterators).unwrap(), words)) + } - let mut other_iterators = iterators.collect::>(); + pub fn get_shared_rooms<'a>( + &'a self, + users: Vec, + ) -> impl Iterator> + 'a { + let iterators = users.into_iter().map(move |user_id| { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); - Ok(( - first_iterator.filter(move |target| { - other_iterators - .iter_mut() - .map(|it| { - while let Some(element) = it.peek() { - if element > target { - return false; - } else if element == target { - return true; - } else { - it.next(); - } - } + self.userroomid_joined + .scan_prefix(&prefix) + .keys() + .filter_map(|r| r.ok()) + .map(|key| { + let roomid_index = key + .iter() + .enumerate() + .filter(|(_, &b)| b == 0xff) + .nth(0) + .ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? + .0 + + 1; // +1 because the room id starts AFTER the separator - false - }) - .all(|b| b) - }), - words, - )) + let room_id = key.subslice(roomid_index, key.len() - roomid_index); + + Ok::<_, Error>(room_id) + }) + .filter_map(|r| r.ok()) + }); + + utils::common_elements(iterators) + .expect("users is not empty") + .map(|bytes| { + RoomId::try_from(utils::string_from_bytes(&*bytes).map_err(|_| { + Error::bad_database("Invalid RoomId bytes in userroomid_joined") + })?) + .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) + }) } /// Returns an iterator over all joined members of a room. diff --git a/src/database/users.rs b/src/database/users.rs index 2500b4c..1b6a681 100644 --- a/src/database/users.rs +++ b/src/database/users.rs @@ -408,19 +408,7 @@ impl Users { &*serde_json::to_string(&device_keys).expect("DeviceKeys::to_string always works"), )?; - let count = globals.next_count()?.to_be_bytes(); - for room_id in rooms.rooms_joined(&user_id) { - let mut key = room_id?.to_string().as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&count); - - self.keychangeid_userid.insert(key, &*user_id.to_string())?; - } - - let mut key = user_id.to_string().as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&count); - self.keychangeid_userid.insert(key, &*user_id.to_string())?; + self.mark_device_key_update(user_id, rooms, globals)?; Ok(()) } @@ -520,19 +508,7 @@ impl Users { .insert(&*user_id.to_string(), user_signing_key_key)?; } - let count = globals.next_count()?.to_be_bytes(); - for room_id in rooms.rooms_joined(&user_id) { - let mut key = room_id?.to_string().as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&count); - - self.keychangeid_userid.insert(key, &*user_id.to_string())?; - } - - let mut key = user_id.to_string().as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&count); - self.keychangeid_userid.insert(key, &*user_id.to_string())?; + self.mark_device_key_update(user_id, rooms, globals)?; Ok(()) } @@ -576,21 +552,7 @@ impl Users { )?; // TODO: Should we notify about this change? - let count = globals.next_count()?.to_be_bytes(); - for room_id in rooms.rooms_joined(&target_id) { - let mut key = room_id?.to_string().as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&count); - - self.keychangeid_userid - .insert(key, &*target_id.to_string())?; - } - - let mut key = target_id.to_string().as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&count); - self.keychangeid_userid - .insert(key, &*target_id.to_string())?; + self.mark_device_key_update(target_id, rooms, globals)?; Ok(()) } @@ -628,6 +590,37 @@ impl Users { }) } + fn mark_device_key_update( + &self, + user_id: &UserId, + rooms: &super::rooms::Rooms, + globals: &super::globals::Globals, + ) -> Result<()> { + let count = globals.next_count()?.to_be_bytes(); + for room_id in rooms.rooms_joined(&user_id).filter_map(|r| r.ok()) { + // Don't send key updates to unencrypted rooms + if rooms + .room_state_get(&room_id, &EventType::RoomEncryption, "")? + .is_none() + { + return Ok(()); + } + + let mut key = room_id.to_string().as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(&count); + + self.keychangeid_userid.insert(key, &*user_id.to_string())?; + } + + let mut key = user_id.to_string().as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(&count); + self.keychangeid_userid.insert(key, &*user_id.to_string())?; + + Ok(()) + } + pub fn get_device_keys( &self, user_id: &UserId, diff --git a/src/utils.rs b/src/utils.rs index 0ab3bfa..473c18f 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,5 +1,6 @@ use argon2::{Config, Variant}; use rand::prelude::*; +use sled::IVec; use std::{ convert::TryInto, time::{SystemTime, UNIX_EPOCH}, @@ -59,3 +60,29 @@ pub fn calculate_hash(password: &str) -> Result { let salt = random_string(32); argon2::hash_encoded(password.as_bytes(), salt.as_bytes(), &hashing_config) } + +pub fn common_elements( + mut iterators: impl Iterator>, +) -> Option> { + let first_iterator = iterators.next()?; + let mut other_iterators = iterators.map(|i| i.peekable()).collect::>(); + + Some(first_iterator.filter(move |target| { + other_iterators + .iter_mut() + .map(|it| { + while let Some(element) = it.peek() { + if element > target { + return false; + } else if element == target { + return true; + } else { + it.next(); + } + } + + false + }) + .all(|b| b) + })) +}