diff --git a/Cargo.lock b/Cargo.lock index c3d7408..c31894a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1761,6 +1761,7 @@ dependencies = [ "serde_urlencoded", "tokio", "tokio-rustls", + "tokio-socks", "url", "wasm-bindgen", "wasm-bindgen-futures", @@ -2732,6 +2733,18 @@ dependencies = [ "webpki", ] +[[package]] +name = "tokio-socks" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51165dfa029d2a65969413a6cc96f354b86b464498702f174a4efa13608fd8c0" +dependencies = [ + "either", + "futures-util", + "thiserror", + "tokio", +] + [[package]] name = "tokio-util" version = "0.6.6" diff --git a/Cargo.toml b/Cargo.toml index 96260ec..4f7095d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,7 +47,7 @@ rand = "0.8.3" # Used to hash passwords rust-argon2 = "0.8.3" # Used to send requests -reqwest = { version = "0.11.3", default-features = false, features = ["rustls-tls-native-roots"] } +reqwest = { version = "0.11.3", default-features = false, features = ["rustls-tls-native-roots", "socks"] } # Custom TLS verifier rustls = { version = "0.19", features = ["dangerous_configuration"] } rustls-native-certs = "0.5.0" diff --git a/conduit-example.toml b/conduit-example.toml index 66c105b..db0bbb7 100644 --- a/conduit-example.toml +++ b/conduit-example.toml @@ -41,3 +41,5 @@ trusted_servers = ["matrix.org"] #workers = 4 # default: cpu core count * 2 address = "127.0.0.1" # This makes sure Conduit can only be reached using the reverse proxy + +proxy = "none" # more examples can be found at src/database/proxy.rs:6 diff --git a/src/database.rs b/src/database.rs index 2846928..0ea4d78 100644 --- a/src/database.rs +++ b/src/database.rs @@ -6,6 +6,7 @@ pub mod appservice; pub mod globals; pub mod key_backups; pub mod media; +pub mod proxy; pub mod pusher; pub mod rooms; pub mod sending; @@ -28,6 +29,8 @@ use std::{ }; use tokio::sync::Semaphore; +use self::proxy::ProxyConfig; + #[derive(Clone, Debug, Deserialize)] pub struct Config { server_name: Box, @@ -46,6 +49,8 @@ pub struct Config { allow_federation: bool, #[serde(default = "false_fn")] pub allow_jaeger: bool, + #[serde(default)] + proxy: ProxyConfig, jwt_secret: Option, #[serde(default = "Vec::new")] trusted_servers: Vec>, diff --git a/src/database/globals.rs b/src/database/globals.rs index 1ce87bd..db166e9 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -125,13 +125,15 @@ impl Globals { tlsconfig.root_store = rustls_native_certs::load_native_certs().expect("Error loading system certificates"); - let reqwest_client = reqwest::Client::builder() + let mut reqwest_client_builder = reqwest::Client::builder() .connect_timeout(Duration::from_secs(30)) .timeout(Duration::from_secs(60 * 3)) .pool_max_idle_per_host(1) - .use_preconfigured_tls(tlsconfig) - .build() - .unwrap(); + .use_preconfigured_tls(tlsconfig); + if let Some(proxy) = config.proxy.to_proxy()? { + reqwest_client_builder = reqwest_client_builder.proxy(proxy); + } + let reqwest_client = reqwest_client_builder.build().unwrap(); let jwt_decoding_key = config .jwt_secret diff --git a/src/database/proxy.rs b/src/database/proxy.rs new file mode 100644 index 0000000..78e9d2b --- /dev/null +++ b/src/database/proxy.rs @@ -0,0 +1,146 @@ +use reqwest::{Proxy, Url}; +use serde::Deserialize; + +use crate::Result; + +/// ## Examples: +/// - No proxy (default): +/// ```toml +/// proxy ="none" +/// ``` +/// - Global proxy +/// ```toml +/// [proxy] +/// global = { url = "socks5h://localhost:9050" } +/// ``` +/// - Proxy some domains +/// ```toml +/// [proxy] +/// [[proxy.by_domain]] +/// url = "socks5h://localhost:9050" +/// include = ["*.onion", "matrix.myspecial.onion"] +/// exclude = ["*.myspecial.onion"] +/// ``` +/// ## Include vs. Exclude +/// If include is an empty list, it is assumed to be `["*"]`. +/// +/// If a domain matches both the exclude and include list, the proxy will only be used if it was +/// included because of a more specific rule than it was excluded. In the above example, the proxy +/// would be used for `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`. +#[derive(Clone, Debug, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ProxyConfig { + None, + Global { + #[serde(deserialize_with = "crate::utils::deserialize_from_str")] + url: Url, + }, + ByDomain(Vec), +} +impl ProxyConfig { + pub fn to_proxy(&self) -> Result> { + Ok(match self.clone() { + ProxyConfig::None => None, + ProxyConfig::Global { url } => Some(Proxy::all(url)?), + ProxyConfig::ByDomain(proxies) => Some(Proxy::custom(move |url| { + proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy + })), + }) + } +} +impl Default for ProxyConfig { + fn default() -> Self { + ProxyConfig::None + } +} + +#[derive(Clone, Debug, Deserialize)] +pub struct PartialProxyConfig { + #[serde(deserialize_with = "crate::utils::deserialize_from_str")] + url: Url, + #[serde(default)] + include: Vec, + #[serde(default)] + exclude: Vec, +} +impl PartialProxyConfig { + pub fn for_url(&self, url: &Url) -> Option<&Url> { + let domain = url.domain()?; + let mut included_because = None; // most specific reason it was included + let mut excluded_because = None; // most specific reason it was excluded + if self.include.is_empty() { + // treat empty include list as `*` + included_because = Some(&WildCardedDomain::WildCard) + } + for wc_domain in &self.include { + if wc_domain.matches(domain) { + match included_because { + Some(prev) if !wc_domain.more_specific_than(prev) => (), + _ => included_because = Some(wc_domain), + } + } + } + for wc_domain in &self.exclude { + if wc_domain.matches(domain) { + match excluded_because { + Some(prev) if !wc_domain.more_specific_than(prev) => (), + _ => excluded_because = Some(wc_domain), + } + } + } + match (included_because, excluded_because) { + (Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded + (Some(_), None) => Some(&self.url), + _ => None, + } + } +} + +/// A domain name, that optionally allows a * as its first subdomain. +#[derive(Clone, Debug)] +pub enum WildCardedDomain { + WildCard, + WildCarded(String), + Exact(String), +} +impl WildCardedDomain { + pub fn matches(&self, domain: &str) -> bool { + match self { + WildCardedDomain::WildCard => true, + WildCardedDomain::WildCarded(d) => domain.ends_with(d), + WildCardedDomain::Exact(d) => domain == d, + } + } + pub fn more_specific_than(&self, other: &Self) -> bool { + match (self, other) { + (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false, + (_, WildCardedDomain::WildCard) => true, + (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a), + (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => { + a != b && a.ends_with(b) + } + _ => false, + } + } +} +impl std::str::FromStr for WildCardedDomain { + type Err = std::convert::Infallible; + fn from_str(s: &str) -> std::result::Result { + // maybe do some domain validation? + Ok(if s.starts_with("*.") { + WildCardedDomain::WildCarded(s[1..].to_owned()) + } else if s == "*" { + WildCardedDomain::WildCarded("".to_owned()) + } else { + WildCardedDomain::Exact(s.to_owned()) + }) + } +} +impl<'de> serde::de::Deserialize<'de> for WildCardedDomain { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::de::Deserializer<'de>, + { + crate::utils::deserialize_from_str(deserializer) + } +} diff --git a/src/utils.rs b/src/utils.rs index 2b5336c..b8ce303 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -5,6 +5,7 @@ use ruma::serde::{try_from_json_map, CanonicalJsonError, CanonicalJsonObject}; use std::{ cmp, convert::TryInto, + str::FromStr, time::{SystemTime, UNIX_EPOCH}, }; @@ -115,3 +116,29 @@ pub fn to_canonical_object( ))), } } + +pub fn deserialize_from_str< + 'de, + D: serde::de::Deserializer<'de>, + T: FromStr, + E: std::fmt::Display, +>( + deserializer: D, +) -> std::result::Result { + struct Visitor, E>(std::marker::PhantomData); + impl<'de, T: FromStr, Err: std::fmt::Display> serde::de::Visitor<'de> + for Visitor + { + type Value = T; + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "a parsable string") + } + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + v.parse().map_err(|e| serde::de::Error::custom(e)) + } + } + deserializer.deserialize_str(Visitor(std::marker::PhantomData)) +}