diff --git a/src/client_server/alias.rs b/src/client_server/alias.rs index 12bb8df..669f558 100644 --- a/src/client_server/alias.rs +++ b/src/client_server/alias.rs @@ -53,11 +53,11 @@ pub async fn get_alias_route( db: State<'_, Database>, body: Ruma, ) -> ConduitResult { - get_alias_helper(db, &body.room_alias).await + get_alias_helper(&db, &body.room_alias).await } pub async fn get_alias_helper( - db: State<'_, Database>, + db: &Database, room_alias: &RoomAliasId, ) -> ConduitResult { if room_alias.server_name() != db.globals.server_name() { diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index 0075861..996d3c4 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -49,13 +49,12 @@ pub async fn join_room_by_id_route( )] pub async fn join_room_by_id_or_alias_route( db: State<'_, Database>, - db2: State<'_, Database>, body: Ruma, ) -> ConduitResult { let room_id = match RoomId::try_from(body.room_id_or_alias.clone()) { Ok(room_id) => room_id, Err(room_alias) => { - client_server::get_alias_helper(db, &room_alias) + client_server::get_alias_helper(&db, &room_alias) .await? .0 .room_id @@ -64,7 +63,7 @@ pub async fn join_room_by_id_or_alias_route( Ok(join_room_by_id_or_alias::Response { room_id: join_room_by_id_helper( - &db2, + &db, body.sender_id.as_ref(), &room_id, body.third_party_signed.as_ref(), @@ -507,14 +506,13 @@ async fn join_room_by_id_helper( .collect::>(); // TODO make StateResolution's methods free functions ? or no self param ? - let sorted_events_ids = state_res::StateResolution::default() - .reverse_topological_power_sort( - &room_id, - &event_map.keys().cloned().collect::>(), - &mut event_map, - &db.rooms, - &[], // TODO auth_diff: is this none since we have a set of resolved events we only want to sort - ); + let sorted_events_ids = state_res::StateResolution::reverse_topological_power_sort( + &room_id, + &event_map.keys().cloned().collect::>(), + &mut event_map, + &db.rooms, + &[], // TODO auth_diff: is this none since we have a set of resolved events we only want to sort + ); for ev_id in &sorted_events_ids { // this is a `state_res::StateEvent` that holds a `ruma::Pdu` diff --git a/src/client_server/state.rs b/src/client_server/state.rs index 2920de2..867b051 100644 --- a/src/client_server/state.rs +++ b/src/client_server/state.rs @@ -9,8 +9,8 @@ use ruma::{ }, }, events::{AnyStateEventContent, EventContent}, + RoomId, UserId, }; -use std::convert::TryFrom; #[cfg(feature = "conduit_bin")] use rocket::{get, put}; @@ -33,45 +33,14 @@ pub fn send_state_event_for_key_route( ) .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?; - if let AnyStateEventContent::RoomCanonicalAlias(canonical_alias) = &body.content { - let mut aliases = canonical_alias.alt_aliases.clone(); - - if let Some(alias) = canonical_alias.alias.clone() { - aliases.push(alias); - } - - for alias in aliases { - if alias.server_name() != db.globals.server_name() - || db - .rooms - .id_from_alias(&alias)? - .filter(|room| room == &body.room_id) // Make sure it's the right room - .is_none() - { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "You are only allowed to send canonical_alias \ - events when it's aliases already exists", - )); - } - } - } - - let event_id = db.rooms.build_and_append_pdu( - PduBuilder { - room_id: body.room_id.clone(), - sender: sender_id.clone(), - event_type: body.content.event_type().into(), - content, - unsigned: None, - state_key: Some(body.state_key.clone()), - redacts: None, - }, - &db.globals, - &db.account_data, - )?; - - Ok(send_state_event_for_key::Response::new(event_id).into()) + send_state_event_for_key_helper( + &db, + sender_id, + &body.content, + content, + &body.room_id, + Some(body.state_key.clone()), + ) } #[cfg_attr( @@ -84,34 +53,30 @@ pub fn send_state_event_for_empty_key_route( ) -> ConduitResult { // This just calls send_state_event_for_key_route let Ruma { - body: - send_state_event_for_empty_key::IncomingRequest { - room_id, content, .. - }, + body, sender_id, - device_id, + device_id: _, json_body, } = body; + let json = serde_json::from_str::( + json_body + .as_ref() + .ok_or(Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))? + .get(), + ) + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?; + Ok(send_state_event_for_empty_key::Response::new( - send_state_event_for_key_route( - db, - Ruma { - body: send_state_event_for_key::IncomingRequest::try_from(http::Request::new( - serde_json::json!({ - "room_id": room_id, - "state_key": "", - "content": content, - }) - .to_string() - .as_bytes() - .to_vec(), - )) - .unwrap(), - sender_id, - device_id, - json_body, - }, + send_state_event_for_key_helper( + &db, + sender_id + .as_ref() + .expect("no user for send state empty key rout"), + &body.content, + json, + &body.room_id, + None, )? .0 .event_id, @@ -210,3 +175,54 @@ pub fn get_state_events_for_empty_key_route( } .into()) } + +pub fn send_state_event_for_key_helper( + db: &Database, + sender: &UserId, + content: &AnyStateEventContent, + json: serde_json::Value, + room_id: &RoomId, + state_key: Option, +) -> ConduitResult { + let sender_id = sender; + + if let AnyStateEventContent::RoomCanonicalAlias(canonical_alias) = content { + let mut aliases = canonical_alias.alt_aliases.clone(); + + if let Some(alias) = canonical_alias.alias.clone() { + aliases.push(alias); + } + + for alias in aliases { + if alias.server_name() != db.globals.server_name() + || db + .rooms + .id_from_alias(&alias)? + .filter(|room| room == room_id) // Make sure it's the right room + .is_none() + { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "You are only allowed to send canonical_alias \ + events when it's aliases already exists", + )); + } + } + } + + let event_id = db.rooms.build_and_append_pdu( + PduBuilder { + room_id: room_id.clone(), + sender: sender_id.clone(), + event_type: content.event_type().into(), + content: json, + unsigned: None, + state_key, + redacts: None, + }, + &db.globals, + &db.account_data, + )?; + + Ok(send_state_event_for_key::Response::new(event_id).into()) +}