fix: server resolution
This commit is contained in:
		
							parent
							
								
									0330d3e270
								
							
						
					
					
						commit
						19b89ab91f
					
				
					 6 changed files with 156 additions and 111 deletions
				
			
		|  | @ -46,7 +46,11 @@ where | |||
|     *reqwest_request.timeout_mut() = Some(Duration::from_secs(30)); | ||||
| 
 | ||||
|     let url = reqwest_request.url().clone(); | ||||
|     let mut response = globals.reqwest_client().execute(reqwest_request).await?; | ||||
|     let mut response = globals | ||||
|         .reqwest_client()? | ||||
|         .build()? | ||||
|         .execute(reqwest_request) | ||||
|         .await?; | ||||
| 
 | ||||
|     // reqwest::Response -> http::Response conversion
 | ||||
|     let status = response.status(); | ||||
|  |  | |||
|  | @ -19,10 +19,7 @@ use ruma::{ | |||
|     DeviceId, DeviceKeyAlgorithm, UserId, | ||||
| }; | ||||
| use serde_json::json; | ||||
| use std::{ | ||||
|     collections::{BTreeMap, HashMap, HashSet}, | ||||
|     time::{Duration, Instant}, | ||||
| }; | ||||
| use std::collections::{BTreeMap, HashMap, HashSet}; | ||||
| 
 | ||||
| #[cfg(feature = "conduit_bin")] | ||||
| use rocket::{get, post}; | ||||
|  |  | |||
|  | @ -1,4 +1,4 @@ | |||
| use crate::{database::Config, utils, ConduitResult, Error, Result}; | ||||
| use crate::{database::Config, server_server::FedDest, utils, ConduitResult, Error, Result}; | ||||
| use ruma::{ | ||||
|     api::{ | ||||
|         client::r0::sync::sync_events, | ||||
|  | @ -6,25 +6,25 @@ use ruma::{ | |||
|     }, | ||||
|     DeviceId, EventId, MilliSecondsSinceUnixEpoch, RoomId, ServerName, ServerSigningKeyId, UserId, | ||||
| }; | ||||
| use rustls::{ServerCertVerifier, WebPKIVerifier}; | ||||
| use std::{ | ||||
|     collections::{BTreeMap, HashMap}, | ||||
|     fs, | ||||
|     future::Future, | ||||
|     net::IpAddr, | ||||
|     path::PathBuf, | ||||
|     sync::{Arc, Mutex, RwLock}, | ||||
|     time::{Duration, Instant}, | ||||
| }; | ||||
| use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore}; | ||||
| use tracing::{error, info}; | ||||
| use tracing::error; | ||||
| use trust_dns_resolver::TokioAsyncResolver; | ||||
| 
 | ||||
| use super::abstraction::Tree; | ||||
| 
 | ||||
| pub const COUNTER: &[u8] = b"c"; | ||||
| 
 | ||||
| type WellKnownMap = HashMap<Box<ServerName>, (String, String)>; | ||||
| type TlsNameMap = HashMap<String, webpki::DNSName>; | ||||
| type WellKnownMap = HashMap<Box<ServerName>, (FedDest, String)>; | ||||
| type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>; | ||||
| type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
 | ||||
