fix: server resolution
This commit is contained in:
		
							parent
							
								
									0330d3e270
								
							
						
					
					
						commit
						19b89ab91f
					
				
					 6 changed files with 156 additions and 111 deletions
				
			
		|  | @ -46,7 +46,11 @@ where | ||||||
|     *reqwest_request.timeout_mut() = Some(Duration::from_secs(30)); |     *reqwest_request.timeout_mut() = Some(Duration::from_secs(30)); | ||||||
| 
 | 
 | ||||||
|     let url = reqwest_request.url().clone(); |     let url = reqwest_request.url().clone(); | ||||||
|     let mut response = globals.reqwest_client().execute(reqwest_request).await?; |     let mut response = globals | ||||||
|  |         .reqwest_client()? | ||||||
|  |         .build()? | ||||||
|  |         .execute(reqwest_request) | ||||||
|  |         .await?; | ||||||
| 
 | 
 | ||||||
|     // reqwest::Response -> http::Response conversion
 |     // reqwest::Response -> http::Response conversion
 | ||||||
|     let status = response.status(); |     let status = response.status(); | ||||||
|  |  | ||||||
|  | @ -19,10 +19,7 @@ use ruma::{ | ||||||
|     DeviceId, DeviceKeyAlgorithm, UserId, |     DeviceId, DeviceKeyAlgorithm, UserId, | ||||||
| }; | }; | ||||||
| use serde_json::json; | use serde_json::json; | ||||||
| use std::{ | use std::collections::{BTreeMap, HashMap, HashSet}; | ||||||
|     collections::{BTreeMap, HashMap, HashSet}, |  | ||||||
|     time::{Duration, Instant}, |  | ||||||
| }; |  | ||||||
| 
 | 
 | ||||||
| #[cfg(feature = "conduit_bin")] | #[cfg(feature = "conduit_bin")] | ||||||
| use rocket::{get, post}; | use rocket::{get, post}; | ||||||
|  |  | ||||||
|  | @ -1,4 +1,4 @@ | ||||||
| use crate::{database::Config, utils, ConduitResult, Error, Result}; | use crate::{database::Config, server_server::FedDest, utils, ConduitResult, Error, Result}; | ||||||
| use ruma::{ | use ruma::{ | ||||||
|     api::{ |     api::{ | ||||||
|         client::r0::sync::sync_events, |         client::r0::sync::sync_events, | ||||||
|  | @ -6,25 +6,25 @@ use ruma::{ | ||||||
|     }, |     }, | ||||||
|     DeviceId, EventId, MilliSecondsSinceUnixEpoch, RoomId, ServerName, ServerSigningKeyId, UserId, |     DeviceId, EventId, MilliSecondsSinceUnixEpoch, RoomId, ServerName, ServerSigningKeyId, UserId, | ||||||
| }; | }; | ||||||
| use rustls::{ServerCertVerifier, WebPKIVerifier}; |  | ||||||
| use std::{ | use std::{ | ||||||
|     collections::{BTreeMap, HashMap}, |     collections::{BTreeMap, HashMap}, | ||||||
|     fs, |     fs, | ||||||
|     future::Future, |     future::Future, | ||||||
|  |     net::IpAddr, | ||||||
|     path::PathBuf, |     path::PathBuf, | ||||||
|     sync::{Arc, Mutex, RwLock}, |     sync::{Arc, Mutex, RwLock}, | ||||||
|     time::{Duration, Instant}, |     time::{Duration, Instant}, | ||||||
| }; | }; | ||||||
| use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore}; | use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore}; | ||||||
| use tracing::{error, info}; | use tracing::error; | ||||||
| use trust_dns_resolver::TokioAsyncResolver; | use trust_dns_resolver::TokioAsyncResolver; | ||||||
| 
 | 
 | ||||||
| use super::abstraction::Tree; | use super::abstraction::Tree; | ||||||
| 
 | 
 | ||||||
