Merge pull request 'refactor: better error handling' (#116) from error-handling into master
This commit is contained in:
		
						commit
						4c98079c4c
					
				
					 16 changed files with 2155 additions and 1932 deletions
				
			
		
							
								
								
									
										2767
									
								
								src/client_server.rs
									
									
									
									
									
								
							
							
						
						
									
										2767
									
								
								src/client_server.rs
									
									
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load diff
											
										
									
								
							|  | @ -6,6 +6,7 @@ pub(self) mod rooms; | |||
| pub(self) mod uiaa; | ||||
| pub(self) mod users; | ||||
| 
 | ||||
| use crate::{Error, Result}; | ||||
| use directories::ProjectDirs; | ||||
| use log::info; | ||||
| use std::fs::remove_dir_all; | ||||
|  | @ -25,84 +26,92 @@ pub struct Database { | |||
| 
 | ||||
| impl Database { | ||||
|     /// Tries to remove the old database but ignores all errors.
 | ||||
|     pub fn try_remove(server_name: &str) { | ||||
|     pub fn try_remove(server_name: &str) -> Result<()> { | ||||
|         let mut path = ProjectDirs::from("xyz", "koesters", "conduit") | ||||
|             .unwrap() | ||||
|             .ok_or(Error::BadConfig( | ||||
|                 "The OS didn't return a valid home directory path.", | ||||
|             ))? | ||||
|             .data_dir() | ||||
|             .to_path_buf(); | ||||
|         path.push(server_name); | ||||
|         let _ = remove_dir_all(path); | ||||
| 
 | ||||
|         Ok(()) | ||||
|     } | ||||
| 
 | ||||
|     /// Load an existing database or create a new one.
 | ||||
|     pub fn load_or_create(config: &Config) -> Self { | ||||
|     pub fn load_or_create(config: &Config) -> Result<Self> { | ||||
|         let server_name = config.get_str("server_name").unwrap_or("localhost"); | ||||
| 
 | ||||
|         let path = config | ||||
|             .get_str("database_path") | ||||
|             .map(|x| x.to_owned()) | ||||
|             .map(|x| Ok::<_, Error>(x.to_owned())) | ||||
|             .unwrap_or_else(|_| { | ||||
|                 let path = ProjectDirs::from("xyz", "koesters", "conduit") | ||||
|                     .unwrap() | ||||
|                     .ok_or(Error::BadConfig( | ||||
|                         "The OS didn't return a valid home directory path.", | ||||
|                     ))? | ||||
|                     .data_dir() | ||||
|                     .join(server_name); | ||||
|                 path.to_str().unwrap().to_owned() | ||||
|             }); | ||||
| 
 | ||||
|         let db = sled::open(&path).unwrap(); | ||||
|                 Ok(path | ||||
|                     .to_str() | ||||
|                     .ok_or(Error::BadConfig("Database path contains invalid unicode."))? | ||||
|                     .to_owned()) | ||||
|             })?; | ||||
| 
 | ||||
|         let db = sled::open(&path)?; | ||||
|         info!("Opened sled database at {}", path); | ||||
| 
 | ||||
|         Self { | ||||
|             globals: globals::Globals::load(db.open_tree("global").unwrap(), config), | ||||
|         Ok(Self { | ||||
|             globals: globals::Globals::load(db.open_tree("global")?, config)?, | ||||
|             users: users::Users { | ||||
|                 userid_password: db.open_tree("userid_password").unwrap(), | ||||
|                 userid_displayname: db.open_tree("userid_displayname").unwrap(), | ||||
|                 userid_avatarurl: db.open_tree("userid_avatarurl").unwrap(), | ||||
|                 userdeviceid_token: db.open_tree("userdeviceid_token").unwrap(), | ||||
|                 userdeviceid_metadata: db.open_tree("userdeviceid_metadata").unwrap(), | ||||
|                 token_userdeviceid: db.open_tree("token_userdeviceid").unwrap(), | ||||
|                 onetimekeyid_onetimekeys: db.open_tree("onetimekeyid_onetimekeys").unwrap(), | ||||
|                 userdeviceid_devicekeys: db.open_tree("userdeviceid_devicekeys").unwrap(), | ||||
|                 devicekeychangeid_userid: db.open_tree("devicekeychangeid_userid").unwrap(), | ||||
|                 todeviceid_events: db.open_tree("todeviceid_events").unwrap(), | ||||
|                 userid_password: db.open_tree("userid_password")?, | ||||
|                 userid_displayname: db.open_tree("userid_displayname")?, | ||||
|                 userid_avatarurl: db.open_tree("userid_avatarurl")?, | ||||
|                 userdeviceid_token: db.open_tree("userdeviceid_token")?, | ||||
|                 userdeviceid_metadata: db.open_tree("userdeviceid_metadata")?, | ||||
|                 token_userdeviceid: db.open_tree("token_userdeviceid")?, | ||||
|                 onetimekeyid_onetimekeys: db.open_tree("onetimekeyid_onetimekeys")?, | ||||
|                 userdeviceid_devicekeys: db.open_tree("userdeviceid_devicekeys")?, | ||||
|                 devicekeychangeid_userid: db.open_tree("devicekeychangeid_userid")?, | ||||
|                 todeviceid_events: db.open_tree("todeviceid_events")?, | ||||
|             }, | ||||
|             uiaa: uiaa::Uiaa { | ||||
|                 userdeviceid_uiaainfo: db.open_tree("userdeviceid_uiaainfo").unwrap(), | ||||
|                 userdeviceid_uiaainfo: db.open_tree("userdeviceid_uiaainfo")?, | ||||
|             }, | ||||
|             rooms: rooms::Rooms { | ||||
|                 edus: rooms::RoomEdus { | ||||
|                     roomuserid_lastread: db.open_tree("roomuserid_lastread").unwrap(), // "Private" read receipt
 | ||||
|                     roomlatestid_roomlatest: db.open_tree("roomlatestid_roomlatest").unwrap(), // Read receipts
 | ||||
|                     roomactiveid_userid: db.open_tree("roomactiveid_userid").unwrap(), // Typing notifs
 | ||||
|                     roomid_lastroomactiveupdate: db | ||||
|                         .open_tree("roomid_lastroomactiveupdate") | ||||
|                         .unwrap(), | ||||
|                     roomuserid_lastread: db.open_tree("roomuserid_lastread")?, // "Private" read receipt
 | ||||
|                     roomlatestid_roomlatest: db.open_tree("roomlatestid_roomlatest")?, // Read receipts
 | ||||
|                     roomactiveid_userid: db.open_tree("roomactiveid_userid")?, // Typing notifs
 | ||||
|                     roomid_lastroomactiveupdate: db.open_tree("roomid_lastroomactiveupdate")?, | ||||
|                 }, | ||||
|                 pduid_pdu: db.open_tree("pduid_pdu").unwrap(), | ||||
|                 eventid_pduid: db.open_tree("eventid_pduid").unwrap(), | ||||
|                 roomid_pduleaves: db.open_tree("roomid_pduleaves").unwrap(), | ||||
|                 roomstateid_pdu: db.open_tree("roomstateid_pdu").unwrap(), | ||||
|                 pduid_pdu: db.open_tree("pduid_pdu")?, | ||||
|                 eventid_pduid: db.open_tree("eventid_pduid")?, | ||||
|                 roomid_pduleaves: db.open_tree("roomid_pduleaves")?, | ||||
|                 roomstateid_pdu: db.open_tree("roomstateid_pdu")?, | ||||
| 
 | ||||
|                 alias_roomid: db.open_tree("alias_roomid").unwrap(), | ||||
|                 aliasid_alias: db.open_tree("alias_roomid").unwrap(), | ||||
|                 publicroomids: db.open_tree("publicroomids").unwrap(), | ||||
|                 alias_roomid: db.open_tree("alias_roomid")?, | ||||
|                 aliasid_alias: db.open_tree("alias_roomid")?, | ||||
|                 publicroomids: db.open_tree("publicroomids")?, | ||||
| 
 | ||||
|                 userroomid_joined: db.open_tree("userroomid_joined").unwrap(), | ||||
|                 roomuserid_joined: db.open_tree("roomuserid_joined").unwrap(), | ||||
|                 userroomid_invited: db.open_tree("userroomid_invited").unwrap(), | ||||
|                 roomuserid_invited: db.open_tree("roomuserid_invited").unwrap(), | ||||
|                 userroomid_left: db.open_tree("userroomid_left").unwrap(), | ||||
|                 userroomid_joined: db.open_tree("userroomid_joined")?, | ||||
|                 roomuserid_joined: db.open_tree("roomuserid_joined")?, | ||||
|                 userroomid_invited: db.open_tree("userroomid_invited")?, | ||||
|                 roomuserid_invited: db.open_tree("roomuserid_invited")?, | ||||
|                 userroomid_left: db.open_tree("userroomid_left")?, | ||||
|             }, | ||||
|             account_data: account_data::AccountData { | ||||
|                 roomuserdataid_accountdata: db.open_tree("roomuserdataid_accountdata").unwrap(), | ||||
|                 roomuserdataid_accountdata: db.open_tree("roomuserdataid_accountdata")?, | ||||
|             }, | ||||
|             global_edus: global_edus::GlobalEdus { | ||||
|                 presenceid_presence: db.open_tree("presenceid_presence").unwrap(), // Presence
 | ||||
|                 presenceid_presence: db.open_tree("presenceid_presence")?, // Presence
 | ||||
|             }, | ||||
|             media: media::Media { | ||||
|                 mediaid_file: db.open_tree("mediaid_file").unwrap(), | ||||
|                 mediaid_file: db.open_tree("mediaid_file")?, | ||||
|             }, | ||||
|             _db: db, | ||||
|         } | ||||
|         }) | ||||
|     } | ||||
| } | ||||
|  |  | |||
|  | @ -1,5 +1,6 @@ | |||
| use crate::{utils, Error, Result}; | ||||
| use ruma::{ | ||||
|     api::client::error::ErrorKind, | ||||
|     events::{collections::only::Event as EduEvent, EventJson, EventType}, | ||||
|     identifiers::{RoomId, UserId}, | ||||
| }; | ||||
|  | @ -20,7 +21,10 @@ impl AccountData { | |||
|         globals: &super::globals::Globals, | ||||
|     ) -> Result<()> { | ||||
|         if json.get("content").is_none() { | ||||
|             return Err(Error::BadRequest("json needs to have a content field")); | ||||
|             return Err(Error::BadRequest( | ||||
|                 ErrorKind::BadJson, | ||||
|                 "Json needs to have a content field.", | ||||
|             )); | ||||
|         } | ||||
|         json.insert("type".to_owned(), kind.to_string().into()); | ||||
| 
 | ||||
|  | @ -62,9 +66,10 @@ impl AccountData { | |||
|         key.push(0xff); | ||||
|         key.extend_from_slice(kind.to_string().as_bytes()); | ||||
| 
 | ||||
|         self.roomuserdataid_accountdata | ||||
|             .insert(key, &*serde_json::to_string(&json)?) | ||||
|             .unwrap(); | ||||
|         self.roomuserdataid_accountdata.insert( | ||||
|             key, | ||||
|             &*serde_json::to_string(&json).expect("Map::to_string always works"), | ||||
|         )?; | ||||
| 
 | ||||
|         Ok(()) | ||||
|     } | ||||
|  | @ -109,17 +114,20 @@ impl AccountData { | |||
|             .take_while(move |(k, _)| k.starts_with(&prefix)) | ||||
|             .map(|(k, v)| { | ||||
|                 Ok::<_, Error>(( | ||||
|                     EventType::try_from(utils::string_from_bytes( | ||||
|                         k.rsplit(|&b| b == 0xff) | ||||
|                             .next() | ||||
|                             .ok_or(Error::BadDatabase("roomuserdataid is invalid"))?, | ||||
|                     )?) | ||||
|                     .map_err(|_| Error::BadDatabase("roomuserdataid is invalid"))?, | ||||
|                     serde_json::from_slice::<EventJson<EduEvent>>(&v).unwrap(), | ||||
|                     EventType::try_from( | ||||
|                         utils::string_from_bytes(k.rsplit(|&b| b == 0xff).next().ok_or_else( | ||||
|                             || Error::bad_database("RoomUserData ID in db is invalid."), | ||||
|                         )?) | ||||
|                         .map_err(|_| Error::bad_database("RoomUserData ID in db is invalid."))?, | ||||
|                     ) | ||||
|                     .map_err(|_| Error::bad_database("RoomUserData ID in db is invalid."))?, | ||||
|                     serde_json::from_slice::<EventJson<EduEvent>>(&v).map_err(|_| { | ||||
|                         Error::bad_database("Database contains invalid account data.") | ||||
|                     })?, | ||||
|                 )) | ||||
|             }) | ||||
|         { | ||||
|             let (kind, data) = r.unwrap(); | ||||
|             let (kind, data) = r?; | ||||
|             userdata.insert(kind, data); | ||||
|         } | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,4 +1,4 @@ | |||
| use crate::Result; | ||||
| use crate::{Error, Result}; | ||||
| use ruma::events::EventJson; | ||||
| 
 | ||||
| pub struct GlobalEdus { | ||||
|  | @ -21,7 +21,10 @@ impl GlobalEdus { | |||
|             .rev() | ||||
|             .filter_map(|r| r.ok()) | ||||
|             .find(|key| { | ||||
|                 key.rsplit(|&b| b == 0xff).next().unwrap() == presence.sender.to_string().as_bytes() | ||||
|                 key.rsplit(|&b| b == 0xff) | ||||
|                     .next() | ||||
|                     .expect("rsplit always returns an element") | ||||
|                     == presence.sender.to_string().as_bytes() | ||||
|             }) | ||||
|         { | ||||
|             // This is the old global_latest
 | ||||
|  | @ -32,8 +35,10 @@ impl GlobalEdus { | |||
|         presence_id.push(0xff); | ||||
|         presence_id.extend_from_slice(&presence.sender.to_string().as_bytes()); | ||||
| 
 | ||||
|         self.presenceid_presence | ||||
|             .insert(presence_id, &*serde_json::to_string(&presence)?)?; | ||||
|         self.presenceid_presence.insert( | ||||
|             presence_id, | ||||
|             &*serde_json::to_string(&presence).expect("PresenceEvent can be serialized"), | ||||
|         )?; | ||||
| 
 | ||||
|         Ok(()) | ||||
|     } | ||||
|  | @ -50,6 +55,9 @@ impl GlobalEdus { | |||
|             .presenceid_presence | ||||
|             .range(&*first_possible_edu..) | ||||
|             .filter_map(|r| r.ok()) | ||||
|             .map(|(_, v)| Ok(serde_json::from_slice(&v)?))) | ||||
|             .map(|(_, v)| { | ||||
|                 Ok(serde_json::from_slice(&v) | ||||
|                     .map_err(|_| Error::bad_database("Invalid presence event in db."))?) | ||||
|             })) | ||||
|     } | ||||
| } | ||||
|  |  | |||
|  | @ -1,4 +1,4 @@ | |||
| use crate::{utils, Result}; | ||||
| use crate::{utils, Error, Result}; | ||||
| 
 | ||||
| pub const COUNTER: &str = "c"; | ||||
| 
 | ||||
|  | @ -11,17 +11,16 @@ pub struct Globals { | |||
| } | ||||
| 
 | ||||
| impl Globals { | ||||
|     pub fn load(globals: sled::Tree, config: &rocket::Config) -> Self { | ||||
|     pub fn load(globals: sled::Tree, config: &rocket::Config) -> Result<Self> { | ||||
|         let keypair = ruma::signatures::Ed25519KeyPair::new( | ||||
|             &*globals | ||||
|                 .update_and_fetch("keypair", utils::generate_keypair) | ||||
|                 .unwrap() | ||||
|                 .unwrap(), | ||||
|                 .update_and_fetch("keypair", utils::generate_keypair)? | ||||
|                 .expect("utils::generate_keypair always returns Some"), | ||||
|             "key1".to_owned(), | ||||
|         ) | ||||
|         .unwrap(); | ||||
|         .map_err(|_| Error::bad_database("Private or public keys are invalid."))?; | ||||
| 
 | ||||
|         Self { | ||||
|         Ok(Self { | ||||
|             globals, | ||||
|             keypair, | ||||
|             reqwest_client: reqwest::Client::new(), | ||||
|  | @ -30,7 +29,7 @@ impl Globals { | |||
|                 .unwrap_or("localhost") | ||||
|                 .to_owned(), | ||||
|             registration_disabled: config.get_bool("registration_disabled").unwrap_or(false), | ||||
|         } | ||||
|         }) | ||||
|     } | ||||
| 
 | ||||
|     /// Returns this server's keypair.
 | ||||
|  | @ -49,14 +48,15 @@ impl Globals { | |||
|                 .globals | ||||
|                 .update_and_fetch(COUNTER, utils::increment)? | ||||
|                 .expect("utils::increment will always put in a value"), | ||||
|         )) | ||||
|         ) | ||||
|         .map_err(|_| Error::bad_database("Count has invalid bytes."))?) | ||||
|     } | ||||
| 
 | ||||
|     pub fn current_count(&self) -> Result<u64> { | ||||
|         Ok(self | ||||
|             .globals | ||||
|             .get(COUNTER)? | ||||
|             .map_or(0_u64, |bytes| utils::u64_from_bytes(&bytes))) | ||||
|         self.globals.get(COUNTER)?.map_or(Ok(0_u64), |bytes| { | ||||
|             Ok(utils::u64_from_bytes(&bytes) | ||||
|                 .map_err(|_| Error::bad_database("Count has invalid bytes."))?) | ||||
|         }) | ||||
|     } | ||||
| 
 | ||||
|     pub fn server_name(&self) -> &str { | ||||
|  |  | |||
|  | @ -43,16 +43,20 @@ impl Media { | |||
|             let content_type = utils::string_from_bytes( | ||||
|                 parts | ||||
|                     .next() | ||||
|                     .ok_or(Error::BadDatabase("mediaid is invalid"))?, | ||||
|             )?; | ||||
|                     .ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?, | ||||
|             ) | ||||
|             .map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode."))?; | ||||
| 
 | ||||
|             let filename_bytes = parts | ||||
|                 .next() | ||||
|                 .ok_or(Error::BadDatabase("mediaid is invalid"))?; | ||||
|                 .ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?; | ||||
| 
 | ||||
|             let filename = if filename_bytes.is_empty() { | ||||
|                 None | ||||
|             } else { | ||||
|                 Some(utils::string_from_bytes(filename_bytes)?) | ||||
|                 Some(utils::string_from_bytes(filename_bytes).map_err(|_| { | ||||
|                     Error::bad_database("Filename in mediaid_file is invalid unicode.") | ||||
|                 })?) | ||||
|             }; | ||||
| 
 | ||||
|             Ok(Some((filename, content_type, file.to_vec()))) | ||||
|  | @ -89,16 +93,21 @@ impl Media { | |||
|             let content_type = utils::string_from_bytes( | ||||
|                 parts | ||||
|                     .next() | ||||
|                     .ok_or(Error::BadDatabase("mediaid is invalid"))?, | ||||
|             )?; | ||||
|                     .ok_or_else(|| Error::bad_database("Invalid Media ID in db"))?, | ||||
|             ) | ||||
|             .map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode."))?; | ||||
| 
 | ||||
|             let filename_bytes = parts | ||||
|                 .next() | ||||
|                 .ok_or(Error::BadDatabase("mediaid is invalid"))?; | ||||
|                 .ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?; | ||||
| 
 | ||||
|             let filename = if filename_bytes.is_empty() { | ||||
|                 None | ||||
|             } else { | ||||
|                 Some(utils::string_from_bytes(filename_bytes)?) | ||||
|                 Some( | ||||
|                     utils::string_from_bytes(filename_bytes) | ||||
|                         .map_err(|_| Error::bad_database("Filename in db is invalid."))?, | ||||
|                 ) | ||||
|             }; | ||||
| 
 | ||||
|             Ok(Some((filename, content_type, file.to_vec()))) | ||||
|  | @ -110,16 +119,20 @@ impl Media { | |||
|             let content_type = utils::string_from_bytes( | ||||
|                 parts | ||||
|                     .next() | ||||
|                     .ok_or(Error::BadDatabase("mediaid is invalid"))?, | ||||
|             )?; | ||||
|                     .ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?, | ||||
|             ) | ||||
|             .map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode."))?; | ||||
| 
 | ||||
|             let filename_bytes = parts | ||||
|                 .next() | ||||
|                 .ok_or(Error::BadDatabase("mediaid is invalid"))?; | ||||
|                 .ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?; | ||||
| 
 | ||||
|             let filename = if filename_bytes.is_empty() { | ||||
|                 None | ||||
|             } else { | ||||
|                 Some(utils::string_from_bytes(filename_bytes)?) | ||||
|                 Some(utils::string_from_bytes(filename_bytes).map_err(|_| { | ||||
|                     Error::bad_database("Filename in mediaid_file is invalid unicode.") | ||||
|                 })?) | ||||
|             }; | ||||
| 
 | ||||
|             if let Ok(image) = image::load_from_memory(&file) { | ||||
|  | @ -132,7 +145,7 @@ impl Media { | |||
|                 let width_index = thumbnail_key | ||||
|                     .iter() | ||||
|                     .position(|&b| b == 0xff) | ||||
|                     .ok_or(Error::BadDatabase("mediaid is invalid"))? | ||||
|                     .ok_or_else(|| Error::bad_database("Media in db is invalid."))? | ||||
|                     + 1; | ||||
|                 let mut widthheight = width.to_be_bytes().to_vec(); | ||||
|                 widthheight.extend_from_slice(&height.to_be_bytes()); | ||||
|  |  | |||
|  | @ -5,6 +5,7 @@ pub use edus::RoomEdus; | |||
| use crate::{utils, Error, PduEvent, Result}; | ||||
| use log::error; | ||||
| use ruma::{ | ||||
|     api::client::error::ErrorKind, | ||||
|     events::{ | ||||
|         room::{ | ||||
|             join_rules, member, | ||||
|  | @ -61,30 +62,34 @@ impl Rooms { | |||
|             .roomstateid_pdu | ||||
|             .scan_prefix(&room_id.to_string().as_bytes()) | ||||
|             .values() | ||||
|             .map(|value| Ok::<_, Error>(serde_json::from_slice::<PduEvent>(&value?)?)) | ||||
|             .map(|value| { | ||||
|                 Ok::<_, Error>( | ||||
|                     serde_json::from_slice::<PduEvent>(&value?) | ||||
|                         .map_err(|_| Error::bad_database("Invalid PDU in db."))?, | ||||
|                 ) | ||||
|             }) | ||||
|         { | ||||
|             let pdu = pdu?; | ||||
|             hashmap.insert( | ||||
|                 ( | ||||
|                     pdu.kind.clone(), | ||||
|                     pdu.state_key | ||||
|                         .clone() | ||||
|                         .expect("state events have a state key"), | ||||
|                 ), | ||||
|                 pdu, | ||||
|             ); | ||||
|             let state_key = pdu.state_key.clone().ok_or_else(|| { | ||||
|                 Error::bad_database("Room state contains event without state_key.") | ||||
|             })?; | ||||
|             hashmap.insert((pdu.kind.clone(), state_key), pdu); | ||||
|         } | ||||
|         Ok(hashmap) | ||||
|     } | ||||
| 
 | ||||
|     /// Returns the `count` of this pdu's id.
 | ||||
|     pub fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<u64>> { | ||||
|         Ok(self | ||||
|             .eventid_pduid | ||||
|         self.eventid_pduid | ||||
|             .get(event_id.to_string().as_bytes())? | ||||
|             .map(|pdu_id| { | ||||
|                 utils::u64_from_bytes(&pdu_id[pdu_id.len() - mem::size_of::<u64>()..pdu_id.len()]) | ||||
|             })) | ||||
|             .map_or(Ok(None), |pdu_id| { | ||||
|                 Ok(Some( | ||||
|                     utils::u64_from_bytes( | ||||
|                         &pdu_id[pdu_id.len() - mem::size_of::<u64>()..pdu_id.len()], | ||||
|                     ) | ||||
|                     .map_err(|_| Error::bad_database("PDU has invalid count bytes."))?, | ||||
|                 )) | ||||
|             }) | ||||
|     } | ||||
| 
 | ||||
|     /// Returns the json of a pdu.
 | ||||
|  | @ -92,11 +97,12 @@ impl Rooms { | |||
|         self.eventid_pduid | ||||
|             .get(event_id.to_string().as_bytes())? | ||||
|             .map_or(Ok(None), |pdu_id| { | ||||
|                 Ok(Some(serde_json::from_slice( | ||||
|                     &self.pduid_pdu.get(pdu_id)?.ok_or(Error::BadDatabase( | ||||
|                         "eventid_pduid points to nonexistent pdu", | ||||
|                     ))?, | ||||
|                 )?)) | ||||
|                 Ok(Some( | ||||
|                     serde_json::from_slice(&self.pduid_pdu.get(pdu_id)?.ok_or_else(|| { | ||||
|                         Error::bad_database("eventid_pduid points to nonexistent pdu.") | ||||
|                     })?) | ||||
|                     .map_err(|_| Error::bad_database("Invalid PDU in db."))?, | ||||
|                 )) | ||||
|             }) | ||||
|     } | ||||
| 
 | ||||
|  | @ -112,28 +118,37 @@ impl Rooms { | |||
|         self.eventid_pduid | ||||
|             .get(event_id.to_string().as_bytes())? | ||||
|             .map_or(Ok(None), |pdu_id| { | ||||
|                 Ok(Some(serde_json::from_slice( | ||||
|                     &self.pduid_pdu.get(pdu_id)?.ok_or(Error::BadDatabase( | ||||
|                         "eventid_pduid points to nonexistent pdu", | ||||
|                     ))?, | ||||
|                 )?)) | ||||
|                 Ok(Some( | ||||
|                     serde_json::from_slice(&self.pduid_pdu.get(pdu_id)?.ok_or_else(|| { | ||||
|                         Error::bad_database("eventid_pduid points to nonexistent pdu.") | ||||
|                     })?) | ||||
|                     .map_err(|_| Error::bad_database("Invalid PDU in db."))?, | ||||
|                 )) | ||||
|             }) | ||||
|     } | ||||
|     /// Returns the pdu.
 | ||||
|     pub fn get_pdu_from_id(&self, pdu_id: &IVec) -> Result<Option<PduEvent>> { | ||||
|         self.pduid_pdu | ||||
|             .get(pdu_id)? | ||||
|             .map_or(Ok(None), |pdu| Ok(Some(serde_json::from_slice(&pdu)?))) | ||||
|         self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { | ||||
|             Ok(Some( | ||||
|                 serde_json::from_slice(&pdu) | ||||
|                     .map_err(|_| Error::bad_database("Invalid PDU in db."))?, | ||||
|             )) | ||||
|         }) | ||||
|     } | ||||
| 
 | ||||
|     /// Returns the pdu.
 | ||||
|     pub fn replace_pdu(&self, pdu_id: &IVec, pdu: &PduEvent) -> Result<()> { | ||||
|     /// Removes a pdu and creates a new one with the same id.
 | ||||
|     fn replace_pdu(&self, pdu_id: &IVec, pdu: &PduEvent) -> Result<()> { | ||||
|         if self.pduid_pdu.get(&pdu_id)?.is_some() { | ||||
|             self.pduid_pdu | ||||
|                 .insert(&pdu_id, &*serde_json::to_string(pdu)?)?; | ||||
|             self.pduid_pdu.insert( | ||||
|                 &pdu_id, | ||||
|                 &*serde_json::to_string(pdu).expect("PduEvent::to_string always works"), | ||||
|             )?; | ||||
|             Ok(()) | ||||
|         } else { | ||||
|             Err(Error::BadRequest("pdu does not exist")) | ||||
|             Err(Error::BadRequest( | ||||
|                 ErrorKind::NotFound, | ||||
|                 "PDU does not exist.", | ||||
|             )) | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|  | @ -148,7 +163,14 @@ impl Rooms { | |||
|             .roomid_pduleaves | ||||
|             .scan_prefix(prefix) | ||||
|             .values() | ||||
|             .map(|bytes| Ok::<_, Error>(EventId::try_from(&*utils::string_from_bytes(&bytes?)?)?)) | ||||
|             .map(|bytes| { | ||||
|                 Ok::<_, Error>( | ||||
|                     EventId::try_from(utils::string_from_bytes(&bytes?).map_err(|_| { | ||||
|                         Error::bad_database("EventID in roomid_pduleaves is invalid unicode.") | ||||
|                     })?) | ||||
|                     .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))?, | ||||
|                 ) | ||||
|             }) | ||||
|         { | ||||
|             events.push(event?); | ||||
|         } | ||||
|  | @ -214,174 +236,205 @@ impl Rooms { | |||
|                         Ok( | ||||
|                             serde_json::from_value::<EventJson<PowerLevelsEventContent>>( | ||||
|                                 power_levels.content.clone(), | ||||
|                             )? | ||||
|                             .deserialize()?, | ||||
|                             ) | ||||
|                             .expect("EventJson::from_value always works.") | ||||
|                             .deserialize() | ||||
|                             .map_err(|_| Error::bad_database("Invalid PowerLevels event in db."))?, | ||||
|                         ) | ||||
|                     }, | ||||
|                 )?; | ||||
|             { | ||||
|                 let sender_membership = self | ||||
|                     .room_state(&room_id)? | ||||
|                     .get(&(EventType::RoomMember, sender.to_string())) | ||||
|                     .map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| { | ||||
|                         Ok( | ||||
|                             serde_json::from_value::<EventJson<member::MemberEventContent>>( | ||||
|                                 pdu.content.clone(), | ||||
|                             )? | ||||
|                             .deserialize()? | ||||
|                             .membership, | ||||
|             let sender_membership = self | ||||
|                 .room_state(&room_id)? | ||||
|                 .get(&(EventType::RoomMember, sender.to_string())) | ||||
|                 .map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| { | ||||
|                     Ok( | ||||
|                         serde_json::from_value::<EventJson<member::MemberEventContent>>( | ||||
|                             pdu.content.clone(), | ||||
|                         ) | ||||
|                         .expect("EventJson::from_value always works.") | ||||
|                         .deserialize() | ||||
|                         .map_err(|_| Error::bad_database("Invalid Member event in db."))? | ||||
|                         .membership, | ||||
|                     ) | ||||
|                 })?; | ||||
| 
 | ||||
|             let sender_power = power_levels.users.get(&sender).map_or_else( | ||||
|                 || { | ||||
|                     if sender_membership != member::MembershipState::Join { | ||||
|                         None | ||||
|                     } else { | ||||
|                         Some(&power_levels.users_default) | ||||
|                     } | ||||
|                 }, | ||||
|                 // If it's okay, wrap with Some(_)
 | ||||
|                 Some, | ||||
|             ); | ||||
| 
 | ||||
|             if !match event_type { | ||||
|                 EventType::RoomMember => { | ||||
|                     let target_user_id = UserId::try_from(&**state_key).map_err(|_| { | ||||
|                         Error::BadRequest( | ||||
|                             ErrorKind::InvalidParam, | ||||
|                             "State key of member event does not contain user id.", | ||||
|                         ) | ||||
|                     })?; | ||||
| 
 | ||||
|                 let sender_power = power_levels.users.get(&sender).map_or_else( | ||||
|                     || { | ||||
|                         if sender_membership != member::MembershipState::Join { | ||||
|                             None | ||||
|                         } else { | ||||
|                             Some(&power_levels.users_default) | ||||
|                         } | ||||
|                     }, | ||||
|                     // If it's okay, wrap with Some(_)
 | ||||
|                     Some, | ||||
|                 ); | ||||
|                     let current_membership = self | ||||
|                         .room_state(&room_id)? | ||||
|                         .get(&(EventType::RoomMember, target_user_id.to_string())) | ||||
|                         .map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| { | ||||
|                             Ok( | ||||
|                                 serde_json::from_value::<EventJson<member::MemberEventContent>>( | ||||
|                                     pdu.content.clone(), | ||||
|                                 ) | ||||
|                                 .expect("EventJson::from_value always works.") | ||||
|                                 .deserialize() | ||||
|                                 .map_err(|_| Error::bad_database("Invalid Member event in db."))? | ||||
|                                 .membership, | ||||
|                             ) | ||||
|                         })?; | ||||
| 
 | ||||
|                 if !match event_type { | ||||
|                     EventType::RoomMember => { | ||||
|                         let target_user_id = UserId::try_from(&**state_key)?; | ||||
|                     let target_membership = serde_json::from_value::< | ||||
|                         EventJson<member::MemberEventContent>, | ||||
|                     >(content.clone()) | ||||
|                     .expect("EventJson::from_value always works.") | ||||
|                     .deserialize() | ||||
|                     .map_err(|_| Error::bad_database("Invalid Member event in db."))? | ||||
|                     .membership; | ||||
| 
 | ||||
|                         let current_membership = self | ||||
|                             .room_state(&room_id)? | ||||
|                             .get(&(EventType::RoomMember, target_user_id.to_string())) | ||||
|                             .map_or(Ok::<_, Error>(member::MembershipState::Leave), |pdu| { | ||||
|                                 Ok(serde_json::from_value::< | ||||
|                                         EventJson<member::MemberEventContent>, | ||||
|                                     >(pdu.content.clone())? | ||||
|                                     .deserialize()? | ||||
|                                     .membership) | ||||
|                             })?; | ||||
|                     let target_power = power_levels.users.get(&target_user_id).map_or_else( | ||||
|                         || { | ||||
|                             if target_membership != member::MembershipState::Join { | ||||
|                                 None | ||||
|                             } else { | ||||
|                                 Some(&power_levels.users_default) | ||||
|                             } | ||||
|                         }, | ||||
|                         // If it's okay, wrap with Some(_)
 | ||||
|                         Some, | ||||
|                     ); | ||||
| 
 | ||||
|                         let target_membership = serde_json::from_value::< | ||||
|                             EventJson<member::MemberEventContent>, | ||||
|                         >(content.clone())? | ||||
|                         .deserialize()? | ||||
|                         .membership; | ||||
| 
 | ||||
|                         let target_power = power_levels.users.get(&target_user_id).map_or_else( | ||||
|                             || { | ||||
|                                 if target_membership != member::MembershipState::Join { | ||||
|                                     None | ||||
|                                 } else { | ||||
|                                     Some(&power_levels.users_default) | ||||
|                                 } | ||||
|                             }, | ||||
|                             // If it's okay, wrap with Some(_)
 | ||||
|                             Some, | ||||
|                         ); | ||||
| 
 | ||||
|                         let join_rules = self | ||||
|                             .room_state(&room_id)? | ||||
|                     let join_rules = | ||||
|                         self.room_state(&room_id)? | ||||
|                             .get(&(EventType::RoomJoinRules, "".to_owned())) | ||||
|                             .map_or(join_rules::JoinRule::Public, |pdu| { | ||||
|                                 serde_json::from_value::< | ||||
|                             .map_or(Ok::<_, Error>(join_rules::JoinRule::Public), |pdu| { | ||||
|                                 Ok(serde_json::from_value::< | ||||
|                                     EventJson<join_rules::JoinRulesEventContent>, | ||||
|                                 >(pdu.content.clone()) | ||||
|                                 .unwrap() | ||||
|                                 .expect("EventJson::from_value always works.") | ||||
|                                 .deserialize() | ||||
|                                 .unwrap() | ||||
|                                 .join_rule | ||||
|                             }); | ||||
|                                 .map_err(|_| { | ||||
|                                     Error::bad_database("Database contains invalid JoinRules event") | ||||
|                                 })? | ||||
|                                 .join_rule) | ||||
|                             })?; | ||||
| 
 | ||||
|                         let authorized = if target_membership == member::MembershipState::Join { | ||||
|                             let mut prev_events = prev_events.iter(); | ||||
|                             let prev_event = self | ||||
|                                 .get_pdu(prev_events.next().ok_or(Error::BadRequest( | ||||
|                                     "membership can't be the first event", | ||||
|                                 ))?)? | ||||
|                                 .ok_or(Error::BadDatabase("pdu leave points to valid event"))?; | ||||
|                             if prev_event.kind == EventType::RoomCreate | ||||
|                                 && prev_event.prev_events.is_empty() | ||||
|                             { | ||||
|                                 true | ||||
|                             } else if sender != target_user_id { | ||||
|                                 false | ||||
|                             } else if let member::MembershipState::Ban = current_membership { | ||||
|                                 false | ||||
|                             } else { | ||||
|                                 join_rules == join_rules::JoinRule::Invite | ||||
|                                     && (current_membership == member::MembershipState::Join | ||||
|                                         || current_membership == member::MembershipState::Invite) | ||||
|                                     || join_rules == join_rules::JoinRule::Public | ||||
|                             } | ||||
|                         } else if target_membership == member::MembershipState::Invite { | ||||
|                             if let Some(third_party_invite_json) = content.get("third_party_invite") | ||||
|                             { | ||||
|                                 if current_membership == member::MembershipState::Ban { | ||||
|                                     false | ||||
|                                 } else { | ||||
|                                     let _third_party_invite = | ||||
|                                         serde_json::from_value::<member::ThirdPartyInvite>( | ||||
|                                             third_party_invite_json.clone(), | ||||
|                                         )?; | ||||
|                                     todo!("handle third party invites"); | ||||
|                                 } | ||||
|                             } else if sender_membership != member::MembershipState::Join | ||||
|                                 || current_membership == member::MembershipState::Join | ||||
|                                 || current_membership == member::MembershipState::Ban | ||||
|                             { | ||||
|                                 false | ||||
|                             } else { | ||||
|                                 sender_power | ||||
|                                     .filter(|&p| p >= &power_levels.invite) | ||||
|                                     .is_some() | ||||
|                             } | ||||
|                         } else if target_membership == member::MembershipState::Leave { | ||||
|                             if sender == target_user_id { | ||||
|                                 current_membership == member::MembershipState::Join | ||||
|                                     || current_membership == member::MembershipState::Invite | ||||
|                             } else if sender_membership != member::MembershipState::Join | ||||
|                                 || current_membership == member::MembershipState::Ban | ||||
|                                     && sender_power.filter(|&p| p < &power_levels.ban).is_some() | ||||
|                             { | ||||
|                                 false | ||||
|                             } else { | ||||
|                                 sender_power.filter(|&p| p >= &power_levels.kick).is_some() | ||||
|                                     && target_power < sender_power | ||||
|                             } | ||||
|                         } else if target_membership == member::MembershipState::Ban { | ||||
|                             if sender_membership != member::MembershipState::Join { | ||||
|                                 false | ||||
|                             } else { | ||||
|                                 sender_power.filter(|&p| p >= &power_levels.ban).is_some() | ||||
|                                     && target_power < sender_power | ||||
|                             } | ||||
|                         } else { | ||||
|                     let authorized = if target_membership == member::MembershipState::Join { | ||||
|                         let mut prev_events = prev_events.iter(); | ||||
|                         let prev_event = self | ||||
|                             .get_pdu(prev_events.next().ok_or(Error::BadRequest( | ||||
|                                 ErrorKind::Unknown, | ||||
|                                 "Membership can't be the first event", | ||||
|                             ))?)? | ||||
|                             .ok_or_else(|| { | ||||
|                                 Error::bad_database("PDU leaf points to invalid event!") | ||||
|                             })?; | ||||
|                         if prev_event.kind == EventType::RoomCreate | ||||
|                             && prev_event.prev_events.is_empty() | ||||
|                         { | ||||
|                             true | ||||
|                         } else if sender != target_user_id { | ||||
|                             false | ||||
|                         }; | ||||
| 
 | ||||
|                         if authorized { | ||||
|                             // Update our membership info
 | ||||
|                             self.update_membership(&room_id, &target_user_id, &target_membership)?; | ||||
|                         } else if let member::MembershipState::Ban = current_membership { | ||||
|                             false | ||||
|                         } else { | ||||
|                             join_rules == join_rules::JoinRule::Invite | ||||
|                                 && (current_membership == member::MembershipState::Join | ||||
|                                     || current_membership == member::MembershipState::Invite) | ||||
|                                 || join_rules == join_rules::JoinRule::Public | ||||
|                         } | ||||
|                     } else if target_membership == member::MembershipState::Invite { | ||||
|                         if let Some(third_party_invite_json) = content.get("third_party_invite") { | ||||
|                             if current_membership == member::MembershipState::Ban { | ||||
|                                 false | ||||
|                             } else { | ||||
|                                 let _third_party_invite = | ||||
|                                     serde_json::from_value::<member::ThirdPartyInvite>( | ||||
|                                         third_party_invite_json.clone(), | ||||
|                                     ) | ||||
|                                     .map_err(|_| { | ||||
|                                         Error::BadRequest( | ||||
|                                             ErrorKind::InvalidParam, | ||||
|                                             "ThirdPartyInvite is invalid", | ||||
|                                         ) | ||||
|                                     })?; | ||||
|                                 todo!("handle third party invites"); | ||||
|                             } | ||||
|                         } else if sender_membership != member::MembershipState::Join | ||||
|                             || current_membership == member::MembershipState::Join | ||||
|                             || current_membership == member::MembershipState::Ban | ||||
|                         { | ||||
|                             false | ||||
|                         } else { | ||||
|                             sender_power | ||||
|                                 .filter(|&p| p >= &power_levels.invite) | ||||
|                                 .is_some() | ||||
|                         } | ||||
|                     } else if target_membership == member::MembershipState::Leave { | ||||
|                         if sender == target_user_id { | ||||
|                             current_membership == member::MembershipState::Join | ||||
|                                 || current_membership == member::MembershipState::Invite | ||||
|                         } else if sender_membership != member::MembershipState::Join | ||||
|                             || current_membership == member::MembershipState::Ban | ||||
|                                 && sender_power.filter(|&p| p < &power_levels.ban).is_some() | ||||
|                         { | ||||
|                             false | ||||
|                         } else { | ||||
|                             sender_power.filter(|&p| p >= &power_levels.kick).is_some() | ||||
|                                 && target_power < sender_power | ||||
|                         } | ||||
|                     } else if target_membership == member::MembershipState::Ban { | ||||
|                         if sender_membership != member::MembershipState::Join { | ||||
|                             false | ||||
|                         } else { | ||||
|                             sender_power.filter(|&p| p >= &power_levels.ban).is_some() | ||||
|                                 && target_power < sender_power | ||||
|                         } | ||||
|                     } else { | ||||
|                         false | ||||
|                     }; | ||||
| 
 | ||||
|                         authorized | ||||
|                     if authorized { | ||||
|                         // Update our membership info
 | ||||
|                         self.update_membership(&room_id, &target_user_id, &target_membership)?; | ||||
|                     } | ||||
|                     EventType::RoomCreate => prev_events.is_empty(), | ||||
|                     // Not allow any of the following events if the sender is not joined.
 | ||||
|                     _ if sender_membership != member::MembershipState::Join => false, | ||||
| 
 | ||||
|                     _ => { | ||||
|                         // TODO
 | ||||
|                         sender_power.unwrap_or(&power_levels.users_default) | ||||
|                             >= &power_levels.state_default | ||||
|                     } | ||||
|                 } { | ||||
|                     error!("Unauthorized"); | ||||
|                     // Not authorized
 | ||||
|                     return Err(Error::BadRequest("event not authorized")); | ||||
|                     authorized | ||||
|                 } | ||||
|                 EventType::RoomCreate => prev_events.is_empty(), | ||||
|                 // Not allow any of the following events if the sender is not joined.
 | ||||
|                 _ if sender_membership != member::MembershipState::Join => false, | ||||
| 
 | ||||
|                 _ => { | ||||
|                     // TODO
 | ||||
|                     sender_power.unwrap_or(&power_levels.users_default) | ||||
|                         >= &power_levels.state_default | ||||
|                 } | ||||
|             } { | ||||
|                 error!("Unauthorized"); | ||||
|                 // Not authorized
 | ||||
|                 return Err(Error::BadRequest( | ||||
|                     ErrorKind::Forbidden, | ||||
|                     "Event is not authorized", | ||||
|                 )); | ||||
|             } | ||||
|         } else if !self.is_joined(&sender, &room_id)? { | ||||
|             return Err(Error::BadRequest("event not authorized")); | ||||
|             // TODO: auth rules apply to all events, not only those with a state key
 | ||||
|             error!("Unauthorized"); | ||||
|             return Err(Error::BadRequest( | ||||
|                 ErrorKind::Forbidden, | ||||
|                 "Event is not authorized", | ||||
|             )); | ||||
|         } | ||||
| 
 | ||||
|         // Our depth is the maximum depth of prev_events + 1
 | ||||
|  | @ -410,14 +463,14 @@ impl Rooms { | |||
|             origin: globals.server_name().to_owned(), | ||||
|             origin_server_ts: utils::millis_since_unix_epoch() | ||||
|                 .try_into() | ||||
|                 .expect("this only fails many years in the future"), | ||||
|                 .expect("time is valid"), | ||||
|             kind: event_type.clone(), | ||||
|             content: content.clone(), | ||||
|             state_key, | ||||
|             prev_events, | ||||
|             depth: depth | ||||
|                 .try_into() | ||||
|                 .expect("depth can overflow and should be deprecated..."), | ||||
|                 .map_err(|_| Error::bad_database("Depth is invalid"))?, | ||||
|             auth_events: Vec::new(), | ||||
|             redacts: redacts.clone(), | ||||
|             unsigned, | ||||
|  | @ -430,18 +483,20 @@ impl Rooms { | |||
|         // Generate event id
 | ||||
|         pdu.event_id = EventId::try_from(&*format!( | ||||
|             "${}", | ||||
|             ruma::signatures::reference_hash(&serde_json::to_value(&pdu)?) | ||||
|                 .expect("ruma can calculate reference hashes") | ||||
|             ruma::signatures::reference_hash( | ||||
|                 &serde_json::to_value(&pdu).expect("event is valid, we just created it") | ||||
|             ) | ||||
|             .expect("ruma can calculate reference hashes") | ||||
|         )) | ||||
|         .expect("ruma's reference hashes are correct"); | ||||
|         .expect("ruma's reference hashes are valid event ids"); | ||||
| 
 | ||||
|         let mut pdu_json = serde_json::to_value(&pdu)?; | ||||
|         let mut pdu_json = serde_json::to_value(&pdu).expect("event is valid, we just created it"); | ||||
|         ruma::signatures::hash_and_sign_event( | ||||
|             globals.server_name(), | ||||
|             globals.keypair(), | ||||
|             &mut pdu_json, | ||||
|         ) | ||||
|         .expect("our new event can be hashed and signed"); | ||||
|         .expect("event is valid, we just created it"); | ||||
| 
 | ||||
|         self.replace_pdu_leaves(&room_id, &pdu.event_id)?; | ||||
| 
 | ||||
|  | @ -473,8 +528,15 @@ impl Rooms { | |||
|                     // TODO: Reason
 | ||||
|                     let _reason = serde_json::from_value::< | ||||
|                         EventJson<redaction::RedactionEventContent>, | ||||
|                     >(content)? | ||||
|                     .deserialize()? | ||||
|                     >(content) | ||||
|                     .expect("EventJson::from_value always works.") | ||||
|                     .deserialize() | ||||
|                     .map_err(|_| { | ||||
|                         Error::BadRequest( | ||||
|                             ErrorKind::InvalidParam, | ||||
|                             "Invalid redaction event content.", | ||||
|                         ) | ||||
|                     })? | ||||
|                     .reason; | ||||
| 
 | ||||
|                     self.redact_pdu(&redact_id)?; | ||||
|  | @ -528,7 +590,10 @@ impl Rooms { | |||
|             }) | ||||
|             .filter_map(|r| r.ok()) | ||||
|             .take_while(move |(k, _)| k.starts_with(&prefix)) | ||||
|             .map(|(_, v)| Ok(serde_json::from_slice(&v)?))) | ||||
|             .map(|(_, v)| { | ||||
|                 Ok(serde_json::from_slice(&v) | ||||
|                     .map_err(|_| Error::bad_database("PDU in db is invalid."))?) | ||||
|             })) | ||||
|     } | ||||
| 
 | ||||
|     /// Returns an iterator over all events in a room that happened before the event with id
 | ||||
|  | @ -552,7 +617,10 @@ impl Rooms { | |||
|             .rev() | ||||
|             .filter_map(|r| r.ok()) | ||||
|             .take_while(move |(k, _)| k.starts_with(&prefix)) | ||||
|             .map(|(_, v)| Ok(serde_json::from_slice(&v)?)) | ||||
|             .map(|(_, v)| { | ||||
|                 Ok(serde_json::from_slice(&v) | ||||
|                     .map_err(|_| Error::bad_database("PDU in db is invalid."))?) | ||||
|             }) | ||||
|     } | ||||
| 
 | ||||
|     /// Returns an iterator over all events in a room that happened after the event with id
 | ||||
|  | @ -575,7 +643,10 @@ impl Rooms { | |||
|             .range(current..) | ||||
|             .filter_map(|r| r.ok()) | ||||
|             .take_while(move |(k, _)| k.starts_with(&prefix)) | ||||
|             .map(|(_, v)| Ok(serde_json::from_slice(&v)?)) | ||||
|             .map(|(_, v)| { | ||||
|                 Ok(serde_json::from_slice(&v) | ||||
|                     .map_err(|_| Error::bad_database("PDU in db is invalid."))?) | ||||
|             }) | ||||
|     } | ||||
| 
 | ||||
|     /// Replace a PDU with the redacted form.
 | ||||
|  | @ -583,12 +654,15 @@ impl Rooms { | |||
|         if let Some(pdu_id) = self.get_pdu_id(event_id)? { | ||||
|             let mut pdu = self | ||||
|                 .get_pdu_from_id(&pdu_id)? | ||||
|                 .ok_or(Error::BadDatabase("pduid points to invalid pdu"))?; | ||||
|             pdu.redact(); | ||||
|                 .ok_or_else(|| Error::bad_database("PDU ID points to invalid PDU."))?; | ||||
|             pdu.redact()?; | ||||
|             self.replace_pdu(&pdu_id, &pdu)?; | ||||
|             Ok(()) | ||||
|         } else { | ||||
|             Err(Error::BadRequest("eventid does not exist")) | ||||
|             Err(Error::BadRequest( | ||||
|                 ErrorKind::NotFound, | ||||
|                 "Event ID does not exist.", | ||||
|             )) | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|  | @ -664,7 +738,10 @@ impl Rooms { | |||
|             let room_id = self | ||||
|                 .alias_roomid | ||||
|                 .remove(alias.alias())? | ||||
|                 .ok_or(Error::BadRequest("Alias does not exist"))?; | ||||
|                 .ok_or(Error::BadRequest( | ||||
|                     ErrorKind::NotFound, | ||||
|                     "Alias does not exist.", | ||||
|                 ))?; | ||||
| 
 | ||||
|             for key in self.aliasid_alias.scan_prefix(room_id).keys() { | ||||
|                 self.aliasid_alias.remove(key?)?; | ||||
|  | @ -678,7 +755,12 @@ impl Rooms { | |||
|         self.alias_roomid | ||||
|             .get(alias.alias())? | ||||
|             .map_or(Ok(None), |bytes| { | ||||
|                 Ok(Some(RoomId::try_from(utils::string_from_bytes(&bytes)?)?)) | ||||
|                 Ok(Some( | ||||
|                     RoomId::try_from(utils::string_from_bytes(&bytes).map_err(|_| { | ||||
|                         Error::bad_database("Room ID in alias_roomid is invalid unicode.") | ||||
|                     })?) | ||||
|                     .map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid."))?, | ||||
|                 )) | ||||
|             }) | ||||
|     } | ||||
| 
 | ||||
|  | @ -689,7 +771,10 @@ impl Rooms { | |||
|         self.aliasid_alias | ||||
|             .scan_prefix(prefix) | ||||
|             .values() | ||||
|             .map(|bytes| Ok(RoomAliasId::try_from(utils::string_from_bytes(&bytes?)?)?)) | ||||
|             .map(|bytes| { | ||||
|                 Ok(serde_json::from_slice(&bytes?) | ||||
|                     .map_err(|_| Error::bad_database("Alias in aliasid_alias is invalid."))?) | ||||
|             }) | ||||
|     } | ||||
| 
 | ||||
|     pub fn set_public(&self, room_id: &RoomId, public: bool) -> Result<()> { | ||||
|  | @ -707,54 +792,76 @@ impl Rooms { | |||
|     } | ||||
| 
 | ||||
|     pub fn public_rooms(&self) -> impl Iterator<Item = Result<RoomId>> { | ||||
|         self.publicroomids | ||||
|             .iter() | ||||
|             .keys() | ||||
|             .map(|bytes| Ok(RoomId::try_from(utils::string_from_bytes(&bytes?)?)?)) | ||||
|         self.publicroomids.iter().keys().map(|bytes| { | ||||
|             Ok( | ||||
|                 RoomId::try_from(utils::string_from_bytes(&bytes?).map_err(|_| { | ||||
|                     Error::bad_database("Room ID in publicroomids is invalid unicode.") | ||||
|                 })?) | ||||
|                 .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid."))?, | ||||
|             ) | ||||
|         }) | ||||
|     } | ||||
| 
 | ||||
|     /// Returns an iterator over all rooms a user joined.
 | ||||
|     /// Returns an iterator over all joined members of a room.
 | ||||
|     pub fn room_members(&self, room_id: &RoomId) -> impl Iterator<Item = Result<UserId>> { | ||||
|         self.roomuserid_joined | ||||
|             .scan_prefix(room_id.to_string()) | ||||
|             .values() | ||||
|             .keys() | ||||
|             .map(|key| { | ||||
|                 Ok(UserId::try_from(&*utils::string_from_bytes( | ||||
|                     &key? | ||||
|                         .rsplit(|&b| b == 0xff) | ||||
|                         .next() | ||||
|                         .ok_or(Error::BadDatabase("userroomid is invalid"))?, | ||||
|                 )?)?) | ||||
|                 Ok(UserId::try_from( | ||||
|                     utils::string_from_bytes( | ||||
|                         &key? | ||||
|                             .rsplit(|&b| b == 0xff) | ||||
|                             .next() | ||||
|                             .expect("rsplit always returns an element"), | ||||
|                     ) | ||||
|                     .map_err(|_| { | ||||
|                         Error::bad_database("User ID in roomuserid_joined is invalid unicode.") | ||||
|                     })?, | ||||
|                 ) | ||||
|                 .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid."))?) | ||||
|             }) | ||||
|     } | ||||
| 
 | ||||
|     /// Returns an iterator over all rooms a user joined.
 | ||||
|     /// Returns an iterator over all invited members of a room.
 | ||||
|     pub fn room_members_invited(&self, room_id: &RoomId) -> impl Iterator<Item = Result<UserId>> { | ||||
|         self.roomuserid_invited | ||||
|             .scan_prefix(room_id.to_string()) | ||||
|             .keys() | ||||
|             .map(|key| { | ||||
|                 Ok(UserId::try_from(&*utils::string_from_bytes( | ||||
|                     &key? | ||||
|                         .rsplit(|&b| b == 0xff) | ||||
|                         .next() | ||||
|                         .ok_or(Error::BadDatabase("userroomid is invalid"))?, | ||||
|                 )?)?) | ||||
|                 Ok(UserId::try_from( | ||||
|                     utils::string_from_bytes( | ||||
|                         &key? | ||||
|                             .rsplit(|&b| b == 0xff) | ||||
|                             .next() | ||||
|                             .expect("rsplit always returns an element"), | ||||
|                     ) | ||||
|                     .map_err(|_| { | ||||
|                         Error::bad_database("User ID in roomuserid_invited is invalid unicode.") | ||||
|                     })?, | ||||
|                 ) | ||||
|                 .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid."))?) | ||||
|             }) | ||||
|     } | ||||
| 
 | ||||
|     /// Returns an iterator over all rooms a user joined.
 | ||||
|     /// Returns an iterator over all left members of a room.
 | ||||
|     pub fn rooms_joined(&self, user_id: &UserId) -> impl Iterator<Item = Result<RoomId>> { | ||||
|         self.userroomid_joined | ||||
|             .scan_prefix(user_id.to_string()) | ||||
|             .keys() | ||||
|             .map(|key| { | ||||
|                 Ok(RoomId::try_from(&*utils::string_from_bytes( | ||||
|                     &key? | ||||
|                         .rsplit(|&b| b == 0xff) | ||||
|                         .next() | ||||
|                         .ok_or(Error::BadDatabase("userroomid is invalid"))?, | ||||
|                 )?)?) | ||||
|                 Ok(RoomId::try_from( | ||||
|                     utils::string_from_bytes( | ||||
|                         &key? | ||||
|                             .rsplit(|&b| b == 0xff) | ||||
|                             .next() | ||||
|                             .expect("rsplit always returns an element"), | ||||
|                     ) | ||||
|                     .map_err(|_| { | ||||
|                         Error::bad_database("Room ID in userroomid_joined is invalid unicode.") | ||||
|                     })?, | ||||
|                 ) | ||||
|                 .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid."))?) | ||||
|             }) | ||||
|     } | ||||
| 
 | ||||
|  | @ -764,12 +871,18 @@ impl Rooms { | |||
|             .scan_prefix(&user_id.to_string()) | ||||
|             .keys() | ||||
|             .map(|key| { | ||||
|                 Ok(RoomId::try_from(&*utils::string_from_bytes( | ||||
|                     &key? | ||||
|                         .rsplit(|&b| b == 0xff) | ||||
|                         .next() | ||||
|                         .ok_or(Error::BadDatabase("userroomid is invalid"))?, | ||||
|                 )?)?) | ||||
|                 Ok(RoomId::try_from( | ||||
|                     utils::string_from_bytes( | ||||
|                         &key? | ||||
|                             .rsplit(|&b| b == 0xff) | ||||
|                             .next() | ||||
|                             .expect("rsplit always returns an element"), | ||||
|                     ) | ||||
|                     .map_err(|_| { | ||||
|                         Error::bad_database("Room ID in userroomid_invited is invalid unicode.") | ||||
|                     })?, | ||||
|                 ) | ||||
|                 .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?) | ||||
|             }) | ||||
|     } | ||||
| 
 | ||||
|  | @ -779,12 +892,18 @@ impl Rooms { | |||
|             .scan_prefix(&user_id.to_string()) | ||||
|             .keys() | ||||
|             .map(|key| { | ||||
|                 Ok(RoomId::try_from(&*utils::string_from_bytes( | ||||
|                     &key? | ||||
|                         .rsplit(|&b| b == 0xff) | ||||
|                         .next() | ||||
|                         .ok_or(Error::BadDatabase("userroomid is invalid"))?, | ||||
|                 )?)?) | ||||
|                 Ok(RoomId::try_from( | ||||
|                     utils::string_from_bytes( | ||||
|                         &key? | ||||
|                             .rsplit(|&b| b == 0xff) | ||||
|                             .next() | ||||
|                             .expect("rsplit always returns an element"), | ||||
|                     ) | ||||
|                     .map_err(|_| { | ||||
|                         Error::bad_database("Room ID in userroomid_left is invalid unicode.") | ||||
|                     })?, | ||||
|                 ) | ||||
|                 .map_err(|_| Error::bad_database("Room ID in userroomid_left is invalid."))?) | ||||
|             }) | ||||
|     } | ||||
| 
 | ||||
|  |  | |||
|  | @ -33,7 +33,10 @@ impl RoomEdus { | |||
|             .filter_map(|r| r.ok()) | ||||
|             .take_while(|key| key.starts_with(&prefix)) | ||||
|             .find(|key| { | ||||
|                 key.rsplit(|&b| b == 0xff).next().unwrap() == user_id.to_string().as_bytes() | ||||
|                 key.rsplit(|&b| b == 0xff) | ||||
|                     .next() | ||||
|                     .expect("rsplit always returns an element") | ||||
|                     == user_id.to_string().as_bytes() | ||||
|             }) | ||||
|         { | ||||
|             // This is the old room_latest
 | ||||
|  | @ -45,8 +48,10 @@ impl RoomEdus { | |||
|         room_latest_id.push(0xff); | ||||
|         room_latest_id.extend_from_slice(&user_id.to_string().as_bytes()); | ||||
| 
 | ||||
|         self.roomlatestid_roomlatest | ||||
|             .insert(room_latest_id, &*serde_json::to_string(&event)?)?; | ||||
|         self.roomlatestid_roomlatest.insert( | ||||
|             room_latest_id, | ||||
|             &*serde_json::to_string(&event).expect("EduEvent::to_string always works"), | ||||
|         )?; | ||||
| 
 | ||||
|         Ok(()) | ||||
|     } | ||||
|  | @ -68,7 +73,11 @@ impl RoomEdus { | |||
|             .range(&*first_possible_edu..) | ||||
|             .filter_map(|r| r.ok()) | ||||
|             .take_while(move |(k, _)| k.starts_with(&prefix)) | ||||
|             .map(|(_, v)| Ok(serde_json::from_slice(&v)?))) | ||||
|             .map(|(_, v)| { | ||||
|                 Ok(serde_json::from_slice(&v).map_err(|_| { | ||||
|                     Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid.") | ||||
|                 })?) | ||||
|             })) | ||||
|     } | ||||
| 
 | ||||
|     /// Sets a user as typing until the timeout timestamp is reached or roomactive_remove is
 | ||||
|  | @ -152,17 +161,21 @@ impl RoomEdus { | |||
|             .roomactiveid_userid | ||||
|             .scan_prefix(&prefix) | ||||
|             .keys() | ||||
|             .filter_map(|r| r.ok()) | ||||
|             .take_while(|k| { | ||||
|                 utils::u64_from_bytes( | ||||
|                     k.split(|&c| c == 0xff) | ||||
|                         .nth(1) | ||||
|                         .expect("roomactive has valid timestamp and delimiters"), | ||||
|                 ) < current_timestamp | ||||
|             .map(|key| { | ||||
|                 let key = key?; | ||||
|                 Ok::<_, Error>(( | ||||
|                     key.clone(), | ||||
|                     utils::u64_from_bytes(key.split(|&b| b == 0xff).nth(1).ok_or_else(|| { | ||||
|                         Error::bad_database("RoomActive has invalid timestamp or delimiters.") | ||||
|                     })?) | ||||
|                     .map_err(|_| Error::bad_database("RoomActive has invalid timestamp bytes."))?, | ||||
|                 )) | ||||
|             }) | ||||
|             .filter_map(|r| r.ok()) | ||||
|             .take_while(|&(_, timestamp)| timestamp < current_timestamp) | ||||
|         { | ||||
|             // This is an outdated edu (time > timestamp)
 | ||||
|             self.roomactiveid_userid.remove(outdated_edu)?; | ||||
|             self.roomactiveid_userid.remove(outdated_edu.0)?; | ||||
|             found_outdated = true; | ||||
|         } | ||||
| 
 | ||||
|  | @ -187,7 +200,11 @@ impl RoomEdus { | |||
|         Ok(self | ||||
|             .roomid_lastroomactiveupdate | ||||
|             .get(&room_id.to_string().as_bytes())? | ||||
|             .map(|bytes| utils::u64_from_bytes(&bytes)) | ||||
|             .map_or(Ok::<_, Error>(None), |bytes| { | ||||
|                 Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { | ||||
|                     Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.") | ||||
|                 })?)) | ||||
|             })? | ||||
|             .unwrap_or(0)) | ||||
|     } | ||||
| 
 | ||||
|  | @ -202,7 +219,16 @@ impl RoomEdus { | |||
|             .roomactiveid_userid | ||||
|             .scan_prefix(prefix) | ||||
|             .values() | ||||
|             .map(|user_id| Ok::<_, Error>(UserId::try_from(utils::string_from_bytes(&user_id?)?)?)) | ||||
|             .map(|user_id| { | ||||
|                 Ok::<_, Error>( | ||||
|                     UserId::try_from(utils::string_from_bytes(&user_id?).map_err(|_| { | ||||
|                         Error::bad_database("User ID in roomactiveid_userid is invalid unicode.") | ||||
|                     })?) | ||||
|                     .map_err(|_| { | ||||
|                         Error::bad_database("User ID in roomactiveid_userid is invalid.") | ||||
|                     })?, | ||||
|                 ) | ||||
|             }) | ||||
|         { | ||||
|             user_ids.push(user_id?); | ||||
|         } | ||||
|  | @ -230,9 +256,10 @@ impl RoomEdus { | |||
|         key.push(0xff); | ||||
|         key.extend_from_slice(&user_id.to_string().as_bytes()); | ||||
| 
 | ||||
|         Ok(self | ||||
|             .roomuserid_lastread | ||||
|             .get(key)? | ||||
|             .map(|v| utils::u64_from_bytes(&v))) | ||||
|         self.roomuserid_lastread.get(key)?.map_or(Ok(None), |v| { | ||||
|             Ok(Some(utils::u64_from_bytes(&v).map_err(|_| { | ||||
|                 Error::bad_database("Invalid private read marker bytes") | ||||
|             })?)) | ||||
|         }) | ||||
|     } | ||||
| } | ||||
|  |  | |||
|  | @ -43,15 +43,51 @@ impl Uiaa { | |||
|             // Find out what the user completed
 | ||||
|             match &**kind { | ||||
|                 "m.login.password" => { | ||||
|                     if auth_parameters["identifier"]["type"] != "m.id.user" { | ||||
|                         panic!("identifier not supported"); | ||||
|                     let identifier = auth_parameters.get("identifier").ok_or(Error::BadRequest( | ||||
|                         ErrorKind::MissingParam, | ||||
|                         "m.login.password needs identifier.", | ||||
|                     ))?; | ||||
| 
 | ||||
|                     let identifier_type = identifier.get("type").ok_or(Error::BadRequest( | ||||
|                         ErrorKind::MissingParam, | ||||
|                         "Identifier needs a type.", | ||||
|                     ))?; | ||||
| 
 | ||||
|                     if identifier_type != "m.id.user" { | ||||
|                         return Err(Error::BadRequest( | ||||
|                             ErrorKind::Unrecognized, | ||||
|                             "Identifier type not recognized.", | ||||
|                         )); | ||||
|                     } | ||||
| 
 | ||||
|                     let user_id = UserId::parse_with_server_name( | ||||
|                         auth_parameters["identifier"]["user"].as_str().unwrap(), | ||||
|                         globals.server_name(), | ||||
|                     )?; | ||||
|                     let password = auth_parameters["password"].as_str().unwrap(); | ||||
|                     let username = identifier | ||||
|                         .get("user") | ||||
|                         .ok_or(Error::BadRequest( | ||||
|                             ErrorKind::MissingParam, | ||||
|                             "Identifier needs user field.", | ||||
|                         ))? | ||||
|                         .as_str() | ||||
|                         .ok_or(Error::BadRequest( | ||||
|                             ErrorKind::BadJson, | ||||
|                             "User is not a string.", | ||||
|                         ))?; | ||||
| 
 | ||||
|                     let user_id = UserId::parse_with_server_name(username, globals.server_name()) | ||||
|                         .map_err(|_| { | ||||
|                         Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.") | ||||
|                     })?; | ||||
| 
 | ||||
|                     let password = auth_parameters | ||||
|                         .get("password") | ||||
|                         .ok_or(Error::BadRequest( | ||||
|                             ErrorKind::MissingParam, | ||||
|                             "Password is missing.", | ||||
|                         ))? | ||||
|                         .as_str() | ||||
|                         .ok_or(Error::BadRequest( | ||||
|                             ErrorKind::BadJson, | ||||
|                             "Password is not a string.", | ||||
|                         ))?; | ||||
| 
 | ||||
|                     // Check if password is correct
 | ||||
|                     if let Some(hash) = users.password_hash(&user_id)? { | ||||
|  | @ -59,7 +95,6 @@ impl Uiaa { | |||
|                             argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false); | ||||
| 
 | ||||
|                         if !hash_matches { | ||||
|                             debug!("Invalid password."); | ||||
|                             uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody { | ||||
|                                 kind: ErrorKind::Forbidden, | ||||
|                                 message: "Invalid username or password.".to_owned(), | ||||
|  | @ -113,8 +148,10 @@ impl Uiaa { | |||
|         userdeviceid.extend_from_slice(device_id.as_bytes()); | ||||
| 
 | ||||
|         if let Some(uiaainfo) = uiaainfo { | ||||
|             self.userdeviceid_uiaainfo | ||||
|                 .insert(&userdeviceid, &*serde_json::to_string(&uiaainfo)?)?; | ||||
|             self.userdeviceid_uiaainfo.insert( | ||||
|                 &userdeviceid, | ||||
|                 &*serde_json::to_string(&uiaainfo).expect("UiaaInfo::to_string always works"), | ||||
|             )?; | ||||
|         } else { | ||||
|             self.userdeviceid_uiaainfo.remove(&userdeviceid)?; | ||||
|         } | ||||
|  | @ -136,8 +173,12 @@ impl Uiaa { | |||
|             &self | ||||
|                 .userdeviceid_uiaainfo | ||||
|                 .get(&userdeviceid)? | ||||
|                 .ok_or(Error::BadRequest("session does not exist"))?, | ||||
|         )?; | ||||
|                 .ok_or(Error::BadRequest( | ||||
|                     ErrorKind::Forbidden, | ||||
|                     "UIAA session does not exist.", | ||||
|                 ))?, | ||||
|         ) | ||||
|         .map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))?; | ||||
| 
 | ||||
|         if uiaainfo | ||||
|             .session | ||||
|  | @ -145,7 +186,10 @@ impl Uiaa { | |||
|             .filter(|&s| s == session) | ||||
|             .is_none() | ||||
|         { | ||||
|             return Err(Error::BadRequest("wrong session token")); | ||||
|             return Err(Error::BadRequest( | ||||
|                 ErrorKind::Forbidden, | ||||
|                 "UIAA session token invalid.", | ||||
|             )); | ||||
|         } | ||||
| 
 | ||||
|         Ok(uiaainfo) | ||||
|  |  | |||
|  | @ -43,24 +43,36 @@ impl Users { | |||
|             .get(token)? | ||||
|             .map_or(Ok(None), |bytes| { | ||||
|                 let mut parts = bytes.split(|&b| b == 0xff); | ||||
|                 let user_bytes = parts | ||||
|                     .next() | ||||
|                     .ok_or(Error::BadDatabase("token_userdeviceid value invalid"))?; | ||||
|                 let device_bytes = parts | ||||
|                     .next() | ||||
|                     .ok_or(Error::BadDatabase("token_userdeviceid value invalid"))?; | ||||
|                 let user_bytes = parts.next().ok_or_else(|| { | ||||
|                     Error::bad_database("User ID in token_userdeviceid is invalid.") | ||||
|                 })?; | ||||
|                 let device_bytes = parts.next().ok_or_else(|| { | ||||
|                     Error::bad_database("Device ID in token_userdeviceid is invalid.") | ||||
|                 })?; | ||||
| 
 | ||||
|                 Ok(Some(( | ||||
|                     UserId::try_from(utils::string_from_bytes(&user_bytes)?)?, | ||||
|                     utils::string_from_bytes(&device_bytes)?, | ||||
|                     UserId::try_from(utils::string_from_bytes(&user_bytes).map_err(|_| { | ||||
|                         Error::bad_database("User ID in token_userdeviceid is invalid unicode.") | ||||
|                     })?) | ||||
|                     .map_err(|_| { | ||||
|                         Error::bad_database("User ID in token_userdeviceid is invalid.") | ||||
|                     })?, | ||||
|                     utils::string_from_bytes(&device_bytes).map_err(|_| { | ||||
|                         Error::bad_database("Device ID in token_userdeviceid is invalid.") | ||||
|                     })?, | ||||
|                 ))) | ||||
|             }) | ||||
|     } | ||||
| 
 | ||||
|     /// Returns an iterator over all users on this homeserver.
 | ||||
|     pub fn iter(&self) -> impl Iterator<Item = Result<UserId>> { | ||||
|         self.userid_password.iter().keys().map(|r| { | ||||
|             utils::string_from_bytes(&r?).and_then(|string| Ok(UserId::try_from(&*string)?)) | ||||
|         self.userid_password.iter().keys().map(|bytes| { | ||||
|             Ok( | ||||
|                 UserId::try_from(utils::string_from_bytes(&bytes?).map_err(|_| { | ||||
|                     Error::bad_database("User ID in userid_password is invalid unicode.") | ||||
|                 })?) | ||||
|                 .map_err(|_| Error::bad_database("User ID in userid_password is invalid."))?, | ||||
|             ) | ||||
|         }) | ||||
|     } | ||||
| 
 | ||||
|  | @ -68,14 +80,22 @@ impl Users { | |||
|     pub fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> { | ||||
|         self.userid_password | ||||
|             .get(user_id.to_string())? | ||||
|             .map_or(Ok(None), |bytes| utils::string_from_bytes(&bytes).map(Some)) | ||||
|             .map_or(Ok(None), |bytes| { | ||||
|                 Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { | ||||
|                     Error::bad_database("Password hash in db is not valid string.") | ||||
|                 })?)) | ||||
|             }) | ||||
|     } | ||||
| 
 | ||||
|     /// Returns the displayname of a user on this homeserver.
 | ||||
|     pub fn displayname(&self, user_id: &UserId) -> Result<Option<String>> { | ||||
|         self.userid_displayname | ||||
|             .get(user_id.to_string())? | ||||
|             .map_or(Ok(None), |bytes| utils::string_from_bytes(&bytes).map(Some)) | ||||
|             .map_or(Ok(None), |bytes| { | ||||
|                 Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { | ||||
|                     Error::bad_database("Displayname in db is invalid.") | ||||
|                 })?)) | ||||
|             }) | ||||
|     } | ||||
| 
 | ||||
|     /// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change.
 | ||||
|  | @ -94,7 +114,11 @@ impl Users { | |||
|     pub fn avatar_url(&self, user_id: &UserId) -> Result<Option<String>> { | ||||
|         self.userid_avatarurl | ||||
|             .get(user_id.to_string())? | ||||
|             .map_or(Ok(None), |bytes| utils::string_from_bytes(&bytes).map(Some)) | ||||
|             .map_or(Ok(None), |bytes| { | ||||
|                 Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { | ||||
|                     Error::bad_database("Avatar URL in db is invalid.") | ||||
|                 })?)) | ||||
|             }) | ||||
|     } | ||||
| 
 | ||||
|     /// Sets a new avatar_url or removes it if avatar_url is None.
 | ||||
|  | @ -117,11 +141,8 @@ impl Users { | |||
|         token: &str, | ||||
|         initial_device_display_name: Option<String>, | ||||
|     ) -> Result<()> { | ||||
|         if !self.exists(user_id)? { | ||||
|             return Err(Error::BadRequest( | ||||
|                 "tried to create device for nonexistent user", | ||||
|             )); | ||||
|         } | ||||
|         // This method should never be called for nonexistent users.
 | ||||
|         assert!(self.exists(user_id)?); | ||||
| 
 | ||||
|         let mut userdeviceid = user_id.to_string().as_bytes().to_vec(); | ||||
|         userdeviceid.push(0xff); | ||||
|  | @ -134,7 +155,8 @@ impl Users { | |||
|                 display_name: initial_device_display_name, | ||||
|                 last_seen_ip: None, // TODO
 | ||||
|                 last_seen_ts: Some(SystemTime::now()), | ||||
|             })? | ||||
|             }) | ||||
|             .expect("Device::to_string never fails.") | ||||
|             .as_bytes(), | ||||
|         )?; | ||||
| 
 | ||||
|  | @ -185,23 +207,22 @@ impl Users { | |||
|                     &*bytes? | ||||
|                         .rsplit(|&b| b == 0xff) | ||||
|                         .next() | ||||
|                         .ok_or(Error::BadDatabase("userdeviceid is invalid"))?, | ||||
|                 )?) | ||||
|                         .ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?, | ||||
|                 ) | ||||
|                 .map_err(|_| { | ||||
|                     Error::bad_database("Device ID in userdeviceid_metadata is invalid.") | ||||
|                 })?) | ||||
|             }) | ||||
|     } | ||||
| 
 | ||||
|     /// Replaces the access token of one device.
 | ||||
|     pub fn set_token(&self, user_id: &UserId, device_id: &str, token: &str) -> Result<()> { | ||||
|     fn set_token(&self, user_id: &UserId, device_id: &str, token: &str) -> Result<()> { | ||||
|         let mut userdeviceid = user_id.to_string().as_bytes().to_vec(); | ||||
|         userdeviceid.push(0xff); | ||||
|         userdeviceid.extend_from_slice(device_id.as_bytes()); | ||||
| 
 | ||||
|         // All devices have metadata
 | ||||
|         if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() { | ||||
|             return Err(Error::BadRequest( | ||||
|                 "Tried to set token for nonexistent device", | ||||
|             )); | ||||
|         } | ||||
|         assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some()); | ||||
| 
 | ||||
|         // Remove old token
 | ||||
|         if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { | ||||
|  | @ -228,19 +249,23 @@ impl Users { | |||
|         key.extend_from_slice(device_id.as_bytes()); | ||||
| 
 | ||||
|         // All devices have metadata
 | ||||
|         if self.userdeviceid_metadata.get(&key)?.is_none() { | ||||
|             return Err(Error::BadRequest( | ||||
|                 "Tried to set token for nonexistent device", | ||||
|             )); | ||||
|         } | ||||
|         // Only existing devices should be able to call this.
 | ||||
|         assert!(self.userdeviceid_metadata.get(&key)?.is_some()); | ||||
| 
 | ||||
|         key.push(0xff); | ||||
|         // TODO: Use AlgorithmAndDeviceId::to_string when it's available (and update everything,
 | ||||
|         // because there are no wrapping quotation marks anymore)
 | ||||
|         key.extend_from_slice(&serde_json::to_string(one_time_key_key)?.as_bytes()); | ||||
|         key.extend_from_slice( | ||||
|             &serde_json::to_string(one_time_key_key) | ||||
|                 .expect("AlgorithmAndDeviceId::to_string always works") | ||||
|                 .as_bytes(), | ||||
|         ); | ||||
| 
 | ||||
|         self.onetimekeyid_onetimekeys | ||||
|             .insert(&key, &*serde_json::to_string(&one_time_key_value)?)?; | ||||
|         self.onetimekeyid_onetimekeys.insert( | ||||
|             &key, | ||||
|             &*serde_json::to_string(&one_time_key_value) | ||||
|                 .expect("OneTimeKey::to_string always works"), | ||||
|         )?; | ||||
| 
 | ||||
|         Ok(()) | ||||
|     } | ||||
|  | @ -271,9 +296,11 @@ impl Users { | |||
|                         &*key | ||||
|                             .rsplit(|&b| b == 0xff) | ||||
|                             .next() | ||||
|                             .ok_or(Error::BadDatabase("onetimekeyid is invalid"))?, | ||||
|                     )?, | ||||
|                     serde_json::from_slice(&*value)?, | ||||
|                             .ok_or_else(|| Error::bad_database("OneTimeKeyId in db is invalid."))?, | ||||
|                     ) | ||||
|                     .map_err(|_| Error::bad_database("OneTimeKeyId in db is invalid."))?, | ||||
|                     serde_json::from_slice(&*value) | ||||
|                         .map_err(|_| Error::bad_database("OneTimeKeys in db are invalid."))?, | ||||
|                 )) | ||||
|             }) | ||||
|             .transpose() | ||||
|  | @ -297,11 +324,11 @@ impl Users { | |||
|             .map(|bytes| { | ||||
|                 Ok::<_, Error>( | ||||
|                     serde_json::from_slice::<AlgorithmAndDeviceId>( | ||||
|                         &*bytes? | ||||
|                             .rsplit(|&b| b == 0xff) | ||||
|                             .next() | ||||
|                             .ok_or(Error::BadDatabase("onetimekeyid is invalid"))?, | ||||
|                     )? | ||||
|                         &*bytes?.rsplit(|&b| b == 0xff).next().ok_or_else(|| { | ||||
|                             Error::bad_database("OneTimeKey ID in db is invalid.") | ||||
|                         })?, | ||||
|                     ) | ||||
|                     .map_err(|_| Error::bad_database("AlgorithmAndDeviceID in db is invalid."))? | ||||
|                     .0, | ||||
|                 ) | ||||
|             }) | ||||
|  | @ -323,8 +350,10 @@ impl Users { | |||
|         userdeviceid.push(0xff); | ||||
|         userdeviceid.extend_from_slice(device_id.as_bytes()); | ||||
| 
 | ||||
|         self.userdeviceid_devicekeys | ||||
|             .insert(&userdeviceid, &*serde_json::to_string(&device_keys)?)?; | ||||
|         self.userdeviceid_devicekeys.insert( | ||||
|             &userdeviceid, | ||||
|             &*serde_json::to_string(&device_keys).expect("DeviceKeys::to_string always works"), | ||||
|         )?; | ||||
| 
 | ||||
|         self.devicekeychangeid_userid | ||||
|             .insert(globals.next_count()?.to_be_bytes(), &*user_id.to_string())?; | ||||
|  | @ -344,14 +373,28 @@ impl Users { | |||
|         self.userdeviceid_devicekeys | ||||
|             .scan_prefix(key) | ||||
|             .values() | ||||
|             .map(|bytes| Ok(serde_json::from_slice(&bytes?)?)) | ||||
|             .map(|bytes| { | ||||
|                 Ok(serde_json::from_slice(&bytes?) | ||||
|                     .map_err(|_| Error::bad_database("DeviceKeys in db are invalid."))?) | ||||
|             }) | ||||
|     } | ||||
| 
 | ||||
|     pub fn device_keys_changed(&self, since: u64) -> impl Iterator<Item = Result<UserId>> { | ||||
|         self.devicekeychangeid_userid | ||||
|             .range(since.to_be_bytes()..) | ||||
|             .values() | ||||
|             .map(|bytes| Ok(UserId::try_from(utils::string_from_bytes(&bytes?)?)?)) | ||||
|             .map(|bytes| { | ||||
|                 Ok( | ||||
|                     UserId::try_from(utils::string_from_bytes(&bytes?).map_err(|_| { | ||||
|                         Error::bad_database( | ||||
|                             "User ID in devicekeychangeid_userid is invalid unicode.", | ||||
|                         ) | ||||
|                     })?) | ||||
|                     .map_err(|_| { | ||||
|                         Error::bad_database("User ID in devicekeychangeid_userid is invalid.") | ||||
|                     })?, | ||||
|                 ) | ||||
|             }) | ||||
|     } | ||||
| 
 | ||||
|     pub fn all_device_keys( | ||||
|  | @ -366,9 +409,14 @@ impl Users { | |||
|             let userdeviceid = utils::string_from_bytes( | ||||
|                 key.rsplit(|&b| b == 0xff) | ||||
|                     .next() | ||||
|                     .ok_or(Error::BadDatabase("userdeviceid is invalid"))?, | ||||
|             )?; | ||||
|             Ok((userdeviceid, serde_json::from_slice(&*value)?)) | ||||
|                     .ok_or_else(|| Error::bad_database("UserDeviceID in db is invalid."))?, | ||||
|             ) | ||||
|             .map_err(|_| Error::bad_database("UserDeviceId in db is invalid."))?; | ||||
|             Ok(( | ||||
|                 userdeviceid, | ||||
|                 serde_json::from_slice(&*value) | ||||
|                     .map_err(|_| Error::bad_database("DeviceKeys in db are invalid."))?, | ||||
|             )) | ||||
|         }) | ||||
|     } | ||||
| 
 | ||||
|  | @ -392,8 +440,10 @@ impl Users { | |||
|         json.insert("sender".to_owned(), sender.to_string().into()); | ||||
|         json.insert("content".to_owned(), content); | ||||
| 
 | ||||
|         self.todeviceid_events | ||||
|             .insert(&key, &*serde_json::to_string(&json)?)?; | ||||
|         self.todeviceid_events.insert( | ||||
|             &key, | ||||
|             &*serde_json::to_string(&json).expect("Map::to_string always works"), | ||||
|         )?; | ||||
| 
 | ||||
|         Ok(()) | ||||
|     } | ||||
|  | @ -413,7 +463,10 @@ impl Users { | |||
| 
 | ||||
|         for result in self.todeviceid_events.scan_prefix(&prefix).take(max) { | ||||
|             let (key, value) = result?; | ||||
|             events.push(serde_json::from_slice(&*value)?); | ||||
|             events.push( | ||||
|                 serde_json::from_slice(&*value) | ||||
|                     .map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?, | ||||
|             ); | ||||
|             self.todeviceid_events.remove(key)?; | ||||
|         } | ||||
| 
 | ||||
|  | @ -430,12 +483,15 @@ impl Users { | |||
|         userdeviceid.push(0xff); | ||||
|         userdeviceid.extend_from_slice(device_id.as_bytes()); | ||||
| 
 | ||||
|         if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() { | ||||
|             return Err(Error::BadRequest("device does not exist")); | ||||
|         } | ||||
|         // Only existing devices should be able to call this.
 | ||||
|         assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some()); | ||||
| 
 | ||||
|         self.userdeviceid_metadata | ||||
|             .insert(userdeviceid, serde_json::to_string(device)?.as_bytes())?; | ||||
|         self.userdeviceid_metadata.insert( | ||||
|             userdeviceid, | ||||
|             serde_json::to_string(device) | ||||
|                 .expect("Device::to_string always works") | ||||
|                 .as_bytes(), | ||||
|         )?; | ||||
| 
 | ||||
|         Ok(()) | ||||
|     } | ||||
|  | @ -448,7 +504,11 @@ impl Users { | |||
| 
 | ||||
|         self.userdeviceid_metadata | ||||
|             .get(&userdeviceid)? | ||||
|             .map_or(Ok(None), |bytes| Ok(Some(serde_json::from_slice(&bytes)?))) | ||||
|             .map_or(Ok(None), |bytes| { | ||||
|                 Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { | ||||
|                     Error::bad_database("Metadata in userdeviceid_metadata is invalid.") | ||||
|                 })?)) | ||||
|             }) | ||||
|     } | ||||
| 
 | ||||
|     pub fn all_devices_metadata(&self, user_id: &UserId) -> impl Iterator<Item = Result<Device>> { | ||||
|  | @ -458,6 +518,10 @@ impl Users { | |||
|         self.userdeviceid_metadata | ||||
|             .scan_prefix(key) | ||||
|             .values() | ||||
|             .map(|bytes| Ok(serde_json::from_slice::<Device>(&bytes?)?)) | ||||
|             .map(|bytes| { | ||||
|                 Ok(serde_json::from_slice::<Device>(&bytes?).map_err(|_| { | ||||
|                     Error::bad_database("Device in userdeviceid_metadata is invalid.") | ||||
|                 })?) | ||||
|             }) | ||||
|     } | ||||
| } | ||||
|  |  | |||
							
								
								
									
										97
									
								
								src/error.rs
									
									
									
									
									
								
							
							
						
						
									
										97
									
								
								src/error.rs
									
									
									
									
									
								
							|  | @ -1,41 +1,88 @@ | |||
| use crate::RumaResponse; | ||||
| use http::StatusCode; | ||||
| use log::error; | ||||
| use rocket::{ | ||||
|     response::{self, Responder}, | ||||
|     Request, | ||||
| }; | ||||
| use ruma::api::client::{ | ||||
|     error::{Error as RumaError, ErrorKind}, | ||||
|     r0::uiaa::{UiaaInfo, UiaaResponse}, | ||||
| }; | ||||
| use thiserror::Error; | ||||
| 
 | ||||
| pub type Result<T> = std::result::Result<T, Error>; | ||||
| 
 | ||||
| #[derive(Error, Debug)] | ||||
| pub enum Error { | ||||
|     #[error("problem with the database")] | ||||
|     #[error("There was a problem with the connection to the database.")] | ||||
|     SledError { | ||||
|         #[from] | ||||
|         source: sled::Error, | ||||
|     }, | ||||
|     #[error("tried to parse invalid string")] | ||||
|     StringFromBytesError { | ||||
|         #[from] | ||||
|         source: std::string::FromUtf8Error, | ||||
|     }, | ||||
|     #[error("tried to parse invalid identifier")] | ||||
|     SerdeJsonError { | ||||
|         #[from] | ||||
|         source: serde_json::Error, | ||||
|     }, | ||||
|     #[error("tried to parse invalid identifier")] | ||||
|     RumaIdentifierError { | ||||
|         #[from] | ||||
|         source: ruma::identifiers::Error, | ||||
|     }, | ||||
|     #[error("tried to parse invalid event")] | ||||
|     RumaEventError { | ||||
|         #[from] | ||||
|         source: ruma::events::InvalidEvent, | ||||
|     }, | ||||
|     #[error("could not generate image")] | ||||
|     #[error("Could not generate an image.")] | ||||
|     ImageError { | ||||
|         #[from] | ||||
|         source: image::error::ImageError, | ||||
|     }, | ||||
|     #[error("bad request")] | ||||
|     BadRequest(&'static str), | ||||
|     #[error("problem in that database")] | ||||
|     #[error("{0}")] | ||||
|     BadConfig(&'static str), | ||||
|     #[error("{0}")] | ||||
|     /// Don't create this directly. Use Error::bad_database instead.
 | ||||
|     BadDatabase(&'static str), | ||||
|     #[error("uiaa")] | ||||
|     Uiaa(UiaaInfo), | ||||
| 
 | ||||
|     #[error("{0}: {1}")] | ||||
|     BadRequest(ErrorKind, &'static str), | ||||
|     #[error("{0}")] | ||||
|     Conflict(&'static str), // This is only needed for when a room alias already exists
 | ||||
| } | ||||
| 
 | ||||
| impl Error { | ||||
|     pub fn bad_database(message: &'static str) -> Self { | ||||
|         error!("BadDatabase: {}", message); | ||||
|         Self::BadDatabase(message) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| #[rocket::async_trait] | ||||
| impl<'r> Responder<'r> for Error { | ||||
|     async fn respond_to(self, r: &'r Request<'_>) -> response::Result<'r> { | ||||
|         if let Self::Uiaa(uiaainfo) = &self { | ||||
|             return RumaResponse::from(UiaaResponse::AuthResponse(uiaainfo.clone())) | ||||
|                 .respond_to(r) | ||||
|                 .await; | ||||
|         } | ||||
| 
 | ||||
|         let message = format!("{}", self); | ||||
| 
 | ||||
|         use ErrorKind::*; | ||||
|         let (kind, status_code) = match self { | ||||
|             Self::BadRequest(kind, _) => ( | ||||
|                 kind, | ||||
|                 match kind { | ||||
|                     Forbidden | GuestAccessForbidden | ThreepidAuthFailed | ThreepidDenied => { | ||||
|                         StatusCode::FORBIDDEN | ||||
|                     } | ||||
|                     Unauthorized | UnknownToken | MissingToken => StatusCode::UNAUTHORIZED, | ||||
|                     NotFound => StatusCode::NOT_FOUND, | ||||
|                     LimitExceeded => StatusCode::TOO_MANY_REQUESTS, | ||||
|                     UserDeactivated => StatusCode::FORBIDDEN, | ||||
|                     TooLarge => StatusCode::PAYLOAD_TOO_LARGE, | ||||
|                     _ => StatusCode::BAD_REQUEST, | ||||
|                 }, | ||||
|             ), | ||||
|             Self::Conflict(_) => (Unknown, StatusCode::CONFLICT), | ||||
|             _ => (Unknown, StatusCode::INTERNAL_SERVER_ERROR), | ||||
|         }; | ||||
| 
 | ||||
|         RumaResponse::from(RumaError { | ||||
|             kind, | ||||
|             message, | ||||
|             status_code, | ||||
|         }) | ||||
|         .respond_to(r) | ||||
|         .await | ||||
|     } | ||||
| } | ||||
|  |  | |||
|  | @ -12,7 +12,7 @@ mod utils; | |||
| pub use database::Database; | ||||
| pub use error::{Error, Result}; | ||||
| pub use pdu::PduEvent; | ||||
| pub use ruma_wrapper::{MatrixResult, Ruma}; | ||||
| pub use ruma_wrapper::{ConduitResult, Ruma, RumaResponse}; | ||||
| 
 | ||||
| use rocket::{fairing::AdHoc, routes}; | ||||
| 
 | ||||
|  | @ -95,7 +95,7 @@ fn setup_rocket() -> rocket::Rocket { | |||
|             ], | ||||
|         ) | ||||
|         .attach(AdHoc::on_attach("Config", |rocket| { | ||||
|             let data = Database::load_or_create(&rocket.config()); | ||||
|             let data = Database::load_or_create(&rocket.config()).expect("valid config"); | ||||
| 
 | ||||
|             Ok(rocket.manage(data)) | ||||
|         })) | ||||
|  |  | |||
							
								
								
									
										29
									
								
								src/pdu.rs
									
									
									
									
									
								
							
							
						
						
									
										29
									
								
								src/pdu.rs
									
									
									
									
									
								
							|  | @ -1,3 +1,4 @@ | |||
| use crate::{Error, Result}; | ||||
| use js_int::UInt; | ||||
| use ruma::{ | ||||
|     api::federation::pdu::EventHash, | ||||
|  | @ -36,7 +37,7 @@ pub struct PduEvent { | |||
| } | ||||
| 
 | ||||
| impl PduEvent { | ||||
|     pub fn redact(&mut self) { | ||||
|     pub fn redact(&mut self) -> Result<()> { | ||||
|         self.unsigned.clear(); | ||||
|         let allowed = match self.kind { | ||||
|             EventType::RoomMember => vec!["membership"], | ||||
|  | @ -56,7 +57,11 @@ impl PduEvent { | |||
|             _ => vec![], | ||||
|         }; | ||||
| 
 | ||||
|         let old_content = self.content.as_object_mut().unwrap(); // TODO error
 | ||||
|         let old_content = self | ||||
|             .content | ||||
|             .as_object_mut() | ||||
|             .ok_or_else(|| Error::bad_database("PDU in db has invalid content."))?; | ||||
| 
 | ||||
|         let mut new_content = serde_json::Map::new(); | ||||
| 
 | ||||
|         for key in allowed { | ||||
|  | @ -71,21 +76,23 @@ impl PduEvent { | |||
|         ); | ||||
| 
 | ||||
|         self.content = new_content.into(); | ||||
| 
 | ||||
|         Ok(()) | ||||
|     } | ||||
| 
 | ||||
|     pub fn to_room_event(&self) -> EventJson<RoomEvent> { | ||||
|         // Can only fail in rare circumstances that won't ever happen here, see
 | ||||
|         // https://docs.rs/serde_json/1.0.50/serde_json/fn.to_string.html
 | ||||
|         let json = serde_json::to_string(&self).unwrap(); | ||||
|         // EventJson's deserialize implementation always returns `Ok(...)`
 | ||||
|         serde_json::from_str::<EventJson<RoomEvent>>(&json).unwrap() | ||||
|         let json = serde_json::to_string(&self).expect("PDUs are always valid"); | ||||
|         serde_json::from_str::<EventJson<RoomEvent>>(&json) | ||||
|             .expect("EventJson::from_str always works") | ||||
|     } | ||||
|     pub fn to_state_event(&self) -> EventJson<StateEvent> { | ||||
|         let json = serde_json::to_string(&self).unwrap(); | ||||
|         serde_json::from_str::<EventJson<StateEvent>>(&json).unwrap() | ||||
|         let json = serde_json::to_string(&self).expect("PDUs are always valid"); | ||||
|         serde_json::from_str::<EventJson<StateEvent>>(&json) | ||||
|             .expect("EventJson::from_str always works") | ||||
|     } | ||||
|     pub fn to_stripped_state_event(&self) -> EventJson<AnyStrippedStateEvent> { | ||||
|         let json = serde_json::to_string(&self).unwrap(); | ||||
|         serde_json::from_str::<EventJson<AnyStrippedStateEvent>>(&json).unwrap() | ||||
|         let json = serde_json::to_string(&self).expect("PDUs are always valid"); | ||||
|         serde_json::from_str::<EventJson<AnyStrippedStateEvent>>(&json) | ||||
|             .expect("EventJson::from_str always works") | ||||
|     } | ||||
| } | ||||
|  |  | |||
|  | @ -1,4 +1,4 @@ | |||
| use crate::utils; | ||||
| use crate::{utils, Error}; | ||||
| use log::warn; | ||||
| use rocket::{ | ||||
|     data::{Data, FromData, FromDataFuture, Transform, TransformFuture, Transformed}, | ||||
|  | @ -42,7 +42,10 @@ impl<'a, T: Endpoint> FromData<'a> for Ruma<T> { | |||
|             let data = rocket::try_outcome!(outcome.owned()); | ||||
| 
 | ||||
|             let (user_id, device_id) = if T::METADATA.requires_authentication { | ||||
|                 let db = request.guard::<State<'_, crate::Database>>().await.unwrap(); | ||||
|                 let db = request | ||||
|                     .guard::<State<'_, crate::Database>>() | ||||
|                     .await | ||||
|                     .expect("database was loaded"); | ||||
| 
 | ||||
|                 // Get token from header or query value
 | ||||
|                 let token = match request | ||||
|  | @ -108,32 +111,24 @@ impl<T> Deref for Ruma<T> { | |||
| } | ||||
| 
 | ||||
| /// This struct converts ruma responses into rocket http responses.
 | ||||
| pub struct MatrixResult<T, E = ruma::api::client::Error>(pub std::result::Result<T, E>); | ||||
| pub type ConduitResult<T> = std::result::Result<RumaResponse<T>, Error>; | ||||
| 
 | ||||
| impl<T, E> TryInto<http::Response<Vec<u8>>> for MatrixResult<T, E> | ||||
| where | ||||
|     T: TryInto<http::Response<Vec<u8>>>, | ||||
|     E: Into<http::Response<Vec<u8>>>, | ||||
| { | ||||
|     type Error = T::Error; | ||||
| pub struct RumaResponse<T: TryInto<http::Response<Vec<u8>>>>(pub T); | ||||
| 
 | ||||
|     fn try_into(self) -> Result<http::Response<Vec<u8>>, T::Error> { | ||||
|         match self.0 { | ||||
|             Ok(t) => t.try_into(), | ||||
|             Err(e) => Ok(e.into()), | ||||
|         } | ||||
| impl<T: TryInto<http::Response<Vec<u8>>>> From<T> for RumaResponse<T> { | ||||
|     fn from(t: T) -> Self { | ||||
|         Self(t) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| #[rocket::async_trait] | ||||
| impl<'r, T, E> Responder<'r> for MatrixResult<T, E> | ||||
| impl<'r, T> Responder<'r> for RumaResponse<T> | ||||
| where | ||||
|     T: Send + TryInto<http::Response<Vec<u8>>>, | ||||
|     T::Error: Send, | ||||
|     E: Into<http::Response<Vec<u8>>> + Send, | ||||
| { | ||||
|     async fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r> { | ||||
|         let http_response: Result<http::Response<_>, _> = self.try_into(); | ||||
|         let http_response: Result<http::Response<_>, _> = self.0.try_into(); | ||||
|         match http_response { | ||||
|             Ok(http_response) => { | ||||
|                 let mut response = rocket::response::Response::build(); | ||||
|  | @ -165,11 +160,3 @@ where | |||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl<T, E> Deref for MatrixResult<T, E> { | ||||
|     type Target = Result<T, E>; | ||||
| 
 | ||||
|     fn deref(&self) -> &Self::Target { | ||||
|         &self.0 | ||||
|     } | ||||
| } | ||||
|  |  | |||
							
								
								
									
										32
									
								
								src/utils.rs
									
									
									
									
									
								
							
							
						
						
									
										32
									
								
								src/utils.rs
									
									
									
									
									
								
							|  | @ -1,4 +1,3 @@ | |||
| use crate::Result; | ||||
| use argon2::{Config, Variant}; | ||||
| use rand::prelude::*; | ||||
| use std::{ | ||||
|  | @ -9,39 +8,38 @@ use std::{ | |||
| pub fn millis_since_unix_epoch() -> u64 { | ||||
|     SystemTime::now() | ||||
|         .duration_since(UNIX_EPOCH) | ||||
|         .unwrap() | ||||
|         .expect("time is valid") | ||||
|         .as_millis() as u64 | ||||
| } | ||||
| 
 | ||||
| pub fn increment(old: Option<&[u8]>) -> Option<Vec<u8>> { | ||||
|     let number = match old { | ||||
|         Some(bytes) => { | ||||
|             let array: [u8; 8] = bytes.try_into().unwrap(); | ||||
|             let number = u64::from_be_bytes(array); | ||||
|     let number = match old.map(|bytes| bytes.try_into()) { | ||||
|         Some(Ok(bytes)) => { | ||||
|             let number = u64::from_be_bytes(bytes); | ||||
|             number + 1 | ||||
|         } | ||||
|         None => 1, // Start at one. since 0 should return the first event in the db
 | ||||
|         _ => 1, // Start at one. since 0 should return the first event in the db
 | ||||
|     }; | ||||
| 
 | ||||
|     Some(number.to_be_bytes().to_vec()) | ||||
| } | ||||
| 
 | ||||
| pub fn generate_keypair(old: Option<&[u8]>) -> Option<Vec<u8>> { | ||||
|     Some( | ||||
|         old.map(|s| s.to_vec()) | ||||
|             .unwrap_or_else(|| ruma::signatures::Ed25519KeyPair::generate().unwrap()), | ||||
|     ) | ||||
|     Some(old.map(|s| s.to_vec()).unwrap_or_else(|| { | ||||
|         ruma::signatures::Ed25519KeyPair::generate() | ||||
|             .expect("Ed25519KeyPair generation always works (?)") | ||||
|     })) | ||||
| } | ||||
| 
 | ||||
| /// Parses the bytes into an u64.
 | ||||
| pub fn u64_from_bytes(bytes: &[u8]) -> u64 { | ||||
|     let array: [u8; 8] = bytes.try_into().expect("bytes are valid u64"); | ||||
|     u64::from_be_bytes(array) | ||||
| pub fn u64_from_bytes(bytes: &[u8]) -> Result<u64, std::array::TryFromSliceError> { | ||||
|     let array: [u8; 8] = bytes.try_into()?; | ||||
|     Ok(u64::from_be_bytes(array)) | ||||
| } | ||||
| 
 | ||||
| /// Parses the bytes into a string.
 | ||||
| pub fn string_from_bytes(bytes: &[u8]) -> Result<String> { | ||||
|     Ok(String::from_utf8(bytes.to_vec())?) | ||||
| pub fn string_from_bytes(bytes: &[u8]) -> Result<String, std::string::FromUtf8Error> { | ||||
|     String::from_utf8(bytes.to_vec()) | ||||
| } | ||||
| 
 | ||||
| pub fn random_string(length: usize) -> String { | ||||
|  | @ -52,7 +50,7 @@ pub fn random_string(length: usize) -> String { | |||
| } | ||||
| 
 | ||||
| /// Calculate a new hash for the given password
 | ||||
| pub fn calculate_hash(password: &str) -> std::result::Result<String, argon2::Error> { | ||||
| pub fn calculate_hash(password: &str) -> Result<String, argon2::Error> { | ||||
|     let hashing_config = Config { | ||||
|         variant: Variant::Argon2id, | ||||
|         ..Default::default() | ||||
|  |  | |||
|  | @ -35,8 +35,6 @@ POST /rooms/:room_id/invite can send an invite | |||
| PUT /rooms/:room_id/state/m.room.power_levels can set levels | ||||
| PUT power_levels should not explode if the old power levels were empty | ||||
| Both GET and PUT work | ||||
| Room creation reports m.room.create to myself | ||||
| Room creation reports m.room.member to myself | ||||
| Version responds 200 OK with valid structure | ||||
| PUT /profile/:user_id/displayname sets my name | ||||
| GET /profile/:user_id/displayname publicly accessible | ||||
|  | @ -78,3 +76,4 @@ User directory correctly update on display name change | |||
| User in shared private room does appear in user directory | ||||
| User in dir while user still shares private rooms | ||||
| POST /rooms/:room_id/ban can ban a user | ||||
| Alternative server names do not cause a routing loop | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue