From afcb68cc0ed004aa2ece5a9d92e5b3e137fa4e5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Tue, 26 Nov 2019 20:34:11 +0100 Subject: [PATCH] async_client: Make our sync method threadsafe across yield points. --- src/async_client.rs | 59 ++++++++++++++++++++++++++++----------------- 1 file changed, 37 insertions(+), 22 deletions(-) diff --git a/src/async_client.rs b/src/async_client.rs index b6534f3e..ff1808b6 100644 --- a/src/async_client.rs +++ b/src/async_client.rs @@ -1,7 +1,7 @@ use futures::future::{BoxFuture, Future, FutureExt}; use std::convert::{TryFrom, TryInto}; use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, Mutex, RwLock}; use http::Method as HttpMethod; use http::Response as HttpResponse; @@ -12,6 +12,7 @@ use url::Url; use ruma_api::Endpoint; use ruma_events::collections::all::RoomEvent; use ruma_events::room::message::MessageEventContent; +use ruma_events::EventResult; pub use ruma_events::EventType; use ruma_identifiers::RoomId; @@ -23,7 +24,7 @@ use crate::session::Session; use crate::VERSION; type RoomEventCallbackF = - Box>, Arc) -> BoxFuture<'static, ()>>; + Box>, Arc) -> BoxFuture<'static, ()> + Send>; #[derive(Clone)] pub struct AsyncClient { @@ -36,7 +37,7 @@ pub struct AsyncClient { /// The transaction id. transaction_id: Arc, /// Event futures - event_futures: Arc>>, + event_futures: Arc>>, } #[derive(Default, Debug)] @@ -152,8 +153,7 @@ impl AsyncClient { http_client, base_client: Arc::new(RwLock::new(BaseClient::new(session))), transaction_id: Arc::new(AtomicU64::new(0)), - // event_callbacks: Vec::new(), - event_futures: Arc::new(RwLock::new(Vec::new())), + event_futures: Arc::new(Mutex::new(Vec::new())), }) } @@ -163,11 +163,11 @@ impl AsyncClient { pub fn add_event_future( &mut self, - mut callback: impl FnMut(Arc>, Arc) -> C + 'static, + mut callback: impl FnMut(Arc>, Arc) -> C + 'static + Send, ) where C: Future + Send, { - let mut futures = self.event_futures.write().unwrap(); + let mut futures = self.event_futures.lock().unwrap(); let future = move |room, event| callback(room, event).boxed(); @@ -210,19 +210,20 @@ impl AsyncClient { let response = self.send(request).await?; - let mut client = self.base_client.write().unwrap(); - for (room_id, room) in &response.rooms.join { let room_id = room_id.to_string(); - for event in &room.state.events { - let event = match event.clone().into_result() { - Ok(e) => e, - Err(e) => continue, - }; + let matrix_room = { + let mut client = self.base_client.write().unwrap(); - client.receive_joined_state_event(&room_id, &event); - } + for event in &room.state.events { + if let EventResult::Ok(e) = event { + client.receive_joined_state_event(&room_id, &e); + } + } + + client.joined_rooms.get(&room_id).unwrap().clone() + }; for event in &room.timeline.events { let event = match event.clone().into_result() { @@ -230,24 +231,38 @@ impl AsyncClient { Err(e) => continue, }; - client.receive_joined_timeline_event(&room_id, &event); - - let room = client.joined_rooms.get(&room_id).unwrap(); - let mut cb_futures = self.event_futures.write().unwrap(); + { + let mut client = self.base_client.write().unwrap(); + client.receive_joined_timeline_event(&room_id, &event); + } let event = Arc::new(event.clone()); - for cb in &mut cb_futures.iter_mut() { - cb(room.clone(), event.clone()).await; + let callbacks = { + let mut cb_futures = self.event_futures.lock().unwrap(); + let mut callbacks = Vec::new(); + + for cb in &mut cb_futures.iter_mut() { + callbacks.push(cb(matrix_room.clone(), event.clone())); + } + + callbacks + }; + + for cb in callbacks { + cb.await; } } + let mut client = self.base_client.write().unwrap(); client.receive_sync_response(&response); } Ok(response) } + async fn sync_forever() {} + async fn send(&self, request: Request) -> Result { let request: http::Request> = request.try_into()?; let url = request.uri();