diff --git a/src/client_server/keys.rs b/src/client_server/keys.rs index 1ae9f80..418e41a 100644 --- a/src/client_server/keys.rs +++ b/src/client_server/keys.rs @@ -302,6 +302,7 @@ pub async fn get_keys_helper bool>( .entry(user_id.server_name()) .or_insert_with(Vec::new) .push((user_id, device_ids)); + continue; } if device_ids.is_empty() { @@ -364,20 +365,29 @@ pub async fn get_keys_helper bool>( let mut failures = BTreeMap::new(); for (server, vec) in get_over_federation { - let mut device_keys = BTreeMap::new(); + let mut device_keys_input_fed = BTreeMap::new(); for (user_id, keys) in vec { - device_keys.insert(user_id.clone(), keys.clone()); + device_keys_input_fed.insert(user_id.clone(), keys.clone()); } - if let Err(_e) = db + match db .sending .send_federation_request( &db.globals, server, - federation::keys::get_keys::v1::Request { device_keys }, + federation::keys::get_keys::v1::Request { + device_keys: device_keys_input_fed, + }, ) .await { - failures.insert(server.to_string(), json!({})); + Ok(response) => { + master_keys.extend(response.master_keys); + self_signing_keys.extend(response.self_signing_keys); + device_keys.extend(response.device_keys); + } + Err(_e) => { + failures.insert(server.to_string(), json!({})); + } } } diff --git a/src/client_server/to_device.rs b/src/client_server/to_device.rs index 7896af9..e3fd780 100644 --- a/src/client_server/to_device.rs +++ b/src/client_server/to_device.rs @@ -1,6 +1,12 @@ +use std::collections::BTreeMap; + use crate::{database::DatabaseGuard, ConduitResult, Error, Ruma}; use ruma::{ - api::client::{error::ErrorKind, r0::to_device::send_event_to_device}, + api::{ + client::{error::ErrorKind, r0::to_device::send_event_to_device}, + federation::{self, transactions::edu::DirectDeviceContent}, + }, + events::EventType, to_device::DeviceIdOrAllDevices, }; @@ -33,6 +39,28 @@ pub async fn send_event_to_device_route( for (target_user_id, map) in &body.messages { for (target_device_id_maybe, event) in map { + if target_user_id.server_name() != db.globals.server_name() { + let mut map = BTreeMap::new(); + map.insert(target_device_id_maybe.clone(), event.clone()); + let mut messages = BTreeMap::new(); + messages.insert(target_user_id.clone(), map); + + db.sending.send_reliable_edu( + target_user_id.server_name(), + &serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice( + DirectDeviceContent { + sender: sender_user.clone(), + ev_type: EventType::from(&body.event_type), + message_id: body.txn_id.clone(), + messages, + }, + )) + .expect("DirectToDevice EDU can be serialized"), + )?; + + continue; + } + match target_device_id_maybe { DeviceIdOrAllDevices::DeviceId(target_device_id) => db.users.add_to_device_event( sender_user, diff --git a/src/database/sending.rs b/src/database/sending.rs index 7c9cf64..8dfcbee 100644 --- a/src/database/sending.rs +++ b/src/database/sending.rs @@ -164,9 +164,10 @@ impl Sending { // Find events that have been added since starting the last request let new_events = guard.sending.servernamepduids .scan_prefix(prefix.clone()) - .map(|(k, _)| { - SendingEventType::Pdu(k[prefix.len()..].to_vec()) + .filter_map(|(k, _)| { + Self::parse_servercurrentevent(&k).ok() }) + .map(|(_, event)| event) .take(30) .collect::>(); @@ -290,7 +291,14 @@ impl Sending { if let OutgoingKind::Normal(server_name) = outgoing_kind { if let Ok((select_edus, last_count)) = Self::select_edus(db, server_name) { - events.extend_from_slice(&select_edus); + for edu in &select_edus { + let mut full_key = vec![b'*']; + full_key.extend_from_slice(&edu); + db.sending.servercurrentevents.insert(&full_key, &[])?; + } + + events.extend(select_edus.into_iter().map(SendingEventType::Edu)); + db.sending .servername_educount .insert(server_name.as_bytes(), &last_count.to_be_bytes())?; @@ -301,7 +309,7 @@ impl Sending { Ok(Some(events)) } - pub fn select_edus(db: &Database, server: &ServerName) -> Result<(Vec, u64)> { + pub fn select_edus(db: &Database, server: &ServerName) -> Result<(Vec>, u64)> { // u64: count of last edu let since = db .sending @@ -366,9 +374,7 @@ impl Sending { } }; - events.push(SendingEventType::Edu( - serde_json::to_vec(&federation_event).expect("json can be serialized"), - )); + events.push(serde_json::to_vec(&federation_event).expect("json can be serialized")); if events.len() >= 20 { break 'outer; @@ -402,6 +408,18 @@ impl Sending { Ok(()) } + #[tracing::instrument(skip(self))] + pub fn send_reliable_edu(&self, server: &ServerName, serialized: &[u8]) -> Result<()> { + let mut key = server.as_bytes().to_vec(); + key.push(0xff); + key.push(b'*'); + key.extend_from_slice(serialized); + self.servernamepduids.insert(&key, b"")?; + self.sender.unbounded_send(key).unwrap(); + + Ok(()) + } + #[tracing::instrument(skip(self))] pub fn send_pdu_appservice(&self, appservice_id: &str, pdu_id: &[u8]) -> Result<()> { let mut key = b"+".to_vec();