From 533260edd84e18c1c773d43b6fd71d2e99032ac0 Mon Sep 17 00:00:00 2001 From: timokoesters Date: Sun, 29 Mar 2020 21:05:20 +0200 Subject: [PATCH] Add auth --- Cargo.lock | 26 ++++---- Cargo.toml | 7 +- src/data.rs | 110 ++++++++++++++++++++++++++++--- src/main.rs | 154 +++++++++++++++++++++++++++++++------------- src/ruma_wrapper.rs | 53 ++++++++------- src/utils.rs | 18 ++++++ 6 files changed, 274 insertions(+), 94 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b194092..1a1da0c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -497,6 +497,7 @@ dependencies = [ "ruma-client-api", "ruma-events", "ruma-identifiers", + "serde_json", "sled", ] @@ -807,9 +808,9 @@ dependencies = [ [[package]] name = "ruma-api" -version = "0.15.0-dev.1" +version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44987d5fefcf801a6fb5c5843c17f876a53852fa07e5e4d99e0dca3670f1441a" +checksum = "120f0cd8625b842423ef3a63cabb8c309ca35a02de87cc4b377fb2cdd43f1fe5" dependencies = [ "http", "percent-encoding 2.1.0", @@ -824,9 +825,9 @@ dependencies = [ [[package]] name = "ruma-api-macros" -version = "0.12.0-dev.1" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36931db94874129f9202f650d91447d8317b099bae1e12cdd5769ba4eced07d2" +checksum = "bfc523efc9c1ba7033ff17888551c1d378e12eae087cfbe4fcee938ff516759e" dependencies = [ "proc-macro2 1.0.9", "quote 1.0.3", @@ -835,8 +836,9 @@ dependencies = [ [[package]] name = "ruma-client-api" -version = "0.6.0" -source = "git+https://github.com/ruma/ruma-client-api#57f5e8d66168a54128426c8e34b26fa78f739c3e" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a64241cdc0cff76038484451d7a5d2689f8ea4e59b6695cd3c8448af7bcc016" dependencies = [ "http", "js_int", @@ -851,9 +853,9 @@ dependencies = [ [[package]] name = "ruma-events" -version = "0.17.0" +version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11951235b25c72a82eb988aabf5af23cae883562665e0cb73954ffe4ae81f11c" +checksum = "80e34bfc20462f18d7f0beb6f1863db62d29438f2dcf390b625e9b20696cb2b3" dependencies = [ "js_int", "ruma-events-macros", @@ -864,9 +866,9 @@ dependencies = [ [[package]] name = "ruma-events-macros" -version = "0.2.0" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "962d93056619ed61826a9d8872c863560e4892ff6a69b70f593baa5ae8b19dc8" +checksum = "ff95b6b4480c570db471b490b35ad70add5470651654e75faf0b97052b4f29e1" dependencies = [ "proc-macro2 1.0.9", "quote 1.0.3", @@ -994,9 +996,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.49" +version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02044a6a92866fd61624b3db4d2c9dccc2feabbc6be490b87611bf285edbac55" +checksum = "78a7a12c167809363ec3bd7329fc0a3369056996de43c4b37ef3cd54a6ce4867" dependencies = [ "itoa", "ryu", diff --git a/Cargo.toml b/Cargo.toml index b693378..3adca54 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,12 +9,13 @@ edition = "2018" [dependencies] rocket = { git = "https://github.com/SergioBenitez/Rocket.git", features = ["tls"] } http = "0.2.1" -ruma-client-api = { git = "https://github.com/ruma/ruma-client-api" } +ruma-client-api = "0.7.0" pretty_env_logger = "0.4.0" log = "0.4.8" sled = "0.31.0" directories = "2.0.2" ruma-identifiers = "0.14.1" -ruma-api = "0.15.0-dev.1" -ruma-events = "0.17.0" +ruma-api = "0.15.0" +ruma-events = "0.18.0" js_int = "0.1.3" +serde_json = "1.0.50" diff --git a/src/data.rs b/src/data.rs index de063ab..9a0a9c2 100644 --- a/src/data.rs +++ b/src/data.rs @@ -1,16 +1,18 @@ +use crate::utils; use directories::ProjectDirs; use ruma_events::collections::all::RoomEvent; use ruma_identifiers::UserId; +use std::convert::TryInto; + +const USERID_PASSWORD: &str = "userid_password"; +const USERID_DEVICEIDS: &str = "userid_deviceids"; +const DEVICEID_TOKEN: &str = "deviceid_token"; +const TOKEN_USERID: &str = "token_userid"; pub struct Data(sled::Db); impl Data { - pub fn set_hostname(&self, hostname: &str) { - self.0.insert("hostname", hostname).unwrap(); - } - pub fn hostname(&self) -> String { - String::from_utf8(self.0.get("hostname").unwrap().unwrap().to_vec()).unwrap() - } + /// Load an existing database or create a new one. pub fn load_or_create() -> Self { Data( sled::open( @@ -22,21 +24,109 @@ impl Data { ) } + /// Set the hostname of the server. Warning: Hostname changes will likely break things. + pub fn set_hostname(&self, hostname: &str) { + self.0.insert("hostname", hostname).unwrap(); + } + + /// Get the hostname of the server. + pub fn hostname(&self) -> String { + utils::bytes_to_string(&self.0.get("hostname").unwrap().unwrap()) + } + + /// Check if a user has an account by looking for an assigned password. pub fn user_exists(&self, user_id: &UserId) -> bool { self.0 - .open_tree("username_password") + .open_tree(USERID_PASSWORD) .unwrap() .contains_key(user_id.to_string()) .unwrap() } - pub fn user_add(&self, user_id: UserId, password: Option) { + /// Create a new user account by assigning them a password. + pub fn user_add(&self, user_id: &UserId, password: Option) { self.0 - .open_tree("username_password") + .open_tree(USERID_PASSWORD) .unwrap() .insert(user_id.to_string(), &*password.unwrap_or_default()) .unwrap(); } - pub fn room_event_add(&self, room_event: &RoomEvent) {} + /// Find out which user an access token belongs to. + pub fn user_from_token(&self, token: &str) -> Option { + self.0 + .open_tree(TOKEN_USERID) + .unwrap() + .get(token) + .unwrap() + .and_then(|bytes| (*utils::bytes_to_string(&bytes)).try_into().ok()) + } + + /// Checks if the given password is equal to the one in the database. + pub fn password_get(&self, user_id: &UserId) -> Option { + self.0 + .open_tree(USERID_PASSWORD) + .unwrap() + .get(user_id.to_string()) + .unwrap() + .map(|bytes| utils::bytes_to_string(&bytes)) + } + + /// Add a new device to a user. + pub fn device_add(&self, user_id: &UserId, device_id: &str) { + self.0 + .open_tree(USERID_DEVICEIDS) + .unwrap() + .insert(user_id.to_string(), device_id) + .unwrap(); + } + + /// Replace the access token of one device. + pub fn token_replace(&self, user_id: &UserId, device_id: &String, token: String) { + // Make sure the device id belongs to the user + debug_assert!(self + .0 + .open_tree(USERID_DEVICEIDS) + .unwrap() + .get(&user_id.to_string()) // Does the user exist? + .unwrap() + .map(|bytes| utils::bytes_to_vec(&bytes)) + .filter(|devices| devices.contains(device_id)) // Does the user have that device? + .is_some()); + + // Remove old token + if let Some(old_token) = self + .0 + .open_tree(DEVICEID_TOKEN) + .unwrap() + .get(device_id) + .unwrap() + { + self.0 + .open_tree(TOKEN_USERID) + .unwrap() + .remove(old_token) + .unwrap(); + // It will be removed from DEVICEID_TOKEN by the insert later + } + + // Assign token to device_id + self.0 + .open_tree(DEVICEID_TOKEN) + .unwrap() + .insert(device_id, &*token) + .unwrap(); + + // Assign token to user + self.0 + .open_tree(TOKEN_USERID) + .unwrap() + .insert(token, &*user_id.to_string()) + .unwrap(); + } + + /// Create a new room event. + pub fn room_event_add(&self, _room_event: &RoomEvent) { + todo!(); + } } diff --git a/src/main.rs b/src/main.rs index 0097109..7cb7c67 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,9 +14,10 @@ use ruma_client_api::{ }, unversioned::get_supported_versions, }; -use ruma_events::room::message::MessageEvent; +use ruma_events::{room::message::MessageEvent, EventResult}; use ruma_identifiers::{EventId, UserId}; use ruma_wrapper::{MatrixResult, Ruma}; +use serde_json::map::Map; use std::convert::TryFrom; use std::{collections::HashMap, convert::TryInto}; @@ -41,6 +42,7 @@ fn register_route( data: State, body: Ruma, ) -> MatrixResult { + // Validate user id let user_id: UserId = match (*format!( "@{}:{}", body.username.clone().unwrap_or("randomname".to_owned()), @@ -59,6 +61,7 @@ fn register_route( Ok(user_id) => user_id, }; + // Check if username is creative enough if data.user_exists(&user_id) { debug!("ID already taken"); return MatrixResult(Err(Error { @@ -68,68 +71,115 @@ fn register_route( })); } - data.user_add(user_id.clone(), body.password.clone()); + // Create user + data.user_add(&user_id, body.password.clone()); + + // Generate new device id if the user didn't specify one + let device_id = body + .device_id + .clone() + .unwrap_or_else(|| "TODO:randomdeviceid".to_owned()); + + // Add device + data.device_add(&user_id, &device_id); + + // Generate new token for the device + let token = "TODO:randomtoken".to_owned(); + data.token_replace(&user_id, &device_id, token.clone()); MatrixResult(Ok(register::Response { - access_token: "randomtoken".to_owned(), + access_token: token, home_server: data.hostname(), user_id, - device_id: body.device_id.clone().unwrap_or("randomid".to_owned()), + device_id, })) } #[post("/_matrix/client/r0/login", data = "")] fn login_route(data: State, body: Ruma) -> MatrixResult { - let username = if let login::UserInfo::MatrixId(mut username) = body.user.clone() { - if !username.contains(':') { - username = format!("@{}:{}", username, data.hostname()); - } - if let Ok(user_id) = (*username).try_into() { - if !data.user_exists(&user_id) { - debug!("Userid does not exist. Can't log in."); + // Validate login method + let user_id = + if let (login::UserInfo::MatrixId(mut username), login::LoginInfo::Password { password }) = + (body.user.clone(), body.login_info.clone()) + { + if !username.contains(':') { + username = format!("@{}:{}", username, data.hostname()); + } + if let Ok(user_id) = (*username).try_into() { + if !data.user_exists(&user_id) {} + + // Check password + if let Some(correct_password) = data.password_get(&user_id) { + if password == correct_password { + // Success! + user_id + } else { + debug!("Invalid password."); + return MatrixResult(Err(Error { + kind: ErrorKind::Unknown, + message: "".to_owned(), + status_code: http::StatusCode::FORBIDDEN, + })); + } + } else { + debug!("UserId does not exist (has no assigned password). Can't log in."); + return MatrixResult(Err(Error { + kind: ErrorKind::Forbidden, + message: "".to_owned(), + status_code: http::StatusCode::FORBIDDEN, + })); + } + } else { + debug!("Invalid UserId."); return MatrixResult(Err(Error { - kind: ErrorKind::Forbidden, - message: "UserId not found.".to_owned(), + kind: ErrorKind::Unknown, + message: "Bad login type.".to_owned(), status_code: http::StatusCode::BAD_REQUEST, })); } - user_id } else { - debug!("Invalid UserId."); + debug!("Bad login type"); return MatrixResult(Err(Error { kind: ErrorKind::Unknown, message: "Bad login type.".to_owned(), status_code: http::StatusCode::BAD_REQUEST, })); - } - } else { - debug!("Bad login type"); - return MatrixResult(Err(Error { - kind: ErrorKind::Unknown, - message: "Bad login type.".to_owned(), - status_code: http::StatusCode::BAD_REQUEST, - })); - }; + }; + + // Generate new device id if the user didn't specify one + let device_id = body + .device_id + .clone() + .unwrap_or("TODO:randomdeviceid".to_owned()); + + // Add device (TODO: We might not want to call it when using an existing device) + data.device_add(&user_id, &device_id); + + // Generate a new token for the device + let token = "TODO:randomtoken".to_owned(); + data.token_replace(&user_id, &device_id, token.clone()); return MatrixResult(Ok(login::Response { - user_id: username.try_into().unwrap(), // Unwrap is okay because the user is already registered - access_token: "randomtoken".to_owned(), - home_server: Some("localhost".to_owned()), - device_id: body.device_id.clone().unwrap_or("randomid".to_owned()), + user_id, + access_token: token, + home_server: Some(data.hostname()), + device_id, well_known: None, })); } #[get("/_matrix/client/r0/directory/room/")] fn get_alias_route(room_alias: String) -> MatrixResult { + // TODO let room_id = match &*room_alias { "#room:localhost" => "!xclkjvdlfj:localhost", _ => { + debug!("Room not found."); return MatrixResult(Err(Error { kind: ErrorKind::NotFound, message: "Room not found.".to_owned(), status_code: http::StatusCode::NOT_FOUND, - })) + })); } } .try_into() @@ -146,6 +196,7 @@ fn join_room_by_id_route( _room_id: String, body: Ruma, ) -> MatrixResult { + // TODO MatrixResult(Ok(join_room_by_id::Response { room_id: body.room_id.clone(), })) @@ -162,23 +213,34 @@ fn create_message_event_route( _txn_id: String, body: Ruma, ) -> MatrixResult { - dbg!(&body); - if let Ok(content) = body.data.clone().into_result() { - data.room_event_add( - &MessageEvent { - content, - event_id: EventId::try_from("$randomeventid:localhost").unwrap(), - origin_server_ts: utils::millis_since_unix_epoch(), - room_id: Some(body.room_id.clone()), - sender: UserId::try_from("@TODO:localhost").unwrap(), - unsigned: None, - } - .into(), - ); - } - MatrixResult(Ok(create_message_event::Response { - event_id: "$randomeventid:localhost".try_into().unwrap(), - })) + // Check if content is valid + let content = match body.data.clone() { + EventResult::Ok(content) => content, + EventResult::Err(_) => { + debug!("No content."); + return MatrixResult(Err(Error { + kind: ErrorKind::NotFound, + message: "No content.".to_owned(), + status_code: http::StatusCode::BAD_REQUEST, + })); + } + }; + + let event_id = EventId::try_from("$TODOrandomeventid:localhost").unwrap(); + + data.room_event_add( + &MessageEvent { + content, + event_id: event_id.clone(), + origin_server_ts: utils::millis_since_unix_epoch(), + room_id: Some(body.room_id.clone()), + sender: body.user_id.expect("user is authenticated"), + unsigned: Map::default(), + } + .into(), + ); + + MatrixResult(Ok(create_message_event::Response { event_id })) } fn main() { diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs index 5b71925..0b42ceb 100644 --- a/src/ruma_wrapper.rs +++ b/src/ruma_wrapper.rs @@ -10,10 +10,10 @@ use { Endpoint, Outgoing, }, ruma_client_api::error::Error, + ruma_identifiers::UserId, std::ops::Deref, std::{ convert::{TryFrom, TryInto}, - fmt, io::{Cursor, Read}, }, }; @@ -22,9 +22,10 @@ const MESSAGE_LIMIT: u64 = 65535; /// This struct converts rocket requests into ruma structs by converting them into http requests /// first. +#[derive(Debug)] pub struct Ruma { body: T::Incoming, - headers: http::HeaderMap, + pub user_id: Option, } impl FromDataSimple for Ruma @@ -37,9 +38,34 @@ where Error = FromHttpResponseError<::ResponseError>, >, { - type Error = (); + type Error = (); // TODO: Better error handling fn from_data(request: &Request, data: rocket::Data) -> Outcome { + let user_id = if T::METADATA.requires_authentication { + let data = request.guard::>().unwrap(); + + // Get token from header or query value + let token = match request + .headers() + .get_one("Authorization") + .map(|s| s.to_owned()) + .or_else(|| request.get_query_value("access_token").and_then(|r| r.ok())) + { + // TODO: M_MISSING_TOKEN + None => return Failure((Status::Unauthorized, ())), + Some(token) => token, + }; + + // Check if token is valid + match data.user_from_token(&token) { + // TODO: M_UNKNOWN_TOKEN + None => return Failure((Status::Unauthorized, ())), + Some(user_id) => Some(user_id), + } + } else { + None + }; + let mut http_request = http::Request::builder() .uri(request.uri().to_string()) .method(&*request.method().to_string()); @@ -52,17 +78,10 @@ where handle.read_to_end(&mut body).unwrap(); let http_request = http_request.body(body).unwrap(); - let headers = http_request.headers().clone(); log::info!("{:?}", http_request); match T::Incoming::try_from(http_request) { - Ok(t) => { - if T::METADATA.requires_authentication { - let data = request.guard::>(); - // TODO: auth - } - Success(Ruma { body: t, headers }) - } + Ok(t) => Success(Ruma { body: t, user_id }), Err(e) => { log::error!("{:?}", e); Failure((Status::InternalServerError, ())) @@ -79,18 +98,6 @@ impl Deref for Ruma { } } -impl fmt::Debug for Ruma -where - T::Incoming: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Ruma") - .field("body", &self.body) - .field("headers", &self.headers) - .finish() - } -} - /// This struct converts ruma responses into rocket http responses. pub struct MatrixResult(pub std::result::Result); impl>>> TryInto>> for MatrixResult { diff --git a/src/utils.rs b/src/utils.rs index 2905088..fd7b4cb 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -7,3 +7,21 @@ pub fn millis_since_unix_epoch() -> js_int::UInt { .as_millis() as u32) .into() } + +pub fn bytes_to_string(bytes: &[u8]) -> String { + String::from_utf8(bytes.to_vec()).expect("convert bytes to string") +} + +pub fn vec_to_bytes(vec: Vec) -> Vec { + vec.into_iter() + .map(|string| string.into_bytes()) + .collect::>>() + .join(&0) +} + +pub fn bytes_to_vec(bytes: &[u8]) -> Vec { + bytes + .split(|&b| b == 0) + .map(|bytes_string| bytes_to_string(bytes_string)) + .collect::>() +}