From 8f27e6123b8d142ec647808e659859c9bc37a131 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Mon, 17 May 2021 10:25:27 +0200 Subject: [PATCH] feat: send read receipts over federation currently they will only be sent if a PDU has to be sent as well --- src/client_server/sync.rs | 1 + src/database.rs | 29 +++++++++- src/database/globals.rs | 12 +++++ src/database/rooms.rs | 28 ++++++++++ src/database/rooms/edus.rs | 29 ++++++++-- src/database/sending.rs | 108 +++++++++++++++++++++++++++++++++++-- 6 files changed, 197 insertions(+), 10 deletions(-) diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index 66a1e13..fe6f692 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -406,6 +406,7 @@ pub async fn sync_events_route( .edus .readreceipts_since(&room_id, since)? .filter_map(|r| r.ok()) // Filter out buggy events + .map(|(_, _, v)| v) .collect::>(); if db.rooms.edus.last_typing_update(&room_id, &db.globals)? > since { diff --git a/src/database.rs b/src/database.rs index 62b3a40..6b68b9e 100644 --- a/src/database.rs +++ b/src/database.rs @@ -14,7 +14,7 @@ pub mod users; use crate::{Error, Result}; use directories::ProjectDirs; use futures::StreamExt; -use log::info; +use log::{error, info}; use rocket::futures::{self, channel::mpsc}; use ruma::{DeviceId, ServerName, UserId}; use serde::Deserialize; @@ -160,6 +160,7 @@ impl Database { tokenids: db.open_tree("tokenids")?, roomserverids: db.open_tree("roomserverids")?, + serverroomids: db.open_tree("serverroomids")?, userroomid_joined: db.open_tree("userroomid_joined")?, roomuserid_joined: db.open_tree("roomuserid_joined")?, roomuseroncejoinedids: db.open_tree("roomuseroncejoinedids")?, @@ -197,6 +198,7 @@ impl Database { userdevicetxnid_response: db.open_tree("userdevicetxnid_response")?, }, sending: sending::Sending { + servername_educount: db.open_tree("servername_educount")?, servernamepduids: db.open_tree("servernamepduids")?, servercurrentevents: db.open_tree("servercurrentevents")?, maximum_requests: Arc::new(Semaphore::new(config.max_concurrent_requests as usize)), @@ -217,6 +219,31 @@ impl Database { _db: db, }; + // MIGRATIONS + if db.globals.database_version()? < 1 { + for roomserverid in db.rooms.roomserverids.iter().keys() { + let roomserverid = roomserverid?; + let mut parts = roomserverid.split(|&b| b == 0xff); + let room_id = parts.next().expect("split always returns one element"); + let servername = match parts.next() { + Some(s) => s, + None => { + error!("Migration: Invalid roomserverid in db."); + continue; + } + }; + let mut serverroomid = servername.to_vec(); + serverroomid.push(0xff); + serverroomid.extend_from_slice(room_id); + + db.rooms.serverroomids.insert(serverroomid, &[])?; + } + + db.globals.bump_database_version(1)?; + + info!("Migration: 0 -> 1 finished"); + } + // This data is probably outdated db.rooms.edus.presenceid_presence.clear()?; diff --git a/src/database/globals.rs b/src/database/globals.rs index 04f8d29..c1eafe0 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -258,4 +258,16 @@ impl Globals { } Ok(response) } + + pub fn database_version(&self) -> Result { + self.globals.get("version")?.map_or(Ok(0), |version| { + utils::u64_from_bytes(&version) + .map_err(|_| Error::bad_database("Database version id is invalid.")) + }) + } + + pub fn bump_database_version(&self, new_version: u64) -> Result<()> { + self.globals.insert("version", &new_version.to_be_bytes())?; + Ok(()) + } } diff --git a/src/database/rooms.rs b/src/database/rooms.rs index c359997..48e6e11 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -50,6 +50,8 @@ pub struct Rooms { /// Participating servers in a room. pub(super) roomserverids: sled::Tree, // RoomServerId = RoomId + ServerName + pub(super) serverroomids: sled::Tree, // ServerRoomId = ServerName + RoomId + pub(super) userroomid_joined: sled::Tree, pub(super) roomuserid_joined: sled::Tree, pub(super) roomuseroncejoinedids: sled::Tree, @@ -1597,6 +1599,10 @@ impl Rooms { roomserver_id.push(0xff); roomserver_id.extend_from_slice(user_id.server_name().as_bytes()); + let mut serverroom_id = user_id.server_name().as_bytes().to_vec(); + serverroom_id.push(0xff); + serverroom_id.extend_from_slice(room_id.as_bytes()); + let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xff); userroom_id.extend_from_slice(room_id.as_bytes()); @@ -1700,6 +1706,7 @@ impl Rooms { } self.roomserverids.insert(&roomserver_id, &[])?; + self.serverroomids.insert(&serverroom_id, &[])?; self.userroomid_joined.insert(&userroom_id, &[])?; self.roomuserid_joined.insert(&roomuser_id, &[])?; self.userroomid_invitestate.remove(&userroom_id)?; @@ -1725,6 +1732,7 @@ impl Rooms { } self.roomserverids.insert(&roomserver_id, &[])?; + self.serverroomids.insert(&serverroom_id, &[])?; self.userroomid_invitestate.insert( &userroom_id, serde_json::to_vec(&last_state.unwrap_or_default()) @@ -1745,6 +1753,7 @@ impl Rooms { .all(|u| u.server_name() != user_id.server_name()) { self.roomserverids.remove(&roomserver_id)?; + self.serverroomids.remove(&serverroom_id)?; } self.userroomid_leftstate.insert( &userroom_id, @@ -2152,6 +2161,25 @@ impl Rooms { }) } + /// Returns an iterator of all rooms a server participates in (as far as we know). + pub fn server_rooms(&self, server: &ServerName) -> impl Iterator> { + let mut prefix = server.as_bytes().to_vec(); + prefix.push(0xff); + + self.serverroomids.scan_prefix(prefix).keys().map(|key| { + Ok(RoomId::try_from( + utils::string_from_bytes( + &key? + .rsplit(|&b| b == 0xff) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid."))?) + }) + } + /// Returns an iterator over all joined members of a room. #[tracing::instrument(skip(self))] pub fn room_members(&self, room_id: &RoomId) -> impl Iterator> { diff --git a/src/database/rooms/edus.rs b/src/database/rooms/edus.rs index 3bf2e06..89f2905 100644 --- a/src/database/rooms/edus.rs +++ b/src/database/rooms/edus.rs @@ -76,9 +76,12 @@ impl RoomEdus { &self, room_id: &RoomId, since: u64, - ) -> Result>>> { + ) -> Result< + impl Iterator)>>, + > { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); + let prefix2 = prefix.clone(); let mut first_possible_edu = prefix.clone(); first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since @@ -87,14 +90,30 @@ impl RoomEdus { .readreceiptid_readreceipt .range(&*first_possible_edu..) .filter_map(|r| r.ok()) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(|(_, v)| { + .take_while(move |(k, _)| k.starts_with(&prefix2)) + .map(move |(k, v)| { + let count = + utils::u64_from_bytes(&k[prefix.len()..prefix.len() + mem::size_of::()]) + .map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?; + let user_id = UserId::try_from( + utils::string_from_bytes(&k[prefix.len() + mem::size_of::() + 1..]) + .map_err(|_| { + Error::bad_database("Invalid readreceiptid userid bytes in db.") + })?, + ) + .map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?; + let mut json = serde_json::from_slice::(&v).map_err(|_| { Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json.") })?; json.remove("room_id"); - Ok(Raw::from_json( - serde_json::value::to_raw_value(&json).expect("json is valid raw value"), + + Ok(( + user_id, + count, + Raw::from_json( + serde_json::value::to_raw_value(&json).expect("json is valid raw value"), + ), )) })) } diff --git a/src/database/sending.rs b/src/database/sending.rs index e530396..199bd05 100644 --- a/src/database/sending.rs +++ b/src/database/sending.rs @@ -1,5 +1,5 @@ use std::{ - collections::HashMap, + collections::{BTreeMap, HashMap}, convert::{TryFrom, TryInto}, fmt::Debug, sync::Arc, @@ -14,8 +14,15 @@ use log::{error, warn}; use ring::digest; use rocket::futures::stream::{FuturesUnordered, StreamExt}; use ruma::{ - api::{appservice, federation, OutgoingRequest}, - events::{push_rules, EventType}, + api::{ + appservice, + federation::{ + self, + transactions::edu::{Edu, ReceiptContent, ReceiptData, ReceiptMap}, + }, + OutgoingRequest, + }, + events::{push_rules, AnySyncEphemeralRoomEvent, EventType}, push, ServerName, UInt, UserId, }; use sled::IVec; @@ -64,6 +71,7 @@ pub enum SendingEventType { #[derive(Clone)] pub struct Sending { /// The state for a given state hash. + pub(super) servername_educount: sled::Tree, // EduCount: Count of last EDU sync pub(super) servernamepduids: sled::Tree, // ServernamePduId = (+ / $)SenderKey / ServerName / UserId + PduId pub(super) servercurrentevents: sled::Tree, // ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / (*)EduEvent pub(super) maximum_requests: Arc, @@ -194,7 +202,7 @@ impl Sending { if let sled::Event::Insert { key, .. } = event { if let Ok((outgoing_kind, event)) = Self::parse_servercurrentevent(&key) { - if let Some(events) = Self::select_events(&outgoing_kind, vec![(event, key)], &mut current_transaction_status, &servercurrentevents, &servernamepduids) { + if let Some(events) = Self::select_events(&outgoing_kind, vec![(event, key)], &mut current_transaction_status, &servercurrentevents, &servernamepduids, &db) { futures.push(Self::handle_events(outgoing_kind, events, &db)); } } @@ -211,6 +219,7 @@ impl Sending { current_transaction_status: &mut HashMap, TransactionStatus>, servercurrentevents: &sled::Tree, servernamepduids: &sled::Tree, + db: &Database, ) -> Option> { let mut retry = false; let mut allow = true; @@ -267,11 +276,102 @@ impl Sending { events.push(e); } + + match outgoing_kind { + OutgoingKind::Normal(server_name) => { + if let Ok((select_edus, last_count)) = Self::select_edus(db, server_name) { + events.extend_from_slice(&select_edus); + db.sending + .servername_educount + .insert(server_name.as_bytes(), &last_count.to_be_bytes()) + .unwrap(); + } + } + _ => {} + } } Some(events) } + pub fn select_edus(db: &Database, server: &ServerName) -> Result<(Vec, u64)> { + // u64: count of last edu + let since = db + .sending + .servername_educount + .get(server.as_bytes())? + .map_or(Ok(0), |bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid u64 in servername_educount.")) + })?; + let mut events = Vec::new(); + let mut max_edu_count = since; + 'outer: for room_id in db.rooms.server_rooms(server) { + let room_id = room_id?; + for r in db.rooms.edus.readreceipts_since(&room_id, since)? { + let (user_id, count, read_receipt) = r?; + + if count > max_edu_count { + max_edu_count = count; + } + + if user_id.server_name() != db.globals.server_name() { + continue; + } + + let event = + serde_json::from_str::(&read_receipt.json().get()) + .map_err(|_| Error::bad_database("Invalid edu event in read_receipts."))?; + let federation_event = match event { + AnySyncEphemeralRoomEvent::Receipt(r) => { + let mut read = BTreeMap::new(); + + let (event_id, receipt) = r + .content + .0 + .into_iter() + .next() + .expect("we only use one event per read receipt"); + let receipt = receipt + .read + .expect("our read receipts always set this") + .remove(&user_id) + .expect("our read receipts always have the user here"); + + read.insert( + user_id, + ReceiptData { + data: receipt.clone(), + event_ids: vec![event_id.clone()], + }, + ); + + let receipt_map = ReceiptMap { read }; + + let mut receipts = BTreeMap::new(); + receipts.insert(room_id.clone(), receipt_map); + + Edu::Receipt(ReceiptContent { receipts }) + } + _ => { + Error::bad_database("Invalid event type in read_receipts"); + continue; + } + }; + + events.push(SendingEventType::Edu( + serde_json::to_vec(&federation_event).expect("json can be serialized"), + )); + + if events.len() >= 20 { + break 'outer; + } + } + } + + Ok((events, max_edu_count)) + } + #[tracing::instrument(skip(self))] pub fn send_push_pdu(&self, pdu_id: &[u8], senderkey: IVec) -> Result<()> { let mut key = b"$".to_vec();