Implement max_request_size config option

next
CapsizeGlimmer 2020-07-23 23:03:24 -04:00 committed by timokoesters
parent 3451b10a4b
commit fa2da9e048
No known key found for this signature in database
GPG Key ID: 24DA7517711A2BA4
3 changed files with 24 additions and 16 deletions

View File

@ -2977,11 +2977,11 @@ pub fn send_event_to_device_route(
} }
#[get("/_matrix/media/r0/config")] #[get("/_matrix/media/r0/config")]
pub fn get_media_config_route() -> ConduitResult<get_media_config::Response> { pub fn get_media_config_route(
Ok(get_media_config::Response { db: State<'_, Database>,
upload_size: (20_u32 * 1024 * 1024).into(), // 20 MB ) -> ConduitResult<get_media_config::Response> {
} let upload_size = db.globals.max_request_size().into();
.into()) Ok(get_media_config::Response { upload_size }.into())
} }
#[post("/_matrix/media/r0/upload", data = "<body>")] #[post("/_matrix/media/r0/upload", data = "<body>")]

View File

@ -1,7 +1,7 @@
use std::convert::TryInto;
use crate::{utils, Error, Result}; use crate::{utils, Error, Result};
use ruma::ServerName; use ruma::ServerName;
use std::convert::TryInto;
pub const COUNTER: &str = "c"; pub const COUNTER: &str = "c";
pub struct Globals { pub struct Globals {
@ -9,6 +9,7 @@ pub struct Globals {
keypair: ruma::signatures::Ed25519KeyPair, keypair: ruma::signatures::Ed25519KeyPair,
reqwest_client: reqwest::Client, reqwest_client: reqwest::Client,
server_name: Box<ServerName>, server_name: Box<ServerName>,
max_request_size: u32,
registration_disabled: bool, registration_disabled: bool,
encryption_disabled: bool, encryption_disabled: bool,
} }
@ -32,7 +33,12 @@ impl Globals {
.unwrap_or("localhost") .unwrap_or("localhost")
.to_string() .to_string()
.try_into() .try_into()
.map_err(|_| Error::BadConfig("Invalid server name found."))?, .map_err(|_| Error::BadConfig("Invalid server_name."))?,
max_request_size: config
.get_int("max_request_size")
.unwrap_or(20 * 1024 * 1024) // Default to 20 MB
.try_into()
.map_err(|_| Error::BadConfig("Invalid max_request_size."))?,
registration_disabled: config.get_bool("registration_disabled").unwrap_or(false), registration_disabled: config.get_bool("registration_disabled").unwrap_or(false),
encryption_disabled: config.get_bool("encryption_disabled").unwrap_or(false), encryption_disabled: config.get_bool("encryption_disabled").unwrap_or(false),
}) })
@ -69,6 +75,10 @@ impl Globals {
self.server_name.as_ref() self.server_name.as_ref()
} }
pub fn max_request_size(&self) -> u32 {
self.max_request_size
}
pub fn registration_disabled(&self) -> bool { pub fn registration_disabled(&self) -> bool {
self.registration_disabled self.registration_disabled
} }

View File

@ -11,8 +11,6 @@ use ruma::{api::Endpoint, DeviceId, UserId};
use std::{convert::TryInto, io::Cursor, ops::Deref}; use std::{convert::TryInto, io::Cursor, ops::Deref};
use tokio::io::AsyncReadExt; use tokio::io::AsyncReadExt;
const MESSAGE_LIMIT: u64 = 20 * 1024 * 1024; // 20 MB
/// This struct converts rocket requests into ruma structs by converting them into http requests /// This struct converts rocket requests into ruma structs by converting them into http requests
/// first. /// first.
pub struct Ruma<T> { pub struct Ruma<T> {
@ -40,13 +38,12 @@ impl<'a, T: Endpoint> FromTransformedData<'a> for Ruma<T> {
) -> FromDataFuture<'a, Self, Self::Error> { ) -> FromDataFuture<'a, Self, Self::Error> {
Box::pin(async move { Box::pin(async move {
let data = rocket::try_outcome!(outcome.owned()); let data = rocket::try_outcome!(outcome.owned());
let db = request
.guard::<State<'_, crate::Database>>()
.await
.expect("database was loaded");
let (user_id, device_id) = if T::METADATA.requires_authentication { let (user_id, device_id) = if T::METADATA.requires_authentication {
let db = request
.guard::<State<'_, crate::Database>>()
.await
.expect("database was loaded");
// Get token from header or query value // Get token from header or query value
let token = match request let token = match request
.headers() .headers()
@ -76,7 +73,8 @@ impl<'a, T: Endpoint> FromTransformedData<'a> for Ruma<T> {
http_request = http_request.header(header.name.as_str(), &*header.value); http_request = http_request.header(header.name.as_str(), &*header.value);
} }
let mut handle = data.open().take(MESSAGE_LIMIT); let limit = db.globals.max_request_size();
let mut handle = data.open().take(limit.into());
let mut body = Vec::new(); let mut body = Vec::new();
handle.read_to_end(&mut body).await.unwrap(); handle.read_to_end(&mut body).await.unwrap();