| type SyncHandle = ( | ||||
|     Option<String>,                                         // since
 | ||||
|  | @ -37,7 +37,6 @@ pub struct Globals { | |||
|     pub(super) globals: Arc<dyn Tree>, | ||||
|     config: Config, | ||||
|     keypair: Arc<ruma::signatures::Ed25519KeyPair>, | ||||
|     reqwest_client: reqwest::Client, | ||||
|     dns_resolver: TokioAsyncResolver, | ||||
|     jwt_decoding_key: Option<jsonwebtoken::DecodingKey<'static>>, | ||||
|     pub(super) server_signingkeys: Arc<dyn Tree>, | ||||
|  | @ -51,40 +50,6 @@ pub struct Globals { | |||
|     pub rotate: RotationHandler, | ||||
| } | ||||
| 
 | ||||
| struct MatrixServerVerifier { | ||||
|     inner: WebPKIVerifier, | ||||
|     tls_name_override: Arc<RwLock<TlsNameMap>>, | ||||
| } | ||||
| 
 | ||||
| impl ServerCertVerifier for MatrixServerVerifier { | ||||
|     #[tracing::instrument(skip(self, roots, presented_certs, dns_name, ocsp_response))] | ||||
|     fn verify_server_cert( | ||||
|         &self, | ||||
|         roots: &rustls::RootCertStore, | ||||
|         presented_certs: &[rustls::Certificate], | ||||
|         dns_name: webpki::DNSNameRef<'_>, | ||||
|         ocsp_response: &[u8], | ||||
|     ) -> std::result::Result<rustls::ServerCertVerified, rustls::TLSError> { | ||||
|         if let Some(override_name) = self.tls_name_override.read().unwrap().get(dns_name.into()) { | ||||
|             let result = self.inner.verify_server_cert( | ||||
|                 roots, | ||||
|                 presented_certs, | ||||
|                 override_name.as_ref(), | ||||
|                 ocsp_response, | ||||
|             ); | ||||
|             if result.is_ok() { | ||||
|                 return result; | ||||
|             } | ||||
|             info!( | ||||
|                 "Server {:?} is non-compliant, retrying TLS verification with original name", | ||||
|                 dns_name | ||||
|             ); | ||||
|         } | ||||
|         self.inner | ||||
|             .verify_server_cert(roots, presented_certs, dns_name, ocsp_response) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// Handles "rotation" of long-polling requests. "Rotation" in this context is similar to "rotation" of log files and the like.
 | ||||
| ///
 | ||||
| /// This is utilized to have sync workers return early and release read locks on the database.
 | ||||
|  | @ -162,24 +127,6 @@ impl Globals { | |||
|         }; | ||||
| 
 | ||||
|         let tls_name_override = Arc::new(RwLock::new(TlsNameMap::new())); | ||||
|         let verifier = Arc::new(MatrixServerVerifier { | ||||
|             inner: WebPKIVerifier::new(), | ||||
|             tls_name_override: tls_name_override.clone(), | ||||
|         }); | ||||
|         let mut tlsconfig = rustls::ClientConfig::new(); | ||||
|         tlsconfig.dangerous().set_certificate_verifier(verifier); | ||||
|         tlsconfig.root_store = | ||||
|             rustls_native_certs::load_native_certs().expect("Error loading system certificates"); | ||||
| 
 | ||||
|         let mut reqwest_client_builder = reqwest::Client::builder() | ||||
|             .connect_timeout(Duration::from_secs(30)) | ||||
|             .timeout(Duration::from_secs(60 * 3)) | ||||
|             .pool_max_idle_per_host(1) | ||||
|             .use_preconfigured_tls(tlsconfig); | ||||
|         if let Some(proxy) = config.proxy.to_proxy()? { | ||||
|             reqwest_client_builder = reqwest_client_builder.proxy(proxy); | ||||
|         } | ||||
|         let reqwest_client = reqwest_client_builder.build().unwrap(); | ||||
| 
 | ||||
|         let jwt_decoding_key = config | ||||
|             .jwt_secret | ||||
|  | @ -190,7 +137,6 @@ impl Globals { | |||
|             globals, | ||||
|             config, | ||||
|             keypair: Arc::new(keypair), | ||||
|             reqwest_client, | ||||
|             dns_resolver: TokioAsyncResolver::tokio_from_system_conf().map_err(|_| { | ||||
|                 Error::bad_config("Failed to set up trust dns resolver with system config.") | ||||
|             })?, | ||||
|  | @ -219,8 +165,16 @@ impl Globals { | |||
|     } | ||||
| 
 | ||||
|     /// Returns a reqwest client which can be used to send requests.
 | ||||
|     pub fn reqwest_client(&self) -> &reqwest::Client { | ||||
|         &self.reqwest_client | ||||
|     pub fn reqwest_client(&self) -> Result<reqwest::ClientBuilder> { | ||||
|         let mut reqwest_client_builder = reqwest::Client::builder() | ||||
|             .connect_timeout(Duration::from_secs(30)) | ||||
|             .timeout(Duration::from_secs(60 * 3)) | ||||
|             .pool_max_idle_per_host(1); | ||||
|         if let Some(proxy) = self.config.proxy.to_proxy()? { | ||||
|             reqwest_client_builder = reqwest_client_builder.proxy(proxy); | ||||
|         } | ||||
| 
 | ||||
|         Ok(reqwest_client_builder) | ||||
|     } | ||||
| 
 | ||||
|     #[tracing::instrument(skip(self))] | ||||
|  |  | |||
|  | @ -113,7 +113,11 @@ where | |||
|     //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5));
 | ||||
| 
 | ||||
|     let url = reqwest_request.url().clone(); | ||||
|     let response = globals.reqwest_client().execute(reqwest_request).await; | ||||
|     let response = globals | ||||
|         .reqwest_client()? | ||||
|         .build()? | ||||
|         .execute(reqwest_request) | ||||
|         .await; | ||||
| 
 | ||||
|     match response { | ||||
|         Ok(mut response) => { | ||||
|  |  | |||
|  | @ -3,7 +3,7 @@ mod edus; | |||
| pub use edus::RoomEdus; | ||||
| use member::MembershipState; | ||||
| 
 | ||||
| use crate::{pdu::PduBuilder, utils, Database, Error, PduEvent, Result}; | ||||
| use crate::{Database, Error, PduEvent, Result, pdu::PduBuilder, server_server, utils}; | ||||
| use lru_cache::LruCache; | ||||
| use regex::Regex; | ||||
| use ring::digest; | ||||
|  | @ -22,12 +22,7 @@ use ruma::{ | |||
|     state_res::{self, RoomVersion, StateMap}, | ||||
|     uint, EventId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId, | ||||
| }; | ||||
| use std::{ | ||||
|     collections::{BTreeMap, HashMap, HashSet}, | ||||
|     convert::{TryFrom, TryInto}, | ||||
|     mem::size_of, | ||||
|     sync::{Arc, Mutex}, | ||||
| }; | ||||
| use std::{collections::{BTreeMap, HashMap, HashSet}, convert::{TryFrom, TryInto}, mem::size_of, sync::{Arc, Mutex}, time::Instant}; | ||||
| use tokio::sync::MutexGuard; | ||||
| use tracing::{error, warn}; | ||||
| 
 | ||||
|  | @ -1515,6 +1510,23 @@ impl Rooms { | |||
|                                 "list_appservices" => { | ||||
|                                     db.admin.send(AdminCommand::ListAppservices); | ||||
|                                 } | ||||
|                                 "get_auth_chain" => { | ||||
|                                     if args.len() == 1 { | ||||
|                                         if let Ok(event_id) = EventId::try_from(args[0]) { | ||||
|                                             let start = Instant::now(); | ||||
|                                             let count = | ||||
|                                                 server_server::get_auth_chain(vec![event_id], db)? | ||||
|                                                     .count(); | ||||
|                                             let elapsed = start.elapsed(); | ||||
|                                             db.admin.send(AdminCommand::SendMessage( | ||||
|                                                 message::MessageEventContent::text_plain(format!( | ||||
|                                                     "Loaded auth chain with length {} in {:?}", | ||||
|                                                     count, elapsed | ||||
|                                                 )), | ||||
|                                             )); | ||||
|                                         } | ||||
|                                     } | ||||
|                                 } | ||||
|                                 "get_pdu" => { | ||||
|                                     if args.len() == 1 { | ||||
|                                         if let Ok(event_id) = EventId::try_from(args[0]) { | ||||
|  |  | |||
|  | @ -4,7 +4,7 @@ use crate::{ | |||
|     utils, ConduitResult, Database, Error, PduEvent, Result, Ruma, | ||||
| }; | ||||
| use get_profile_information::v1::ProfileField; | ||||
| use http::header::{HeaderValue, AUTHORIZATION, HOST}; | ||||
| use http::header::{HeaderValue, AUTHORIZATION}; | ||||
| use regex::Regex; | ||||
| use rocket::response::content::Json; | ||||
| use ruma::{ | ||||
|  | @ -83,7 +83,7 @@ use rocket::{get, post, put}; | |||
| /// FedDest::Named("198.51.100.5".to_owned(), "".to_owned());
 | ||||
| /// ```
 | ||||
| #[derive(Clone, Debug, PartialEq)] | ||||
| enum FedDest { | ||||
| pub enum FedDest { | ||||
|     Literal(SocketAddr), | ||||
|     Named(String, String), | ||||
| } | ||||
|  | @ -109,6 +109,13 @@ impl FedDest { | |||
|             Self::Named(host, _) => host.clone(), | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     fn port(&self) -> Option<u16> { | ||||
|         match &self { | ||||
|             Self::Literal(addr) => Some(addr.port()), | ||||
|             Self::Named(_, port) => port[1..].parse().ok(), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| #[tracing::instrument(skip(globals, request))] | ||||
|  | @ -124,41 +131,34 @@ where | |||
|         return Err(Error::bad_config("Federation is disabled.")); | ||||
|     } | ||||
| 
 | ||||
|     let maybe_result = globals | ||||
|     let mut write_destination_to_cache = false; | ||||
| 
 | ||||
|     let cached_result = globals | ||||
|         .actual_destination_cache | ||||
|         .read() | ||||
|         .unwrap() | ||||
|         .get(destination) | ||||
|         .cloned(); | ||||
| 
 | ||||
|     let (actual_destination, host) = if let Some(result) = maybe_result { | ||||
|     let (actual_destination, host) = if let Some(result) = cached_result { | ||||
|         result | ||||
|     } else { | ||||
|         write_destination_to_cache = true; | ||||
| 
 | ||||
|         let result = find_actual_destination(globals, &destination).await; | ||||
|         let (actual_destination, host) = result.clone(); | ||||
|         let result_string = (result.0.into_https_string(), result.1.into_uri_string()); | ||||
|         globals | ||||
|             .actual_destination_cache | ||||
|             .write() | ||||
|             .unwrap() | ||||
|             .insert(Box::<ServerName>::from(destination), result_string.clone()); | ||||
|         let dest_hostname = actual_destination.hostname(); | ||||
|         let host_hostname = host.hostname(); | ||||
|         if dest_hostname != host_hostname { | ||||
|             globals.tls_name_override.write().unwrap().insert( | ||||
|                 dest_hostname, | ||||
|                 webpki::DNSNameRef::try_from_ascii_str(&host_hostname) | ||||
|                     .unwrap() | ||||
|                     .to_owned(), | ||||
|             ); | ||||
|         } | ||||
|         result_string | ||||
| 
 | ||||
|         (result.0, result.1.clone().into_uri_string()) | ||||
|     }; | ||||
| 
 | ||||
|     let actual_destination_str = actual_destination.clone().into_https_string(); | ||||
| 
 | ||||
|     let mut http_request = request | ||||
|         .try_into_http_request::<Vec<u8>>(&actual_destination, SendAccessToken::IfRequired("")) | ||||
|         .try_into_http_request::<Vec<u8>>(&actual_destination_str, SendAccessToken::IfRequired("")) | ||||
|         .map_err(|e| { | ||||
|             warn!("Failed to find destination {}: {}", actual_destination, e); | ||||
|             warn!( | ||||
|                 "Failed to find destination {}: {}", | ||||
|                 actual_destination_str, e | ||||
|             ); | ||||
|             Error::BadServerResponse("Invalid destination") | ||||
|         })?; | ||||
| 
 | ||||
|  | @ -224,15 +224,26 @@ where | |||
|         } | ||||
|     } | ||||
| 
 | ||||
|     http_request | ||||
|         .headers_mut() | ||||
|         .insert(HOST, HeaderValue::from_str(&host).unwrap()); | ||||
| 
 | ||||
|     let reqwest_request = reqwest::Request::try_from(http_request) | ||||
|         .expect("all http requests are valid reqwest requests"); | ||||
| 
 | ||||
|     let url = reqwest_request.url().clone(); | ||||
|     let response = globals.reqwest_client().execute(reqwest_request).await; | ||||
| 
 | ||||
|     let mut client = globals.reqwest_client()?; | ||||
|     if let Some((override_name, port)) = globals | ||||
|         .tls_name_override | ||||
|         .read() | ||||
|         .unwrap() | ||||
|         .get(&actual_destination.hostname()) | ||||
|     { | ||||
|         client = client.resolve( | ||||
|             &actual_destination.hostname(), | ||||
|             SocketAddr::new(override_name[0], *port), | ||||
|         ); | ||||
|         // port will be ignored
 | ||||
|     } | ||||
| 
 | ||||
|     let response = client.build()?.execute(reqwest_request).await; | ||||
| 
 | ||||
|     match response { | ||||
|         Ok(mut response) => { | ||||
|  | @ -271,6 +282,13 @@ where | |||
| 
 | ||||
|             if status == 200 { | ||||
|                 let response = T::IncomingResponse::try_from_http_response(http_response); | ||||
|                 if response.is_ok() && write_destination_to_cache { | ||||
|                     globals.actual_destination_cache.write().unwrap().insert( | ||||
|                         Box::<ServerName>::from(destination), | ||||
|                         (actual_destination, host), | ||||
|                     ); | ||||
|                 } | ||||
| 
 | ||||
|                 response.map_err(|e| { | ||||
|                     warn!( | ||||
|                         "Invalid 200 response from {} on: {} {}", | ||||
|  | @ -339,7 +357,7 @@ async fn find_actual_destination( | |||
|                 match request_well_known(globals, &destination.as_str()).await { | ||||
|                     // 3: A .well-known file is available
 | ||||
|                     Some(delegated_hostname) => { | ||||
|                         hostname = delegated_hostname.clone(); | ||||
|                         hostname = add_port_to_hostname(&delegated_hostname).into_uri_string(); | ||||
|                         match get_ip_with_port(&delegated_hostname) { | ||||
|                             Some(host_and_port) => host_and_port, // 3.1: IP literal in .well-known file
 | ||||
|                             None => { | ||||
|  | @ -348,11 +366,40 @@ async fn find_actual_destination( | |||
|                                     let (host, port) = delegated_hostname.split_at(pos); | ||||
|                                     FedDest::Named(host.to_string(), port.to_string()) | ||||
|                                 } else { | ||||
|                                     match query_srv_record(globals, &delegated_hostname).await { | ||||
|                                     // Delegated hostname has no port in this branch
 | ||||
|                                     if let Some(hostname_override) = | ||||
|                                         query_srv_record(globals, &delegated_hostname).await | ||||
|                                     { | ||||
|                                         // 3.3: SRV lookup successful
 | ||||
|                                         Some(hostname) => hostname, | ||||
|                                         let force_port = hostname_override.port(); | ||||
| 
 | ||||
|                                         if let Ok(override_ip) = globals | ||||
|                                             .dns_resolver() | ||||
|                                             .lookup_ip(hostname_override.hostname()) | ||||
|                                             .await | ||||
|                                         { | ||||
|                                             globals.tls_name_override.write().unwrap().insert( | ||||
|                                                 delegated_hostname.clone(), | ||||
|                                                 ( | ||||
|                                                     override_ip.iter().collect(), | ||||
|                                                     force_port.unwrap_or(8448), | ||||
|                                                 ), | ||||
|                                             ); | ||||
|                                         } else { | ||||
|                                             warn!("Using SRV record, but could not resolve to IP"); | ||||
|                                         } | ||||
| 
 | ||||
|                                         if let Some(port) = force_port { | ||||
|                                             FedDest::Named( | ||||
|                                                 delegated_hostname, | ||||
|                                                 format!(":{}", port.to_string()), | ||||
|                                             ) | ||||
|                                         } else { | ||||
|                                             add_port_to_hostname(&delegated_hostname) | ||||
|                                         } | ||||
|                                     } else { | ||||
|                                         // 3.4: No SRV records, just use the hostname from .well-known
 | ||||
|                                         None => add_port_to_hostname(&delegated_hostname), | ||||
|                                         add_port_to_hostname(&delegated_hostname) | ||||
|                                     } | ||||
|                                 } | ||||
|                             } | ||||
|  | @ -362,7 +409,31 @@ async fn find_actual_destination( | |||
|                     None => { | ||||
|                         match query_srv_record(globals, &destination_str).await { | ||||
|                             // 4: SRV record found
 | ||||
|                             Some(hostname) => hostname, | ||||
|                             Some(hostname_override) => { | ||||
|                                 let force_port = hostname_override.port(); | ||||
| 
 | ||||
|                                 if let Ok(override_ip) = globals | ||||
|                                     .dns_resolver() | ||||
|                                     .lookup_ip(hostname_override.hostname()) | ||||
|                                     .await | ||||
|                                 { | ||||
|                                     globals.tls_name_override.write().unwrap().insert( | ||||
|                                         hostname.clone(), | ||||
|                                         (override_ip.iter().collect(), force_port.unwrap_or(8448)), | ||||
|                                     ); | ||||
|                                 } else { | ||||
|                                     warn!("Using SRV record, but could not resolve to IP"); | ||||
|                                 } | ||||
| 
 | ||||
|                                 if let Some(port) = force_port { | ||||
|                                     FedDest::Named( | ||||
|                                         hostname.clone(), | ||||
|                                         format!(":{}", port.to_string()), | ||||
|                                     ) | ||||
|                                 } else { | ||||
|                                     add_port_to_hostname(&hostname) | ||||
|                                 } | ||||
|                             } | ||||
|                             // 5: No SRV record found
 | ||||
|                             None => add_port_to_hostname(&destination_str), | ||||
|                         } | ||||
|  | @ -377,12 +448,12 @@ async fn find_actual_destination( | |||
|     let hostname = if let Ok(addr) = hostname.parse::<SocketAddr>() { | ||||
|         FedDest::Literal(addr) | ||||
|     } else if let Ok(addr) = hostname.parse::<IpAddr>() { | ||||
|         FedDest::Named(addr.to_string(), "".to_string()) | ||||
|         FedDest::Named(addr.to_string(), ":8448".to_string()) | ||||
|     } else if let Some(pos) = hostname.find(':') { | ||||
|         let (host, port) = hostname.split_at(pos); | ||||
|         FedDest::Named(host.to_string(), port.to_string()) | ||||
|     } else { | ||||
|         FedDest::Named(hostname, "".to_string()) | ||||
|         FedDest::Named(hostname, ":8448".to_string()) | ||||
|     }; | ||||
|     (actual_destination, hostname) | ||||
| } | ||||
|  | @ -423,6 +494,9 @@ pub async fn request_well_known( | |||
|     let body: serde_json::Value = serde_json::from_str( | ||||
|         &globals | ||||
|             .reqwest_client() | ||||
|             .ok()? | ||||
|             .build() | ||||
|             .ok()? | ||||
|             .get(&format!( | ||||
|                 "https://{}/.well-known/matrix/server", | ||||
|                 destination | ||||
|  | @ -1971,7 +2045,7 @@ fn append_incoming_pdu( | |||
| } | ||||
| 
 | ||||
| #[tracing::instrument(skip(starting_events, db))] | ||||
| fn get_auth_chain( | ||||
| pub fn get_auth_chain( | ||||
|     starting_events: Vec<EventId>, | ||||
|     db: &Database, | ||||
| ) -> Result<impl Iterator<Item = EventId> + '_> { | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue