nio: Allow the AsyncClient to be clonable.

master
Damir Jelić 2019-11-10 11:44:03 +01:00
parent 066d76cc8e
commit 2aca52c9f0
2 changed files with 62 additions and 36 deletions

View File

@ -2,7 +2,8 @@ use std::convert::{TryFrom, TryInto};
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use std::sync::{Arc, RwLock};
use std::sync::atomic::{AtomicU64, Ordering};
use http::Method as HttpMethod;
use http::Response as HttpResponse;
@ -12,7 +13,10 @@ use url::Url;
use ruma_api::Endpoint;
use ruma_events::collections::all::RoomEvent;
use ruma_events::room::message::MessageEvent;
use ruma_events::room::message::MessageEventContent;
use ruma_events::Event;
use ruma_identifiers::RoomId;
pub use ruma_events::EventType;
use crate::api;
@ -22,19 +26,22 @@ use crate::error::{Error, InnerError};
use crate::session::Session;
type RoomEventCallback = Box::<dyn FnMut(&Room, &RoomEvent)>;
type RoomEventCallbackF = Box::<dyn FnMut(Arc<Mutex<Room>>, Arc<RoomEvent>) -> Pin<Box<dyn Future<Output = ()> + Send + Sync>> + Send + Sync>;
type RoomEventCallbackF = Box::<dyn FnMut(Arc<RwLock<Room>>, Arc<RoomEvent>) -> Pin<Box<dyn Future<Output = ()> + Send + Sync>> + Send + Sync>;
#[derive(Clone)]
pub struct AsyncClient {
/// The URL of the homeserver to connect to.
homeserver: Url,
/// The underlying HTTP client.
http_client: reqwest::Client,
/// User session data.
base_client: BaseClient,
base_client: Arc<RwLock<BaseClient>>,
/// The transaction id.
transaction_id: Arc<AtomicU64>,
// /// Event callbacks
// event_callbacks: Vec<RoomEventCallback>,
/// Event futures
event_futures: Vec<RoomEventCallbackF>,
// /// Event futures
// event_futures: Vec<RoomEventCallbackF>,
}
#[derive(Default, Debug)]
@ -115,20 +122,25 @@ impl SyncSettings {
use api::r0::session::login;
use api::r0::sync::sync_events;
use api::r0::send::send_message_event;
impl AsyncClient {
/// Creates a new client for making HTTP requests to the given homeserver.
pub fn new(homeserver_url: &str, session: Option<Session>) -> Result<Self, Error> {
pub fn new<U: TryInto<Url>>(homeserver_url: U, session: Option<Session>) -> Result<Self, Error> {
let config = AsyncClientConfig::new();
AsyncClient::new_with_config(homeserver_url, session, config)
}
pub fn new_with_config(
homeserver_url: &str,
pub fn new_with_config<U: TryInto<Url>>(
homeserver_url: U,
session: Option<Session>,
config: AsyncClientConfig,
) -> Result<Self, Error> {
let homeserver = Url::parse(homeserver_url)?;
let homeserver: Url = match homeserver_url.try_into() {
Ok(u) => u,
Err(e) => panic!("Error parsing homeserver url")
};
let http_client = reqwest::Client::builder();
let http_client = if config.disable_ssl_verification {
@ -165,9 +177,10 @@ impl AsyncClient {
Ok(Self {
homeserver,
http_client,
base_client: BaseClient::new(session),
base_client: Arc::new(RwLock::new(BaseClient::new(session))),
transaction_id: Arc::new(AtomicU64::new(0)),
// event_callbacks: Vec::new(),
event_futures: Vec::new(),
// event_futures: Vec::new(),
})
}
@ -179,13 +192,13 @@ impl AsyncClient {
// self.event_callbacks.push(callback);
// }
pub fn add_event_future(
&mut self,
event_type: EventType,
callback: RoomEventCallbackF,
) {
self.event_futures.push(callback);
}
// pub fn add_event_future(
// &mut self,
// event_type: EventType,
// callback: RoomEventCallbackF,
// ) {
// self.event_futures.push(callback);
// }
pub async fn login<S: Into<String>>(
&mut self,
@ -203,7 +216,8 @@ impl AsyncClient {
};
let response = self.send(request).await?;
self.base_client.receive_login_response(&response);
let mut client = self.base_client.write().unwrap();
client.receive_login_response(&response);
Ok(response)
}
@ -222,6 +236,8 @@ impl AsyncClient {
let response = self.send(request).await?;
let mut client = self.base_client.write().unwrap();
for (room_id, room) in &response.rooms.join {
let room_id = room_id.to_string();
@ -231,7 +247,7 @@ impl AsyncClient {
Err(e) => continue
};
self.base_client.receive_joined_state_event(&room_id, &event);
client.receive_joined_state_event(&room_id, &event);
}
for event in &room.timeline.events {
@ -240,18 +256,17 @@ impl AsyncClient {
Err(e) => continue
};
self.base_client
.receive_joined_timeline_event(&room_id, &event);
client.receive_joined_timeline_event(&room_id, &event);
let room = self.base_client.joined_rooms.get(&room_id).unwrap();
let room = client.joined_rooms.get(&room_id).unwrap();
// for mut cb in &mut self.event_callbacks {
// cb(&room.lock().unwrap(), &event);
// }
for mut cb in &mut self.event_futures {
cb(room.clone(), Arc::new(event.clone())).await;
}
// for mut cb in &mut self.event_futures {
// cb(room.clone(), Arc::new(event.clone())).await;
// }
}
}
@ -261,7 +276,7 @@ impl AsyncClient {
async fn send<Request: Endpoint>(&self, request: Request) -> Result<Request::Response, Error> {
let request: http::Request<Vec<u8>> = request.try_into()?;
let url = request.uri();
let url = self.homeserver.join(url.path()).unwrap();
let url = self.homeserver.join(url.path_and_query().unwrap().as_str()).unwrap();
let request_builder = match Request::METADATA.method {
HttpMethod::GET => self.http_client.get(url),
@ -269,13 +284,24 @@ impl AsyncClient {
let body = request.body().clone();
self.http_client.post(url).body(body)
}
HttpMethod::PUT => unimplemented!(),
HttpMethod::PUT => {
let body = request.body().clone();
self.http_client.put(url).body(body)
}
HttpMethod::DELETE => unimplemented!(),
_ => panic!("Unsuported method"),
};
// let request_builder = if let Some(query) = request.uri().query() {
// request_builder.query(query)
// } else {
// request_builder
// };
let request_builder = if Request::METADATA.requires_authentication {
if let Some(ref session) = self.base_client.session {
let client = self.base_client.read().unwrap();
if let Some(ref session) = client.session {
request_builder.bearer_auth(&session.access_token)
} else {
return Err(Error(InnerError::AuthenticationRequired));

View File

@ -4,7 +4,7 @@ use crate::api::r0 as api;
use crate::events::collections::all::{RoomEvent, StateEvent};
use crate::events::room::member::{MemberEvent, MembershipState};
use crate::session::Session;
use std::sync::{Arc, Mutex};
use std::sync::{Arc, RwLock};
pub type Token = String;
pub type RoomId = String;
@ -99,7 +99,7 @@ impl Room {
false
}
fn handle_membership(&mut self, event: &MemberEvent) -> bool {
pub fn handle_membership(&mut self, event: &MemberEvent) -> bool {
match event.content.membership {
MembershipState::Join => self.handle_join(event),
MembershipState::Leave => self.handle_leave(event),
@ -149,7 +149,7 @@ pub struct Client {
/// The current sync token that should be used for the next sync call.
pub sync_token: Option<Token>,
/// A map of the rooms our user is joined in.
pub joined_rooms: HashMap<RoomId, Arc<Mutex<Room>>>,
pub joined_rooms: HashMap<RoomId, Arc<RwLock<Room>>>,
}
impl Client {
@ -185,11 +185,11 @@ impl Client {
self.session = Some(session);
}
fn get_or_create_room(&mut self, room_id: &RoomId) -> &mut Arc<Mutex<Room>> {
fn get_or_create_room(&mut self, room_id: &RoomId) -> &mut Arc<RwLock<Room>> {
self.joined_rooms
.entry(room_id.to_string())
.or_insert(
Arc::new(Mutex::new(Room::new(
Arc::new(RwLock::new(Room::new(
room_id,
&self
.session
@ -209,7 +209,7 @@ impl Client {
/// Returns true if the membership list of the room changed, false
/// otherwise.
pub fn receive_joined_timeline_event(&mut self, room_id: &RoomId, event: &RoomEvent) -> bool {
let mut room = self.get_or_create_room(room_id).lock().unwrap();
let mut room = self.get_or_create_room(room_id).write().unwrap();
room.receive_timeline_event(event)
}
@ -222,7 +222,7 @@ impl Client {
/// Returns true if the membership list of the room changed, false
/// otherwise.
pub fn receive_joined_state_event(&mut self, room_id: &RoomId, event: &StateEvent) -> bool {
let mut room = self.get_or_create_room(room_id).lock().unwrap();
let mut room = self.get_or_create_room(room_id).write().unwrap();
room.receive_state_event(event)
}
}