| pub const COUNTER: &[u8] = b"c"; | pub const COUNTER: &[u8] = b"c"; | ||||||
| 
 | 
 | ||||||
| type WellKnownMap = HashMap<Box<ServerName>, (String, String)>; | type WellKnownMap = HashMap<Box<ServerName>, (FedDest, String)>; | ||||||
| type TlsNameMap = HashMap<String, webpki::DNSName>; | type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>; | ||||||
| type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
 | type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
 | ||||||
| type SyncHandle = ( | type SyncHandle = ( | ||||||
|     Option<String>,                                         // since
 |     Option<String>,                                         // since
 | ||||||
|  | @ -37,7 +37,6 @@ pub struct Globals { | ||||||
|     pub(super) globals: Arc<dyn Tree>, |     pub(super) globals: Arc<dyn Tree>, | ||||||
|     config: Config, |     config: Config, | ||||||
|     keypair: Arc<ruma::signatures::Ed25519KeyPair>, |     keypair: Arc<ruma::signatures::Ed25519KeyPair>, | ||||||
|     reqwest_client: reqwest::Client, |  | ||||||
|     dns_resolver: TokioAsyncResolver, |     dns_resolver: TokioAsyncResolver, | ||||||
|     jwt_decoding_key: Option<jsonwebtoken::DecodingKey<'static>>, |     jwt_decoding_key: Option<jsonwebtoken::DecodingKey<'static>>, | ||||||
|     pub(super) server_signingkeys: Arc<dyn Tree>, |     pub(super) server_signingkeys: Arc<dyn Tree>, | ||||||
|  | @ -51,40 +50,6 @@ pub struct Globals { | ||||||
|     pub rotate: RotationHandler, |     pub rotate: RotationHandler, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| struct MatrixServerVerifier { |  | ||||||
|     inner: WebPKIVerifier, |  | ||||||
|     tls_name_override: Arc<RwLock<TlsNameMap>>, |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl ServerCertVerifier for MatrixServerVerifier { |  | ||||||
|     #[tracing::instrument(skip(self, roots, presented_certs, dns_name, ocsp_response))] |  | ||||||
|     fn verify_server_cert( |  | ||||||
|         &self, |  | ||||||
|         roots: &rustls::RootCertStore, |  | ||||||
|         presented_certs: &[rustls::Certificate], |  | ||||||
|         dns_name: webpki::DNSNameRef<'_>, |  | ||||||
|         ocsp_response: &[u8], |  | ||||||
|     ) -> std::result::Result<rustls::ServerCertVerified, rustls::TLSError> { |  | ||||||
|         if let Some(override_name) = self.tls_name_override.read().unwrap().get(dns_name.into()) { |  | ||||||
|             let result = self.inner.verify_server_cert( |  | ||||||
|                 roots, |  | ||||||
|                 presented_certs, |  | ||||||
|                 override_name.as_ref(), |  | ||||||
|                 ocsp_response, |  | ||||||
|             ); |  | ||||||
|             if result.is_ok() { |  | ||||||
|                 return result; |  | ||||||
|             } |  | ||||||
|             info!( |  | ||||||
|                 "Server {:?} is non-compliant, retrying TLS verification with original name", |  | ||||||
|                 dns_name |  | ||||||
|             ); |  | ||||||
|         } |  | ||||||
|         self.inner |  | ||||||
|             .verify_server_cert(roots, presented_certs, dns_name, ocsp_response) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /// Handles "rotation" of long-polling requests. "Rotation" in this context is similar to "rotation" of log files and the like.
 | /// Handles "rotation" of long-polling requests. "Rotation" in this context is similar to "rotation" of log files and the like.
 | ||||||
| ///
 | ///
 | ||||||
| /// This is utilized to have sync workers return early and release read locks on the database.
 | /// This is utilized to have sync workers return early and release read locks on the database.
 | ||||||
|  | @ -162,24 +127,6 @@ impl Globals { | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|         let tls_name_override = Arc::new(RwLock::new(TlsNameMap::new())); |         let tls_name_override = Arc::new(RwLock::new(TlsNameMap::new())); | ||||||
|         let verifier = Arc::new(MatrixServerVerifier { |  | ||||||
|             inner: WebPKIVerifier::new(), |  | ||||||
|             tls_name_override: tls_name_override.clone(), |  | ||||||
|         }); |  | ||||||
|         let mut tlsconfig = rustls::ClientConfig::new(); |  | ||||||
|         tlsconfig.dangerous().set_certificate_verifier(verifier); |  | ||||||
|         tlsconfig.root_store = |  | ||||||
|             rustls_native_certs::load_native_certs().expect("Error loading system certificates"); |  | ||||||
| 
 |  | ||||||
|         let mut reqwest_client_builder = reqwest::Client::builder() |  | ||||||
|             .connect_timeout(Duration::from_secs(30)) |  | ||||||
|             .timeout(Duration::from_secs(60 * 3)) |  | ||||||
|             .pool_max_idle_per_host(1) |  | ||||||
|             .use_preconfigured_tls(tlsconfig); |  | ||||||
|         if let Some(proxy) = config.proxy.to_proxy()? { |  | ||||||
|             reqwest_client_builder = reqwest_client_builder.proxy(proxy); |  | ||||||
|         } |  | ||||||
|         let reqwest_client = reqwest_client_builder.build().unwrap(); |  | ||||||
| 
 | 
 | ||||||
|         let jwt_decoding_key = config |         let jwt_decoding_key = config | ||||||
|             .jwt_secret |             .jwt_secret | ||||||
|  | @ -190,7 +137,6 @@ impl Globals { | ||||||
|             globals, |             globals, | ||||||
|             config, |             config, | ||||||
|             keypair: Arc::new(keypair), |             keypair: Arc::new(keypair), | ||||||
|             reqwest_client, |  | ||||||
|             dns_resolver: TokioAsyncResolver::tokio_from_system_conf().map_err(|_| { |             dns_resolver: TokioAsyncResolver::tokio_from_system_conf().map_err(|_| { | ||||||
|                 Error::bad_config("Failed to set up trust dns resolver with system config.") |                 Error::bad_config("Failed to set up trust dns resolver with system config.") | ||||||
|             })?, |             })?, | ||||||
|  | @ -219,8 +165,16 @@ impl Globals { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /// Returns a reqwest client which can be used to send requests.
 |     /// Returns a reqwest client which can be used to send requests.
 | ||||||
|     pub fn reqwest_client(&self) -> &reqwest::Client { |     pub fn reqwest_client(&self) -> Result<reqwest::ClientBuilder> { | ||||||
|         &self.reqwest_client |         let mut reqwest_client_builder = reqwest::Client::builder() | ||||||
|  |             .connect_timeout(Duration::from_secs(30)) | ||||||
|  |             .timeout(Duration::from_secs(60 * 3)) | ||||||
|  |             .pool_max_idle_per_host(1); | ||||||
|  |         if let Some(proxy) = self.config.proxy.to_proxy()? { | ||||||
|  |             reqwest_client_builder = reqwest_client_builder.proxy(proxy); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         Ok(reqwest_client_builder) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     #[tracing::instrument(skip(self))] |     #[tracing::instrument(skip(self))] | ||||||
|  |  | ||||||
|  | @ -113,7 +113,11 @@ where | ||||||
|     //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5));
 |     //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5));
 | ||||||
| 
 | 
 | ||||||
|     let url = reqwest_request.url().clone(); |     let url = reqwest_request.url().clone(); | ||||||
|     let response = globals.reqwest_client().execute(reqwest_request).await; |     let response = globals | ||||||
|  |         .reqwest_client()? | ||||||
|  |         .build()? | ||||||
|  |         .execute(reqwest_request) | ||||||
|  |         .await; | ||||||
| 
 | 
 | ||||||
|     match response { |     match response { | ||||||
|         Ok(mut response) => { |         Ok(mut response) => { | ||||||
|  |  | ||||||
|  | @ -3,7 +3,7 @@ mod edus; | ||||||
| pub use edus::RoomEdus; | pub use edus::RoomEdus; | ||||||
| use member::MembershipState; | use member::MembershipState; | ||||||
| 
 | 
 | ||||||
| use crate::{pdu::PduBuilder, utils, Database, Error, PduEvent, Result}; | use crate::{Database, Error, PduEvent, Result, pdu::PduBuilder, server_server, utils}; | ||||||
| use lru_cache::LruCache; | use lru_cache::LruCache; | ||||||
| use regex::Regex; | use regex::Regex; | ||||||
| use ring::digest; | use ring::digest; | ||||||
|  | @ -22,12 +22,7 @@ use ruma::{ | ||||||
|     state_res::{self, RoomVersion, StateMap}, |     state_res::{self, RoomVersion, StateMap}, | ||||||
|     uint, EventId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId, |     uint, EventId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId, | ||||||
| }; | }; | ||||||
| use std::{ | use std::{collections::{BTreeMap, HashMap, HashSet}, convert::{TryFrom, TryInto}, mem::size_of, sync::{Arc, Mutex}, time::Instant}; | ||||||
|     collections::{BTreeMap, HashMap, HashSet}, |  | ||||||
|     convert::{TryFrom, TryInto}, |  | ||||||
|     mem::size_of, |  | ||||||
|     sync::{Arc, Mutex}, |  | ||||||
| }; |  | ||||||
| use tokio::sync::MutexGuard; | use tokio::sync::MutexGuard; | ||||||
| use tracing::{error, warn}; | use tracing::{error, warn}; | ||||||
| 
 | 
 | ||||||
|  | @ -1515,6 +1510,23 @@ impl Rooms { | ||||||
|                                 "list_appservices" => { |                                 "list_appservices" => { | ||||||
|                                     db.admin.send(AdminCommand::ListAppservices); |                                     db.admin.send(AdminCommand::ListAppservices); | ||||||
|                                 } |                                 } | ||||||
|  |                                 "get_auth_chain" => { | ||||||
|  |                                     if args.len() == 1 { | ||||||
|  |                                         if let Ok(event_id) = EventId::try_from(args[0]) { | ||||||
|  |                                             let start = Instant::now(); | ||||||
|  |                                             let count = | ||||||
|  |                                                 server_server::get_auth_chain(vec![event_id], db)? | ||||||
|  |                                                     .count(); | ||||||
|  |                                             let elapsed = start.elapsed(); | ||||||
|  |                                             db.admin.send(AdminCommand::SendMessage( | ||||||
|  |                                                 message::MessageEventContent::text_plain(format!( | ||||||
|  |                                                     "Loaded auth chain with length {} in {:?}", | ||||||
|  |                                                     count, elapsed | ||||||
|  |                                                 )), | ||||||
|  |                                             )); | ||||||
|  |                                         } | ||||||
|  |                                     } | ||||||
|  |                                 } | ||||||
|                                 "get_pdu" => { |                                 "get_pdu" => { | ||||||
|                                     if args.len() == 1 { |                                     if args.len() == 1 { | ||||||
|                                         if let Ok(event_id) = EventId::try_from(args[0]) { |                                         if let Ok(event_id) = EventId::try_from(args[0]) { | ||||||
|  |  | ||||||
|  | @ -4,7 +4,7 @@ use crate::{ | ||||||
|     utils, ConduitResult, Database, Error, PduEvent, Result, Ruma, |     utils, ConduitResult, Database, Error, PduEvent, Result, Ruma, | ||||||
| }; | }; | ||||||
| use get_profile_information::v1::ProfileField; | use get_profile_information::v1::ProfileField; | ||||||
| use http::header::{HeaderValue, AUTHORIZATION, HOST}; | use http::header::{HeaderValue, AUTHORIZATION}; | ||||||
| use regex::Regex; | use regex::Regex; | ||||||
| use rocket::response::content::Json; | use rocket::response::content::Json; | ||||||
| use ruma::{ | use ruma::{ | ||||||
|  | @ -83,7 +83,7 @@ use rocket::{get, post, put}; | ||||||
| /// FedDest::Named("198.51.100.5".to_owned(), "".to_owned());
 | /// FedDest::Named("198.51.100.5".to_owned(), "".to_owned());
 | ||||||
| /// ```
 | /// ```
 | ||||||
| #[derive(Clone, Debug, PartialEq)] | #[derive(Clone, Debug, PartialEq)] | ||||||
| enum FedDest { | pub enum FedDest { | ||||||
|     Literal(SocketAddr), |     Literal(SocketAddr), | ||||||
|     Named(String, String), |     Named(String, String), | ||||||
| } | } | ||||||
|  | @ -109,6 +109,13 @@ impl FedDest { | ||||||
|             Self::Named(host, _) => host.clone(), |             Self::Named(host, _) => host.clone(), | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  | 
 | ||||||
|  |     fn port(&self) -> Option<u16> { | ||||||
|  |         match &self { | ||||||
|  |             Self::Literal(addr) => Some(addr.port()), | ||||||
|  |             Self::Named(_, port) => port[1..].parse().ok(), | ||||||
|  |         } | ||||||
|  |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| #[tracing::instrument(skip(globals, request))] | #[tracing::instrument(skip(globals, request))] | ||||||
|  | @ -124,41 +131,34 @@ where | ||||||
|         return Err(Error::bad_config("Federation is disabled.")); |         return Err(Error::bad_config("Federation is disabled.")); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     let maybe_result = globals |     let mut write_destination_to_cache = false; | ||||||
|  | 
 | ||||||
|  |     let cached_result = globals | ||||||
|         .actual_destination_cache |         .actual_destination_cache | ||||||
|         .read() |         .read() | ||||||
|         .unwrap() |         .unwrap() | ||||||
|         .get(destination) |         .get(destination) | ||||||
|         .cloned(); |         .cloned(); | ||||||
| 
 | 
 | ||||||
|     let (actual_destination, host) = if let Some(result) = maybe_result { |     let (actual_destination, host) = if let Some(result) = cached_result { | ||||||
|         result |         result | ||||||
|     } else { |     } else { | ||||||
|  |         write_destination_to_cache = true; | ||||||
|  | 
 | ||||||
|         let result = find_actual_destination(globals, &destination).await; |         let result = find_actual_destination(globals, &destination).await; | ||||||
|         let (actual_destination, host) = result.clone(); | 
 | ||||||
|         let result_string = (result.0.into_https_string(), result.1.into_uri_string()); |         (result.0, result.1.clone().into_uri_string()) | ||||||
|         globals |  | ||||||
|             .actual_destination_cache |  | ||||||
|             .write() |  | ||||||
|             .unwrap() |  | ||||||
|             .insert(Box::<ServerName>::from(destination), result_string.clone()); |  | ||||||
|         let dest_hostname = actual_destination.hostname(); |  | ||||||
|         let host_hostname = host.hostname(); |  | ||||||
|         if dest_hostname != host_hostname { |  | ||||||
|             globals.tls_name_override.write().unwrap().insert( |  | ||||||
|                 dest_hostname, |  | ||||||
|                 webpki::DNSNameRef::try_from_ascii_str(&host_hostname) |  | ||||||
|                     .unwrap() |  | ||||||
|                     .to_owned(), |  | ||||||
|             ); |  | ||||||
|         } |  | ||||||
|         result_string |  | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|  |     let actual_destination_str = actual_destination.clone().into_https_string(); | ||||||
|  | 
 | ||||||
|     let mut http_request = request |     let mut http_request = request | ||||||
|         .try_into_http_request::<Vec<u8>>(&actual_destination, SendAccessToken::IfRequired("")) |         .try_into_http_request::<Vec<u8>>(&actual_destination_str, SendAccessToken::IfRequired("")) | ||||||
|         .map_err(|e| { |         .map_err(|e| { | ||||||
|             warn!("Failed to find destination {}: {}", actual_destination, e); |             warn!( | ||||||
|  |                 "Failed to find destination {}: {}", | ||||||
|  |                 actual_destination_str, e | ||||||
|  |             ); | ||||||
|             Error::BadServerResponse("Invalid destination") |             Error::BadServerResponse("Invalid destination") | ||||||
|         })?; |         })?; | ||||||
| 
 | 
 | ||||||
|  | @ -224,15 +224,26 @@ where | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     http_request |  | ||||||
|         .headers_mut() |  | ||||||
|         .insert(HOST, HeaderValue::from_str(&host).unwrap()); |  | ||||||
| 
 |  | ||||||
|     let reqwest_request = reqwest::Request::try_from(http_request) |     let reqwest_request = reqwest::Request::try_from(http_request) | ||||||
|         .expect("all http requests are valid reqwest requests"); |         .expect("all http requests are valid reqwest requests"); | ||||||
| 
 | 
 | ||||||
|     let url = reqwest_request.url().clone(); |     let url = reqwest_request.url().clone(); | ||||||
|     let response = globals.reqwest_client().execute(reqwest_request).await; | 
 | ||||||
|  |     let mut client = globals.reqwest_client()?; | ||||||
|  |     if let Some((override_name, port)) = globals | ||||||
|  |         .tls_name_override | ||||||
|  |         .read() | ||||||
|  |         .unwrap() | ||||||
|  |         .get(&actual_destination.hostname()) | ||||||
|  |     { | ||||||
|  |         client = client.resolve( | ||||||
|  |             &actual_destination.hostname(), | ||||||
|  |             SocketAddr::new(override_name[0], *port), | ||||||
|  |         ); | ||||||
|  |         // port will be ignored
 | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     let response = client.build()?.execute(reqwest_request).await; | ||||||
| 
 | 
 | ||||||
|     match response { |     match response { | ||||||
|         Ok(mut response) => { |         Ok(mut response) => { | ||||||
|  | @ -271,6 +282,13 @@ where | ||||||
| 
 | 
 | ||||||
|             if status == 200 { |             if status == 200 { | ||||||
|                 let response = T::IncomingResponse::try_from_http_response(http_response); |                 let response = T::IncomingResponse::try_from_http_response(http_response); | ||||||
|  |                 if response.is_ok() && write_destination_to_cache { | ||||||
|  |                     globals.actual_destination_cache.write().unwrap().insert( | ||||||
|  |                         Box::<ServerName>::from(destination), | ||||||
|  |                         (actual_destination, host), | ||||||
|  |                     ); | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|                 response.map_err(|e| { |                 response.map_err(|e| { | ||||||
|                     warn!( |                     warn!( | ||||||
|                         "Invalid 200 response from {} on: {} {}", |                         "Invalid 200 response from {} on: {} {}", | ||||||
|  | @ -339,7 +357,7 @@ async fn find_actual_destination( | ||||||
|                 match request_well_known(globals, &destination.as_str()).await { |                 match request_well_known(globals, &destination.as_str()).await { | ||||||
|                     // 3: A .well-known file is available
 |                     // 3: A .well-known file is available
 | ||||||
|                     Some(delegated_hostname) => { |                     Some(delegated_hostname) => { | ||||||
|                         hostname = delegated_hostname.clone(); |                         hostname = add_port_to_hostname(&delegated_hostname).into_uri_string(); | ||||||
|                         match get_ip_with_port(&delegated_hostname) { |                         match get_ip_with_port(&delegated_hostname) { | ||||||
|                             Some(host_and_port) => host_and_port, // 3.1: IP literal in .well-known file
 |                             Some(host_and_port) => host_and_port, // 3.1: IP literal in .well-known file
 | ||||||
|                             None => { |                             None => { | ||||||
|  | @ -348,11 +366,40 @@ async fn find_actual_destination( | ||||||
|                                     let (host, port) = delegated_hostname.split_at(pos); |                                     let (host, port) = delegated_hostname.split_at(pos); | ||||||
|                                     FedDest::Named(host.to_string(), port.to_string()) |                                     FedDest::Named(host.to_string(), port.to_string()) | ||||||
|                                 } else { |                                 } else { | ||||||
|                                     match query_srv_record(globals, &delegated_hostname).await { |                                     // Delegated hostname has no port in this branch
 | ||||||
|  |                                     if let Some(hostname_override) = | ||||||
|  |                                         query_srv_record(globals, &delegated_hostname).await | ||||||
|  |                                     { | ||||||
|                                         // 3.3: SRV lookup successful
 |                                         // 3.3: SRV lookup successful
 | ||||||
|                                         Some(hostname) => hostname, |                                         let force_port = hostname_override.port(); | ||||||
|  | 
 | ||||||
|  |                                         if let Ok(override_ip) = globals | ||||||
|  |                                             .dns_resolver() | ||||||
|  |                                             .lookup_ip(hostname_override.hostname()) | ||||||
|  |                                             .await | ||||||
|  |                                         { | ||||||
|  |                                             globals.tls_name_override.write().unwrap().insert( | ||||||
|  |                                                 delegated_hostname.clone(), | ||||||
|  |                                                 ( | ||||||
|  |                                                     override_ip.iter().collect(), | ||||||
|  |                                                     force_port.unwrap_or(8448), | ||||||
|  |                                                 ), | ||||||
|  |                                             ); | ||||||
|  |                                         } else { | ||||||
|  |                                             warn!("Using SRV record, but could not resolve to IP"); | ||||||
|  |                                         } | ||||||
|  | 
 | ||||||
|  |                                         if let Some(port) = force_port { | ||||||
|  |                                             FedDest::Named( | ||||||
|  |                                                 delegated_hostname, | ||||||
|  |                                                 format!(":{}", port.to_string()), | ||||||
|  |                                             ) | ||||||
|  |                                         } else { | ||||||
|  |                                             add_port_to_hostname(&delegated_hostname) | ||||||
|  |                                         } | ||||||
|  |                                     } else { | ||||||
|                                         // 3.4: No SRV records, just use the hostname from .well-known
 |                                         // 3.4: No SRV records, just use the hostname from .well-known
 | ||||||
|                                         None => add_port_to_hostname(&delegated_hostname), |                                         add_port_to_hostname(&delegated_hostname) | ||||||
|                                     } |                                     } | ||||||
|                                 } |                                 } | ||||||
|                             } |                             } | ||||||
|  | @ -362,7 +409,31 @@ async fn find_actual_destination( | ||||||
|                     None => { |                     None => { | ||||||
|                         match query_srv_record(globals, &destination_str).await { |                         match query_srv_record(globals, &destination_str).await { | ||||||
|                             // 4: SRV record found
 |                             // 4: SRV record found
 | ||||||
|                             Some(hostname) => hostname, |                             Some(hostname_override) => { | ||||||
|  |                                 let force_port = hostname_override.port(); | ||||||
|  | 
 | ||||||
|  |                                 if let Ok(override_ip) = globals | ||||||
|  |                                     .dns_resolver() | ||||||
|  |                                     .lookup_ip(hostname_override.hostname()) | ||||||
|  |                                     .await | ||||||
|  |                                 { | ||||||
|  |                                     globals.tls_name_override.write().unwrap().insert( | ||||||
|  |                                         hostname.clone(), | ||||||
|  |                                         (override_ip.iter().collect(), force_port.unwrap_or(8448)), | ||||||
|  |                                     ); | ||||||
|  |                                 } else { | ||||||
|  |                                     warn!("Using SRV record, but could not resolve to IP"); | ||||||
|  |                                 } | ||||||
|  | 
 | ||||||
|  |                                 if let Some(port) = force_port { | ||||||
|  |                                     FedDest::Named( | ||||||
|  |                                         hostname.clone(), | ||||||
|  |                                         format!(":{}", port.to_string()), | ||||||
|  |                                     ) | ||||||
|  |                                 } else { | ||||||
|  |                                     add_port_to_hostname(&hostname) | ||||||
|  |                                 } | ||||||
|  |                             } | ||||||
|                             // 5: No SRV record found
 |                             // 5: No SRV record found
 | ||||||
|                             None => add_port_to_hostname(&destination_str), |                             None => add_port_to_hostname(&destination_str), | ||||||
|                         } |                         } | ||||||
|  | @ -377,12 +448,12 @@ async fn find_actual_destination( | ||||||
|     let hostname = if let Ok(addr) = hostname.parse::<SocketAddr>() { |     let hostname = if let Ok(addr) = hostname.parse::<SocketAddr>() { | ||||||
|         FedDest::Literal(addr) |         FedDest::Literal(addr) | ||||||
|     } else if let Ok(addr) = hostname.parse::<IpAddr>() { |     } else if let Ok(addr) = hostname.parse::<IpAddr>() { | ||||||
|         FedDest::Named(addr.to_string(), "".to_string()) |         FedDest::Named(addr.to_string(), ":8448".to_string()) | ||||||
|     } else if let Some(pos) = hostname.find(':') { |     } else if let Some(pos) = hostname.find(':') { | ||||||
|         let (host, port) = hostname.split_at(pos); |         let (host, port) = hostname.split_at(pos); | ||||||
|         FedDest::Named(host.to_string(), port.to_string()) |         FedDest::Named(host.to_string(), port.to_string()) | ||||||
|     } else { |     } else { | ||||||
|         FedDest::Named(hostname, "".to_string()) |         FedDest::Named(hostname, ":8448".to_string()) | ||||||
|     }; |     }; | ||||||
|     (actual_destination, hostname) |     (actual_destination, hostname) | ||||||
| } | } | ||||||
|  | @ -423,6 +494,9 @@ pub async fn request_well_known( | ||||||
|     let body: serde_json::Value = serde_json::from_str( |     let body: serde_json::Value = serde_json::from_str( | ||||||
|         &globals |         &globals | ||||||
|             .reqwest_client() |             .reqwest_client() | ||||||
|  |             .ok()? | ||||||
|  |             .build() | ||||||
|  |             .ok()? | ||||||
|             .get(&format!( |             .get(&format!( | ||||||
|                 "https://{}/.well-known/matrix/server", |                 "https://{}/.well-known/matrix/server", | ||||||
|                 destination |                 destination | ||||||
|  | @ -1971,7 +2045,7 @@ fn append_incoming_pdu( | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| #[tracing::instrument(skip(starting_events, db))] | #[tracing::instrument(skip(starting_events, db))] | ||||||
| fn get_auth_chain( | pub fn get_auth_chain( | ||||||
|     starting_events: Vec<EventId>, |     starting_events: Vec<EventId>, | ||||||
|     db: &Database, |     db: &Database, | ||||||
| ) -> Result<impl Iterator<Item = EventId> + '_> { | ) -> Result<impl Iterator<Item = EventId> + '_> { | ||||||
|  |  | ||||||
		Loading…
	
		Reference in a new issue