From b2d55160585f15c93695f1dc60f3ca3eb7967911 Mon Sep 17 00:00:00 2001 From: Aiden McClelland Date: Tue, 13 Apr 2021 12:15:58 -0600 Subject: [PATCH 1/3] add support for arbitrary proxies --- Cargo.lock | 13 +++++ Cargo.toml | 2 +- src/database.rs | 121 ++++++++++++++++++++++++++++++++++++++++ src/database/globals.rs | 10 ++-- src/utils.rs | 27 +++++++++ 5 files changed, 168 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c3d7408..c31894a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1761,6 +1761,7 @@ dependencies = [ "serde_urlencoded", "tokio", "tokio-rustls", + "tokio-socks", "url", "wasm-bindgen", "wasm-bindgen-futures", @@ -2732,6 +2733,18 @@ dependencies = [ "webpki", ] +[[package]] +name = "tokio-socks" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51165dfa029d2a65969413a6cc96f354b86b464498702f174a4efa13608fd8c0" +dependencies = [ + "either", + "futures-util", + "thiserror", + "tokio", +] + [[package]] name = "tokio-util" version = "0.6.6" diff --git a/Cargo.toml b/Cargo.toml index 96260ec..4f7095d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,7 +47,7 @@ rand = "0.8.3" # Used to hash passwords rust-argon2 = "0.8.3" # Used to send requests -reqwest = { version = "0.11.3", default-features = false, features = ["rustls-tls-native-roots"] } +reqwest = { version = "0.11.3", default-features = false, features = ["rustls-tls-native-roots", "socks"] } # Custom TLS verifier rustls = { version = "0.19", features = ["dangerous_configuration"] } rustls-native-certs = "0.5.0" diff --git a/src/database.rs b/src/database.rs index 2846928..52d92a5 100644 --- a/src/database.rs +++ b/src/database.rs @@ -46,6 +46,8 @@ pub struct Config { allow_federation: bool, #[serde(default = "false_fn")] pub allow_jaeger: bool, + #[serde(default)] + proxy: ProxyConfig, jwt_secret: Option, #[serde(default = "Vec::new")] trusted_servers: Vec>, @@ -83,6 +85,125 @@ pub type Engine = abstraction::SledEngine; #[cfg(feature = "rocksdb")] pub type Engine = abstraction::RocksDbEngine; +#[derive(Clone, Debug, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ProxyConfig { + None, + Global { + #[serde(deserialize_with = "crate::utils::deserialize_from_str")] + url: reqwest::Url, + }, + ByDomain(Vec), +} +impl ProxyConfig { + pub fn to_proxy(&self) -> Result> { + Ok(match self.clone() { + ProxyConfig::None => None, + ProxyConfig::Global { url } => Some(reqwest::Proxy::all(url)?), + ProxyConfig::ByDomain(proxies) => Some(reqwest::Proxy::custom(move |url| { + proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy + })), + }) + } +} +impl Default for ProxyConfig { + fn default() -> Self { + ProxyConfig::None + } +} + +#[derive(Clone, Debug, Deserialize)] +pub struct PartialProxyConfig { + #[serde(deserialize_with = "crate::utils::deserialize_from_str")] + url: reqwest::Url, + #[serde(default)] + include: Vec, + #[serde(default)] + exclude: Vec, +} +impl PartialProxyConfig { + pub fn for_url(&self, url: &reqwest::Url) -> Option<&reqwest::Url> { + let domain = url.domain()?; + let mut included_because = None; // most specific reason it was included + let mut excluded_because = None; // most specific reason it was excluded + if self.include.is_empty() { + // treat empty include list as `*` + included_because = Some(&WildCardedDomain::WildCard) + } + for wc_domain in &self.include { + if wc_domain.matches(domain) { + match included_because { + Some(prev) if !wc_domain.more_specific_than(prev) => (), + _ => included_because = Some(wc_domain), + } + } + } + for wc_domain in &self.exclude { + if wc_domain.matches(domain) { + match excluded_because { + Some(prev) if !wc_domain.more_specific_than(prev) => (), + _ => excluded_because = Some(wc_domain), + } + } + } + match (included_because, excluded_because) { + (Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded + (Some(_), None) => Some(&self.url), + _ => None, + } + } +} + +/// A domain name, that optionally allows a * as its first subdomain. +#[derive(Clone, Debug)] +pub enum WildCardedDomain { + WildCard, + WildCarded(String), + Exact(String), +} +impl WildCardedDomain { + pub fn matches(&self, domain: &str) -> bool { + match self { + WildCardedDomain::WildCard => true, + WildCardedDomain::WildCarded(d) => domain.ends_with(d), + WildCardedDomain::Exact(d) => domain == d, + } + } + pub fn more_specific_than(&self, other: &Self) -> bool { + match (self, other) { + (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false, + (_, WildCardedDomain::WildCard) => true, + (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a), + (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => { + a != b && a.ends_with(b) + } + _ => false, + } + } +} +impl std::str::FromStr for WildCardedDomain { + type Err = std::convert::Infallible; + fn from_str(s: &str) -> std::result::Result { + // maybe do some domain validation? + Ok(if s.starts_with("*.") { + WildCardedDomain::WildCarded(s[1..].to_owned()) + } else if s == "*" { + WildCardedDomain::WildCarded("".to_owned()) + } else { + WildCardedDomain::Exact(s.to_owned()) + }) + } +} +impl<'de> serde::de::Deserialize<'de> for WildCardedDomain { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::de::Deserializer<'de>, + { + crate::utils::deserialize_from_str(deserializer) + } +} + +#[derive(Clone)] pub struct Database { pub globals: globals::Globals, pub users: users::Users, diff --git a/src/database/globals.rs b/src/database/globals.rs index 1ce87bd..db166e9 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -125,13 +125,15 @@ impl Globals { tlsconfig.root_store = rustls_native_certs::load_native_certs().expect("Error loading system certificates"); - let reqwest_client = reqwest::Client::builder() + let mut reqwest_client_builder = reqwest::Client::builder() .connect_timeout(Duration::from_secs(30)) .timeout(Duration::from_secs(60 * 3)) .pool_max_idle_per_host(1) - .use_preconfigured_tls(tlsconfig) - .build() - .unwrap(); + .use_preconfigured_tls(tlsconfig); + if let Some(proxy) = config.proxy.to_proxy()? { + reqwest_client_builder = reqwest_client_builder.proxy(proxy); + } + let reqwest_client = reqwest_client_builder.build().unwrap(); let jwt_decoding_key = config .jwt_secret diff --git a/src/utils.rs b/src/utils.rs index 2b5336c..b8ce303 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -5,6 +5,7 @@ use ruma::serde::{try_from_json_map, CanonicalJsonError, CanonicalJsonObject}; use std::{ cmp, convert::TryInto, + str::FromStr, time::{SystemTime, UNIX_EPOCH}, }; @@ -115,3 +116,29 @@ pub fn to_canonical_object( ))), } } + +pub fn deserialize_from_str< + 'de, + D: serde::de::Deserializer<'de>, + T: FromStr, + E: std::fmt::Display, +>( + deserializer: D, +) -> std::result::Result { + struct Visitor, E>(std::marker::PhantomData); + impl<'de, T: FromStr, Err: std::fmt::Display> serde::de::Visitor<'de> + for Visitor + { + type Value = T; + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "a parsable string") + } + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + v.parse().map_err(|e| serde::de::Error::custom(e)) + } + } + deserializer.deserialize_str(Visitor(std::marker::PhantomData)) +} From f25f61d4a9e42d29704c357868074e45d24bd4df Mon Sep 17 00:00:00 2001 From: Aiden McClelland Date: Thu, 1 Jul 2021 12:48:12 -0600 Subject: [PATCH 2/3] fix errors introduced by rebase --- src/database.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/database.rs b/src/database.rs index 52d92a5..64b5ee3 100644 --- a/src/database.rs +++ b/src/database.rs @@ -203,7 +203,6 @@ impl<'de> serde::de::Deserialize<'de> for WildCardedDomain { } } -#[derive(Clone)] pub struct Database { pub globals: globals::Globals, pub users: users::Users, From c53cc03ff8db65b6b447a852eee85e540ad38cb1 Mon Sep 17 00:00:00 2001 From: Aiden McClelland Date: Thu, 1 Jul 2021 13:38:25 -0600 Subject: [PATCH 3/3] address pr comments --- conduit-example.toml | 2 + src/database.rs | 121 +--------------------------------- src/database/proxy.rs | 146 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 151 insertions(+), 118 deletions(-) create mode 100644 src/database/proxy.rs diff --git a/conduit-example.toml b/conduit-example.toml index 66c105b..db0bbb7 100644 --- a/conduit-example.toml +++ b/conduit-example.toml @@ -41,3 +41,5 @@ trusted_servers = ["matrix.org"] #workers = 4 # default: cpu core count * 2 address = "127.0.0.1" # This makes sure Conduit can only be reached using the reverse proxy + +proxy = "none" # more examples can be found at src/database/proxy.rs:6 diff --git a/src/database.rs b/src/database.rs index 64b5ee3..0ea4d78 100644 --- a/src/database.rs +++ b/src/database.rs @@ -6,6 +6,7 @@ pub mod appservice; pub mod globals; pub mod key_backups; pub mod media; +pub mod proxy; pub mod pusher; pub mod rooms; pub mod sending; @@ -28,6 +29,8 @@ use std::{ }; use tokio::sync::Semaphore; +use self::proxy::ProxyConfig; + #[derive(Clone, Debug, Deserialize)] pub struct Config { server_name: Box, @@ -85,124 +88,6 @@ pub type Engine = abstraction::SledEngine; #[cfg(feature = "rocksdb")] pub type Engine = abstraction::RocksDbEngine; -#[derive(Clone, Debug, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum ProxyConfig { - None, - Global { - #[serde(deserialize_with = "crate::utils::deserialize_from_str")] - url: reqwest::Url, - }, - ByDomain(Vec), -} -impl ProxyConfig { - pub fn to_proxy(&self) -> Result> { - Ok(match self.clone() { - ProxyConfig::None => None, - ProxyConfig::Global { url } => Some(reqwest::Proxy::all(url)?), - ProxyConfig::ByDomain(proxies) => Some(reqwest::Proxy::custom(move |url| { - proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy - })), - }) - } -} -impl Default for ProxyConfig { - fn default() -> Self { - ProxyConfig::None - } -} - -#[derive(Clone, Debug, Deserialize)] -pub struct PartialProxyConfig { - #[serde(deserialize_with = "crate::utils::deserialize_from_str")] - url: reqwest::Url, - #[serde(default)] - include: Vec, - #[serde(default)] - exclude: Vec, -} -impl PartialProxyConfig { - pub fn for_url(&self, url: &reqwest::Url) -> Option<&reqwest::Url> { - let domain = url.domain()?; - let mut included_because = None; // most specific reason it was included - let mut excluded_because = None; // most specific reason it was excluded - if self.include.is_empty() { - // treat empty include list as `*` - included_because = Some(&WildCardedDomain::WildCard) - } - for wc_domain in &self.include { - if wc_domain.matches(domain) { - match included_because { - Some(prev) if !wc_domain.more_specific_than(prev) => (), - _ => included_because = Some(wc_domain), - } - } - } - for wc_domain in &self.exclude { - if wc_domain.matches(domain) { - match excluded_because { - Some(prev) if !wc_domain.more_specific_than(prev) => (), - _ => excluded_because = Some(wc_domain), - } - } - } - match (included_because, excluded_because) { - (Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded - (Some(_), None) => Some(&self.url), - _ => None, - } - } -} - -/// A domain name, that optionally allows a * as its first subdomain. -#[derive(Clone, Debug)] -pub enum WildCardedDomain { - WildCard, - WildCarded(String), - Exact(String), -} -impl WildCardedDomain { - pub fn matches(&self, domain: &str) -> bool { - match self { - WildCardedDomain::WildCard => true, - WildCardedDomain::WildCarded(d) => domain.ends_with(d), - WildCardedDomain::Exact(d) => domain == d, - } - } - pub fn more_specific_than(&self, other: &Self) -> bool { - match (self, other) { - (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false, - (_, WildCardedDomain::WildCard) => true, - (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a), - (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => { - a != b && a.ends_with(b) - } - _ => false, - } - } -} -impl std::str::FromStr for WildCardedDomain { - type Err = std::convert::Infallible; - fn from_str(s: &str) -> std::result::Result { - // maybe do some domain validation? - Ok(if s.starts_with("*.") { - WildCardedDomain::WildCarded(s[1..].to_owned()) - } else if s == "*" { - WildCardedDomain::WildCarded("".to_owned()) - } else { - WildCardedDomain::Exact(s.to_owned()) - }) - } -} -impl<'de> serde::de::Deserialize<'de> for WildCardedDomain { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::de::Deserializer<'de>, - { - crate::utils::deserialize_from_str(deserializer) - } -} - pub struct Database { pub globals: globals::Globals, pub users: users::Users, diff --git a/src/database/proxy.rs b/src/database/proxy.rs new file mode 100644 index 0000000..78e9d2b --- /dev/null +++ b/src/database/proxy.rs @@ -0,0 +1,146 @@ +use reqwest::{Proxy, Url}; +use serde::Deserialize; + +use crate::Result; + +/// ## Examples: +/// - No proxy (default): +/// ```toml +/// proxy ="none" +/// ``` +/// - Global proxy +/// ```toml +/// [proxy] +/// global = { url = "socks5h://localhost:9050" } +/// ``` +/// - Proxy some domains +/// ```toml +/// [proxy] +/// [[proxy.by_domain]] +/// url = "socks5h://localhost:9050" +/// include = ["*.onion", "matrix.myspecial.onion"] +/// exclude = ["*.myspecial.onion"] +/// ``` +/// ## Include vs. Exclude +/// If include is an empty list, it is assumed to be `["*"]`. +/// +/// If a domain matches both the exclude and include list, the proxy will only be used if it was +/// included because of a more specific rule than it was excluded. In the above example, the proxy +/// would be used for `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`. +#[derive(Clone, Debug, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ProxyConfig { + None, + Global { + #[serde(deserialize_with = "crate::utils::deserialize_from_str")] + url: Url, + }, + ByDomain(Vec), +} +impl ProxyConfig { + pub fn to_proxy(&self) -> Result> { + Ok(match self.clone() { + ProxyConfig::None => None, + ProxyConfig::Global { url } => Some(Proxy::all(url)?), + ProxyConfig::ByDomain(proxies) => Some(Proxy::custom(move |url| { + proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy + })), + }) + } +} +impl Default for ProxyConfig { + fn default() -> Self { + ProxyConfig::None + } +} + +#[derive(Clone, Debug, Deserialize)] +pub struct PartialProxyConfig { + #[serde(deserialize_with = "crate::utils::deserialize_from_str")] + url: Url, + #[serde(default)] + include: Vec, + #[serde(default)] + exclude: Vec, +} +impl PartialProxyConfig { + pub fn for_url(&self, url: &Url) -> Option<&Url> { + let domain = url.domain()?; + let mut included_because = None; // most specific reason it was included + let mut excluded_because = None; // most specific reason it was excluded + if self.include.is_empty() { + // treat empty include list as `*` + included_because = Some(&WildCardedDomain::WildCard) + } + for wc_domain in &self.include { + if wc_domain.matches(domain) { + match included_because { + Some(prev) if !wc_domain.more_specific_than(prev) => (), + _ => included_because = Some(wc_domain), + } + } + } + for wc_domain in &self.exclude { + if wc_domain.matches(domain) { + match excluded_because { + Some(prev) if !wc_domain.more_specific_than(prev) => (), + _ => excluded_because = Some(wc_domain), + } + } + } + match (included_because, excluded_because) { + (Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded + (Some(_), None) => Some(&self.url), + _ => None, + } + } +} + +/// A domain name, that optionally allows a * as its first subdomain. +#[derive(Clone, Debug)] +pub enum WildCardedDomain { + WildCard, + WildCarded(String), + Exact(String), +} +impl WildCardedDomain { + pub fn matches(&self, domain: &str) -> bool { + match self { + WildCardedDomain::WildCard => true, + WildCardedDomain::WildCarded(d) => domain.ends_with(d), + WildCardedDomain::Exact(d) => domain == d, + } + } + pub fn more_specific_than(&self, other: &Self) -> bool { + match (self, other) { + (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false, + (_, WildCardedDomain::WildCard) => true, + (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a), + (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => { + a != b && a.ends_with(b) + } + _ => false, + } + } +} +impl std::str::FromStr for WildCardedDomain { + type Err = std::convert::Infallible; + fn from_str(s: &str) -> std::result::Result { + // maybe do some domain validation? + Ok(if s.starts_with("*.") { + WildCardedDomain::WildCarded(s[1..].to_owned()) + } else if s == "*" { + WildCardedDomain::WildCarded("".to_owned()) + } else { + WildCardedDomain::Exact(s.to_owned()) + }) + } +} +impl<'de> serde::de::Deserialize<'de> for WildCardedDomain { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::de::Deserializer<'de>, + { + crate::utils::deserialize_from_str(deserializer) + } +}