async_client: Make our sync method threadsafe across yield points.

master
Damir Jelić 2019-11-26 20:34:11 +01:00
parent 19b9927de6
commit afcb68cc0e
1 changed files with 37 additions and 22 deletions

View File

@ -1,7 +1,7 @@
use futures::future::{BoxFuture, Future, FutureExt};
use std::convert::{TryFrom, TryInto};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use std::sync::{Arc, Mutex, RwLock};
use http::Method as HttpMethod;
use http::Response as HttpResponse;
@ -12,6 +12,7 @@ use url::Url;
use ruma_api::Endpoint;
use ruma_events::collections::all::RoomEvent;
use ruma_events::room::message::MessageEventContent;
use ruma_events::EventResult;
pub use ruma_events::EventType;
use ruma_identifiers::RoomId;
@ -23,7 +24,7 @@ use crate::session::Session;
use crate::VERSION;
type RoomEventCallbackF =
Box<dyn FnMut(Arc<RwLock<Room>>, Arc<RoomEvent>) -> BoxFuture<'static, ()>>;
Box<dyn FnMut(Arc<RwLock<Room>>, Arc<RoomEvent>) -> BoxFuture<'static, ()> + Send>;
#[derive(Clone)]
pub struct AsyncClient {
@ -36,7 +37,7 @@ pub struct AsyncClient {
/// The transaction id.
transaction_id: Arc<AtomicU64>,
/// Event futures
event_futures: Arc<RwLock<Vec<RoomEventCallbackF>>>,
event_futures: Arc<Mutex<Vec<RoomEventCallbackF>>>,
}
#[derive(Default, Debug)]
@ -152,8 +153,7 @@ impl AsyncClient {
http_client,
base_client: Arc::new(RwLock::new(BaseClient::new(session))),
transaction_id: Arc::new(AtomicU64::new(0)),
// event_callbacks: Vec::new(),
event_futures: Arc::new(RwLock::new(Vec::new())),
event_futures: Arc::new(Mutex::new(Vec::new())),
})
}
@ -163,11 +163,11 @@ impl AsyncClient {
pub fn add_event_future<C: 'static>(
&mut self,
mut callback: impl FnMut(Arc<RwLock<Room>>, Arc<RoomEvent>) -> C + 'static,
mut callback: impl FnMut(Arc<RwLock<Room>>, Arc<RoomEvent>) -> C + 'static + Send,
) where
C: Future<Output = ()> + Send,
{
let mut futures = self.event_futures.write().unwrap();
let mut futures = self.event_futures.lock().unwrap();
let future = move |room, event| callback(room, event).boxed();
@ -210,19 +210,20 @@ impl AsyncClient {
let response = self.send(request).await?;
let mut client = self.base_client.write().unwrap();
for (room_id, room) in &response.rooms.join {
let room_id = room_id.to_string();
for event in &room.state.events {
let event = match event.clone().into_result() {
Ok(e) => e,
Err(e) => continue,
};
let matrix_room = {
let mut client = self.base_client.write().unwrap();
client.receive_joined_state_event(&room_id, &event);
for event in &room.state.events {
if let EventResult::Ok(e) = event {
client.receive_joined_state_event(&room_id, &e);
}
}
client.joined_rooms.get(&room_id).unwrap().clone()
};
for event in &room.timeline.events {
let event = match event.clone().into_result() {
@ -230,24 +231,38 @@ impl AsyncClient {
Err(e) => continue,
};
{
let mut client = self.base_client.write().unwrap();
client.receive_joined_timeline_event(&room_id, &event);
let room = client.joined_rooms.get(&room_id).unwrap();
let mut cb_futures = self.event_futures.write().unwrap();
}
let event = Arc::new(event.clone());
let callbacks = {
let mut cb_futures = self.event_futures.lock().unwrap();
let mut callbacks = Vec::new();
for cb in &mut cb_futures.iter_mut() {
cb(room.clone(), event.clone()).await;
callbacks.push(cb(matrix_room.clone(), event.clone()));
}
callbacks
};
for cb in callbacks {
cb.await;
}
}
let mut client = self.base_client.write().unwrap();
client.receive_sync_response(&response);
}
Ok(response)
}
async fn sync_forever() {}
async fn send<Request: Endpoint>(&self, request: Request) -> Result<Request::Response, Error> {
let request: http::Request<Vec<u8>> = request.try_into()?;
let url = request.uri();