address pr comments
This commit is contained in:
		
							parent
							
								
									f25f61d4a9
								
							
						
					
					
						commit
						c53cc03ff8
					
				
					 3 changed files with 151 additions and 118 deletions
				
			
		|  | @ -41,3 +41,5 @@ trusted_servers = ["matrix.org"] | ||||||
| #workers = 4 # default: cpu core count * 2 | #workers = 4 # default: cpu core count * 2 | ||||||
| 
 | 
 | ||||||
| address = "127.0.0.1" # This makes sure Conduit can only be reached using the reverse proxy | address = "127.0.0.1" # This makes sure Conduit can only be reached using the reverse proxy | ||||||
|  | 
 | ||||||
|  | proxy = "none" # more examples can be found at src/database/proxy.rs:6 | ||||||
|  |  | ||||||
							
								
								
									
										121
									
								
								src/database.rs
									
									
									
									
									
								
							
							
						
						
									
										121
									
								
								src/database.rs
									
									
									
									
									
								
							|  | @ -6,6 +6,7 @@ pub mod appservice; | ||||||
| pub mod globals; | pub mod globals; | ||||||
| pub mod key_backups; | pub mod key_backups; | ||||||
| pub mod media; | pub mod media; | ||||||
|  | pub mod proxy; | ||||||
| pub mod pusher; | pub mod pusher; | ||||||
| pub mod rooms; | pub mod rooms; | ||||||
| pub mod sending; | pub mod sending; | ||||||
|  | @ -28,6 +29,8 @@ use std::{ | ||||||
| }; | }; | ||||||
| use tokio::sync::Semaphore; | use tokio::sync::Semaphore; | ||||||
| 
 | 
 | ||||||
|  | use self::proxy::ProxyConfig; | ||||||
|  | 
 | ||||||
