async_client: Make our sync method threadsafe across yield points.

master
Damir Jelić 2019-11-26 20:34:11 +01:00
parent 19b9927de6
commit afcb68cc0e
1 changed files with 37 additions and 22 deletions

View File

@ -1,7 +1,7 @@
use futures::future::{BoxFuture, Future, FutureExt}; use futures::future::{BoxFuture, Future, FutureExt};
use std::convert::{TryFrom, TryInto}; use std::convert::{TryFrom, TryInto};
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock}; use std::sync::{Arc, Mutex, RwLock};
use http::Method as HttpMethod; use http::Method as HttpMethod;
use http::Response as HttpResponse; use http::Response as HttpResponse;
@ -12,6 +12,7 @@ use url::Url;
use ruma_api::Endpoint; use ruma_api::Endpoint;
use ruma_events::collections::all::RoomEvent; use ruma_events::collections::all::RoomEvent;
use ruma_events::room::message::MessageEventContent; use ruma_events::room::message::MessageEventContent;
use ruma_events::EventResult;
pub use ruma_events::EventType; pub use ruma_events::EventType;
use ruma_identifiers::RoomId; use ruma_identifiers::RoomId;
@ -23,7 +24,7 @@ use crate::session::Session;
use crate::VERSION; use crate::VERSION;
type RoomEventCallbackF = type RoomEventCallbackF =
Box<dyn FnMut(Arc<RwLock<Room>>, Arc<RoomEvent>) -> BoxFuture<'static, ()>>; Box<dyn FnMut(Arc<RwLock<Room>>, Arc<RoomEvent>) -> BoxFuture<'static, ()> + Send>;
#[derive(Clone)] #[derive(Clone)]
pub struct AsyncClient { pub struct AsyncClient {
@ -36,7 +37,7 @@ pub struct AsyncClient {
/// The transaction id. /// The transaction id.
transaction_id: Arc<AtomicU64>, transaction_id: Arc<AtomicU64>,
/// Event futures /// Event futures
event_futures: Arc<RwLock<Vec<RoomEventCallbackF>>>, event_futures: Arc<Mutex<Vec<RoomEventCallbackF>>>,
} }
#[derive(Default, Debug)] #[derive(Default, Debug)]
@ -152,8 +153,7 @@ impl AsyncClient {
http_client, http_client,
base_client: Arc::new(RwLock::new(BaseClient::new(session))), base_client: Arc::new(RwLock::new(BaseClient::new(session))),
transaction_id: Arc::new(AtomicU64::new(0)), transaction_id: Arc::new(AtomicU64::new(0)),
// event_callbacks: Vec::new(), event_futures: Arc::new(Mutex::new(Vec::new())),
event_futures: Arc::new(RwLock::new(Vec::new())),
}) })
} }
@ -163,11 +163,11 @@ impl AsyncClient {
pub fn add_event_future<C: 'static>( pub fn add_event_future<C: 'static>(
&mut self, &mut self,
mut callback: impl FnMut(Arc<RwLock<Room>>, Arc<RoomEvent>) -> C + 'static, mut callback: impl FnMut(Arc<RwLock<Room>>, Arc<RoomEvent>) -> C + 'static + Send,
) where ) where
C: Future<Output = ()> + Send, C: Future<Output = ()> + Send,
{ {
let mut futures = self.event_futures.write().unwrap(); let mut futures = self.event_futures.lock().unwrap();
let future = move |room, event| callback(room, event).boxed(); let future = move |room, event| callback(room, event).boxed();
@ -210,19 +210,20 @@ impl AsyncClient {
let response = self.send(request).await?; let response = self.send(request).await?;
let mut client = self.base_client.write().unwrap();
for (room_id, room) in &response.rooms.join { for (room_id, room) in &response.rooms.join {
let room_id = room_id.to_string(); let room_id = room_id.to_string();
for event in &room.state.events { let matrix_room = {
let event = match event.clone().into_result() { let mut client = self.base_client.write().unwrap();
Ok(e) => e,
Err(e) => continue,
};
client.receive_joined_state_event(&room_id, &event); for event in &room.state.events {
} if let EventResult::Ok(e) = event {
client.receive_joined_state_event(&room_id, &e);
}
}
client.joined_rooms.get(&room_id).unwrap().clone()
};
for event in &room.timeline.events { for event in &room.timeline.events {
let event = match event.clone().into_result() { let event = match event.clone().into_result() {
@ -230,24 +231,38 @@ impl AsyncClient {
Err(e) => continue, Err(e) => continue,
}; };
client.receive_joined_timeline_event(&room_id, &event); {
let mut client = self.base_client.write().unwrap();
let room = client.joined_rooms.get(&room_id).unwrap(); client.receive_joined_timeline_event(&room_id, &event);
let mut cb_futures = self.event_futures.write().unwrap(); }
let event = Arc::new(event.clone()); let event = Arc::new(event.clone());
for cb in &mut cb_futures.iter_mut() { let callbacks = {
cb(room.clone(), event.clone()).await; let mut cb_futures = self.event_futures.lock().unwrap();
let mut callbacks = Vec::new();
for cb in &mut cb_futures.iter_mut() {
callbacks.push(cb(matrix_room.clone(), event.clone()));
}
callbacks
};
for cb in callbacks {
cb.await;
} }
} }
let mut client = self.base_client.write().unwrap();
client.receive_sync_response(&response); client.receive_sync_response(&response);
} }
Ok(response) Ok(response)
} }
async fn sync_forever() {}
async fn send<Request: Endpoint>(&self, request: Request) -> Result<Request::Response, Error> { async fn send<Request: Endpoint>(&self, request: Request) -> Result<Request::Response, Error> {
let request: http::Request<Vec<u8>> = request.try_into()?; let request: http::Request<Vec<u8>> = request.try_into()?;
let url = request.uri(); let url = request.uri();