add support for arbitrary proxies
This commit is contained in:
		
							parent
							
								
									cc9111059d
								
							
						
					
					
						commit
						b2d5516058
					
				
					 5 changed files with 168 additions and 5 deletions
				
			
		
							
								
								
									
										13
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										13
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							|  | @ -1761,6 +1761,7 @@ dependencies = [ | ||||||
|  "serde_urlencoded", |  "serde_urlencoded", | ||||||
|  "tokio", |  "tokio", | ||||||
|  "tokio-rustls", |  "tokio-rustls", | ||||||
|  |  "tokio-socks", | ||||||
|  "url", |  "url", | ||||||
|  "wasm-bindgen", |  "wasm-bindgen", | ||||||
|  "wasm-bindgen-futures", |  "wasm-bindgen-futures", | ||||||
|  | @ -2732,6 +2733,18 @@ dependencies = [ | ||||||
|  "webpki", |  "webpki", | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "tokio-socks" | ||||||
|  | version = "0.5.1" | ||||||
|  | source = "registry+https://github.com/rust-lang/crates.io-index" | ||||||
|  | checksum = "51165dfa029d2a65969413a6cc96f354b86b464498702f174a4efa13608fd8c0" | ||||||
|  | dependencies = [ | ||||||
|  |  "either", | ||||||
|  |  "futures-util", | ||||||
|  |  "thiserror", | ||||||
|  |  "tokio", | ||||||
|  | ] | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "tokio-util" | name = "tokio-util" | ||||||
| version = "0.6.6" | version = "0.6.6" | ||||||
|  |  | ||||||
|  | @ -47,7 +47,7 @@ rand = "0.8.3" | ||||||
| # Used to hash passwords | # Used to hash passwords | ||||||
| rust-argon2 = "0.8.3" | rust-argon2 = "0.8.3" | ||||||
| # Used to send requests | # Used to send requests | ||||||
| reqwest = { version = "0.11.3", default-features = false, features = ["rustls-tls-native-roots"] } | reqwest = { version = "0.11.3", default-features = false, features = ["rustls-tls-native-roots", "socks"] } | ||||||
| # Custom TLS verifier | # Custom TLS verifier | ||||||
| rustls = { version = "0.19", features = ["dangerous_configuration"] } | rustls = { version = "0.19", features = ["dangerous_configuration"] } | ||||||
| rustls-native-certs = "0.5.0" | rustls-native-certs = "0.5.0" | ||||||
|  |  | ||||||
							
								
								
									
										121
									
								
								src/database.rs
									
									
									
									
									
								
							
							
						
						
									
										121
									
								
								src/database.rs
									
									
									
									
									
								
							|  | @ -46,6 +46,8 @@ pub struct Config { | ||||||
|     allow_federation: bool, |     allow_federation: bool, | ||||||
|     #[serde(default = "false_fn")] |     #[serde(default = "false_fn")] | ||||||
|     pub allow_jaeger: bool, |     pub allow_jaeger: bool, | ||||||
|  |     #[serde(default)] | ||||||
|  |     proxy: ProxyConfig, | ||||||
|     jwt_secret: Option<String>, |     jwt_secret: Option<String>, | ||||||
|     #[serde(default = "Vec::new")] |     #[serde(default = "Vec::new")] | ||||||
|     trusted_servers: Vec<Box<ServerName>>, |     trusted_servers: Vec<Box<ServerName>>, | ||||||
|  | @ -83,6 +85,125 @@ pub type Engine = abstraction::SledEngine; | ||||||
| #[cfg(feature = "rocksdb")] | #[cfg(feature = "rocksdb")] | ||||||
| pub type Engine = abstraction::RocksDbEngine; | pub type Engine = abstraction::RocksDbEngine; | ||||||
| 
 | 
 | ||||||
|  | #[derive(Clone, Debug, Deserialize)] | ||||||
|  | #[serde(rename_all = "snake_case")] | ||||||
|  | pub enum ProxyConfig { | ||||||
|  |     None, | ||||||
|  |     Global { | ||||||
|  |         #[serde(deserialize_with = "crate::utils::deserialize_from_str")] | ||||||
|  |         url: reqwest::Url, | ||||||
|  |     }, | ||||||
|  |     ByDomain(Vec<PartialProxyConfig>), | ||||||
|  | } | ||||||
|  | impl ProxyConfig { | ||||||
|  |     pub fn to_proxy(&self) -> Result<Option<reqwest::Proxy>> { | ||||||
|  |         Ok(match self.clone() { | ||||||
|  |             ProxyConfig::None => None, | ||||||
|  |             ProxyConfig::Global { url } => Some(reqwest::Proxy::all(url)?), | ||||||
|  |             ProxyConfig::ByDomain(proxies) => Some(reqwest::Proxy::custom(move |url| { | ||||||
|  |                 proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy
 | ||||||
|  |             })), | ||||||
|  |         }) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | impl Default for ProxyConfig { | ||||||
|  |     fn default() -> Self { | ||||||
|  |         ProxyConfig::None | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[derive(Clone, Debug, Deserialize)] | ||||||
|  | pub struct PartialProxyConfig { | ||||||
|  |     #[serde(deserialize_with = "crate::utils::deserialize_from_str")] | ||||||
|  |     url: reqwest::Url, | ||||||
|  |     #[serde(default)] | ||||||
|  |     include: Vec<WildCardedDomain>, | ||||||
|  |     #[serde(default)] | ||||||
|  |     exclude: Vec<WildCardedDomain>, | ||||||
|  | } | ||||||
|  | impl PartialProxyConfig { | ||||||
|  |     pub fn for_url(&self, url: &reqwest::Url) -> Option<&reqwest::Url> { | ||||||
|  |         let domain = url.domain()?; | ||||||
|  |         let mut included_because = None; // most specific reason it was included
 | ||||||
|  |         let mut excluded_because = None; // most specific reason it was excluded
 | ||||||
|  |         if self.include.is_empty() { | ||||||
|  |             // treat empty include list as `*`
 | ||||||
|  |             included_because = Some(&WildCardedDomain::WildCard) | ||||||
|  |         } | ||||||
|  |         for wc_domain in &self.include { | ||||||
|  |             if wc_domain.matches(domain) { | ||||||
|  |                 match included_because { | ||||||
|  |                     Some(prev) if !wc_domain.more_specific_than(prev) => (), | ||||||
|  |                     _ => included_because = Some(wc_domain), | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         for wc_domain in &self.exclude { | ||||||
|  |             if wc_domain.matches(domain) { | ||||||
|  |                 match excluded_because { | ||||||
|  |                     Some(prev) if !wc_domain.more_specific_than(prev) => (), | ||||||
|  |                     _ => excluded_because = Some(wc_domain), | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         match (included_because, excluded_because) { | ||||||
|  |             (Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded
 | ||||||
|  |             (Some(_), None) => Some(&self.url), | ||||||
|  |             _ => None, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /// A domain name, that optionally allows a * as its first subdomain.
 | ||||||
|  | #[derive(Clone, Debug)] | ||||||
|  | pub enum WildCardedDomain { | ||||||
|  |     WildCard, | ||||||
|  |     WildCarded(String), | ||||||
|  |     Exact(String), | ||||||
|  | } | ||||||
|  | impl WildCardedDomain { | ||||||
|  |     pub fn matches(&self, domain: &str) -> bool { | ||||||
|  |         match self { | ||||||
|  |             WildCardedDomain::WildCard => true, | ||||||
|  |             WildCardedDomain::WildCarded(d) => domain.ends_with(d), | ||||||
|  |             WildCardedDomain::Exact(d) => domain == d, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     pub fn more_specific_than(&self, other: &Self) -> bool { | ||||||
|  |         match (self, other) { | ||||||
|  |             (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false, | ||||||
|  |             (_, WildCardedDomain::WildCard) => true, | ||||||
|  |             (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a), | ||||||
|  |             (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => { | ||||||
|  |                 a != b && a.ends_with(b) | ||||||
|  |             } | ||||||
|  |             _ => false, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | impl std::str::FromStr for WildCardedDomain { | ||||||
|  |     type Err = std::convert::Infallible; | ||||||
|  |     fn from_str(s: &str) -> std::result::Result<Self, Self::Err> { | ||||||
|  |         // maybe do some domain validation?
 | ||||||
|  |         Ok(if s.starts_with("*.") { | ||||||
|  |             WildCardedDomain::WildCarded(s[1..].to_owned()) | ||||||
|  |         } else if s == "*" { | ||||||
|  |             WildCardedDomain::WildCarded("".to_owned()) | ||||||
|  |         } else { | ||||||
|  |             WildCardedDomain::Exact(s.to_owned()) | ||||||
|  |         }) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | impl<'de> serde::de::Deserialize<'de> for WildCardedDomain { | ||||||
|  |     fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error> | ||||||
|  |     where | ||||||
|  |         D: serde::de::Deserializer<'de>, | ||||||
|  |     { | ||||||
|  |         crate::utils::deserialize_from_str(deserializer) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[derive(Clone)] | ||||||
| pub struct Database { | pub struct Database { | ||||||
|     pub globals: globals::Globals, |     pub globals: globals::Globals, | ||||||
|     pub users: users::Users, |     pub users: users::Users, | ||||||
|  |  | ||||||
|  | @ -125,13 +125,15 @@ impl Globals { | ||||||
|         tlsconfig.root_store = |         tlsconfig.root_store = | ||||||
|             rustls_native_certs::load_native_certs().expect("Error loading system certificates"); |             rustls_native_certs::load_native_certs().expect("Error loading system certificates"); | ||||||
| 
 | 
 | ||||||
|         let reqwest_client = reqwest::Client::builder() |         let mut reqwest_client_builder = reqwest::Client::builder() | ||||||
|             .connect_timeout(Duration::from_secs(30)) |             .connect_timeout(Duration::from_secs(30)) | ||||||
|             .timeout(Duration::from_secs(60 * 3)) |             .timeout(Duration::from_secs(60 * 3)) | ||||||
|             .pool_max_idle_per_host(1) |             .pool_max_idle_per_host(1) | ||||||
|             .use_preconfigured_tls(tlsconfig) |             .use_preconfigured_tls(tlsconfig); | ||||||
|             .build() |         if let Some(proxy) = config.proxy.to_proxy()? { | ||||||
|             .unwrap(); |             reqwest_client_builder = reqwest_client_builder.proxy(proxy); | ||||||
|  |         } | ||||||
|  |         let reqwest_client = reqwest_client_builder.build().unwrap(); | ||||||
| 
 | 
 | ||||||
|         let jwt_decoding_key = config |         let jwt_decoding_key = config | ||||||
|             .jwt_secret |             .jwt_secret | ||||||
|  |  | ||||||
							
								
								
									
										27
									
								
								src/utils.rs
									
									
									
									
									
								
							
							
						
						
									
										27
									
								
								src/utils.rs
									
									
									
									
									
								
							|  | @ -5,6 +5,7 @@ use ruma::serde::{try_from_json_map, CanonicalJsonError, CanonicalJsonObject}; | ||||||
| use std::{ | use std::{ | ||||||
|     cmp, |     cmp, | ||||||
|     convert::TryInto, |     convert::TryInto, | ||||||
|  |     str::FromStr, | ||||||
|     time::{SystemTime, UNIX_EPOCH}, |     time::{SystemTime, UNIX_EPOCH}, | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | @ -115,3 +116,29 @@ pub fn to_canonical_object<T: serde::Serialize>( | ||||||
|         ))), |         ))), | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | pub fn deserialize_from_str< | ||||||
|  |     'de, | ||||||
|  |     D: serde::de::Deserializer<'de>, | ||||||
|  |     T: FromStr<Err = E>, | ||||||
|  |     E: std::fmt::Display, | ||||||
|  | >( | ||||||
|  |     deserializer: D, | ||||||
|  | ) -> std::result::Result<T, D::Error> { | ||||||
|  |     struct Visitor<T: FromStr<Err = E>, E>(std::marker::PhantomData<T>); | ||||||
|  |     impl<'de, T: FromStr<Err = Err>, Err: std::fmt::Display> serde::de::Visitor<'de> | ||||||
|  |         for Visitor<T, Err> | ||||||
|  |     { | ||||||
|  |         type Value = T; | ||||||
|  |         fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||||
|  |             write!(formatter, "a parsable string") | ||||||
|  |         } | ||||||
|  |         fn visit_str<E>(self, v: &str) -> Result<Self::Value, E> | ||||||
|  |         where | ||||||
|  |             E: serde::de::Error, | ||||||
|  |         { | ||||||
|  |             v.parse().map_err(|e| serde::de::Error::custom(e)) | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     deserializer.deserialize_str(Visitor(std::marker::PhantomData)) | ||||||
|  | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in a new issue