| #[derive(Clone, Debug, Deserialize)] | #[derive(Clone, Debug, Deserialize)] | ||||||
| pub struct Config { | pub struct Config { | ||||||
|     server_name: Box<ServerName>, |     server_name: Box<ServerName>, | ||||||
|  | @ -85,124 +88,6 @@ pub type Engine = abstraction::SledEngine; | ||||||
| #[cfg(feature = "rocksdb")] | #[cfg(feature = "rocksdb")] | ||||||
| pub type Engine = abstraction::RocksDbEngine; | pub type Engine = abstraction::RocksDbEngine; | ||||||
| 
 | 
 | ||||||
| #[derive(Clone, Debug, Deserialize)] |  | ||||||
| #[serde(rename_all = "snake_case")] |  | ||||||
| pub enum ProxyConfig { |  | ||||||
|     None, |  | ||||||
|     Global { |  | ||||||
|         #[serde(deserialize_with = "crate::utils::deserialize_from_str")] |  | ||||||
|         url: reqwest::Url, |  | ||||||
|     }, |  | ||||||
|     ByDomain(Vec<PartialProxyConfig>), |  | ||||||
| } |  | ||||||
| impl ProxyConfig { |  | ||||||
|     pub fn to_proxy(&self) -> Result<Option<reqwest::Proxy>> { |  | ||||||
|         Ok(match self.clone() { |  | ||||||
|             ProxyConfig::None => None, |  | ||||||
|             ProxyConfig::Global { url } => Some(reqwest::Proxy::all(url)?), |  | ||||||
|             ProxyConfig::ByDomain(proxies) => Some(reqwest::Proxy::custom(move |url| { |  | ||||||
|                 proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy
 |  | ||||||
|             })), |  | ||||||
|         }) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| impl Default for ProxyConfig { |  | ||||||
|     fn default() -> Self { |  | ||||||
|         ProxyConfig::None |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| #[derive(Clone, Debug, Deserialize)] |  | ||||||
| pub struct PartialProxyConfig { |  | ||||||
|     #[serde(deserialize_with = "crate::utils::deserialize_from_str")] |  | ||||||
|     url: reqwest::Url, |  | ||||||
|     #[serde(default)] |  | ||||||
|     include: Vec<WildCardedDomain>, |  | ||||||
|     #[serde(default)] |  | ||||||
|     exclude: Vec<WildCardedDomain>, |  | ||||||
| } |  | ||||||
| impl PartialProxyConfig { |  | ||||||
|     pub fn for_url(&self, url: &reqwest::Url) -> Option<&reqwest::Url> { |  | ||||||
|         let domain = url.domain()?; |  | ||||||
|         let mut included_because = None; // most specific reason it was included
 |  | ||||||
|         let mut excluded_because = None; // most specific reason it was excluded
 |  | ||||||
|         if self.include.is_empty() { |  | ||||||
|             // treat empty include list as `*`
 |  | ||||||
|             included_because = Some(&WildCardedDomain::WildCard) |  | ||||||
|         } |  | ||||||
|         for wc_domain in &self.include { |  | ||||||
|             if wc_domain.matches(domain) { |  | ||||||
|                 match included_because { |  | ||||||
|                     Some(prev) if !wc_domain.more_specific_than(prev) => (), |  | ||||||
|                     _ => included_because = Some(wc_domain), |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         for wc_domain in &self.exclude { |  | ||||||
|             if wc_domain.matches(domain) { |  | ||||||
|                 match excluded_because { |  | ||||||
|                     Some(prev) if !wc_domain.more_specific_than(prev) => (), |  | ||||||
|                     _ => excluded_because = Some(wc_domain), |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         match (included_because, excluded_because) { |  | ||||||
|             (Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded
 |  | ||||||
|             (Some(_), None) => Some(&self.url), |  | ||||||
|             _ => None, |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /// A domain name, that optionally allows a * as its first subdomain.
 |  | ||||||
| #[derive(Clone, Debug)] |  | ||||||
| pub enum WildCardedDomain { |  | ||||||
|     WildCard, |  | ||||||
|     WildCarded(String), |  | ||||||
|     Exact(String), |  | ||||||
| } |  | ||||||
| impl WildCardedDomain { |  | ||||||
|     pub fn matches(&self, domain: &str) -> bool { |  | ||||||
|         match self { |  | ||||||
|             WildCardedDomain::WildCard => true, |  | ||||||
|             WildCardedDomain::WildCarded(d) => domain.ends_with(d), |  | ||||||
|             WildCardedDomain::Exact(d) => domain == d, |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|     pub fn more_specific_than(&self, other: &Self) -> bool { |  | ||||||
|         match (self, other) { |  | ||||||
|             (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false, |  | ||||||
|             (_, WildCardedDomain::WildCard) => true, |  | ||||||
|             (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a), |  | ||||||
|             (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => { |  | ||||||
|                 a != b && a.ends_with(b) |  | ||||||
|             } |  | ||||||
|             _ => false, |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| impl std::str::FromStr for WildCardedDomain { |  | ||||||
|     type Err = std::convert::Infallible; |  | ||||||
|     fn from_str(s: &str) -> std::result::Result<Self, Self::Err> { |  | ||||||
|         // maybe do some domain validation?
 |  | ||||||
|         Ok(if s.starts_with("*.") { |  | ||||||
|             WildCardedDomain::WildCarded(s[1..].to_owned()) |  | ||||||
|         } else if s == "*" { |  | ||||||
|             WildCardedDomain::WildCarded("".to_owned()) |  | ||||||
|         } else { |  | ||||||
|             WildCardedDomain::Exact(s.to_owned()) |  | ||||||
|         }) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| impl<'de> serde::de::Deserialize<'de> for WildCardedDomain { |  | ||||||
|     fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error> |  | ||||||
|     where |  | ||||||
|         D: serde::de::Deserializer<'de>, |  | ||||||
|     { |  | ||||||
|         crate::utils::deserialize_from_str(deserializer) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| pub struct Database { | pub struct Database { | ||||||
|     pub globals: globals::Globals, |     pub globals: globals::Globals, | ||||||
|     pub users: users::Users, |     pub users: users::Users, | ||||||
|  |  | ||||||
							
								
								
									
										146
									
								
								src/database/proxy.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										146
									
								
								src/database/proxy.rs
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,146 @@ | ||||||
|  | use reqwest::{Proxy, Url}; | ||||||
|  | use serde::Deserialize; | ||||||
|  | 
 | ||||||
|  | use crate::Result; | ||||||
|  | 
 | ||||||
|  | /// ## Examples:
 | ||||||
|  | /// - No proxy (default):
 | ||||||
|  | /// ```toml
 | ||||||
|  | /// proxy ="none"
 | ||||||
|  | /// ```
 | ||||||
|  | /// - Global proxy
 | ||||||
|  | /// ```toml
 | ||||||
|  | /// [proxy]
 | ||||||
|  | /// global = { url = "socks5h://localhost:9050" }
 | ||||||
|  | /// ```
 | ||||||
|  | /// - Proxy some domains
 | ||||||
|  | /// ```toml
 | ||||||
|  | /// [proxy]
 | ||||||
|  | /// [[proxy.by_domain]]
 | ||||||
|  | /// url = "socks5h://localhost:9050"
 | ||||||
|  | /// include = ["*.onion", "matrix.myspecial.onion"]
 | ||||||
|  | /// exclude = ["*.myspecial.onion"]
 | ||||||
|  | /// ```
 | ||||||
|  | /// ## Include vs. Exclude
 | ||||||
|  | /// If include is an empty list, it is assumed to be `["*"]`.
 | ||||||
|  | ///
 | ||||||
|  | /// If a domain matches both the exclude and include list, the proxy will only be used if it was
 | ||||||
|  | /// included because of a more specific rule than it was excluded. In the above example, the proxy
 | ||||||
|  | /// would be used for `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`.
 | ||||||
|  | #[derive(Clone, Debug, Deserialize)] | ||||||
|  | #[serde(rename_all = "snake_case")] | ||||||
|  | pub enum ProxyConfig { | ||||||
|  |     None, | ||||||
|  |     Global { | ||||||
|  |         #[serde(deserialize_with = "crate::utils::deserialize_from_str")] | ||||||
|  |         url: Url, | ||||||
|  |     }, | ||||||
|  |     ByDomain(Vec<PartialProxyConfig>), | ||||||
|  | } | ||||||
|  | impl ProxyConfig { | ||||||
|  |     pub fn to_proxy(&self) -> Result<Option<Proxy>> { | ||||||
|  |         Ok(match self.clone() { | ||||||
|  |             ProxyConfig::None => None, | ||||||
|  |             ProxyConfig::Global { url } => Some(Proxy::all(url)?), | ||||||
|  |             ProxyConfig::ByDomain(proxies) => Some(Proxy::custom(move |url| { | ||||||
|  |                 proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy
 | ||||||
|  |             })), | ||||||
|  |         }) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | impl Default for ProxyConfig { | ||||||
|  |     fn default() -> Self { | ||||||
|  |         ProxyConfig::None | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[derive(Clone, Debug, Deserialize)] | ||||||
|  | pub struct PartialProxyConfig { | ||||||
|  |     #[serde(deserialize_with = "crate::utils::deserialize_from_str")] | ||||||
|  |     url: Url, | ||||||
|  |     #[serde(default)] | ||||||
|  |     include: Vec<WildCardedDomain>, | ||||||
|  |     #[serde(default)] | ||||||
|  |     exclude: Vec<WildCardedDomain>, | ||||||
|  | } | ||||||
|  | impl PartialProxyConfig { | ||||||
|  |     pub fn for_url(&self, url: &Url) -> Option<&Url> { | ||||||
|  |         let domain = url.domain()?; | ||||||
|  |         let mut included_because = None; // most specific reason it was included
 | ||||||
|  |         let mut excluded_because = None; // most specific reason it was excluded
 | ||||||
|  |         if self.include.is_empty() { | ||||||
|  |             // treat empty include list as `*`
 | ||||||
|  |             included_because = Some(&WildCardedDomain::WildCard) | ||||||
|  |         } | ||||||
|  |         for wc_domain in &self.include { | ||||||
|  |             if wc_domain.matches(domain) { | ||||||
|  |                 match included_because { | ||||||
|  |                     Some(prev) if !wc_domain.more_specific_than(prev) => (), | ||||||
|  |                     _ => included_because = Some(wc_domain), | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         for wc_domain in &self.exclude { | ||||||
|  |             if wc_domain.matches(domain) { | ||||||
|  |                 match excluded_because { | ||||||
|  |                     Some(prev) if !wc_domain.more_specific_than(prev) => (), | ||||||
|  |                     _ => excluded_because = Some(wc_domain), | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         match (included_because, excluded_because) { | ||||||
|  |             (Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded
 | ||||||
|  |             (Some(_), None) => Some(&self.url), | ||||||
|  |             _ => None, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /// A domain name, that optionally allows a * as its first subdomain.
 | ||||||
|  | #[derive(Clone, Debug)] | ||||||
|  | pub enum WildCardedDomain { | ||||||
|  |     WildCard, | ||||||
|  |     WildCarded(String), | ||||||
|  |     Exact(String), | ||||||
|  | } | ||||||
|  | impl WildCardedDomain { | ||||||
|  |     pub fn matches(&self, domain: &str) -> bool { | ||||||
|  |         match self { | ||||||
|  |             WildCardedDomain::WildCard => true, | ||||||
|  |             WildCardedDomain::WildCarded(d) => domain.ends_with(d), | ||||||
|  |             WildCardedDomain::Exact(d) => domain == d, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     pub fn more_specific_than(&self, other: &Self) -> bool { | ||||||
|  |         match (self, other) { | ||||||
|  |             (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false, | ||||||
|  |             (_, WildCardedDomain::WildCard) => true, | ||||||
|  |             (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a), | ||||||
|  |             (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => { | ||||||
|  |                 a != b && a.ends_with(b) | ||||||
|  |             } | ||||||
|  |             _ => false, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | impl std::str::FromStr for WildCardedDomain { | ||||||
|  |     type Err = std::convert::Infallible; | ||||||
|  |     fn from_str(s: &str) -> std::result::Result<Self, Self::Err> { | ||||||
|  |         // maybe do some domain validation?
 | ||||||
|  |         Ok(if s.starts_with("*.") { | ||||||
|  |             WildCardedDomain::WildCarded(s[1..].to_owned()) | ||||||
|  |         } else if s == "*" { | ||||||
|  |             WildCardedDomain::WildCarded("".to_owned()) | ||||||
|  |         } else { | ||||||
|  |             WildCardedDomain::Exact(s.to_owned()) | ||||||
|  |         }) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | impl<'de> serde::de::Deserialize<'de> for WildCardedDomain { | ||||||
|  |     fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error> | ||||||
|  |     where | ||||||
|  |         D: serde::de::Deserializer<'de>, | ||||||
|  |     { | ||||||
|  |         crate::utils::deserialize_from_str(deserializer) | ||||||
|  |     } | ||||||
|  | } | ||||||
		Loading…
	
		Reference in a new issue