Implement max_request_size config option
This commit is contained in:
		
							parent
							
								
									3451b10a4b
								
							
						
					
					
						commit
						fa2da9e048
					
				
					 3 changed files with 24 additions and 16 deletions
				
			
		|  | @ -2977,11 +2977,11 @@ pub fn send_event_to_device_route( | |||
| } | ||||
| 
 | ||||
| #[get("/_matrix/media/r0/config")] | ||||
| pub fn get_media_config_route() -> ConduitResult<get_media_config::Response> { | ||||
|     Ok(get_media_config::Response { | ||||
|         upload_size: (20_u32 * 1024 * 1024).into(), // 20 MB
 | ||||
|     } | ||||
|     .into()) | ||||
| pub fn get_media_config_route( | ||||
|     db: State<'_, Database>, | ||||
| ) -> ConduitResult<get_media_config::Response> { | ||||
|     let upload_size = db.globals.max_request_size().into(); | ||||
|     Ok(get_media_config::Response { upload_size }.into()) | ||||
| } | ||||
| 
 | ||||
| #[post("/_matrix/media/r0/upload", data = "<body>")] | ||||
|  |  | |||
|  | @ -1,7 +1,7 @@ | |||
| use std::convert::TryInto; | ||||
| 
 | ||||
| use crate::{utils, Error, Result}; | ||||
| use ruma::ServerName; | ||||
| use std::convert::TryInto; | ||||
| 
 | ||||
| pub const COUNTER: &str = "c"; | ||||
| 
 | ||||
| pub struct Globals { | ||||
|  | @ -9,6 +9,7 @@ pub struct Globals { | |||
|     keypair: ruma::signatures::Ed25519KeyPair, | ||||
|     reqwest_client: reqwest::Client, | ||||
|     server_name: Box<ServerName>, | ||||
|     max_request_size: u32, | ||||
|     registration_disabled: bool, | ||||
|     encryption_disabled: bool, | ||||
| } | ||||
|  | @ -32,7 +33,12 @@ impl Globals { | |||
|                 .unwrap_or("localhost") | ||||
|                 .to_string() | ||||
|                 .try_into() | ||||
|                 .map_err(|_| Error::BadConfig("Invalid server name found."))?, | ||||
|                 .map_err(|_| Error::BadConfig("Invalid server_name."))?, | ||||
|             max_request_size: config | ||||
|                 .get_int("max_request_size") | ||||
|                 .unwrap_or(20 * 1024 * 1024) // Default to 20 MB
 | ||||
|                 .try_into() | ||||
|                 .map_err(|_| Error::BadConfig("Invalid max_request_size."))?, | ||||
|             registration_disabled: config.get_bool("registration_disabled").unwrap_or(false), | ||||
|             encryption_disabled: config.get_bool("encryption_disabled").unwrap_or(false), | ||||
|         }) | ||||
|  | @ -69,6 +75,10 @@ impl Globals { | |||
|         self.server_name.as_ref() | ||||
|     } | ||||
| 
 | ||||
|     pub fn max_request_size(&self) -> u32 { | ||||
|         self.max_request_size | ||||
|     } | ||||
| 
 | ||||
|     pub fn registration_disabled(&self) -> bool { | ||||
|         self.registration_disabled | ||||
|     } | ||||
|  |  | |||
|  | @ -11,8 +11,6 @@ use ruma::{api::Endpoint, DeviceId, UserId}; | |||
| use std::{convert::TryInto, io::Cursor, ops::Deref}; | ||||
| use tokio::io::AsyncReadExt; | ||||
| 
 | ||||
| const MESSAGE_LIMIT: u64 = 20 * 1024 * 1024; // 20 MB
 | ||||
| 
 | ||||
| /// This struct converts rocket requests into ruma structs by converting them into http requests
 | ||||
| /// first.
 | ||||
| pub struct Ruma<T> { | ||||
|  | @ -40,13 +38,12 @@ impl<'a, T: Endpoint> FromTransformedData<'a> for Ruma<T> { | |||
|     ) -> FromDataFuture<'a, Self, Self::Error> { | ||||
|         Box::pin(async move { | ||||
|             let data = rocket::try_outcome!(outcome.owned()); | ||||
|             let db = request | ||||
|                 .guard::<State<'_, crate::Database>>() | ||||
|                 .await | ||||
|                 .expect("database was loaded"); | ||||
| 
 | ||||
|             let (user_id, device_id) = if T::METADATA.requires_authentication { | ||||
|                 let db = request | ||||
|                     .guard::<State<'_, crate::Database>>() | ||||
|                     .await | ||||
|                     .expect("database was loaded"); | ||||
| 
 | ||||
|                 // Get token from header or query value
 | ||||
|                 let token = match request | ||||
|                     .headers() | ||||
|  | @ -76,7 +73,8 @@ impl<'a, T: Endpoint> FromTransformedData<'a> for Ruma<T> { | |||
|                 http_request = http_request.header(header.name.as_str(), &*header.value); | ||||
|             } | ||||
| 
 | ||||
|             let mut handle = data.open().take(MESSAGE_LIMIT); | ||||
|             let limit = db.globals.max_request_size(); | ||||
|             let mut handle = data.open().take(limit.into()); | ||||
|             let mut body = Vec::new(); | ||||
|             handle.read_to_end(&mut body).await.unwrap(); | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue
	
	 CapsizeGlimmer
						CapsizeGlimmer