improvement: cache actual destination
This commit is contained in:
		
							parent
							
								
									45086b54b3
								
							
						
					
					
						commit
						d62f17a91a
					
				
					 4 changed files with 74 additions and 35 deletions
				
			
		|  | @ -76,7 +76,7 @@ impl Database { | |||
|     } | ||||
| 
 | ||||
|     /// Load an existing database or create a new one.
 | ||||
|     pub fn load_or_create(config: Config) -> Result<Self> { | ||||
|     pub async fn load_or_create(config: Config) -> Result<Self> { | ||||
|         let path = config | ||||
|             .database_path | ||||
|             .clone() | ||||
|  | @ -106,7 +106,7 @@ impl Database { | |||
|         let (admin_sender, admin_receiver) = mpsc::unbounded(); | ||||
| 
 | ||||
|         let db = Self { | ||||
|             globals: globals::Globals::load(db.open_tree("global")?, config)?, | ||||
|             globals: globals::Globals::load(db.open_tree("global")?, config).await?, | ||||
|             users: users::Users { | ||||
|                 userid_password: db.open_tree("userid_password")?, | ||||
|                 userid_displayname: db.open_tree("userid_displayname")?, | ||||
|  |  | |||
|  | @ -1,20 +1,25 @@ | |||
| use crate::{database::Config, utils, Error, Result}; | ||||
| use trust_dns_resolver::TokioAsyncResolver; | ||||
| use std::collections::HashMap; | ||||
| use log::error; | ||||
| use ruma::ServerName; | ||||
| use std::sync::Arc; | ||||
| use std::sync::RwLock; | ||||
| 
 | ||||
| pub const COUNTER: &str = "c"; | ||||
| 
 | ||||
| #[derive(Clone)] | ||||
| pub struct Globals { | ||||
|     pub(super) globals: sled::Tree, | ||||
|     config: Config, | ||||
|     keypair: Arc<ruma::signatures::Ed25519KeyPair>, | ||||
|     reqwest_client: reqwest::Client, | ||||
|     config: Config, | ||||
|     pub actual_destination_cache: Arc<RwLock<HashMap<Box<ServerName>, (String, Option<String>)>>>, // actual_destination, host
 | ||||
|     dns_resolver: TokioAsyncResolver, | ||||
| } | ||||
| 
 | ||||
| impl Globals { | ||||
|     pub fn load(globals: sled::Tree, config: Config) -> Result<Self> { | ||||
|     pub async fn load(globals: sled::Tree, config: Config) -> Result<Self> { | ||||
|         let bytes = &*globals | ||||
|             .update_and_fetch("keypair", utils::generate_keypair)? | ||||
|             .expect("utils::generate_keypair always returns Some"); | ||||
|  | @ -51,9 +56,13 @@ impl Globals { | |||
| 
 | ||||
|         Ok(Self { | ||||
|             globals, | ||||
|             config, | ||||
|             keypair: Arc::new(keypair), | ||||
|             reqwest_client: reqwest::Client::new(), | ||||
|             config, | ||||
|             dns_resolver: TokioAsyncResolver::tokio_from_system_conf().await.map_err(|_| { | ||||
|                 Error::bad_config("Failed to set up trust dns resolver with system config.") | ||||
|             })?, | ||||
|             actual_destination_cache: Arc::new(RwLock::new(HashMap::new())), | ||||
|         }) | ||||
|     } | ||||
| 
 | ||||
|  | @ -103,4 +112,8 @@ impl Globals { | |||
|     pub fn federation_enabled(&self) -> bool { | ||||
|         self.config.federation_enabled | ||||
|     } | ||||
| 
 | ||||
|     pub fn dns_resolver(&self) -> &TokioAsyncResolver { | ||||
|         &self.dns_resolver | ||||
|     } | ||||
| } | ||||
|  |  | |||
|  | @ -136,6 +136,7 @@ fn setup_rocket() -> rocket::Rocket { | |||
|         .attach(AdHoc::on_attach("Config", |rocket| async { | ||||
|             let data = | ||||
|                 Database::load_or_create(rocket.figment().extract().expect("config is valid")) | ||||
|                     .await | ||||
|                     .expect("config is valid"); | ||||
| 
 | ||||
|             data.sending.start_handler(&data.globals, &data.rooms); | ||||
|  |  | |||
|  | @ -30,7 +30,6 @@ use std::{ | |||
|     sync::Arc, | ||||
|     time::{Duration, SystemTime}, | ||||
| }; | ||||
| use trust_dns_resolver::AsyncResolver; | ||||
| 
 | ||||
| pub async fn request_well_known( | ||||
|     globals: &crate::database::globals::Globals, | ||||
|  | @ -66,36 +65,26 @@ where | |||
|         return Err(Error::bad_config("Federation is disabled.")); | ||||
|     } | ||||
| 
 | ||||
|     let resolver = AsyncResolver::tokio_from_system_conf().await.map_err(|_| { | ||||
|         Error::bad_config("Failed to set up trust dns resolver with system config.") | ||||
|     })?; | ||||
|     let maybe_result = globals | ||||
|         .actual_destination_cache | ||||
|         .read() | ||||
|         .unwrap() | ||||
|         .get(&destination) | ||||
|         .cloned(); | ||||
| 
 | ||||
|     let mut host = None; | ||||
| 
 | ||||
|     let actual_destination = "https://".to_owned() | ||||
|         + &if let Some(mut delegated_hostname) = | ||||
|             request_well_known(globals, &destination.as_str()).await | ||||
|         { | ||||
|             if let Ok(Some(srv)) = resolver | ||||
|                 .srv_lookup(format!("_matrix._tcp.{}", delegated_hostname)) | ||||
|                 .await | ||||
|                 .map(|srv| srv.iter().next().map(|result| result.target().to_string())) | ||||
|             { | ||||
|                 host = Some(delegated_hostname); | ||||
|                 srv.trim_end_matches('.').to_owned() | ||||
|             } else { | ||||
|                 if delegated_hostname.find(':').is_none() { | ||||
|                     delegated_hostname += ":8448"; | ||||
|                 } | ||||
|                 delegated_hostname | ||||
|             } | ||||
|         } else { | ||||
|             let mut destination = destination.as_str().to_owned(); | ||||
|             if destination.find(':').is_none() { | ||||
|                 destination += ":8448"; | ||||
|             } | ||||
|             destination | ||||
|         }; | ||||
|     let (actual_destination, host) = if let Some(result) = maybe_result { | ||||
|         println!("Loaded {} -> {:?}", destination, result); | ||||
|         result | ||||
|     } else { | ||||
|         let result = find_actual_destination(globals, &destination).await; | ||||
|         globals | ||||
|             .actual_destination_cache | ||||
|             .write() | ||||
|             .unwrap() | ||||
|             .insert(destination.clone(), result.clone()); | ||||
|         println!("Saving {} -> {:?}", destination, result); | ||||
|         result | ||||
|     }; | ||||
| 
 | ||||
|     let mut http_request = request | ||||
|         .try_into_http_request(&actual_destination, Some("")) | ||||
|  | @ -232,6 +221,42 @@ where | |||
|     } | ||||
| } | ||||
| 
 | ||||
| /// Returns: actual_destination, host header
 | ||||
| async fn find_actual_destination( | ||||
|     globals: &crate::database::globals::Globals, | ||||
|     destination: &Box<ServerName>, | ||||
| ) -> (String, Option<String>) { | ||||
|     let mut host = None; | ||||
| 
 | ||||
|     let actual_destination = "https://".to_owned() | ||||
|         + &if let Some(mut delegated_hostname) = | ||||
|             request_well_known(globals, destination.as_str()).await | ||||
|         { | ||||
|             if let Ok(Some(srv)) = globals | ||||
|                 .dns_resolver() | ||||
|                 .srv_lookup(format!("_matrix._tcp.{}", delegated_hostname)) | ||||
|                 .await | ||||
|                 .map(|srv| srv.iter().next().map(|result| result.target().to_string())) | ||||
|             { | ||||
|                 host = Some(delegated_hostname); | ||||
|                 srv.trim_end_matches('.').to_owned() | ||||
|             } else { | ||||
|                 if delegated_hostname.find(':').is_none() { | ||||
|                     delegated_hostname += ":8448"; | ||||
|                 } | ||||
|                 delegated_hostname | ||||
|             } | ||||
|         } else { | ||||
|             let mut destination = destination.as_str().to_owned(); | ||||
|             if destination.find(':').is_none() { | ||||
|                 destination += ":8448"; | ||||
|             } | ||||
|             destination | ||||
|         }; | ||||
| 
 | ||||
|     (actual_destination, host) | ||||
| } | ||||
| 
 | ||||
| #[cfg_attr(feature = "conduit_bin", get("/_matrix/federation/v1/version"))] | ||||
| pub fn get_server_version_route( | ||||
|     db: State<'_, Database>, | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue