diff --git a/src/database.rs b/src/database.rs index 883ef85..3b0bd6f 100644 --- a/src/database.rs +++ b/src/database.rs @@ -149,7 +149,8 @@ impl Database { userdevicetxnid_response: db.open_tree("userdevicetxnid_response")?, }, sending: sending::Sending { - serverpduids: db.open_tree("serverpduids")?, + servernamepduids: db.open_tree("servernamepduids")?, + servercurrentpdus: db.open_tree("servercurrentpdus")?, }, _db: db, }) diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 35c3eac..1cc20a4 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -367,7 +367,7 @@ impl Rooms { } /// Returns the pdu. - pub fn get_pdu_json_from_id(&self, pdu_id: &IVec) -> Result> { + pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { Ok(Some( serde_json::from_slice(&pdu) diff --git a/src/database/sending.rs b/src/database/sending.rs index 24a783b..33ee530 100644 --- a/src/database/sending.rs +++ b/src/database/sending.rs @@ -1,8 +1,8 @@ -use std::{collections::HashSet, convert::TryFrom, time::SystemTime}; +use std::{collections::HashMap, convert::TryFrom, time::SystemTime}; use crate::{server_server, utils, Error, PduEvent, Result}; use federation::transactions::send_transaction_message; -use log::warn; +use log::debug; use rocket::futures::stream::{FuturesUnordered, StreamExt}; use ruma::{api::federation, ServerName}; use sled::IVec; @@ -10,54 +10,145 @@ use tokio::select; pub struct Sending { /// The state for a given state hash. - pub(super) serverpduids: sled::Tree, // ServerPduId = ServerName + PduId + pub(super) servernamepduids: sled::Tree, // ServernamePduId = ServerName + PduId + pub(super) servercurrentpdus: sled::Tree, // ServerCurrentPdus = ServerName + PduId (pduid can be empty for reservation) } impl Sending { pub fn start_handler(&self, globals: &super::globals::Globals, rooms: &super::rooms::Rooms) { - let serverpduids = self.serverpduids.clone(); + let servernamepduids = self.servernamepduids.clone(); + let servercurrentpdus = self.servercurrentpdus.clone(); let rooms = rooms.clone(); let globals = globals.clone(); tokio::spawn(async move { let mut futures = FuturesUnordered::new(); - let mut waiting_servers = HashSet::new(); - let mut subscriber = serverpduids.watch_prefix(b""); + // Retry requests we could not finish yet + let mut current_transactions = HashMap::new(); + + for (server, pdu) in servercurrentpdus + .iter() + .filter_map(|r| r.ok()) + .map(|(key, _)| { + let mut parts = key.splitn(2, |&b| b == 0xff); + let server = parts.next().expect("splitn always returns one element"); + let pdu = parts.next().ok_or_else(|| { + Error::bad_database("Invalid bytes in servercurrentpdus.") + })?; + + Ok::<_, Error>(( + Box::::try_from(utils::string_from_bytes(&server).map_err( + |_| { + Error::bad_database( + "Invalid server bytes in server_currenttransaction", + ) + }, + )?) + .map_err(|_| { + Error::bad_database( + "Invalid server string in server_currenttransaction", + ) + })?, + IVec::from(pdu), + )) + }) + .filter_map(|r| r.ok()) + { + if !pdu.is_empty() { + current_transactions + .entry(server) + .or_insert_with(Vec::new) + .push(pdu); + } + } + + for (server, pdus) in current_transactions { + futures.push(Self::handle_event(server, pdus, &globals, &rooms)); + } + + let mut subscriber = servernamepduids.watch_prefix(b""); loop { select! { Some(server) = futures.next() => { - warn!("response: {:?}", &server); - warn!("futures left: {}", &futures.len()); + debug!("response: {:?}", &server); match server { Ok((server, _response)) => { - waiting_servers.remove(&server) + let mut prefix = server.as_bytes().to_vec(); + prefix.push(0xff); + + for key in servercurrentpdus + .scan_prefix(&prefix) + .keys() + .filter_map(|r| r.ok()) + { + // Don't remove reservation yet + if prefix.len() != key.len() { + servercurrentpdus.remove(key).unwrap(); + } + } + + // Find events that have been added since starting the last request + let new_pdus = servernamepduids + .scan_prefix(&prefix) + .keys() + .filter_map(|r| r.ok()) + .map(|k| { + k.subslice(prefix.len(), k.len() - prefix.len()) + }).collect::>(); + + if !new_pdus.is_empty() { + for pdu_id in &new_pdus { + let mut current_key = prefix.clone(); + current_key.extend_from_slice(pdu_id); + servercurrentpdus.insert(¤t_key, &[]).unwrap(); + servernamepduids.remove(¤t_key).unwrap(); + } + + futures.push(Self::handle_event(server, new_pdus, &globals, &rooms)); + } else { + servercurrentpdus.remove(&prefix).unwrap(); + } } - Err((server, _e)) => { - waiting_servers.remove(&server) + Err((_server, _e)) => { + // TODO: exponential backoff } }; }, Some(event) = &mut subscriber => { if let sled::Event::Insert { key, .. } = event { - let serverpduid = key.clone(); - let mut parts = serverpduid.splitn(2, |&b| b == 0xff); + let servernamepduid = key.clone(); + let mut parts = servernamepduid.splitn(2, |&b| b == 0xff); if let Some((server, pdu_id)) = utils::string_from_bytes( parts .next() .expect("splitn will always return 1 or more elements"), ) - .map_err(|_| Error::bad_database("ServerName in serverpduid bytes are invalid.")) + .map_err(|_| Error::bad_database("ServerName in servernamepduid bytes are invalid.")) .and_then(|server_str|Box::::try_from(server_str) - .map_err(|_| Error::bad_database("ServerName in serverpduid is invalid."))) + .map_err(|_| Error::bad_database("ServerName in servernamepduid is invalid."))) .ok() - .filter(|server| waiting_servers.insert(server.clone())) .and_then(|server| parts - .next() - .ok_or_else(|| Error::bad_database("Invalid serverpduid in db.")).ok().map(|pdu_id| (server, pdu_id))) + .next() + .ok_or_else(|| Error::bad_database("Invalid servernamepduid in db.")) + .ok() + .map(|pdu_id| (server, pdu_id)) + ) + // TODO: exponential backoff + .filter(|(server, _)| { + let mut prefix = server.to_string().as_bytes().to_vec(); + prefix.push(0xff); + + servercurrentpdus + .compare_and_swap(prefix, Option::<&[u8]>::None, Some(&[])) // Try to reserve + == Ok(Ok(())) + }) { - futures.push(Self::handle_event(server, pdu_id.into(), &globals, &rooms)); + servercurrentpdus.insert(&key, &[]).unwrap(); + servernamepduids.remove(&key).unwrap(); + + futures.push(Self::handle_event(server, vec![pdu_id.into()], &globals, &rooms)); } } } @@ -70,38 +161,44 @@ impl Sending { let mut key = server.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(pdu_id); - self.serverpduids.insert(key, b"")?; + self.servernamepduids.insert(key, b"")?; Ok(()) } async fn handle_event( server: Box, - pdu_id: IVec, + pdu_ids: Vec, globals: &super::globals::Globals, rooms: &super::rooms::Rooms, ) -> std::result::Result< (Box, send_transaction_message::v1::Response), (Box, Error), > { - let pdu_json = PduEvent::convert_to_outgoing_federation_event( - rooms - .get_pdu_json_from_id(&pdu_id) - .map_err(|e| (server.clone(), e))? - .ok_or_else(|| { - ( - server.clone(), - Error::bad_database("Event in serverpduids not found in db."), - ) - })?, - ); + let pdu_jsons = pdu_ids + .iter() + .map(|pdu_id| { + Ok::<_, (Box, Error)>(PduEvent::convert_to_outgoing_federation_event( + rooms + .get_pdu_json_from_id(pdu_id) + .map_err(|e| (server.clone(), e))? + .ok_or_else(|| { + ( + server.clone(), + Error::bad_database("Event in servernamepduids not found in db."), + ) + })?, + )) + }) + .filter_map(|r| r.ok()) + .collect::>(); server_server::send_request( &globals, server.clone(), send_transaction_message::v1::Request { origin: globals.server_name(), - pdus: &[pdu_json], + pdus: &pdu_jsons, edus: &[], origin_server_ts: SystemTime::now(), transaction_id: &utils::random_string(16), diff --git a/src/server_server.rs b/src/server_server.rs index 184f333..ccb1399 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -186,7 +186,10 @@ where let body = reqwest_response .bytes() .await - .unwrap() + .unwrap_or_else(|e| { + warn!("server error: {}", e); + Vec::new().into() + }) // TODO: handle timeout .into_iter() .collect();