From b2d55160585f15c93695f1dc60f3ca3eb7967911 Mon Sep 17 00:00:00 2001 From: Aiden McClelland Date: Tue, 13 Apr 2021 12:15:58 -0600 Subject: [PATCH] add support for arbitrary proxies --- Cargo.lock | 13 +++++ Cargo.toml | 2 +- src/database.rs | 121 ++++++++++++++++++++++++++++++++++++++++ src/database/globals.rs | 10 ++-- src/utils.rs | 27 +++++++++ 5 files changed, 168 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c3d7408..c31894a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1761,6 +1761,7 @@ dependencies = [ "serde_urlencoded", "tokio", "tokio-rustls", + "tokio-socks", "url", "wasm-bindgen", "wasm-bindgen-futures", @@ -2732,6 +2733,18 @@ dependencies = [ "webpki", ] +[[package]] +name = "tokio-socks" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51165dfa029d2a65969413a6cc96f354b86b464498702f174a4efa13608fd8c0" +dependencies = [ + "either", + "futures-util", + "thiserror", + "tokio", +] + [[package]] name = "tokio-util" version = "0.6.6" diff --git a/Cargo.toml b/Cargo.toml index 96260ec..4f7095d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,7 +47,7 @@ rand = "0.8.3" # Used to hash passwords rust-argon2 = "0.8.3" # Used to send requests -reqwest = { version = "0.11.3", default-features = false, features = ["rustls-tls-native-roots"] } +reqwest = { version = "0.11.3", default-features = false, features = ["rustls-tls-native-roots", "socks"] } # Custom TLS verifier rustls = { version = "0.19", features = ["dangerous_configuration"] } rustls-native-certs = "0.5.0" diff --git a/src/database.rs b/src/database.rs index 2846928..52d92a5 100644 --- a/src/database.rs +++ b/src/database.rs @@ -46,6 +46,8 @@ pub struct Config { allow_federation: bool, #[serde(default = "false_fn")] pub allow_jaeger: bool, + #[serde(default)] + proxy: ProxyConfig, jwt_secret: Option, #[serde(default = "Vec::new")] trusted_servers: Vec>, @@ -83,6 +85,125 @@ pub type Engine = abstraction::SledEngine; #[cfg(feature = "rocksdb")] pub type Engine = abstraction::RocksDbEngine; +#[derive(Clone, Debug, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ProxyConfig { + None, + Global { + #[serde(deserialize_with = "crate::utils::deserialize_from_str")] + url: reqwest::Url, + }, + ByDomain(Vec), +} +impl ProxyConfig { + pub fn to_proxy(&self) -> Result> { + Ok(match self.clone() { + ProxyConfig::None => None, + ProxyConfig::Global { url } => Some(reqwest::Proxy::all(url)?), + ProxyConfig::ByDomain(proxies) => Some(reqwest::Proxy::custom(move |url| { + proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy + })), + }) + } +} +impl Default for ProxyConfig { + fn default() -> Self { + ProxyConfig::None + } +} + +#[derive(Clone, Debug, Deserialize)] +pub struct PartialProxyConfig { + #[serde(deserialize_with = "crate::utils::deserialize_from_str")] + url: reqwest::Url, + #[serde(default)] + include: Vec, + #[serde(default)] + exclude: Vec, +} +impl PartialProxyConfig { + pub fn for_url(&self, url: &reqwest::Url) -> Option<&reqwest::Url> { + let domain = url.domain()?; + let mut included_because = None; // most specific reason it was included + let mut excluded_because = None; // most specific reason it was excluded + if self.include.is_empty() { + // treat empty include list as `*` + included_because = Some(&WildCardedDomain::WildCard) + } + for wc_domain in &self.include { + if wc_domain.matches(domain) { + match included_because { + Some(prev) if !wc_domain.more_specific_than(prev) => (), + _ => included_because = Some(wc_domain), + } + } + } + for wc_domain in &self.exclude { + if wc_domain.matches(domain) { + match excluded_because { + Some(prev) if !wc_domain.more_specific_than(prev) => (), + _ => excluded_because = Some(wc_domain), + } + } + } + match (included_because, excluded_because) { + (Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded + (Some(_), None) => Some(&self.url), + _ => None, + } + } +} + +/// A domain name, that optionally allows a * as its first subdomain. +#[derive(Clone, Debug)] +pub enum WildCardedDomain { + WildCard, + WildCarded(String), + Exact(String), +} +impl WildCardedDomain { + pub fn matches(&self, domain: &str) -> bool { + match self { + WildCardedDomain::WildCard => true, + WildCardedDomain::WildCarded(d) => domain.ends_with(d), + WildCardedDomain::Exact(d) => domain == d, + } + } + pub fn more_specific_than(&self, other: &Self) -> bool { + match (self, other) { + (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false, + (_, WildCardedDomain::WildCard) => true, + (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a), + (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => { + a != b && a.ends_with(b) + } + _ => false, + } + } +} +impl std::str::FromStr for WildCardedDomain { + type Err = std::convert::Infallible; + fn from_str(s: &str) -> std::result::Result { + // maybe do some domain validation? + Ok(if s.starts_with("*.") { + WildCardedDomain::WildCarded(s[1..].to_owned()) + } else if s == "*" { + WildCardedDomain::WildCarded("".to_owned()) + } else { + WildCardedDomain::Exact(s.to_owned()) + }) + } +} +impl<'de> serde::de::Deserialize<'de> for WildCardedDomain { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::de::Deserializer<'de>, + { + crate::utils::deserialize_from_str(deserializer) + } +} + +#[derive(Clone)] pub struct Database { pub globals: globals::Globals, pub users: users::Users, diff --git a/src/database/globals.rs b/src/database/globals.rs index 1ce87bd..db166e9 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -125,13 +125,15 @@ impl Globals { tlsconfig.root_store = rustls_native_certs::load_native_certs().expect("Error loading system certificates"); - let reqwest_client = reqwest::Client::builder() + let mut reqwest_client_builder = reqwest::Client::builder() .connect_timeout(Duration::from_secs(30)) .timeout(Duration::from_secs(60 * 3)) .pool_max_idle_per_host(1) - .use_preconfigured_tls(tlsconfig) - .build() - .unwrap(); + .use_preconfigured_tls(tlsconfig); + if let Some(proxy) = config.proxy.to_proxy()? { + reqwest_client_builder = reqwest_client_builder.proxy(proxy); + } + let reqwest_client = reqwest_client_builder.build().unwrap(); let jwt_decoding_key = config .jwt_secret diff --git a/src/utils.rs b/src/utils.rs index 2b5336c..b8ce303 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -5,6 +5,7 @@ use ruma::serde::{try_from_json_map, CanonicalJsonError, CanonicalJsonObject}; use std::{ cmp, convert::TryInto, + str::FromStr, time::{SystemTime, UNIX_EPOCH}, }; @@ -115,3 +116,29 @@ pub fn to_canonical_object( ))), } } + +pub fn deserialize_from_str< + 'de, + D: serde::de::Deserializer<'de>, + T: FromStr, + E: std::fmt::Display, +>( + deserializer: D, +) -> std::result::Result { + struct Visitor, E>(std::marker::PhantomData); + impl<'de, T: FromStr, Err: std::fmt::Display> serde::de::Visitor<'de> + for Visitor + { + type Value = T; + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "a parsable string") + } + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + v.parse().map_err(|e| serde::de::Error::custom(e)) + } + } + deserializer.deserialize_str(Visitor(std::marker::PhantomData)) +}