diff --git a/src/client_server.rs b/src/client_server.rs index 03be7bf..c190ef7 100644 --- a/src/client_server.rs +++ b/src/client_server.rs @@ -117,6 +117,14 @@ pub fn register_route( db: State<'_, Database>, body: Ruma, ) -> MatrixResult { + if db.globals.registration_disabled() { + return MatrixResult(Err(UiaaResponse::MatrixError(Error { + kind: ErrorKind::Unknown, + message: "Registration has been disabled.".to_owned(), + status_code: http::StatusCode::FORBIDDEN, + }))); + } + // Validate user id let user_id = match UserId::parse_with_server_name( body.username diff --git a/src/database.rs b/src/database.rs index 492f880..34af8fc 100644 --- a/src/database.rs +++ b/src/database.rs @@ -7,6 +7,7 @@ pub(self) mod uiaa; pub(self) mod users; use directories::ProjectDirs; +use log::info; use std::fs::remove_dir_all; use rocket::Config; @@ -49,13 +50,10 @@ impl Database { }); let db = sled::open(&path).unwrap(); - log::info!("Opened sled database at {}", path); + info!("Opened sled database at {}", path); Self { - globals: globals::Globals::load( - db.open_tree("global").unwrap(), - server_name.to_owned(), - ), + globals: globals::Globals::load(db.open_tree("global").unwrap(), config), users: users::Users { userid_password: db.open_tree("userid_password").unwrap(), userid_displayname: db.open_tree("userid_displayname").unwrap(), diff --git a/src/database/globals.rs b/src/database/globals.rs index 93d5794..08ab411 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -4,13 +4,14 @@ pub const COUNTER: &str = "c"; pub struct Globals { pub(super) globals: sled::Tree, - server_name: String, keypair: ruma::signatures::Ed25519KeyPair, reqwest_client: reqwest::Client, + server_name: String, + registration_disabled: bool, } impl Globals { - pub fn load(globals: sled::Tree, server_name: String) -> Self { + pub fn load(globals: sled::Tree, config: &rocket::Config) -> Self { let keypair = ruma::signatures::Ed25519KeyPair::new( &*globals .update_and_fetch("keypair", utils::generate_keypair) @@ -22,17 +23,16 @@ impl Globals { Self { globals, - server_name, keypair, reqwest_client: reqwest::Client::new(), + server_name: config + .get_str("server_name") + .unwrap_or("localhost") + .to_owned(), + registration_disabled: config.get_bool("registration_disabled").unwrap_or(false), } } - /// Returns the server_name of the server. - pub fn server_name(&self) -> &str { - &self.server_name - } - /// Returns this server's keypair. pub fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair { &self.keypair @@ -58,4 +58,12 @@ impl Globals { .get(COUNTER)? .map_or(0_u64, |bytes| utils::u64_from_bytes(&bytes))) } + + pub fn server_name(&self) -> &str { + &self.server_name + } + + pub fn registration_disabled(&self) -> bool { + self.registration_disabled + } }