improvement: cache actual destination
parent
45086b54b3
commit
d62f17a91a
|
@ -76,7 +76,7 @@ impl Database {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Load an existing database or create a new one.
|
/// Load an existing database or create a new one.
|
||||||
pub fn load_or_create(config: Config) -> Result<Self> {
|
pub async fn load_or_create(config: Config) -> Result<Self> {
|
||||||
let path = config
|
let path = config
|
||||||
.database_path
|
.database_path
|
||||||
.clone()
|
.clone()
|
||||||
|
@ -106,7 +106,7 @@ impl Database {
|
||||||
let (admin_sender, admin_receiver) = mpsc::unbounded();
|
let (admin_sender, admin_receiver) = mpsc::unbounded();
|
||||||
|
|
||||||
let db = Self {
|
let db = Self {
|
||||||
globals: globals::Globals::load(db.open_tree("global")?, config)?,
|
globals: globals::Globals::load(db.open_tree("global")?, config).await?,
|
||||||
users: users::Users {
|
users: users::Users {
|
||||||
userid_password: db.open_tree("userid_password")?,
|
userid_password: db.open_tree("userid_password")?,
|
||||||
userid_displayname: db.open_tree("userid_displayname")?,
|
userid_displayname: db.open_tree("userid_displayname")?,
|
||||||
|
|
|
@ -1,20 +1,25 @@
|
||||||
use crate::{database::Config, utils, Error, Result};
|
use crate::{database::Config, utils, Error, Result};
|
||||||
|
use trust_dns_resolver::TokioAsyncResolver;
|
||||||
|
use std::collections::HashMap;
|
||||||
use log::error;
|
use log::error;
|
||||||
use ruma::ServerName;
|
use ruma::ServerName;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::sync::RwLock;
|
||||||
|
|
||||||
pub const COUNTER: &str = "c";
|
pub const COUNTER: &str = "c";
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct Globals {
|
pub struct Globals {
|
||||||
pub(super) globals: sled::Tree,
|
pub(super) globals: sled::Tree,
|
||||||
|
config: Config,
|
||||||
keypair: Arc<ruma::signatures::Ed25519KeyPair>,
|
keypair: Arc<ruma::signatures::Ed25519KeyPair>,
|
||||||
reqwest_client: reqwest::Client,
|
reqwest_client: reqwest::Client,
|
||||||
config: Config,
|
pub actual_destination_cache: Arc<RwLock<HashMap<Box<ServerName>, (String, Option<String>)>>>, // actual_destination, host
|
||||||
|
dns_resolver: TokioAsyncResolver,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Globals {
|
impl Globals {
|
||||||
pub fn load(globals: sled::Tree, config: Config) -> Result<Self> {
|
pub async fn load(globals: sled::Tree, config: Config) -> Result<Self> {
|
||||||
let bytes = &*globals
|
let bytes = &*globals
|
||||||
.update_and_fetch("keypair", utils::generate_keypair)?
|
.update_and_fetch("keypair", utils::generate_keypair)?
|
||||||
.expect("utils::generate_keypair always returns Some");
|
.expect("utils::generate_keypair always returns Some");
|
||||||
|
@ -51,9 +56,13 @@ impl Globals {
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
globals,
|
globals,
|
||||||
|
config,
|
||||||
keypair: Arc::new(keypair),
|
keypair: Arc::new(keypair),
|
||||||
reqwest_client: reqwest::Client::new(),
|
reqwest_client: reqwest::Client::new(),
|
||||||
config,
|
dns_resolver: TokioAsyncResolver::tokio_from_system_conf().await.map_err(|_| {
|
||||||
|
Error::bad_config("Failed to set up trust dns resolver with system config.")
|
||||||
|
})?,
|
||||||
|
actual_destination_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,4 +112,8 @@ impl Globals {
|
||||||
pub fn federation_enabled(&self) -> bool {
|
pub fn federation_enabled(&self) -> bool {
|
||||||
self.config.federation_enabled
|
self.config.federation_enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn dns_resolver(&self) -> &TokioAsyncResolver {
|
||||||
|
&self.dns_resolver
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -136,6 +136,7 @@ fn setup_rocket() -> rocket::Rocket {
|
||||||
.attach(AdHoc::on_attach("Config", |rocket| async {
|
.attach(AdHoc::on_attach("Config", |rocket| async {
|
||||||
let data =
|
let data =
|
||||||
Database::load_or_create(rocket.figment().extract().expect("config is valid"))
|
Database::load_or_create(rocket.figment().extract().expect("config is valid"))
|
||||||
|
.await
|
||||||
.expect("config is valid");
|
.expect("config is valid");
|
||||||
|
|
||||||
data.sending.start_handler(&data.globals, &data.rooms);
|
data.sending.start_handler(&data.globals, &data.rooms);
|
||||||
|
|
|
@ -30,7 +30,6 @@ use std::{
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
time::{Duration, SystemTime},
|
time::{Duration, SystemTime},
|
||||||
};
|
};
|
||||||
use trust_dns_resolver::AsyncResolver;
|
|
||||||
|
|
||||||
pub async fn request_well_known(
|
pub async fn request_well_known(
|
||||||
globals: &crate::database::globals::Globals,
|
globals: &crate::database::globals::Globals,
|
||||||
|
@ -66,35 +65,25 @@ where
|
||||||
return Err(Error::bad_config("Federation is disabled."));
|
return Err(Error::bad_config("Federation is disabled."));
|
||||||
}
|
}
|
||||||
|
|
||||||
let resolver = AsyncResolver::tokio_from_system_conf().await.map_err(|_| {
|
let maybe_result = globals
|
||||||
Error::bad_config("Failed to set up trust dns resolver with system config.")
|
.actual_destination_cache
|
||||||
})?;
|
.read()
|
||||||
|
.unwrap()
|
||||||
|
.get(&destination)
|
||||||
|
.cloned();
|
||||||
|
|
||||||
let mut host = None;
|
let (actual_destination, host) = if let Some(result) = maybe_result {
|
||||||
|
println!("Loaded {} -> {:?}", destination, result);
|
||||||
let actual_destination = "https://".to_owned()
|
result
|
||||||
+ &if let Some(mut delegated_hostname) =
|
|
||||||
request_well_known(globals, &destination.as_str()).await
|
|
||||||
{
|
|
||||||
if let Ok(Some(srv)) = resolver
|
|
||||||
.srv_lookup(format!("_matrix._tcp.{}", delegated_hostname))
|
|
||||||
.await
|
|
||||||
.map(|srv| srv.iter().next().map(|result| result.target().to_string()))
|
|
||||||
{
|
|
||||||
host = Some(delegated_hostname);
|
|
||||||
srv.trim_end_matches('.').to_owned()
|
|
||||||
} else {
|
} else {
|
||||||
if delegated_hostname.find(':').is_none() {
|
let result = find_actual_destination(globals, &destination).await;
|
||||||
delegated_hostname += ":8448";
|
globals
|
||||||
}
|
.actual_destination_cache
|
||||||
delegated_hostname
|
.write()
|
||||||
}
|
.unwrap()
|
||||||
} else {
|
.insert(destination.clone(), result.clone());
|
||||||
let mut destination = destination.as_str().to_owned();
|
println!("Saving {} -> {:?}", destination, result);
|
||||||
if destination.find(':').is_none() {
|
result
|
||||||
destination += ":8448";
|
|
||||||
}
|
|
||||||
destination
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut http_request = request
|
let mut http_request = request
|
||||||
|
@ -232,6 +221,42 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns: actual_destination, host header
|
||||||
|
async fn find_actual_destination(
|
||||||
|
globals: &crate::database::globals::Globals,
|
||||||
|
destination: &Box<ServerName>,
|
||||||
|
) -> (String, Option<String>) {
|
||||||
|
let mut host = None;
|
||||||
|
|
||||||
|
let actual_destination = "https://".to_owned()
|
||||||
|
+ &if let Some(mut delegated_hostname) =
|
||||||
|
request_well_known(globals, destination.as_str()).await
|
||||||
|
{
|
||||||
|
if let Ok(Some(srv)) = globals
|
||||||
|
.dns_resolver()
|
||||||
|
.srv_lookup(format!("_matrix._tcp.{}", delegated_hostname))
|
||||||
|
.await
|
||||||
|
.map(|srv| srv.iter().next().map(|result| result.target().to_string()))
|
||||||
|
{
|
||||||
|
host = Some(delegated_hostname);
|
||||||
|
srv.trim_end_matches('.').to_owned()
|
||||||
|
} else {
|
||||||
|
if delegated_hostname.find(':').is_none() {
|
||||||
|
delegated_hostname += ":8448";
|
||||||
|
}
|
||||||
|
delegated_hostname
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let mut destination = destination.as_str().to_owned();
|
||||||
|
if destination.find(':').is_none() {
|
||||||
|
destination += ":8448";
|
||||||
|
}
|
||||||
|
destination
|
||||||
|
};
|
||||||
|
|
||||||
|
(actual_destination, host)
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg_attr(feature = "conduit_bin", get("/_matrix/federation/v1/version"))]
|
#[cfg_attr(feature = "conduit_bin", get("/_matrix/federation/v1/version"))]
|
||||||
pub fn get_server_version_route(
|
pub fn get_server_version_route(
|
||||||
db: State<'_, Database>,
|
db: State<'_, Database>,
|
||||||
|
|
Loading…
Reference in New Issue