diff --git a/src/database.rs b/src/database.rs index 406ce77..4905070 100644 --- a/src/database.rs +++ b/src/database.rs @@ -76,7 +76,7 @@ impl Database { } /// Load an existing database or create a new one. - pub fn load_or_create(config: Config) -> Result { + pub async fn load_or_create(config: Config) -> Result { let path = config .database_path .clone() @@ -106,7 +106,7 @@ impl Database { let (admin_sender, admin_receiver) = mpsc::unbounded(); let db = Self { - globals: globals::Globals::load(db.open_tree("global")?, config)?, + globals: globals::Globals::load(db.open_tree("global")?, config).await?, users: users::Users { userid_password: db.open_tree("userid_password")?, userid_displayname: db.open_tree("userid_displayname")?, diff --git a/src/database/globals.rs b/src/database/globals.rs index 403fadd..1221609 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -1,20 +1,25 @@ use crate::{database::Config, utils, Error, Result}; +use trust_dns_resolver::TokioAsyncResolver; +use std::collections::HashMap; use log::error; use ruma::ServerName; use std::sync::Arc; +use std::sync::RwLock; pub const COUNTER: &str = "c"; #[derive(Clone)] pub struct Globals { pub(super) globals: sled::Tree, + config: Config, keypair: Arc, reqwest_client: reqwest::Client, - config: Config, + pub actual_destination_cache: Arc, (String, Option)>>>, // actual_destination, host + dns_resolver: TokioAsyncResolver, } impl Globals { - pub fn load(globals: sled::Tree, config: Config) -> Result { + pub async fn load(globals: sled::Tree, config: Config) -> Result { let bytes = &*globals .update_and_fetch("keypair", utils::generate_keypair)? .expect("utils::generate_keypair always returns Some"); @@ -51,9 +56,13 @@ impl Globals { Ok(Self { globals, + config, keypair: Arc::new(keypair), reqwest_client: reqwest::Client::new(), - config, + dns_resolver: TokioAsyncResolver::tokio_from_system_conf().await.map_err(|_| { + Error::bad_config("Failed to set up trust dns resolver with system config.") + })?, + actual_destination_cache: Arc::new(RwLock::new(HashMap::new())), }) } @@ -103,4 +112,8 @@ impl Globals { pub fn federation_enabled(&self) -> bool { self.config.federation_enabled } + + pub fn dns_resolver(&self) -> &TokioAsyncResolver { + &self.dns_resolver + } } diff --git a/src/main.rs b/src/main.rs index 75b74cc..58d3427 100644 --- a/src/main.rs +++ b/src/main.rs @@ -136,6 +136,7 @@ fn setup_rocket() -> rocket::Rocket { .attach(AdHoc::on_attach("Config", |rocket| async { let data = Database::load_or_create(rocket.figment().extract().expect("config is valid")) + .await .expect("config is valid"); data.sending.start_handler(&data.globals, &data.rooms); diff --git a/src/server_server.rs b/src/server_server.rs index da046d3..58dd872 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -30,7 +30,6 @@ use std::{ sync::Arc, time::{Duration, SystemTime}, }; -use trust_dns_resolver::AsyncResolver; pub async fn request_well_known( globals: &crate::database::globals::Globals, @@ -66,36 +65,26 @@ where return Err(Error::bad_config("Federation is disabled.")); } - let resolver = AsyncResolver::tokio_from_system_conf().await.map_err(|_| { - Error::bad_config("Failed to set up trust dns resolver with system config.") - })?; + let maybe_result = globals + .actual_destination_cache + .read() + .unwrap() + .get(&destination) + .cloned(); - let mut host = None; - - let actual_destination = "https://".to_owned() - + &if let Some(mut delegated_hostname) = - request_well_known(globals, &destination.as_str()).await - { - if let Ok(Some(srv)) = resolver - .srv_lookup(format!("_matrix._tcp.{}", delegated_hostname)) - .await - .map(|srv| srv.iter().next().map(|result| result.target().to_string())) - { - host = Some(delegated_hostname); - srv.trim_end_matches('.').to_owned() - } else { - if delegated_hostname.find(':').is_none() { - delegated_hostname += ":8448"; - } - delegated_hostname - } - } else { - let mut destination = destination.as_str().to_owned(); - if destination.find(':').is_none() { - destination += ":8448"; - } - destination - }; + let (actual_destination, host) = if let Some(result) = maybe_result { + println!("Loaded {} -> {:?}", destination, result); + result + } else { + let result = find_actual_destination(globals, &destination).await; + globals + .actual_destination_cache + .write() + .unwrap() + .insert(destination.clone(), result.clone()); + println!("Saving {} -> {:?}", destination, result); + result + }; let mut http_request = request .try_into_http_request(&actual_destination, Some("")) @@ -232,6 +221,42 @@ where } } +/// Returns: actual_destination, host header +async fn find_actual_destination( + globals: &crate::database::globals::Globals, + destination: &Box, +) -> (String, Option) { + let mut host = None; + + let actual_destination = "https://".to_owned() + + &if let Some(mut delegated_hostname) = + request_well_known(globals, destination.as_str()).await + { + if let Ok(Some(srv)) = globals + .dns_resolver() + .srv_lookup(format!("_matrix._tcp.{}", delegated_hostname)) + .await + .map(|srv| srv.iter().next().map(|result| result.target().to_string())) + { + host = Some(delegated_hostname); + srv.trim_end_matches('.').to_owned() + } else { + if delegated_hostname.find(':').is_none() { + delegated_hostname += ":8448"; + } + delegated_hostname + } + } else { + let mut destination = destination.as_str().to_owned(); + if destination.find(':').is_none() { + destination += ":8448"; + } + destination + }; + + (actual_destination, host) +} + #[cfg_attr(feature = "conduit_bin", get("/_matrix/federation/v1/version"))] pub fn get_server_version_route( db: State<'_, Database>,