diff --git a/src/client_server.rs b/src/client_server.rs index 0103208..27db78c 100644 --- a/src/client_server.rs +++ b/src/client_server.rs @@ -13,9 +13,10 @@ use ruma_client_api::{ filter::{self, create_filter, get_filter}, keys::{get_keys, upload_keys}, membership::{ - get_member_events, invite_user, join_room_by_id, join_room_by_id_or_alias, forget_room, leave_room, + forget_room, get_member_events, invite_user, join_room_by_id, join_room_by_id_or_alias, + leave_room, }, - message::{get_message_events, create_message_event}, + message::{create_message_event, get_message_events}, presence::set_presence, profile::{ get_avatar_url, get_display_name, get_profile, set_avatar_url, set_display_name, @@ -673,16 +674,18 @@ pub fn join_room_by_id_or_alias_route( ) -> MatrixResult { let room_id = match RoomId::try_from(body.room_id_or_alias.clone()) { Ok(room_id) => room_id, - Err(room_alias) => if room_alias.server_name() == data.hostname() { - return MatrixResult(Err(Error { - kind: ErrorKind::NotFound, - message: "Room alias not found.".to_owned(), - status_code: http::StatusCode::NOT_FOUND, - })); - } else { - // Ask creator server of the room to join TODO ask someone else when not available - //server_server::send_request(data, destination, request) - todo!(); + Err(room_alias) => { + if room_alias.server_name() == data.hostname() { + return MatrixResult(Err(Error { + kind: ErrorKind::NotFound, + message: "Room alias not found.".to_owned(), + status_code: http::StatusCode::NOT_FOUND, + })); + } else { + // Ask creator server of the room to join TODO ask someone else when not available + //server_server::send_request(data, destination, request) + todo!(); + } } }; @@ -762,7 +765,7 @@ pub async fn get_public_rooms_filtered_route( .and_then(|s| s.content.get("name")) .and_then(|n| n.as_str()) .map(|n| n.to_owned()), - num_joined_members: data.room_users(&room_id).into(), + num_joined_members: data.room_users_joined(&room_id).into(), room_id, topic: None, world_readable: false, @@ -917,19 +920,37 @@ pub fn sync_route( body: Ruma, ) -> MatrixResult { std::thread::sleep(Duration::from_millis(300)); + let user_id = body.user_id.clone().expect("user is authenticated"); let next_batch = data.last_pdu_index().to_string(); let mut joined_rooms = BTreeMap::new(); - let joined_roomids = data.rooms_joined(body.user_id.as_ref().expect("user is authenticated")); + let joined_roomids = data.rooms_joined(&user_id); let since = body .since .clone() .and_then(|string| string.parse().ok()) .unwrap_or(0); + for room_id in joined_roomids { let pdus = data.pdus_since(&room_id, since); - let room_events = pdus.into_iter().map(|pdu| pdu.to_room_event()).collect::>(); - let is_first_pdu = data.room_pdu_first(&room_id, since); + + let mut send_member_count = false; + let mut send_full_state = false; + for pdu in &pdus { + if pdu.kind == EventType::RoomMember { + if pdu.state_key == Some(user_id.to_string()) && pdu.content["membership"] == "join" + { + send_full_state = true; + } + send_member_count = true; + } + } + + let room_events = pdus + .into_iter() + .map(|pdu| pdu.to_room_event()) + .collect::>(); + let mut edus = data.roomlatests_since(&room_id, since); edus.extend_from_slice(&data.roomactives_in(&room_id)); @@ -939,8 +960,16 @@ pub fn sync_route( account_data: sync_events::AccountData { events: Vec::new() }, summary: sync_events::RoomSummary { heroes: Vec::new(), - joined_member_count: None, - invited_member_count: None, + joined_member_count: if send_member_count { + Some(data.room_users_joined(&room_id).into()) + } else { + None + }, + invited_member_count: if send_member_count { + Some(data.room_users_invited(&room_id).into()) + } else { + None + }, }, unread_notifications: sync_events::UnreadNotificationsCount { highlight_count: None, @@ -951,14 +980,24 @@ pub fn sync_route( prev_batch: Some(since.to_string()), events: room_events, }, - state: sync_events::State { events: Vec::new() }, + // TODO: state before timeline + state: sync_events::State { + events: if send_full_state { + data.room_state(&room_id) + .into_iter() + .map(|(_, pdu)| pdu.to_state_event()) + .collect() + } else { + Vec::new() + }, + }, ephemeral: sync_events::Ephemeral { events: edus }, }, ); } let mut left_rooms = BTreeMap::new(); - let left_roomids = data.rooms_left(body.user_id.as_ref().expect("user is authenticated")); + let left_roomids = data.rooms_left(&user_id); for room_id in left_roomids { let pdus = data.pdus_since(&room_id, since); let room_events = pdus.into_iter().map(|pdu| pdu.to_room_event()).collect(); @@ -980,7 +1019,7 @@ pub fn sync_route( } let mut invited_rooms = BTreeMap::new(); - for room_id in data.rooms_invited(body.user_id.as_ref().expect("user is authenticated")) { + for room_id in data.rooms_invited(&user_id) { let events = data .pdus_since(&room_id, since) .into_iter() @@ -1013,15 +1052,18 @@ pub fn sync_route( pub fn get_message_events_route( data: State, body: Ruma, - _room_id: String) -> MatrixResult { - if let get_message_events::Direction::Forward = body.dir {todo!();} + _room_id: String, +) -> MatrixResult { + if let get_message_events::Direction::Forward = body.dir { + todo!(); + } - if let Ok(from) = body - .from - .clone() - .parse() { + if let Ok(from) = body.from.clone().parse() { let pdus = data.pdus_until(&body.room_id, from); - let room_events = pdus.into_iter().map(|pdu| pdu.to_room_event()).collect::>(); + let room_events = pdus + .into_iter() + .map(|pdu| pdu.to_room_event()) + .collect::>(); MatrixResult(Ok(get_message_events::Response { start: Some(body.from.clone()), end: None, @@ -1058,7 +1100,9 @@ pub fn publicised_groups_route() -> MatrixResult } #[options("/<_segments..>")] -pub fn options_route(_segments: rocket::http::uri::Segments) -> MatrixResult { +pub fn options_route( + _segments: rocket::http::uri::Segments, +) -> MatrixResult { MatrixResult(Err(Error { kind: ErrorKind::NotFound, message: "This is the options route.".to_owned(), diff --git a/src/data.rs b/src/data.rs index 6e09d07..3b652ba 100644 --- a/src/data.rs +++ b/src/data.rs @@ -193,11 +193,11 @@ impl Data { return false; } - self.db.userid_roomids.add( + self.db.userid_joinroomids.add( user_id.to_string().as_bytes(), room_id.to_string().as_bytes().into(), ); - self.db.roomid_userids.add( + self.db.roomid_joinuserids.add( room_id.to_string().as_bytes(), user_id.to_string().as_bytes().into(), ); @@ -205,6 +205,10 @@ impl Data { user_id.to_string().as_bytes(), room_id.to_string().as_bytes(), ); + self.db.roomid_inviteuserids.remove_value( + user_id.to_string().as_bytes(), + room_id.to_string().as_bytes(), + ); self.db.userid_leftroomids.remove_value( user_id.to_string().as_bytes(), room_id.to_string().as_bytes().into(), @@ -232,7 +236,7 @@ impl Data { pub fn rooms_joined(&self, user_id: &UserId) -> Vec { self.db - .userid_roomids + .userid_joinroomids .get_iter(user_id.to_string().as_bytes()) .values() .map(|room_id| { @@ -282,9 +286,16 @@ impl Data { room_ids } - pub fn room_users(&self, room_id: &RoomId) -> u32 { + pub fn room_users_joined(&self, room_id: &RoomId) -> u32 { self.db - .roomid_userids + .roomid_joinuserids + .get_iter(room_id.to_string().as_bytes()) + .count() as u32 + } + + pub fn room_users_invited(&self, room_id: &RoomId) -> u32 { + self.db + .roomid_inviteuserids .get_iter(room_id.to_string().as_bytes()) .count() as u32 } @@ -324,11 +335,15 @@ impl Data { user_id.to_string().as_bytes(), room_id.to_string().as_bytes().into(), ); - self.db.userid_roomids.remove_value( + self.db.roomid_inviteuserids.remove_value( user_id.to_string().as_bytes(), room_id.to_string().as_bytes().into(), ); - self.db.roomid_userids.remove_value( + self.db.userid_joinroomids.remove_value( + user_id.to_string().as_bytes(), + room_id.to_string().as_bytes().into(), + ); + self.db.roomid_joinuserids.remove_value( room_id.to_string().as_bytes(), user_id.to_string().as_bytes().into(), ); @@ -358,7 +373,7 @@ impl Data { user_id.to_string().as_bytes(), room_id.to_string().as_bytes().into(), ); - self.db.roomid_userids.add( + self.db.roomid_inviteuserids.add( room_id.to_string().as_bytes(), user_id.to_string().as_bytes().into(), ); diff --git a/src/database.rs b/src/database.rs index 73406c1..3dd7564 100644 --- a/src/database.rs +++ b/src/database.rs @@ -71,8 +71,9 @@ pub struct Database { pub eventid_pduid: sled::Tree, pub roomid_pduleaves: MultiValue, pub roomstateid_pdu: sled::Tree, // Room + StateType + StateKey - pub roomid_userids: MultiValue, - pub userid_roomids: MultiValue, + pub roomid_joinuserids: MultiValue, + pub roomid_inviteuserids: MultiValue, + pub userid_joinroomids: MultiValue, pub userid_inviteroomids: MultiValue, pub userid_leftroomids: MultiValue, // EDUs: @@ -115,8 +116,9 @@ impl Database { eventid_pduid: db.open_tree("eventid_pduid").unwrap(), roomid_pduleaves: MultiValue(db.open_tree("roomid_pduleaves").unwrap()), roomstateid_pdu: db.open_tree("roomstateid_pdu").unwrap(), - roomid_userids: MultiValue(db.open_tree("roomid_userids").unwrap()), - userid_roomids: MultiValue(db.open_tree("userid_roomids").unwrap()), + roomid_joinuserids: MultiValue(db.open_tree("roomid_joinuserids").unwrap()), + roomid_inviteuserids: MultiValue(db.open_tree("roomid_inviteuserids").unwrap()), + userid_joinroomids: MultiValue(db.open_tree("userid_joinroomids").unwrap()), userid_inviteroomids: MultiValue(db.open_tree("userid_inviteroomids").unwrap()), userid_leftroomids: MultiValue(db.open_tree("userid_leftroomids").unwrap()), roomlatestid_roomlatest: db.open_tree("roomlatestid_roomlatest").unwrap(), @@ -200,7 +202,7 @@ impl Database { ); } println!("\n# RoomId -> UserIds:"); - for (k, v) in self.roomid_userids.iter_all().map(|r| r.unwrap()) { + for (k, v) in self.roomid_joinuserids.iter_all().map(|r| r.unwrap()) { println!( "{:?} -> {:?}", String::from_utf8_lossy(&k), @@ -208,7 +210,7 @@ impl Database { ); } println!("\n# UserId -> RoomIds:"); - for (k, v) in self.userid_roomids.iter_all().map(|r| r.unwrap()) { + for (k, v) in self.userid_joinroomids.iter_all().map(|r| r.unwrap()) { println!( "{:?} -> {:?}", String::from_utf8_lossy(&k), diff --git a/src/pdu.rs b/src/pdu.rs index b6aa45d..0e1b3de 100644 --- a/src/pdu.rs +++ b/src/pdu.rs @@ -1,6 +1,8 @@ use js_int::UInt; use ruma_events::{ - collections::all::RoomEvent, stripped::AnyStrippedStateEvent, EventJson, EventType, + collections::all::{RoomEvent, StateEvent}, + stripped::AnyStrippedStateEvent, + EventJson, EventType, }; use ruma_federation_api::EventHash; use ruma_identifiers::{EventId, RoomId, UserId}; @@ -39,12 +41,12 @@ impl PduEvent { serde_json::from_str::>(&json).unwrap() } - pub fn to_stripped_state_event(&self) -> EventJson { - // Can only fail in rare circumstances that won't ever happen here, see - // https://docs.rs/serde_json/1.0.50/serde_json/fn.to_string.html + pub fn to_state_event(&self) -> EventJson { + let json = serde_json::to_string(&self).unwrap(); + serde_json::from_str::>(&json).unwrap() + } + pub fn to_stripped_state_event(&self) -> EventJson { let json = serde_json::to_string(&self).unwrap(); - - // EventJson's deserialize implementation always returns `Ok(...)` serde_json::from_str::>(&json).unwrap() } } diff --git a/src/server_server.rs b/src/server_server.rs index da0b8a0..394757a 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -13,18 +13,21 @@ use std::{ }; pub async fn request_well_known(data: &crate::Data, destination: &str) -> Option { - let body: serde_json::Value = serde_json::from_str(&data - .reqwest_client() - .get(&format!( - "https://{}/.well-known/matrix/server", - destination - )) - .send() - .await - .ok()? - .text() - .await - .ok()?).ok()?; + let body: serde_json::Value = serde_json::from_str( + &data + .reqwest_client() + .get(&format!( + "https://{}/.well-known/matrix/server", + destination + )) + .send() + .await + .ok()? + .text() + .await + .ok()?, + ) + .ok()?; Some(body.get("m.server")?.as_str()?.to_owned()) } @@ -35,10 +38,11 @@ pub async fn send_request( ) -> Option { let mut http_request: http::Request<_> = request.try_into().unwrap(); - let actual_destination = "https://".to_owned() + &request_well_known(data, &destination).await.unwrap_or(destination.clone() + ":8448"); - *http_request.uri_mut() = (actual_destination + T::METADATA.path) - .parse() - .unwrap(); + let actual_destination = "https://".to_owned() + + &request_well_known(data, &destination) + .await + .unwrap_or(destination.clone() + ":8448"); + *http_request.uri_mut() = (actual_destination + T::METADATA.path).parse().unwrap(); let mut request_map = serde_json::Map::new(); diff --git a/src/test.rs b/src/test.rs index 8756436..9f56214 100644 --- a/src/test.rs +++ b/src/test.rs @@ -5,10 +5,8 @@ use serde_json::{json, Value}; use std::time::Duration; fn setup_client() -> Client { - Database::try_remove("temp"); - let data = Data::load_or_create("temp"); - - let rocket = setup_rocket(data); + Database::try_remove("localhost"); + let rocket = setup_rocket(); Client::new(rocket).expect("valid rocket instance") }