Merge branch 'feature/proxy' into 'master'
add support for arbitrary proxies See merge request famedly/conduit!54
This commit is contained in:
		
						commit
						5f6b0c673c
					
				
					 7 changed files with 200 additions and 5 deletions
				
			
		
							
								
								
									
										13
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										13
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							|  | @ -1761,6 +1761,7 @@ dependencies = [ | |||
|  "serde_urlencoded", | ||||
|  "tokio", | ||||
|  "tokio-rustls", | ||||
|  "tokio-socks", | ||||
|  "url", | ||||
|  "wasm-bindgen", | ||||
|  "wasm-bindgen-futures", | ||||
|  | @ -2732,6 +2733,18 @@ dependencies = [ | |||
|  "webpki", | ||||
| ] | ||||
| 
 | ||||
| [[package]] | ||||
| name = "tokio-socks" | ||||
| version = "0.5.1" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "51165dfa029d2a65969413a6cc96f354b86b464498702f174a4efa13608fd8c0" | ||||
| dependencies = [ | ||||
|  "either", | ||||
|  "futures-util", | ||||
|  "thiserror", | ||||
|  "tokio", | ||||
| ] | ||||
| 
 | ||||
| [[package]] | ||||
| name = "tokio-util" | ||||
| version = "0.6.6" | ||||
|  |  | |||
|  | @ -47,7 +47,7 @@ rand = "0.8.3" | |||
| # Used to hash passwords | ||||
| rust-argon2 = "0.8.3" | ||||
| # Used to send requests | ||||
| reqwest = { version = "0.11.3", default-features = false, features = ["rustls-tls-native-roots"] } | ||||
| reqwest = { version = "0.11.3", default-features = false, features = ["rustls-tls-native-roots", "socks"] } | ||||
| # Custom TLS verifier | ||||
| rustls = { version = "0.19", features = ["dangerous_configuration"] } | ||||
| rustls-native-certs = "0.5.0" | ||||
|  |  | |||
|  | @ -41,3 +41,5 @@ trusted_servers = ["matrix.org"] | |||
| #workers = 4 # default: cpu core count * 2 | ||||
| 
 | ||||
| address = "127.0.0.1" # This makes sure Conduit can only be reached using the reverse proxy | ||||
| 
 | ||||
| proxy = "none" # more examples can be found at src/database/proxy.rs:6 | ||||
|  |  | |||
|  | @ -6,6 +6,7 @@ pub mod appservice; | |||
| pub mod globals; | ||||
| pub mod key_backups; | ||||
| pub mod media; | ||||
| pub mod proxy; | ||||
| pub mod pusher; | ||||
| pub mod rooms; | ||||
| pub mod sending; | ||||
|  | @ -28,6 +29,8 @@ use std::{ | |||
| }; | ||||
| use tokio::sync::Semaphore; | ||||
| 
 | ||||
| use self::proxy::ProxyConfig; | ||||
| 
 | ||||
| #[derive(Clone, Debug, Deserialize)] | ||||
| pub struct Config { | ||||
|     server_name: Box<ServerName>, | ||||
|  | @ -46,6 +49,8 @@ pub struct Config { | |||
|     allow_federation: bool, | ||||
|     #[serde(default = "false_fn")] | ||||
|     pub allow_jaeger: bool, | ||||
|     #[serde(default)] | ||||
|     proxy: ProxyConfig, | ||||
|     jwt_secret: Option<String>, | ||||
|     #[serde(default = "Vec::new")] | ||||
|     trusted_servers: Vec<Box<ServerName>>, | ||||
|  |  | |||
|  | @ -125,13 +125,15 @@ impl Globals { | |||
|         tlsconfig.root_store = | ||||
|             rustls_native_certs::load_native_certs().expect("Error loading system certificates"); | ||||
| 
 | ||||
|         let reqwest_client = reqwest::Client::builder() | ||||
|         let mut reqwest_client_builder = reqwest::Client::builder() | ||||
|             .connect_timeout(Duration::from_secs(30)) | ||||
|             .timeout(Duration::from_secs(60 * 3)) | ||||
|             .pool_max_idle_per_host(1) | ||||
|             .use_preconfigured_tls(tlsconfig) | ||||
|             .build() | ||||
|             .unwrap(); | ||||
|             .use_preconfigured_tls(tlsconfig); | ||||
|         if let Some(proxy) = config.proxy.to_proxy()? { | ||||
|             reqwest_client_builder = reqwest_client_builder.proxy(proxy); | ||||
|         } | ||||
|         let reqwest_client = reqwest_client_builder.build().unwrap(); | ||||
| 
 | ||||
|         let jwt_decoding_key = config | ||||
|             .jwt_secret | ||||
|  |  | |||
							
								
								
									
										146
									
								
								src/database/proxy.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										146
									
								
								src/database/proxy.rs
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,146 @@ | |||
| use reqwest::{Proxy, Url}; | ||||
| use serde::Deserialize; | ||||
| 
 | ||||
| use crate::Result; | ||||
| 
 | ||||
| /// ## Examples:
 | ||||
| /// - No proxy (default):
 | ||||
| /// ```toml
 | ||||
| /// proxy ="none"
 | ||||
| /// ```
 | ||||
| /// - Global proxy
 | ||||
| /// ```toml
 | ||||
| /// [proxy]
 | ||||
| /// global = { url = "socks5h://localhost:9050" }
 | ||||
| /// ```
 | ||||
| /// - Proxy some domains
 | ||||
| /// ```toml
 | ||||
| /// [proxy]
 | ||||
| /// [[proxy.by_domain]]
 | ||||
| /// url = "socks5h://localhost:9050"
 | ||||
| /// include = ["*.onion", "matrix.myspecial.onion"]
 | ||||
| /// exclude = ["*.myspecial.onion"]
 | ||||
| /// ```
 | ||||
| /// ## Include vs. Exclude
 | ||||
| /// If include is an empty list, it is assumed to be `["*"]`.
 | ||||
| ///
 | ||||
| /// If a domain matches both the exclude and include list, the proxy will only be used if it was
 | ||||
| /// included because of a more specific rule than it was excluded. In the above example, the proxy
 | ||||
| /// would be used for `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`.
 | ||||
| #[derive(Clone, Debug, Deserialize)] | ||||
| #[serde(rename_all = "snake_case")] | ||||
| pub enum ProxyConfig { | ||||
|     None, | ||||
|     Global { | ||||
|         #[serde(deserialize_with = "crate::utils::deserialize_from_str")] | ||||
|         url: Url, | ||||
|     }, | ||||
|     ByDomain(Vec<PartialProxyConfig>), | ||||
| } | ||||
| impl ProxyConfig { | ||||
|     pub fn to_proxy(&self) -> Result<Option<Proxy>> { | ||||
|         Ok(match self.clone() { | ||||
|             ProxyConfig::None => None, | ||||
|             ProxyConfig::Global { url } => Some(Proxy::all(url)?), | ||||
|             ProxyConfig::ByDomain(proxies) => Some(Proxy::custom(move |url| { | ||||
|                 proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy
 | ||||
|             })), | ||||
|         }) | ||||
|     } | ||||
| } | ||||
| impl Default for ProxyConfig { | ||||
|     fn default() -> Self { | ||||
|         ProxyConfig::None | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| #[derive(Clone, Debug, Deserialize)] | ||||
| pub struct PartialProxyConfig { | ||||
|     #[serde(deserialize_with = "crate::utils::deserialize_from_str")] | ||||
|     url: Url, | ||||
|     #[serde(default)] | ||||
|     include: Vec<WildCardedDomain>, | ||||
|     #[serde(default)] | ||||
|     exclude: Vec<WildCardedDomain>, | ||||
| } | ||||
| impl PartialProxyConfig { | ||||
|     pub fn for_url(&self, url: &Url) -> Option<&Url> { | ||||
|         let domain = url.domain()?; | ||||
|         let mut included_because = None; // most specific reason it was included
 | ||||
|         let mut excluded_because = None; // most specific reason it was excluded
 | ||||
|         if self.include.is_empty() { | ||||
|             // treat empty include list as `*`
 | ||||
|             included_because = Some(&WildCardedDomain::WildCard) | ||||
|         } | ||||
|         for wc_domain in &self.include { | ||||
|             if wc_domain.matches(domain) { | ||||
|                 match included_because { | ||||
|                     Some(prev) if !wc_domain.more_specific_than(prev) => (), | ||||
|                     _ => included_because = Some(wc_domain), | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         for wc_domain in &self.exclude { | ||||
|             if wc_domain.matches(domain) { | ||||
|                 match excluded_because { | ||||
|                     Some(prev) if !wc_domain.more_specific_than(prev) => (), | ||||
|                     _ => excluded_because = Some(wc_domain), | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         match (included_because, excluded_because) { | ||||
|             (Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded
 | ||||
|             (Some(_), None) => Some(&self.url), | ||||
|             _ => None, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// A domain name, that optionally allows a * as its first subdomain.
 | ||||
| #[derive(Clone, Debug)] | ||||
| pub enum WildCardedDomain { | ||||
|     WildCard, | ||||
|     WildCarded(String), | ||||
|     Exact(String), | ||||
| } | ||||
| impl WildCardedDomain { | ||||
|     pub fn matches(&self, domain: &str) -> bool { | ||||
|         match self { | ||||
|             WildCardedDomain::WildCard => true, | ||||
|             WildCardedDomain::WildCarded(d) => domain.ends_with(d), | ||||
|             WildCardedDomain::Exact(d) => domain == d, | ||||
|         } | ||||
|     } | ||||
|     pub fn more_specific_than(&self, other: &Self) -> bool { | ||||
|         match (self, other) { | ||||
|             (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false, | ||||
|             (_, WildCardedDomain::WildCard) => true, | ||||
|             (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a), | ||||
|             (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => { | ||||
|                 a != b && a.ends_with(b) | ||||
|             } | ||||
|             _ => false, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| impl std::str::FromStr for WildCardedDomain { | ||||
|     type Err = std::convert::Infallible; | ||||
|     fn from_str(s: &str) -> std::result::Result<Self, Self::Err> { | ||||
|         // maybe do some domain validation?
 | ||||
|         Ok(if s.starts_with("*.") { | ||||
|             WildCardedDomain::WildCarded(s[1..].to_owned()) | ||||
|         } else if s == "*" { | ||||
|             WildCardedDomain::WildCarded("".to_owned()) | ||||
|         } else { | ||||
|             WildCardedDomain::Exact(s.to_owned()) | ||||
|         }) | ||||
|     } | ||||
| } | ||||
| impl<'de> serde::de::Deserialize<'de> for WildCardedDomain { | ||||
|     fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error> | ||||
|     where | ||||
|         D: serde::de::Deserializer<'de>, | ||||
|     { | ||||
|         crate::utils::deserialize_from_str(deserializer) | ||||
|     } | ||||
| } | ||||
							
								
								
									
										27
									
								
								src/utils.rs
									
									
									
									
									
								
							
							
						
						
									
										27
									
								
								src/utils.rs
									
									
									
									
									
								
							|  | @ -5,6 +5,7 @@ use ruma::serde::{try_from_json_map, CanonicalJsonError, CanonicalJsonObject}; | |||
| use std::{ | ||||
|     cmp, | ||||
|     convert::TryInto, | ||||
|     str::FromStr, | ||||
|     time::{SystemTime, UNIX_EPOCH}, | ||||
| }; | ||||
| 
 | ||||
|  | @ -115,3 +116,29 @@ pub fn to_canonical_object<T: serde::Serialize>( | |||
|         ))), | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| pub fn deserialize_from_str< | ||||
|     'de, | ||||
|     D: serde::de::Deserializer<'de>, | ||||
|     T: FromStr<Err = E>, | ||||
|     E: std::fmt::Display, | ||||
| >( | ||||
|     deserializer: D, | ||||
| ) -> std::result::Result<T, D::Error> { | ||||
|     struct Visitor<T: FromStr<Err = E>, E>(std::marker::PhantomData<T>); | ||||
|     impl<'de, T: FromStr<Err = Err>, Err: std::fmt::Display> serde::de::Visitor<'de> | ||||
|         for Visitor<T, Err> | ||||
|     { | ||||
|         type Value = T; | ||||
|         fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||
|             write!(formatter, "a parsable string") | ||||
|         } | ||||
|         fn visit_str<E>(self, v: &str) -> Result<Self::Value, E> | ||||
|         where | ||||
|             E: serde::de::Error, | ||||
|         { | ||||
|             v.parse().map_err(|e| serde::de::Error::custom(e)) | ||||
|         } | ||||
|     } | ||||
|     deserializer.deserialize_str(Visitor(std::marker::PhantomData)) | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue