diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index 0da0747..ab646a1 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -22,7 +22,11 @@ use ruma::{ serde::{to_canonical_value, CanonicalJsonObject, Raw}, EventId, RoomId, RoomVersionId, ServerName, UserId, }; -use std::{collections::BTreeMap, convert::TryFrom, sync::RwLock}; +use std::{ + collections::{BTreeMap, HashSet}, + convert::TryFrom, + sync::RwLock, +}; #[cfg(feature = "conduit_bin")] use rocket::{get, post}; @@ -36,11 +40,29 @@ pub async fn join_room_by_id_route( db: State<'_, Database>, body: Ruma>, ) -> ConduitResult { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + let mut servers = db + .rooms + .invite_state(&sender_user, &body.room_id)? + .unwrap_or_default() + .iter() + .filter_map(|event| { + serde_json::from_str::(&event.json().to_string()).ok() + }) + .filter_map(|event| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) + .filter_map(|sender| UserId::try_from(sender).ok()) + .map(|user| user.server_name().to_owned()) + .collect::>(); + + servers.insert(body.room_id.server_name().to_owned()); + join_room_by_id_helper( &db, body.sender_user.as_ref(), &body.room_id, - &[body.room_id.server_name().to_owned()], + &servers, body.third_party_signed.as_ref(), ) .await @@ -55,12 +77,31 @@ pub async fn join_room_by_id_or_alias_route( db: State<'_, Database>, body: Ruma>, ) -> ConduitResult { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let (servers, room_id) = match RoomId::try_from(body.room_id_or_alias.clone()) { - Ok(room_id) => (vec![room_id.server_name().to_owned()], room_id), + Ok(room_id) => { + let mut servers = db + .rooms + .invite_state(&sender_user, &room_id)? + .unwrap_or_default() + .iter() + .filter_map(|event| { + serde_json::from_str::(&event.json().to_string()).ok() + }) + .filter_map(|event| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) + .filter_map(|sender| UserId::try_from(sender).ok()) + .map(|user| user.server_name().to_owned()) + .collect::>(); + + servers.insert(room_id.server_name().to_owned()); + (servers, room_id) + } Err(room_alias) => { let response = client_server::get_alias_helper(&db, &room_alias).await?; - (response.0.servers, response.0.room_id) + (response.0.servers.into_iter().collect(), response.0.room_id) } }; @@ -406,7 +447,7 @@ async fn join_room_by_id_helper( db: &Database, sender_user: Option<&UserId>, room_id: &RoomId, - servers: &[Box], + servers: &HashSet>, _third_party_signed: Option<&IncomingThirdPartySigned>, ) -> ConduitResult { let sender_user = sender_user.expect("user is authenticated"); diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 955ad83..35c1df7 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -1795,7 +1795,8 @@ impl Rooms { .filter_map(|event| event.get("sender").cloned()) .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) .filter_map(|sender| UserId::try_from(sender).ok()) - .map(|user| user.server_name().to_owned()); + .map(|user| user.server_name().to_owned()) + .collect::>(); for remote_server in servers { let make_leave_response = db