diff --git a/src/server_server.rs b/src/server_server.rs index 7abce5a..c47afab 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -24,30 +24,12 @@ use std::{ collections::BTreeMap, convert::TryFrom, fmt::Debug, + net::{IpAddr, SocketAddr}, time::{Duration, SystemTime}, }; -pub async fn request_well_known( - globals: &crate::database::globals::Globals, - destination: &str, -) -> Option { - let body: serde_json::Value = serde_json::from_str( - &globals - .reqwest_client() - .get(&format!( - "https://{}/.well-known/matrix/server", - destination - )) - .send() - .await - .ok()? - .text() - .await - .ok()?, - ) - .ok()?; - Some(body.get("m.server")?.as_str()?.to_owned()) -} + + pub async fn send_request( globals: &crate::database::globals::Globals, @@ -215,42 +197,130 @@ where } } +fn get_ip_with_port(destination_str: String) -> Option { + if destination_str.parse::().is_ok() { + Some(destination_str) + } else if let Ok(ip_addr) = destination_str.parse::() { + Some(SocketAddr::new(ip_addr, 8448).to_string()) + } else { + None + } +} + +fn add_port_to_hostname(destination_str: String) -> String { + match destination_str.find(':') { + None => destination_str.to_owned() + ":8448", + Some(_) => destination_str.to_string(), + } +} + /// Returns: actual_destination, host header +/// Implemented according to the specification at https://matrix.org/docs/spec/server_server/r0.1.4#resolving-server-names +/// Numbers in comments below refer to bullet points in linked section of specification async fn find_actual_destination( globals: &crate::database::globals::Globals, destination: &Box, ) -> (String, Option) { let mut host = None; + let destination_str = destination.as_str().to_owned(); let actual_destination = "https://".to_owned() - + &if let Some(mut delegated_hostname) = - request_well_known(globals, destination.as_str()).await - { - if let Ok(Some(srv)) = globals - .dns_resolver() - .srv_lookup(format!("_matrix._tcp.{}", delegated_hostname)) - .await - .map(|srv| srv.iter().next().map(|result| result.target().to_string())) - { - host = Some(delegated_hostname); - srv.trim_end_matches('.').to_owned() - } else { - if delegated_hostname.find(':').is_none() { - delegated_hostname += ":8448"; + + &match get_ip_with_port(destination_str.clone()) { + Some(host_port) => { + // 1: IP literal with provided or default port + host_port + } + None => { + if destination_str.find(':').is_some() { + // 2: Hostname with included port + destination_str + } else { + match request_well_known(globals, &destination.as_str()).await { + // 3: A .well-known file is available + Some(delegated_hostname) => { + match get_ip_with_port(delegated_hostname.clone()) { + Some(host_and_port) => host_and_port, // 3.1: IP literal in .well-known file + None => { + if destination_str.find(':').is_some() { + // 3.2: Hostname with port in .well-known file + destination_str + } else { + match query_srv_record(globals, &delegated_hostname).await { + // 3.3: SRV lookup successful + Some(hostname) => hostname, + // 3.4: No SRV records, just use the hostname from .well-known + None => add_port_to_hostname(delegated_hostname), + } + } + } + } + } + // 4: No .well-known or an error occured + None => { + match query_srv_record(globals, &destination_str).await { + // 4: SRV record found + Some(hostname) => { + host = Some(destination_str.to_owned()); + hostname + } + // 5: No SRV record found + None => add_port_to_hostname(destination_str.to_string()), + } + } + } } - delegated_hostname } - } else { - let mut destination = destination.as_str().to_owned(); - if destination.find(':').is_none() { - destination += ":8448"; - } - destination }; (actual_destination, host) } +async fn query_srv_record<'a>( + globals: &crate::database::globals::Globals, + hostname: &'a str, +) -> Option { + if let Ok(Some(host_port)) = globals + .dns_resolver() + .srv_lookup(format!("_matrix._tcp.{}", hostname)) + .await + .map(|srv| { + srv.iter().next().map(|result| { + format!( + "{}:{}", + result.target().to_string().trim_end_matches('.'), + result.port().to_string() + ) + }) + }) + { + Some(host_port) + } else { + None + } +} + +pub async fn request_well_known( + globals: &crate::database::globals::Globals, + destination: &str, +) -> Option { + let body: serde_json::Value = serde_json::from_str( + &globals + .reqwest_client() + .get(&format!( + "https://{}/.well-known/matrix/server", + destination + )) + .send() + .await + .ok()? + .text() + .await + .ok()?, + ) + .ok()?; + Some(body.get("m.server")?.as_str()?.to_owned()) +} + #[cfg_attr(feature = "conduit_bin", get("/_matrix/federation/v1/version"))] pub fn get_server_version_route( db: State<'_, Database>, @@ -622,3 +692,48 @@ pub fn get_user_devices_route<'a>( .into()) } */ + +#[cfg(test)] +mod tests { + use super::{add_port_to_hostname, get_ip_with_port}; + + #[test] + fn ips_get_default_ports() { + assert_eq!( + get_ip_with_port(String::from("1.1.1.1")), + Some(String::from("1.1.1.1:8448")) + ); + assert_eq!( + get_ip_with_port(String::from("dead:beef::")), + Some(String::from("[dead:beef::]:8448")) + ); + } + + #[test] + fn ips_keep_custom_ports() { + assert_eq!( + get_ip_with_port(String::from("1.1.1.1:1234")), + Some(String::from("1.1.1.1:1234")) + ); + assert_eq!( + get_ip_with_port(String::from("[dead::beef]:8933")), + Some(String::from("[dead::beef]:8933")) + ); + } + + #[test] + fn hostnames_get_default_ports() { + assert_eq!( + add_port_to_hostname(String::from("example.com")), + "example.com:8448" + ) + } + + #[test] + fn hostnames_keep_custom_ports() { + assert_eq!( + add_port_to_hostname(String::from("example.com:1337")), + "example.com:1337" + ) + } +}