improvement: uiaa works like in synapse
This commit is contained in:
		
							parent
							
								
									e1c4e5c73e
								
							
						
					
					
						commit
						cf94b8e712
					
				
					 10 changed files with 326 additions and 177 deletions
				
			
		|  | @ -1,4 +1,7 @@ | ||||||
| use std::{collections::BTreeMap, convert::TryInto}; | use std::{ | ||||||
|  |     collections::BTreeMap, | ||||||
|  |     convert::{TryFrom, TryInto}, | ||||||
|  | }; | ||||||
| 
 | 
 | ||||||
| use super::{State, DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; | use super::{State, DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; | ||||||
| use crate::{pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; | use crate::{pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; | ||||||
|  | @ -143,16 +146,28 @@ pub async fn register_route( | ||||||
| 
 | 
 | ||||||
|     if !body.from_appservice { |     if !body.from_appservice { | ||||||
|         if let Some(auth) = &body.auth { |         if let Some(auth) = &body.auth { | ||||||
|             let (worked, uiaainfo) = |             let (worked, uiaainfo) = db.uiaa.try_auth( | ||||||
|                 db.uiaa |                 &UserId::parse_with_server_name("", db.globals.server_name()) | ||||||
|                     .try_auth(&user_id, "".into(), auth, &uiaainfo, &db.users, &db.globals)?; |                     .expect("we know this is valid"), | ||||||
|  |                 "".into(), | ||||||
|  |                 auth, | ||||||
|  |                 &uiaainfo, | ||||||
|  |                 &db.users, | ||||||
|  |                 &db.globals, | ||||||
|  |             )?; | ||||||
|             if !worked { |             if !worked { | ||||||
|                 return Err(Error::Uiaa(uiaainfo)); |                 return Err(Error::Uiaa(uiaainfo)); | ||||||
|             } |             } | ||||||
|         // Success!
 |         // Success!
 | ||||||
|         } else { |         } else { | ||||||
|             uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); |             uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); | ||||||
|             db.uiaa.create(&user_id, "".into(), &uiaainfo)?; |             db.uiaa.create( | ||||||
|  |                 &UserId::parse_with_server_name("", db.globals.server_name()) | ||||||
|  |                     .expect("we know this is valid"), | ||||||
|  |                 "".into(), | ||||||
|  |                 &uiaainfo, | ||||||
|  |                 &body.json_body.expect("body is json"), | ||||||
|  |             )?; | ||||||
|             return Err(Error::Uiaa(uiaainfo)); |             return Err(Error::Uiaa(uiaainfo)); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  | @ -526,7 +541,12 @@ pub async fn change_password_route( | ||||||
|     // Success!
 |     // Success!
 | ||||||
|     } else { |     } else { | ||||||
|         uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); |         uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); | ||||||
|         db.uiaa.create(&sender_user, &sender_device, &uiaainfo)?; |         db.uiaa.create( | ||||||
|  |             &sender_user, | ||||||
|  |             &sender_device, | ||||||
|  |             &uiaainfo, | ||||||
|  |             &body.json_body.expect("body is json"), | ||||||
|  |         )?; | ||||||
|         return Err(Error::Uiaa(uiaainfo)); |         return Err(Error::Uiaa(uiaainfo)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | @ -612,7 +632,12 @@ pub async fn deactivate_route( | ||||||
|     // Success!
 |     // Success!
 | ||||||
|     } else { |     } else { | ||||||
|         uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); |         uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); | ||||||
|         db.uiaa.create(&sender_user, &sender_device, &uiaainfo)?; |         db.uiaa.create( | ||||||
|  |             &sender_user, | ||||||
|  |             &sender_device, | ||||||
|  |             &uiaainfo, | ||||||
|  |             &body.json_body.expect("body is json"), | ||||||
|  |         )?; | ||||||
|         return Err(Error::Uiaa(uiaainfo)); |         return Err(Error::Uiaa(uiaainfo)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -115,7 +115,12 @@ pub async fn delete_device_route( | ||||||
|     // Success!
 |     // Success!
 | ||||||
|     } else { |     } else { | ||||||
|         uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); |         uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); | ||||||
|         db.uiaa.create(&sender_user, &sender_device, &uiaainfo)?; |         db.uiaa.create( | ||||||
|  |             &sender_user, | ||||||
|  |             &sender_device, | ||||||
|  |             &uiaainfo, | ||||||
|  |             &body.json_body.expect("body is json"), | ||||||
|  |         )?; | ||||||
|         return Err(Error::Uiaa(uiaainfo)); |         return Err(Error::Uiaa(uiaainfo)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | @ -164,7 +169,12 @@ pub async fn delete_devices_route( | ||||||
|     // Success!
 |     // Success!
 | ||||||
|     } else { |     } else { | ||||||
|         uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); |         uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); | ||||||
|         db.uiaa.create(&sender_user, &sender_device, &uiaainfo)?; |         db.uiaa.create( | ||||||
|  |             &sender_user, | ||||||
|  |             &sender_device, | ||||||
|  |             &uiaainfo, | ||||||
|  |             &body.json_body.expect("body is json"), | ||||||
|  |         )?; | ||||||
|         return Err(Error::Uiaa(uiaainfo)); |         return Err(Error::Uiaa(uiaainfo)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -203,19 +203,20 @@ pub async fn get_public_rooms_filtered_helper( | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     let mut all_rooms = |     let mut all_rooms = db | ||||||
|         db.rooms |         .rooms | ||||||
|             .public_rooms() |         .public_rooms() | ||||||
|             .map(|room_id| { |         .map(|room_id| { | ||||||
|                 let room_id = room_id?; |             let room_id = room_id?; | ||||||
| 
 | 
 | ||||||
|                 let chunk = PublicRoomsChunk { |             let chunk = PublicRoomsChunk { | ||||||
|                     aliases: Vec::new(), |                 aliases: Vec::new(), | ||||||
|                     canonical_alias: db |                 canonical_alias: db | ||||||
|                         .rooms |                     .rooms | ||||||
|                         .room_state_get(&room_id, &EventType::RoomCanonicalAlias, "")? |                     .room_state_get(&room_id, &EventType::RoomCanonicalAlias, "")? | ||||||
|                         .map_or(Ok::<_, Error>(None), |s| { |                     .map_or(Ok::<_, Error>(None), |s| { | ||||||
|                             Ok(serde_json::from_value::< |                         Ok( | ||||||
|  |                             serde_json::from_value::< | ||||||
|                                 Raw<canonical_alias::CanonicalAliasEventContent>, |                                 Raw<canonical_alias::CanonicalAliasEventContent>, | ||||||
|                             >(s.content) |                             >(s.content) | ||||||
|                             .expect("from_value::<Raw<..>> can never fail") |                             .expect("from_value::<Raw<..>> can never fail") | ||||||
|  | @ -223,62 +224,61 @@ pub async fn get_public_rooms_filtered_helper( | ||||||
|                             .map_err(|_| { |                             .map_err(|_| { | ||||||
|                                 Error::bad_database("Invalid canonical alias event in database.") |                                 Error::bad_database("Invalid canonical alias event in database.") | ||||||
|                             })? |                             })? | ||||||
|                             .alias) |                             .alias, | ||||||
|                         })?, |                         ) | ||||||
|                     name: db |                     })?, | ||||||
|                         .rooms |                 name: db | ||||||
|                         .room_state_get(&room_id, &EventType::RoomName, "")? |                     .rooms | ||||||
|                         .map_or(Ok::<_, Error>(None), |s| { |                     .room_state_get(&room_id, &EventType::RoomName, "")? | ||||||
|                             Ok(serde_json::from_value::<Raw<name::NameEventContent>>( |                     .map_or(Ok::<_, Error>(None), |s| { | ||||||
|                                 s.content, |                         Ok( | ||||||
|                             ) |                             serde_json::from_value::<Raw<name::NameEventContent>>(s.content) | ||||||
|                             .expect("from_value::<Raw<..>> can never fail") |                                 .expect("from_value::<Raw<..>> can never fail") | ||||||
|                             .deserialize() |                                 .deserialize() | ||||||
|                             .map_err(|_| { |                                 .map_err(|_| { | ||||||
|                                 Error::bad_database("Invalid room name event in database.") |                                     Error::bad_database("Invalid room name event in database.") | ||||||
|                             })? |                                 })? | ||||||
|                             .name() |                                 .name() | ||||||
|                             .map(|n| n.to_owned())) |                                 .map(|n| n.to_owned()), | ||||||
|                         })?, |                         ) | ||||||
|                     num_joined_members: (db.rooms.room_members(&room_id).count() as u32).into(), |                     })?, | ||||||
|                     topic: db |                 num_joined_members: (db.rooms.room_members(&room_id).count() as u32).into(), | ||||||
|                         .rooms |                 topic: db | ||||||
|                         .room_state_get(&room_id, &EventType::RoomTopic, "")? |                     .rooms | ||||||
|                         .map_or(Ok::<_, Error>(None), |s| { |                     .room_state_get(&room_id, &EventType::RoomTopic, "")? | ||||||
|                             Ok(Some( |                     .map_or(Ok::<_, Error>(None), |s| { | ||||||
|                                 serde_json::from_value::<Raw<topic::TopicEventContent>>( |                         Ok(Some( | ||||||
|                                     s.content, |                             serde_json::from_value::<Raw<topic::TopicEventContent>>(s.content) | ||||||
|                                 ) |  | ||||||
|                                 .expect("from_value::<Raw<..>> can never fail") |                                 .expect("from_value::<Raw<..>> can never fail") | ||||||
|                                 .deserialize() |                                 .deserialize() | ||||||
|                                 .map_err(|_| { |                                 .map_err(|_| { | ||||||
|                                     Error::bad_database("Invalid room topic event in database.") |                                     Error::bad_database("Invalid room topic event in database.") | ||||||
|                                 })? |                                 })? | ||||||
|                                 .topic, |                                 .topic, | ||||||
|                             )) |                         )) | ||||||
|                         })?, |                     })?, | ||||||
|                     world_readable: db |                 world_readable: db | ||||||
|                         .rooms |                     .rooms | ||||||
|                         .room_state_get(&room_id, &EventType::RoomHistoryVisibility, "")? |                     .room_state_get(&room_id, &EventType::RoomHistoryVisibility, "")? | ||||||
|                         .map_or(Ok::<_, Error>(false), |s| { |                     .map_or(Ok::<_, Error>(false), |s| { | ||||||
|                             Ok(serde_json::from_value::< |                         Ok(serde_json::from_value::< | ||||||
|                                 Raw<history_visibility::HistoryVisibilityEventContent>, |                             Raw<history_visibility::HistoryVisibilityEventContent>, | ||||||
|                             >(s.content) |                         >(s.content) | ||||||
|                             .expect("from_value::<Raw<..>> can never fail") |                         .expect("from_value::<Raw<..>> can never fail") | ||||||
|                             .deserialize() |                         .deserialize() | ||||||
|                             .map_err(|_| { |                         .map_err(|_| { | ||||||
|                                 Error::bad_database( |                             Error::bad_database( | ||||||
|                                     "Invalid room history visibility event in database.", |                                 "Invalid room history visibility event in database.", | ||||||
|                                 ) |                             ) | ||||||
|                             })? |                         })? | ||||||
|                             .history_visibility |                         .history_visibility | ||||||
|                                 == history_visibility::HistoryVisibility::WorldReadable) |                             == history_visibility::HistoryVisibility::WorldReadable) | ||||||
|                         })?, |                     })?, | ||||||
|                     guest_can_join: db |                 guest_can_join: db | ||||||
|                         .rooms |                     .rooms | ||||||
|                         .room_state_get(&room_id, &EventType::RoomGuestAccess, "")? |                     .room_state_get(&room_id, &EventType::RoomGuestAccess, "")? | ||||||
|                         .map_or(Ok::<_, Error>(false), |s| { |                     .map_or(Ok::<_, Error>(false), |s| { | ||||||
|                             Ok( |                         Ok( | ||||||
|                             serde_json::from_value::<Raw<guest_access::GuestAccessEventContent>>( |                             serde_json::from_value::<Raw<guest_access::GuestAccessEventContent>>( | ||||||
|                                 s.content, |                                 s.content, | ||||||
|                             ) |                             ) | ||||||
|  | @ -290,33 +290,31 @@ pub async fn get_public_rooms_filtered_helper( | ||||||
|                             .guest_access |                             .guest_access | ||||||
|                                 == guest_access::GuestAccess::CanJoin, |                                 == guest_access::GuestAccess::CanJoin, | ||||||
|                         ) |                         ) | ||||||
|                         })?, |                     })?, | ||||||
|                     avatar_url: db |                 avatar_url: db | ||||||
|                         .rooms |                     .rooms | ||||||
|                         .room_state_get(&room_id, &EventType::RoomAvatar, "")? |                     .room_state_get(&room_id, &EventType::RoomAvatar, "")? | ||||||
|                         .map(|s| { |                     .map(|s| { | ||||||
|                             Ok::<_, Error>( |                         Ok::<_, Error>( | ||||||
|                                 serde_json::from_value::<Raw<avatar::AvatarEventContent>>( |                             serde_json::from_value::<Raw<avatar::AvatarEventContent>>(s.content) | ||||||
|                                     s.content, |  | ||||||
|                                 ) |  | ||||||
|                                 .expect("from_value::<Raw<..>> can never fail") |                                 .expect("from_value::<Raw<..>> can never fail") | ||||||
|                                 .deserialize() |                                 .deserialize() | ||||||
|                                 .map_err(|_| { |                                 .map_err(|_| { | ||||||
|                                     Error::bad_database("Invalid room avatar event in database.") |                                     Error::bad_database("Invalid room avatar event in database.") | ||||||
|                                 })? |                                 })? | ||||||
|                                 .url, |                                 .url, | ||||||
|                             ) |                         ) | ||||||
|                         }) |                     }) | ||||||
|                         .transpose()? |                     .transpose()? | ||||||
|                         // url is now an Option<String> so we must flatten
 |                     // url is now an Option<String> so we must flatten
 | ||||||
|                         .flatten(), |                     .flatten(), | ||||||
|                     room_id, |                 room_id, | ||||||
|                 }; |             }; | ||||||
|                 Ok(chunk) |             Ok(chunk) | ||||||
|             }) |         }) | ||||||
|             .filter_map(|r: Result<_>| r.ok()) // Filter out buggy rooms
 |         .filter_map(|r: Result<_>| r.ok()) // Filter out buggy rooms
 | ||||||
|             // We need to collect all, so we can sort by member count
 |         // We need to collect all, so we can sort by member count
 | ||||||
|             .collect::<Vec<_>>(); |         .collect::<Vec<_>>(); | ||||||
| 
 | 
 | ||||||
|     all_rooms.sort_by(|l, r| r.num_joined_members.cmp(&l.num_joined_members)); |     all_rooms.sort_by(|l, r| r.num_joined_members.cmp(&l.num_joined_members)); | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -220,7 +220,12 @@ pub async fn upload_signing_keys_route( | ||||||
|     // Success!
 |     // Success!
 | ||||||
|     } else { |     } else { | ||||||
|         uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); |         uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); | ||||||
|         db.uiaa.create(&sender_user, &sender_device, &uiaainfo)?; |         db.uiaa.create( | ||||||
|  |             &sender_user, | ||||||
|  |             &sender_device, | ||||||
|  |             &uiaainfo, | ||||||
|  |             &body.json_body.expect("body is json"), | ||||||
|  |         )?; | ||||||
|         return Err(Error::Uiaa(uiaainfo)); |         return Err(Error::Uiaa(uiaainfo)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -56,13 +56,8 @@ pub async fn send_message_event_route( | ||||||
|     let event_id = db.rooms.build_and_append_pdu( |     let event_id = db.rooms.build_and_append_pdu( | ||||||
|         PduBuilder { |         PduBuilder { | ||||||
|             event_type: EventType::from(&body.event_type), |             event_type: EventType::from(&body.event_type), | ||||||
|             content: serde_json::from_str( |             content: serde_json::from_str(body.body.body.json().get()) | ||||||
|                 body.json_body |                 .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?, | ||||||
|                     .as_ref() |  | ||||||
|                     .ok_or(Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))? |  | ||||||
|                     .get(), |  | ||||||
|             ) |  | ||||||
|             .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?, |  | ||||||
|             unsigned: Some(unsigned), |             unsigned: Some(unsigned), | ||||||
|             state_key: None, |             state_key: None, | ||||||
|             redacts: None, |             redacts: None, | ||||||
|  |  | ||||||
|  | @ -69,9 +69,9 @@ use { | ||||||
|     ruma::api::client::r0::to_device::send_event_to_device, |     ruma::api::client::r0::to_device::send_event_to_device, | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| const DEVICE_ID_LENGTH: usize = 10; | pub const DEVICE_ID_LENGTH: usize = 10; | ||||||
| const TOKEN_LENGTH: usize = 256; | pub const TOKEN_LENGTH: usize = 256; | ||||||
| const SESSION_ID_LENGTH: usize = 256; | pub const SESSION_ID_LENGTH: usize = 256; | ||||||
| 
 | 
 | ||||||
| #[cfg(feature = "conduit_bin")] | #[cfg(feature = "conduit_bin")] | ||||||
| #[options("/<_..>")] | #[options("/<_..>")] | ||||||
|  |  | ||||||
|  | @ -135,7 +135,8 @@ impl Database { | ||||||
|                 todeviceid_events: db.open_tree("todeviceid_events")?, |                 todeviceid_events: db.open_tree("todeviceid_events")?, | ||||||
|             }, |             }, | ||||||
|             uiaa: uiaa::Uiaa { |             uiaa: uiaa::Uiaa { | ||||||
|                 userdeviceid_uiaainfo: db.open_tree("userdeviceid_uiaainfo")?, |                 userdevicesessionid_uiaainfo: db.open_tree("userdevicesessionid_uiaainfo")?, | ||||||
|  |                 userdevicesessionid_uiaarequest: db.open_tree("userdevicesessionid_uiaarequest")?, | ||||||
|             }, |             }, | ||||||
|             rooms: rooms::Rooms { |             rooms: rooms::Rooms { | ||||||
|                 edus: rooms::RoomEdus { |                 edus: rooms::RoomEdus { | ||||||
|  |  | ||||||
|  | @ -1,15 +1,17 @@ | ||||||
| use crate::{Error, Result}; | use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result}; | ||||||
| use ruma::{ | use ruma::{ | ||||||
|     api::client::{ |     api::client::{ | ||||||
|         error::ErrorKind, |         error::ErrorKind, | ||||||
|         r0::uiaa::{IncomingAuthData, UiaaInfo}, |         r0::uiaa::{IncomingAuthData, UiaaInfo}, | ||||||
|     }, |     }, | ||||||
|  |     signatures::CanonicalJsonValue, | ||||||
|     DeviceId, UserId, |     DeviceId, UserId, | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| #[derive(Clone)] | #[derive(Clone)] | ||||||
| pub struct Uiaa { | pub struct Uiaa { | ||||||
|     pub(super) userdeviceid_uiaainfo: sled::Tree, // User-interactive authentication
 |     pub(super) userdevicesessionid_uiaainfo: sled::Tree, // User-interactive authentication
 | ||||||
|  |     pub(super) userdevicesessionid_uiaarequest: sled::Tree, // UiaaRequest = canonical json value
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl Uiaa { | impl Uiaa { | ||||||
|  | @ -19,8 +21,20 @@ impl Uiaa { | ||||||
|         user_id: &UserId, |         user_id: &UserId, | ||||||
|         device_id: &DeviceId, |         device_id: &DeviceId, | ||||||
|         uiaainfo: &UiaaInfo, |         uiaainfo: &UiaaInfo, | ||||||
|  |         json_body: &CanonicalJsonValue, | ||||||
|     ) -> Result<()> { |     ) -> Result<()> { | ||||||
|         self.update_uiaa_session(user_id, device_id, Some(uiaainfo)) |         self.set_uiaa_request( | ||||||
|  |             user_id, | ||||||
|  |             device_id, | ||||||
|  |             uiaainfo.session.as_ref().expect("session should be set"), // TODO: better session error handling (why is it optional in ruma?)
 | ||||||
|  |             json_body, | ||||||
|  |         )?; | ||||||
|  |         self.update_uiaa_session( | ||||||
|  |             user_id, | ||||||
|  |             device_id, | ||||||
|  |             uiaainfo.session.as_ref().expect("session should be set"), | ||||||
|  |             Some(uiaainfo), | ||||||
|  |         ) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn try_auth( |     pub fn try_auth( | ||||||
|  | @ -45,6 +59,10 @@ impl Uiaa { | ||||||
|                 }) |                 }) | ||||||
|                 .unwrap_or_else(|| Ok(uiaainfo.clone()))?; |                 .unwrap_or_else(|| Ok(uiaainfo.clone()))?; | ||||||
| 
 | 
 | ||||||
|  |             if uiaainfo.session.is_none() { | ||||||
|  |                 uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|             // Find out what the user completed
 |             // Find out what the user completed
 | ||||||
|             match &**kind { |             match &**kind { | ||||||
|                 "m.login.password" => { |                 "m.login.password" => { | ||||||
|  | @ -130,35 +148,96 @@ impl Uiaa { | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             if !completed { |             if !completed { | ||||||
|                 self.update_uiaa_session(user_id, device_id, Some(&uiaainfo))?; |                 self.update_uiaa_session( | ||||||
|  |                     user_id, | ||||||
|  |                     device_id, | ||||||
|  |                     uiaainfo.session.as_ref().expect("session is always set"), | ||||||
|  |                     Some(&uiaainfo), | ||||||
|  |                 )?; | ||||||
|                 return Ok((false, uiaainfo)); |                 return Ok((false, uiaainfo)); | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             // UIAA was successful! Remove this session and return true
 |             // UIAA was successful! Remove this session and return true
 | ||||||
|             self.update_uiaa_session(user_id, device_id, None)?; |             self.update_uiaa_session( | ||||||
|  |                 user_id, | ||||||
|  |                 device_id, | ||||||
|  |                 uiaainfo.session.as_ref().expect("session is always set"), | ||||||
|  |                 None, | ||||||
|  |             )?; | ||||||
|             Ok((true, uiaainfo)) |             Ok((true, uiaainfo)) | ||||||
|         } else { |         } else { | ||||||
|             panic!("FallbackAcknowledgement is not supported yet"); |             panic!("FallbackAcknowledgement is not supported yet"); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     fn set_uiaa_request( | ||||||
|  |         &self, | ||||||
|  |         user_id: &UserId, | ||||||
|  |         device_id: &DeviceId, | ||||||
|  |         session: &str, | ||||||
|  |         request: &CanonicalJsonValue, | ||||||
|  |     ) -> Result<()> { | ||||||
|  |         let mut userdevicesessionid = user_id.as_bytes().to_vec(); | ||||||
|  |         userdevicesessionid.push(0xff); | ||||||
|  |         userdevicesessionid.extend_from_slice(device_id.as_bytes()); | ||||||
|  |         userdevicesessionid.push(0xff); | ||||||
|  |         userdevicesessionid.extend_from_slice(session.as_bytes()); | ||||||
|  | 
 | ||||||
|  |         self.userdevicesessionid_uiaarequest.insert( | ||||||
|  |             &userdevicesessionid, | ||||||
|  |             &*serde_json::to_string(request).expect("json value to string always works"), | ||||||
|  |         )?; | ||||||
|  | 
 | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub fn get_uiaa_request( | ||||||
|  |         &self, | ||||||
|  |         user_id: &UserId, | ||||||
|  |         device_id: &DeviceId, | ||||||
|  |         session: &str, | ||||||
|  |     ) -> Result<Option<CanonicalJsonValue>> { | ||||||
|  |         let mut userdevicesessionid = user_id.as_bytes().to_vec(); | ||||||
|  |         userdevicesessionid.push(0xff); | ||||||
|  |         userdevicesessionid.extend_from_slice(device_id.as_bytes()); | ||||||
|  |         userdevicesessionid.push(0xff); | ||||||
|  |         userdevicesessionid.extend_from_slice(session.as_bytes()); | ||||||
|  | 
 | ||||||
|  |         self.userdevicesessionid_uiaarequest | ||||||
|  |             .get(&userdevicesessionid)? | ||||||
|  |             .map_or(Ok(None), |bytes| { | ||||||
|  |                 Ok::<_, Error>(Some( | ||||||
|  |                     serde_json::from_str::<CanonicalJsonValue>( | ||||||
|  |                         &utils::string_from_bytes(&bytes).map_err(|_| { | ||||||
|  |                             Error::bad_database("Invalid uiaa request bytes in db.") | ||||||
|  |                         })?, | ||||||
|  |                     ) | ||||||
|  |                     .map_err(|_| Error::bad_database("Invalid uiaa request in db."))?, | ||||||
|  |                 )) | ||||||
|  |             }) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     fn update_uiaa_session( |     fn update_uiaa_session( | ||||||
|         &self, |         &self, | ||||||
|         user_id: &UserId, |         user_id: &UserId, | ||||||
|         device_id: &DeviceId, |         device_id: &DeviceId, | ||||||
|  |         session: &str, | ||||||
|         uiaainfo: Option<&UiaaInfo>, |         uiaainfo: Option<&UiaaInfo>, | ||||||
|     ) -> Result<()> { |     ) -> Result<()> { | ||||||
|         let mut userdeviceid = user_id.as_bytes().to_vec(); |         let mut userdevicesessionid = user_id.as_bytes().to_vec(); | ||||||
|         userdeviceid.push(0xff); |         userdevicesessionid.push(0xff); | ||||||
|         userdeviceid.extend_from_slice(device_id.as_bytes()); |         userdevicesessionid.extend_from_slice(device_id.as_bytes()); | ||||||
|  |         userdevicesessionid.push(0xff); | ||||||
|  |         userdevicesessionid.extend_from_slice(session.as_bytes()); | ||||||
| 
 | 
 | ||||||
|         if let Some(uiaainfo) = uiaainfo { |         if let Some(uiaainfo) = uiaainfo { | ||||||
|             self.userdeviceid_uiaainfo.insert( |             self.userdevicesessionid_uiaainfo.insert( | ||||||
|                 &userdeviceid, |                 &userdevicesessionid, | ||||||
|                 &*serde_json::to_string(&uiaainfo).expect("UiaaInfo::to_string always works"), |                 &*serde_json::to_string(&uiaainfo).expect("UiaaInfo::to_string always works"), | ||||||
|             )?; |             )?; | ||||||
|         } else { |         } else { | ||||||
|             self.userdeviceid_uiaainfo.remove(&userdeviceid)?; |             self.userdevicesessionid_uiaainfo | ||||||
|  |                 .remove(&userdevicesessionid)?; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|  | @ -170,14 +249,16 @@ impl Uiaa { | ||||||
|         device_id: &DeviceId, |         device_id: &DeviceId, | ||||||
|         session: &str, |         session: &str, | ||||||
|     ) -> Result<UiaaInfo> { |     ) -> Result<UiaaInfo> { | ||||||
|         let mut userdeviceid = user_id.as_bytes().to_vec(); |         let mut userdevicesessionid = user_id.as_bytes().to_vec(); | ||||||
|         userdeviceid.push(0xff); |         userdevicesessionid.push(0xff); | ||||||
|         userdeviceid.extend_from_slice(device_id.as_bytes()); |         userdevicesessionid.extend_from_slice(device_id.as_bytes()); | ||||||
|  |         userdevicesessionid.push(0xff); | ||||||
|  |         userdevicesessionid.extend_from_slice(session.as_bytes()); | ||||||
| 
 | 
 | ||||||
|         let uiaainfo = serde_json::from_slice::<UiaaInfo>( |         let uiaainfo = serde_json::from_slice::<UiaaInfo>( | ||||||
|             &self |             &self | ||||||
|                 .userdeviceid_uiaainfo |                 .userdevicesessionid_uiaainfo | ||||||
|                 .get(&userdeviceid)? |                 .get(&userdevicesessionid)? | ||||||
|                 .ok_or(Error::BadRequest( |                 .ok_or(Error::BadRequest( | ||||||
|                     ErrorKind::Forbidden, |                     ErrorKind::Forbidden, | ||||||
|                     "UIAA session does not exist.", |                     "UIAA session does not exist.", | ||||||
|  | @ -185,18 +266,6 @@ impl Uiaa { | ||||||
|         ) |         ) | ||||||
|         .map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))?; |         .map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))?; | ||||||
| 
 | 
 | ||||||
|         if uiaainfo |  | ||||||
|             .session |  | ||||||
|             .as_ref() |  | ||||||
|             .filter(|&s| s == session) |  | ||||||
|             .is_none() |  | ||||||
|         { |  | ||||||
|             return Err(Error::BadRequest( |  | ||||||
|                 ErrorKind::Forbidden, |  | ||||||
|                 "UIAA session token invalid.", |  | ||||||
|             )); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         Ok(uiaainfo) |         Ok(uiaainfo) | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -8,7 +8,7 @@ use std::ops::Deref; | ||||||
| 
 | 
 | ||||||
| #[cfg(feature = "conduit_bin")] | #[cfg(feature = "conduit_bin")] | ||||||
| use { | use { | ||||||
|     crate::{server_server, utils}, |     crate::server_server, | ||||||
|     log::{debug, warn}, |     log::{debug, warn}, | ||||||
|     rocket::{ |     rocket::{ | ||||||
|         data::{self, ByteUnit, Data, FromData}, |         data::{self, ByteUnit, Data, FromData}, | ||||||
|  | @ -35,7 +35,7 @@ pub struct Ruma<T: Outgoing> { | ||||||
|     pub sender_user: Option<UserId>, |     pub sender_user: Option<UserId>, | ||||||
|     pub sender_device: Option<Box<DeviceId>>, |     pub sender_device: Option<Box<DeviceId>>, | ||||||
|     // This is None when body is not a valid string
 |     // This is None when body is not a valid string
 | ||||||
|     pub json_body: Option<Box<serde_json::value::RawValue>>, |     pub json_body: Option<CanonicalJsonValue>, | ||||||
|     pub from_appservice: bool, |     pub from_appservice: bool, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -66,6 +66,8 @@ where | ||||||
|         let mut body = Vec::new(); |         let mut body = Vec::new(); | ||||||
|         handle.read_to_end(&mut body).await.unwrap(); |         handle.read_to_end(&mut body).await.unwrap(); | ||||||
| 
 | 
 | ||||||
|  |         let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok(); | ||||||
|  | 
 | ||||||
|         let (sender_user, sender_device, from_appservice) = if let Some((_id, registration)) = db |         let (sender_user, sender_device, from_appservice) = if let Some((_id, registration)) = db | ||||||
|             .appservice |             .appservice | ||||||
|             .iter_all() |             .iter_all() | ||||||
|  | @ -115,7 +117,7 @@ where | ||||||
|                             // Unknown Token
 |                             // Unknown Token
 | ||||||
|                             None => return Failure((Status::raw(581), ())), |                             None => return Failure((Status::raw(581), ())), | ||||||
|                             Some((user_id, device_id)) => { |                             Some((user_id, device_id)) => { | ||||||
|                                 (Some(user_id), Some(device_id.into()), false) |                                 (Some(user_id), Some(Box::<DeviceId>::from(device_id)), false) | ||||||
|                             } |                             } | ||||||
|                         } |                         } | ||||||
|                     } else { |                     } else { | ||||||
|  | @ -187,12 +189,10 @@ where | ||||||
|                         } |                         } | ||||||
|                     }; |                     }; | ||||||
| 
 | 
 | ||||||
|                     let json_body = serde_json::from_slice::<CanonicalJsonValue>(&body); |  | ||||||
| 
 |  | ||||||
|                     let mut request_map = BTreeMap::<String, CanonicalJsonValue>::new(); |                     let mut request_map = BTreeMap::<String, CanonicalJsonValue>::new(); | ||||||
| 
 | 
 | ||||||
|                     if let Ok(json_body) = json_body { |                     if let Some(json_body) = &json_body { | ||||||
|                         request_map.insert("content".to_owned(), json_body); |                         request_map.insert("content".to_owned(), json_body.clone()); | ||||||
|                     }; |                     }; | ||||||
| 
 | 
 | ||||||
|                     request_map.insert( |                     request_map.insert( | ||||||
|  | @ -271,6 +271,43 @@ where | ||||||
|             http_request = http_request.header(header.name.as_str(), &*header.value); |             http_request = http_request.header(header.name.as_str(), &*header.value); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  |         match &mut json_body { | ||||||
|  |             Some(CanonicalJsonValue::Object(json_body)) => { | ||||||
|  |                 let user_id = sender_user.clone().unwrap_or_else(|| { | ||||||
|  |                     UserId::parse_with_server_name("", db.globals.server_name()) | ||||||
|  |                         .expect("we know this is valid") | ||||||
|  |                 }); | ||||||
|  | 
 | ||||||
|  |                 if let Some(initial_request) = json_body | ||||||
|  |                     .get("auth") | ||||||
|  |                     .and_then(|auth| auth.as_object()) | ||||||
|  |                     .and_then(|auth| auth.get("session")) | ||||||
|  |                     .and_then(|session| session.as_str()) | ||||||
|  |                     .and_then(|session| { | ||||||
|  |                         db.uiaa | ||||||
|  |                             .get_uiaa_request( | ||||||
|  |                                 &user_id, | ||||||
|  |                                 &sender_device.clone().unwrap_or_else(|| "".into()), | ||||||
|  |                                 session, | ||||||
|  |                             ) | ||||||
|  |                             .ok() | ||||||
|  |                             .flatten() | ||||||
|  |                     }) | ||||||
|  |                 { | ||||||
|  |                     match initial_request { | ||||||
|  |                         CanonicalJsonValue::Object(initial_request) => { | ||||||
|  |                             for (key, value) in initial_request.into_iter() { | ||||||
|  |                                 json_body.entry(key).or_insert(value); | ||||||
|  |                             } | ||||||
|  |                         } | ||||||
|  |                         _ => {} | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |                 body = serde_json::to_vec(json_body).expect("value to bytes can't fail"); | ||||||
|  |             } | ||||||
|  |             _ => {} | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|         let http_request = http_request.body(&*body).unwrap(); |         let http_request = http_request.body(&*body).unwrap(); | ||||||
|         debug!("{:?}", http_request); |         debug!("{:?}", http_request); | ||||||
|         match <T::Incoming as IncomingRequest>::try_from_http_request(http_request) { |         match <T::Incoming as IncomingRequest>::try_from_http_request(http_request) { | ||||||
|  | @ -278,11 +315,8 @@ where | ||||||
|                 body: t, |                 body: t, | ||||||
|                 sender_user, |                 sender_user, | ||||||
|                 sender_device, |                 sender_device, | ||||||
|                 // TODO: Can we avoid parsing it again? (We only need this for append_pdu)
 |  | ||||||
|                 json_body: utils::string_from_bytes(&body) |  | ||||||
|                     .ok() |  | ||||||
|                     .and_then(|s| serde_json::value::RawValue::from_string(s).ok()), |  | ||||||
|                 from_appservice, |                 from_appservice, | ||||||
|  |                 json_body, | ||||||
|             }), |             }), | ||||||
|             Err(e) => { |             Err(e) => { | ||||||
|                 warn!("{:?}", e); |                 warn!("{:?}", e); | ||||||
|  |  | ||||||
|  | @ -1018,29 +1018,6 @@ pub fn handle_incoming_pdu<'a>( | ||||||
|         } |         } | ||||||
|         debug!("Auth check succeeded."); |         debug!("Auth check succeeded."); | ||||||
| 
 | 
 | ||||||
|         // 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" it
 |  | ||||||
|         let current_state = db |  | ||||||
|             .rooms |  | ||||||
|             .room_state_full(&room_id) |  | ||||||
|             .map_err(|_| "Failed to load room state.".to_owned())? |  | ||||||
|             .into_iter() |  | ||||||
|             .map(|(k, v)| (k, Arc::new(v))) |  | ||||||
|             .collect(); |  | ||||||
| 
 |  | ||||||
|         if !state_res::event_auth::auth_check( |  | ||||||
|             &room_version, |  | ||||||
|             &incoming_pdu, |  | ||||||
|             previous_create, |  | ||||||
|             ¤t_state, |  | ||||||
|             None, |  | ||||||
|         ) |  | ||||||
|         .map_err(|_e| "Auth check failed.".to_owned())? |  | ||||||
|         { |  | ||||||
|             // Soft fail, we leave the event as an outlier but don't add it to the timeline
 |  | ||||||
|             return Err("Event has been soft failed".into()); |  | ||||||
|         }; |  | ||||||
|         debug!("Auth check with current state succeeded."); |  | ||||||
| 
 |  | ||||||
|         // Now we calculate the set of extremities this room has after the incoming event has been
 |         // Now we calculate the set of extremities this room has after the incoming event has been
 | ||||||
|         // applied. We start with the previous extremities (aka leaves)
 |         // applied. We start with the previous extremities (aka leaves)
 | ||||||
|         let mut extremities = db |         let mut extremities = db | ||||||
|  | @ -1103,6 +1080,14 @@ pub fn handle_incoming_pdu<'a>( | ||||||
|         //     don't just trust a set of state we got from a remote).
 |         //     don't just trust a set of state we got from a remote).
 | ||||||
| 
 | 
 | ||||||
|         // We do this by adding the current state to the list of fork states
 |         // We do this by adding the current state to the list of fork states
 | ||||||
|  |         let current_state = db | ||||||
|  |             .rooms | ||||||
|  |             .room_state_full(&room_id) | ||||||
|  |             .map_err(|_| "Failed to load room state.".to_owned())? | ||||||
|  |             .into_iter() | ||||||
|  |             .map(|(k, v)| (k, Arc::new(v))) | ||||||
|  |             .collect(); | ||||||
|  | 
 | ||||||
|         fork_states.insert(current_state); |         fork_states.insert(current_state); | ||||||
| 
 | 
 | ||||||
|         // We also add state after incoming event to the fork states
 |         // We also add state after incoming event to the fork states
 | ||||||
|  | @ -1199,18 +1184,40 @@ pub fn handle_incoming_pdu<'a>( | ||||||
|             } |             } | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|         // Now that the event has passed all auth it is added into the timeline.
 |         // 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" it
 | ||||||
|         // We use the `state_at_event` instead of `state_after` so we accurately
 |         let soft_fail = !state_res::event_auth::auth_check( | ||||||
|         // represent the state for this event.
 |             &room_version, | ||||||
|         let pdu_id = append_incoming_pdu( |  | ||||||
|             &db, |  | ||||||
|             &incoming_pdu, |             &incoming_pdu, | ||||||
|             val, |             previous_create, | ||||||
|             extremities, |             &new_room_state | ||||||
|             &state_at_incoming_event, |                 .iter() | ||||||
|  |                 .filter_map(|(k, v)| { | ||||||
|  |                     Some((k.clone(), Arc::new(db.rooms.get_pdu(&v).ok().flatten()?))) | ||||||
|  |                 }) | ||||||
|  |                 .collect(), | ||||||
|  |             None, | ||||||
|         ) |         ) | ||||||
|         .map_err(|_| "Failed to add pdu to db.".to_owned())?; |         .map_err(|_e| "Auth check failed.".to_owned())?; | ||||||
|         debug!("Appended incoming pdu."); | 
 | ||||||
|  |         let mut pdu_id = None; | ||||||
|  |         if !soft_fail { | ||||||
|  |             // Now that the event has passed all auth it is added into the timeline.
 | ||||||
|  |             // We use the `state_at_event` instead of `state_after` so we accurately
 | ||||||
|  |             // represent the state for this event.
 | ||||||
|  |             pdu_id = Some( | ||||||
|  |                 append_incoming_pdu( | ||||||
|  |                     &db, | ||||||
|  |                     &incoming_pdu, | ||||||
|  |                     val, | ||||||
|  |                     extremities, | ||||||
|  |                     &state_at_incoming_event, | ||||||
|  |                 ) | ||||||
|  |                 .map_err(|_| "Failed to add pdu to db.".to_owned())?, | ||||||
|  |             ); | ||||||
|  |             debug!("Appended incoming pdu."); | ||||||
|  |         } else { | ||||||
|  |             warn!("Event was soft failed: {:?}", incoming_pdu); | ||||||
|  |         } | ||||||
| 
 | 
 | ||||||
|         // Set the new room state to the resolved state
 |         // Set the new room state to the resolved state
 | ||||||
|         if update_state { |         if update_state { | ||||||
|  | @ -1220,8 +1227,13 @@ pub fn handle_incoming_pdu<'a>( | ||||||
|         } |         } | ||||||
|         debug!("Updated resolved state"); |         debug!("Updated resolved state"); | ||||||
| 
 | 
 | ||||||
|  |         if soft_fail { | ||||||
|  |             // Soft fail, we leave the event as an outlier but don't add it to the timeline
 | ||||||
|  |             return Err("Event has been soft failed".into()); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|         // Event has passed all auth/stateres checks
 |         // Event has passed all auth/stateres checks
 | ||||||
|         Ok(Some(pdu_id)) |         Ok(pdu_id) | ||||||
|     }) |     }) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in a new issue