diff --git a/src/database/globals.rs b/src/database/globals.rs index 8951425..8ce9c01 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -1,4 +1,5 @@ use crate::{utils, Error, Result}; +use log::error; use ruma::ServerName; use std::{convert::TryInto, sync::Arc}; @@ -17,19 +18,43 @@ pub struct Globals { impl Globals { pub fn load(globals: sled::Tree, config: &rocket::Config) -> Result { - let keypair = Arc::new( - ruma::signatures::Ed25519KeyPair::new( - &*globals - .update_and_fetch("keypair", utils::generate_keypair)? - .expect("utils::generate_keypair always returns Some"), - "key1".to_owned(), - ) - .map_err(|_| Error::bad_database("Private or public keys are invalid."))?, - ); + let bytes = &*globals + .update_and_fetch("keypair", utils::generate_keypair)? + .expect("utils::generate_keypair always returns Some"); + + let mut parts = bytes.splitn(2, |&b| b == 0xff); + + let keypair = utils::string_from_bytes( + // 1. version + parts + .next() + .expect("splitn always returns at least one element"), + ) + .map_err(|_| Error::bad_database("Invalid version bytes in keypair.")) + .and_then(|version| { + // 2. key + parts + .next() + .ok_or_else(|| Error::bad_database("Invalid keypair format in database.")) + .map(|key| (version, key)) + }) + .and_then(|(version, key)| { + ruma::signatures::Ed25519KeyPair::new(&key, version) + .map_err(|_| Error::bad_database("Private or public keys are invalid.")) + }); + + let keypair = match keypair { + Ok(k) => k, + Err(e) => { + error!("Keypair invalid. Deleting..."); + globals.remove("keypair")?; + return Err(e); + } + }; Ok(Self { globals, - keypair, + keypair: Arc::new(keypair), reqwest_client: reqwest::Client::new(), server_name: config .get_str("server_name") diff --git a/src/server_server.rs b/src/server_server.rs index 106f60e..f334d6b 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -17,7 +17,6 @@ use ruma::{ directory::{IncomingFilter, IncomingRoomNetwork}, EventId, ServerName, }; -use serde_json::json; use std::{ collections::BTreeMap, convert::TryFrom, @@ -58,7 +57,13 @@ where let actual_destination = "https://".to_owned() + &request_well_known(globals, &destination.as_str()) .await - .unwrap_or(destination.as_str().to_owned() + ":8448"); + .unwrap_or_else(|| { + let mut destination = destination.as_str().to_owned(); + if destination.find(':').is_none() { + destination += ":8448"; + } + destination + }); let mut http_request = request .try_into_http_request(&actual_destination, Some("")) diff --git a/src/utils.rs b/src/utils.rs index 8cf1b2c..452b7c5 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -29,8 +29,13 @@ pub fn increment(old: Option<&[u8]>) -> Option> { pub fn generate_keypair(old: Option<&[u8]>) -> Option> { Some(old.map(|s| s.to_vec()).unwrap_or_else(|| { - ruma::signatures::Ed25519KeyPair::generate() - .expect("Ed25519KeyPair generation always works (?)") + let mut value = random_string(8).as_bytes().to_vec(); + value.push(0xff); + value.extend_from_slice( + &ruma::signatures::Ed25519KeyPair::generate() + .expect("Ed25519KeyPair generation always works (?)"), + ); + value })) }