add support for arbitrary proxies
parent
cc9111059d
commit
b2d5516058
|
@ -1761,6 +1761,7 @@ dependencies = [
|
||||||
"serde_urlencoded",
|
"serde_urlencoded",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-rustls",
|
"tokio-rustls",
|
||||||
|
"tokio-socks",
|
||||||
"url",
|
"url",
|
||||||
"wasm-bindgen",
|
"wasm-bindgen",
|
||||||
"wasm-bindgen-futures",
|
"wasm-bindgen-futures",
|
||||||
|
@ -2732,6 +2733,18 @@ dependencies = [
|
||||||
"webpki",
|
"webpki",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokio-socks"
|
||||||
|
version = "0.5.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "51165dfa029d2a65969413a6cc96f354b86b464498702f174a4efa13608fd8c0"
|
||||||
|
dependencies = [
|
||||||
|
"either",
|
||||||
|
"futures-util",
|
||||||
|
"thiserror",
|
||||||
|
"tokio",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tokio-util"
|
name = "tokio-util"
|
||||||
version = "0.6.6"
|
version = "0.6.6"
|
||||||
|
|
|
@ -47,7 +47,7 @@ rand = "0.8.3"
|
||||||
# Used to hash passwords
|
# Used to hash passwords
|
||||||
rust-argon2 = "0.8.3"
|
rust-argon2 = "0.8.3"
|
||||||
# Used to send requests
|
# Used to send requests
|
||||||
reqwest = { version = "0.11.3", default-features = false, features = ["rustls-tls-native-roots"] }
|
reqwest = { version = "0.11.3", default-features = false, features = ["rustls-tls-native-roots", "socks"] }
|
||||||
# Custom TLS verifier
|
# Custom TLS verifier
|
||||||
rustls = { version = "0.19", features = ["dangerous_configuration"] }
|
rustls = { version = "0.19", features = ["dangerous_configuration"] }
|
||||||
rustls-native-certs = "0.5.0"
|
rustls-native-certs = "0.5.0"
|
||||||
|
|
121
src/database.rs
121
src/database.rs
|
@ -46,6 +46,8 @@ pub struct Config {
|
||||||
allow_federation: bool,
|
allow_federation: bool,
|
||||||
#[serde(default = "false_fn")]
|
#[serde(default = "false_fn")]
|
||||||
pub allow_jaeger: bool,
|
pub allow_jaeger: bool,
|
||||||
|
#[serde(default)]
|
||||||
|
proxy: ProxyConfig,
|
||||||
jwt_secret: Option<String>,
|
jwt_secret: Option<String>,
|
||||||
#[serde(default = "Vec::new")]
|
#[serde(default = "Vec::new")]
|
||||||
trusted_servers: Vec<Box<ServerName>>,
|
trusted_servers: Vec<Box<ServerName>>,
|
||||||
|
@ -83,6 +85,125 @@ pub type Engine = abstraction::SledEngine;
|
||||||
#[cfg(feature = "rocksdb")]
|
#[cfg(feature = "rocksdb")]
|
||||||
pub type Engine = abstraction::RocksDbEngine;
|
pub type Engine = abstraction::RocksDbEngine;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum ProxyConfig {
|
||||||
|
None,
|
||||||
|
Global {
|
||||||
|
#[serde(deserialize_with = "crate::utils::deserialize_from_str")]
|
||||||
|
url: reqwest::Url,
|
||||||
|
},
|
||||||
|
ByDomain(Vec<PartialProxyConfig>),
|
||||||
|
}
|
||||||
|
impl ProxyConfig {
|
||||||
|
pub fn to_proxy(&self) -> Result<Option<reqwest::Proxy>> {
|
||||||
|
Ok(match self.clone() {
|
||||||
|
ProxyConfig::None => None,
|
||||||
|
ProxyConfig::Global { url } => Some(reqwest::Proxy::all(url)?),
|
||||||
|
ProxyConfig::ByDomain(proxies) => Some(reqwest::Proxy::custom(move |url| {
|
||||||
|
proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy
|
||||||
|
})),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl Default for ProxyConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
ProxyConfig::None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize)]
|
||||||
|
pub struct PartialProxyConfig {
|
||||||
|
#[serde(deserialize_with = "crate::utils::deserialize_from_str")]
|
||||||
|
url: reqwest::Url,
|
||||||
|
#[serde(default)]
|
||||||
|
include: Vec<WildCardedDomain>,
|
||||||
|
#[serde(default)]
|
||||||
|
exclude: Vec<WildCardedDomain>,
|
||||||
|
}
|
||||||
|
impl PartialProxyConfig {
|
||||||
|
pub fn for_url(&self, url: &reqwest::Url) -> Option<&reqwest::Url> {
|
||||||
|
let domain = url.domain()?;
|
||||||
|
let mut included_because = None; // most specific reason it was included
|
||||||
|
let mut excluded_because = None; // most specific reason it was excluded
|
||||||
|
if self.include.is_empty() {
|
||||||
|
// treat empty include list as `*`
|
||||||
|
included_because = Some(&WildCardedDomain::WildCard)
|
||||||
|
}
|
||||||
|
for wc_domain in &self.include {
|
||||||
|
if wc_domain.matches(domain) {
|
||||||
|
match included_because {
|
||||||
|
Some(prev) if !wc_domain.more_specific_than(prev) => (),
|
||||||
|
_ => included_because = Some(wc_domain),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for wc_domain in &self.exclude {
|
||||||
|
if wc_domain.matches(domain) {
|
||||||
|
match excluded_because {
|
||||||
|
Some(prev) if !wc_domain.more_specific_than(prev) => (),
|
||||||
|
_ => excluded_because = Some(wc_domain),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
match (included_because, excluded_because) {
|
||||||
|
(Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded
|
||||||
|
(Some(_), None) => Some(&self.url),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A domain name, that optionally allows a * as its first subdomain.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub enum WildCardedDomain {
|
||||||
|
WildCard,
|
||||||
|
WildCarded(String),
|
||||||
|
Exact(String),
|
||||||
|
}
|
||||||
|
impl WildCardedDomain {
|
||||||
|
pub fn matches(&self, domain: &str) -> bool {
|
||||||
|
match self {
|
||||||
|
WildCardedDomain::WildCard => true,
|
||||||
|
WildCardedDomain::WildCarded(d) => domain.ends_with(d),
|
||||||
|
WildCardedDomain::Exact(d) => domain == d,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn more_specific_than(&self, other: &Self) -> bool {
|
||||||
|
match (self, other) {
|
||||||
|
(WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false,
|
||||||
|
(_, WildCardedDomain::WildCard) => true,
|
||||||
|
(WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a),
|
||||||
|
(WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => {
|
||||||
|
a != b && a.ends_with(b)
|
||||||
|
}
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl std::str::FromStr for WildCardedDomain {
|
||||||
|
type Err = std::convert::Infallible;
|
||||||
|
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||||
|
// maybe do some domain validation?
|
||||||
|
Ok(if s.starts_with("*.") {
|
||||||
|
WildCardedDomain::WildCarded(s[1..].to_owned())
|
||||||
|
} else if s == "*" {
|
||||||
|
WildCardedDomain::WildCarded("".to_owned())
|
||||||
|
} else {
|
||||||
|
WildCardedDomain::Exact(s.to_owned())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl<'de> serde::de::Deserialize<'de> for WildCardedDomain {
|
||||||
|
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::de::Deserializer<'de>,
|
||||||
|
{
|
||||||
|
crate::utils::deserialize_from_str(deserializer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct Database {
|
pub struct Database {
|
||||||
pub globals: globals::Globals,
|
pub globals: globals::Globals,
|
||||||
pub users: users::Users,
|
pub users: users::Users,
|
||||||
|
|
|
@ -125,13 +125,15 @@ impl Globals {
|
||||||
tlsconfig.root_store =
|
tlsconfig.root_store =
|
||||||
rustls_native_certs::load_native_certs().expect("Error loading system certificates");
|
rustls_native_certs::load_native_certs().expect("Error loading system certificates");
|
||||||
|
|
||||||
let reqwest_client = reqwest::Client::builder()
|
let mut reqwest_client_builder = reqwest::Client::builder()
|
||||||
.connect_timeout(Duration::from_secs(30))
|
.connect_timeout(Duration::from_secs(30))
|
||||||
.timeout(Duration::from_secs(60 * 3))
|
.timeout(Duration::from_secs(60 * 3))
|
||||||
.pool_max_idle_per_host(1)
|
.pool_max_idle_per_host(1)
|
||||||
.use_preconfigured_tls(tlsconfig)
|
.use_preconfigured_tls(tlsconfig);
|
||||||
.build()
|
if let Some(proxy) = config.proxy.to_proxy()? {
|
||||||
.unwrap();
|
reqwest_client_builder = reqwest_client_builder.proxy(proxy);
|
||||||
|
}
|
||||||
|
let reqwest_client = reqwest_client_builder.build().unwrap();
|
||||||
|
|
||||||
let jwt_decoding_key = config
|
let jwt_decoding_key = config
|
||||||
.jwt_secret
|
.jwt_secret
|
||||||
|
|
27
src/utils.rs
27
src/utils.rs
|
@ -5,6 +5,7 @@ use ruma::serde::{try_from_json_map, CanonicalJsonError, CanonicalJsonObject};
|
||||||
use std::{
|
use std::{
|
||||||
cmp,
|
cmp,
|
||||||
convert::TryInto,
|
convert::TryInto,
|
||||||
|
str::FromStr,
|
||||||
time::{SystemTime, UNIX_EPOCH},
|
time::{SystemTime, UNIX_EPOCH},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -115,3 +116,29 @@ pub fn to_canonical_object<T: serde::Serialize>(
|
||||||
))),
|
))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn deserialize_from_str<
|
||||||
|
'de,
|
||||||
|
D: serde::de::Deserializer<'de>,
|
||||||
|
T: FromStr<Err = E>,
|
||||||
|
E: std::fmt::Display,
|
||||||
|
>(
|
||||||
|
deserializer: D,
|
||||||
|
) -> std::result::Result<T, D::Error> {
|
||||||
|
struct Visitor<T: FromStr<Err = E>, E>(std::marker::PhantomData<T>);
|
||||||
|
impl<'de, T: FromStr<Err = Err>, Err: std::fmt::Display> serde::de::Visitor<'de>
|
||||||
|
for Visitor<T, Err>
|
||||||
|
{
|
||||||
|
type Value = T;
|
||||||
|
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(formatter, "a parsable string")
|
||||||
|
}
|
||||||
|
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||||
|
where
|
||||||
|
E: serde::de::Error,
|
||||||
|
{
|
||||||
|
v.parse().map_err(|e| serde::de::Error::custom(e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
deserializer.deserialize_str(Visitor(std::marker::PhantomData))
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue