diff --git a/src/async_client.rs b/src/async_client.rs index 96b50123..98a0ce42 100644 --- a/src/async_client.rs +++ b/src/async_client.rs @@ -2,7 +2,8 @@ use std::convert::{TryFrom, TryInto}; use std::future::Future; use std::pin::Pin; use std::rc::Rc; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, RwLock}; +use std::sync::atomic::{AtomicU64, Ordering}; use http::Method as HttpMethod; use http::Response as HttpResponse; @@ -12,7 +13,10 @@ use url::Url; use ruma_api::Endpoint; use ruma_events::collections::all::RoomEvent; +use ruma_events::room::message::MessageEvent; +use ruma_events::room::message::MessageEventContent; use ruma_events::Event; +use ruma_identifiers::RoomId; pub use ruma_events::EventType; use crate::api; @@ -22,19 +26,22 @@ use crate::error::{Error, InnerError}; use crate::session::Session; type RoomEventCallback = Box::; -type RoomEventCallbackF = Box::>, Arc) -> Pin + Send + Sync>> + Send + Sync>; +type RoomEventCallbackF = Box::>, Arc) -> Pin + Send + Sync>> + Send + Sync>; +#[derive(Clone)] pub struct AsyncClient { /// The URL of the homeserver to connect to. homeserver: Url, /// The underlying HTTP client. http_client: reqwest::Client, /// User session data. - base_client: BaseClient, + base_client: Arc>, + /// The transaction id. + transaction_id: Arc, // /// Event callbacks // event_callbacks: Vec, - /// Event futures - event_futures: Vec, + // /// Event futures + // event_futures: Vec, } #[derive(Default, Debug)] @@ -115,20 +122,25 @@ impl SyncSettings { use api::r0::session::login; use api::r0::sync::sync_events; +use api::r0::send::send_message_event; impl AsyncClient { /// Creates a new client for making HTTP requests to the given homeserver. - pub fn new(homeserver_url: &str, session: Option) -> Result { + pub fn new>(homeserver_url: U, session: Option) -> Result { let config = AsyncClientConfig::new(); AsyncClient::new_with_config(homeserver_url, session, config) } - pub fn new_with_config( - homeserver_url: &str, + pub fn new_with_config>( + homeserver_url: U, session: Option, config: AsyncClientConfig, ) -> Result { - let homeserver = Url::parse(homeserver_url)?; + let homeserver: Url = match homeserver_url.try_into() { + Ok(u) => u, + Err(e) => panic!("Error parsing homeserver url") + }; + let http_client = reqwest::Client::builder(); let http_client = if config.disable_ssl_verification { @@ -165,9 +177,10 @@ impl AsyncClient { Ok(Self { homeserver, http_client, - base_client: BaseClient::new(session), + base_client: Arc::new(RwLock::new(BaseClient::new(session))), + transaction_id: Arc::new(AtomicU64::new(0)), // event_callbacks: Vec::new(), - event_futures: Vec::new(), + // event_futures: Vec::new(), }) } @@ -179,13 +192,13 @@ impl AsyncClient { // self.event_callbacks.push(callback); // } - pub fn add_event_future( - &mut self, - event_type: EventType, - callback: RoomEventCallbackF, - ) { - self.event_futures.push(callback); - } + // pub fn add_event_future( + // &mut self, + // event_type: EventType, + // callback: RoomEventCallbackF, + // ) { + // self.event_futures.push(callback); + // } pub async fn login>( &mut self, @@ -203,7 +216,8 @@ impl AsyncClient { }; let response = self.send(request).await?; - self.base_client.receive_login_response(&response); + let mut client = self.base_client.write().unwrap(); + client.receive_login_response(&response); Ok(response) } @@ -222,6 +236,8 @@ impl AsyncClient { let response = self.send(request).await?; + let mut client = self.base_client.write().unwrap(); + for (room_id, room) in &response.rooms.join { let room_id = room_id.to_string(); @@ -231,7 +247,7 @@ impl AsyncClient { Err(e) => continue }; - self.base_client.receive_joined_state_event(&room_id, &event); + client.receive_joined_state_event(&room_id, &event); } for event in &room.timeline.events { @@ -240,18 +256,17 @@ impl AsyncClient { Err(e) => continue }; - self.base_client - .receive_joined_timeline_event(&room_id, &event); + client.receive_joined_timeline_event(&room_id, &event); - let room = self.base_client.joined_rooms.get(&room_id).unwrap(); + let room = client.joined_rooms.get(&room_id).unwrap(); // for mut cb in &mut self.event_callbacks { // cb(&room.lock().unwrap(), &event); // } - for mut cb in &mut self.event_futures { - cb(room.clone(), Arc::new(event.clone())).await; - } + // for mut cb in &mut self.event_futures { + // cb(room.clone(), Arc::new(event.clone())).await; + // } } } @@ -261,7 +276,7 @@ impl AsyncClient { async fn send(&self, request: Request) -> Result { let request: http::Request> = request.try_into()?; let url = request.uri(); - let url = self.homeserver.join(url.path()).unwrap(); + let url = self.homeserver.join(url.path_and_query().unwrap().as_str()).unwrap(); let request_builder = match Request::METADATA.method { HttpMethod::GET => self.http_client.get(url), @@ -269,13 +284,24 @@ impl AsyncClient { let body = request.body().clone(); self.http_client.post(url).body(body) } - HttpMethod::PUT => unimplemented!(), + HttpMethod::PUT => { + let body = request.body().clone(); + self.http_client.put(url).body(body) + } HttpMethod::DELETE => unimplemented!(), _ => panic!("Unsuported method"), }; + // let request_builder = if let Some(query) = request.uri().query() { + // request_builder.query(query) + // } else { + // request_builder + // }; + let request_builder = if Request::METADATA.requires_authentication { - if let Some(ref session) = self.base_client.session { + let client = self.base_client.read().unwrap(); + + if let Some(ref session) = client.session { request_builder.bearer_auth(&session.access_token) } else { return Err(Error(InnerError::AuthenticationRequired)); diff --git a/src/base_client.rs b/src/base_client.rs index 9b7be64d..6b77e0b8 100644 --- a/src/base_client.rs +++ b/src/base_client.rs @@ -4,7 +4,7 @@ use crate::api::r0 as api; use crate::events::collections::all::{RoomEvent, StateEvent}; use crate::events::room::member::{MemberEvent, MembershipState}; use crate::session::Session; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, RwLock}; pub type Token = String; pub type RoomId = String; @@ -99,7 +99,7 @@ impl Room { false } - fn handle_membership(&mut self, event: &MemberEvent) -> bool { + pub fn handle_membership(&mut self, event: &MemberEvent) -> bool { match event.content.membership { MembershipState::Join => self.handle_join(event), MembershipState::Leave => self.handle_leave(event), @@ -149,7 +149,7 @@ pub struct Client { /// The current sync token that should be used for the next sync call. pub sync_token: Option, /// A map of the rooms our user is joined in. - pub joined_rooms: HashMap>>, + pub joined_rooms: HashMap>>, } impl Client { @@ -185,11 +185,11 @@ impl Client { self.session = Some(session); } - fn get_or_create_room(&mut self, room_id: &RoomId) -> &mut Arc> { + fn get_or_create_room(&mut self, room_id: &RoomId) -> &mut Arc> { self.joined_rooms .entry(room_id.to_string()) .or_insert( - Arc::new(Mutex::new(Room::new( + Arc::new(RwLock::new(Room::new( room_id, &self .session @@ -209,7 +209,7 @@ impl Client { /// Returns true if the membership list of the room changed, false /// otherwise. pub fn receive_joined_timeline_event(&mut self, room_id: &RoomId, event: &RoomEvent) -> bool { - let mut room = self.get_or_create_room(room_id).lock().unwrap(); + let mut room = self.get_or_create_room(room_id).write().unwrap(); room.receive_timeline_event(event) } @@ -222,7 +222,7 @@ impl Client { /// Returns true if the membership list of the room changed, false /// otherwise. pub fn receive_joined_state_event(&mut self, room_id: &RoomId, event: &StateEvent) -> bool { - let mut room = self.get_or_create_room(room_id).lock().unwrap(); + let mut room = self.get_or_create_room(room_id).write().unwrap(); room.receive_state_event(event) } }