diff --git a/examples/command_bot.rs b/examples/command_bot.rs index a2d58fb0..220b0ea1 100644 --- a/examples/command_bot.rs +++ b/examples/command_bot.rs @@ -12,9 +12,6 @@ use url::Url; struct CommandBot { /// This clone of the `AsyncClient` will send requests to the server, /// while the other keeps us in sync with the server using `sync_forever`. - /// - /// The type parameter is for the `StateStore` trait specifying the `Store` - /// type for state storage, here we don't care. client: AsyncClient, } diff --git a/src/async_client.rs b/src/async_client.rs index 5c3b33c9..aae5f5d8 100644 --- a/src/async_client.rs +++ b/src/async_client.rs @@ -16,7 +16,6 @@ use std::collections::HashMap; use std::convert::{TryFrom, TryInto}; use std::ops::Deref; -use std::path::{Path, PathBuf}; use std::result::Result as StdResult; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -48,7 +47,7 @@ use crate::api; use crate::base_client::Client as BaseClient; use crate::models::Room; use crate::session::Session; -use crate::state::{ClientState, JsonStore, StateStore}; +use crate::state::{ClientState, StateStore}; use crate::VERSION; use crate::{Error, EventEmitter, Result}; @@ -65,8 +64,6 @@ pub struct AsyncClient { http_client: reqwest::Client, /// User session data. pub(crate) base_client: Arc>, - /// The path to the default state store. - state_store_path: Option, } impl std::fmt::Debug for AsyncClient { @@ -93,7 +90,6 @@ pub struct AsyncClientConfig { proxy: Option, user_agent: Option, disable_ssl_verification: bool, - store_path: Option, state_store: Option>, } @@ -103,7 +99,6 @@ impl std::fmt::Debug for AsyncClientConfig { .field("proxy", &self.proxy) .field("user_agent", &self.user_agent) .field("disable_ssl_verification", &self.disable_ssl_verification) - .field("store_path", &self.store_path) .finish() } } @@ -148,15 +143,6 @@ impl AsyncClientConfig { Ok(self) } - /// Set the path for the default `StateStore`. - /// - /// When the path is set `AsyncClient` will set the state store - /// to `JsonStore`. - pub fn state_store_path>(mut self, path: P) -> Self { - self.store_path = Some(path.as_ref().to_owned()); - self - } - /// Set a custom implementation of a `StateStore`. /// /// The state store should be "connected" before being set. @@ -289,11 +275,8 @@ impl AsyncClient { let http_client = http_client.default_headers(headers).build()?; let mut base_client = BaseClient::new(session)?; - if let Some(path) = config.store_path.as_ref() { - let store = JsonStore; - store.open(path)?; - base_client.state_store = Some(Box::new(store)); - } else if let Some(store) = config.state_store { + + if let Some(store) = config.state_store { base_client.state_store = Some(store); }; @@ -301,7 +284,6 @@ impl AsyncClient { homeserver, http_client, base_client: Arc::new(RwLock::new(base_client)), - state_store_path: config.store_path, }) } @@ -382,15 +364,8 @@ impl AsyncClient { let response = self.send(request).await?; let mut client = self.base_client.write().await; - // TODO avoid allocation somehow? - let path = self.state_store_path.as_ref().map(|p| { - let mut path = PathBuf::from(p); - path.push(response.user_id.to_string()); - path - }); - client - .receive_login_response(&response, path.as_ref()) - .await?; + + client.receive_login_response(&response).await?; Ok(response) } @@ -706,11 +681,9 @@ impl AsyncClient { if updated { if let Some(store) = self.base_client.read().await.state_store.as_ref() { - if let Some(path) = self.state_store_path.as_ref() { - store - .store_room_state(&path, matrix_room.read().await.deref()) - .await?; - }; + store + .store_room_state(matrix_room.read().await.deref()) + .await?; } } } @@ -720,10 +693,8 @@ impl AsyncClient { if updated { if let Some(store) = client.state_store.as_ref() { - if let Some(path) = self.state_store_path.as_ref() { - let state = ClientState::from_base_client(&client); - store.store_client_state(&path, state).await?; - }; + let state = ClientState::from_base_client(&client); + store.store_client_state(state).await?; } } Ok(response) diff --git a/src/base_client.rs b/src/base_client.rs index 92c191f6..f1d04759 100644 --- a/src/base_client.rs +++ b/src/base_client.rs @@ -19,7 +19,6 @@ use std::collections::HashSet; use std::fmt; use std::sync::Arc; -use std::path::PathBuf; #[cfg(feature = "encryption")] use std::result::Result as StdResult; @@ -145,7 +144,6 @@ impl Client { pub async fn receive_login_response( &mut self, response: &api::session::login::Response, - store_path: Option<&PathBuf>, ) -> Result<()> { let session = Session { access_token: response.access_token.clone(), @@ -160,25 +158,23 @@ impl Client { *olm = Some(OlmMachine::new(&response.user_id, &response.device_id)?); } - if let Some(path) = store_path { - if let Some(store) = self.state_store.as_ref() { - let ClientState { - session, - sync_token, - ignored_users, - push_ruleset, - } = store.load_client_state(&path).await?; - let mut rooms = store.load_all_rooms(&path).await?; + if let Some(store) = self.state_store.as_ref() { + let ClientState { + session, + sync_token, + ignored_users, + push_ruleset, + } = store.load_client_state().await?; + let mut rooms = store.load_all_rooms().await?; - self.joined_rooms = rooms - .drain() - .map(|(k, room)| (k, Arc::new(RwLock::new(room)))) - .collect(); - self.session = session; - self.sync_token = sync_token; - self.ignored_users = ignored_users; - self.push_ruleset = push_ruleset; - } + self.joined_rooms = rooms + .drain() + .map(|(k, room)| (k, Arc::new(RwLock::new(room)))) + .collect(); + self.session = session; + self.sync_token = sync_token; + self.ignored_users = ignored_users; + self.push_ruleset = push_ruleset; } Ok(()) diff --git a/src/state/mod.rs b/src/state/mod.rs index e6513e98..f643570e 100644 --- a/src/state/mod.rs +++ b/src/state/mod.rs @@ -14,7 +14,6 @@ // limitations under the License. use std::collections::HashMap; -use std::path::Path; pub mod state_store; pub use state_store::JsonStore; @@ -61,20 +60,18 @@ impl ClientState { /// Abstraction around the data store to avoid unnecessary request on client initialization. #[async_trait::async_trait] pub trait StateStore: Send + Sync { - /// Set up connections or check files exist to load/save state. - fn open(&self, path: &Path) -> Result<()>; /// Loads the state of `BaseClient` through `StateStore::Store` type. - async fn load_client_state(&self, path: &Path) -> Result; + async fn load_client_state(&self) -> Result; /// Load the state of a single `Room` by `RoomId`. - async fn load_room_state(&self, path: &Path, room_id: &RoomId) -> Result; + async fn load_room_state(&self, room_id: &RoomId) -> Result; /// Load the state of all `Room`s. /// /// This will be mapped over in the client in order to store `Room`s in an async safe way. - async fn load_all_rooms(&self, path: &Path) -> Result>; + async fn load_all_rooms(&self) -> Result>; /// Save the current state of the `BaseClient` using the `StateStore::Store` type. - async fn store_client_state(&self, path: &Path, _: ClientState) -> Result<()>; + async fn store_client_state(&self, _: ClientState) -> Result<()>; /// Save the state a single `Room`. - async fn store_room_state(&self, path: &Path, _: &Room) -> Result<()>; + async fn store_room_state(&self, _: &Room) -> Result<()>; } #[cfg(test)] diff --git a/src/state/state_store.rs b/src/state/state_store.rs index 83ea4e63..98d2f7c3 100644 --- a/src/state/state_store.rs +++ b/src/state/state_store.rs @@ -1,26 +1,36 @@ use std::collections::HashMap; use std::fs::{self, OpenOptions}; use std::io::{BufReader, BufWriter, Write}; -use std::path::Path; +use std::path::{Path, PathBuf}; use super::{ClientState, StateStore}; use crate::identifiers::RoomId; use crate::{Error, Result, Room}; /// A default `StateStore` implementation that serializes state as json /// and saves it to disk. -pub struct JsonStore; +pub struct JsonStore { + path: PathBuf, +} + +impl JsonStore { + /// Create a `JsonStore` to store the client and room state. + /// + /// Checks if the provided path exists and creates the directories if not. + pub fn open>(path: P) -> Result { + let p = path.as_ref(); + if !p.exists() { + std::fs::create_dir_all(p)?; + } + Ok(Self { + path: p.to_path_buf(), + }) + } +} #[async_trait::async_trait] impl StateStore for JsonStore { - fn open(&self, path: &Path) -> Result<()> { - if !path.exists() { - std::fs::create_dir_all(path)?; - } - Ok(()) - } - - async fn load_client_state(&self, path: &Path) -> Result { - let mut path = path.to_path_buf(); + async fn load_client_state(&self) -> Result { + let mut path = self.path.clone(); path.push("client.json"); let file = OpenOptions::new().read(true).open(path)?; @@ -28,8 +38,8 @@ impl StateStore for JsonStore { serde_json::from_reader(reader).map_err(Error::from) } - async fn load_room_state(&self, path: &Path, room_id: &RoomId) -> Result { - let mut path = path.to_path_buf(); + async fn load_room_state(&self, room_id: &RoomId) -> Result { + let mut path = self.path.clone(); path.push(&format!("rooms/{}.json", room_id)); let file = OpenOptions::new().read(true).open(path)?; @@ -37,8 +47,8 @@ impl StateStore for JsonStore { serde_json::from_reader(reader).map_err(Error::from) } - async fn load_all_rooms(&self, path: &Path) -> Result> { - let mut path = path.to_path_buf(); + async fn load_all_rooms(&self) -> Result> { + let mut path = self.path.clone(); path.push("rooms"); let mut rooms_map = HashMap::new(); @@ -61,8 +71,8 @@ impl StateStore for JsonStore { Ok(rooms_map) } - async fn store_client_state(&self, path: &Path, state: ClientState) -> Result<()> { - let mut path = path.to_path_buf(); + async fn store_client_state(&self, state: ClientState) -> Result<()> { + let mut path = self.path.clone(); path.push("client.json"); if !Path::new(&path).exists() { @@ -84,8 +94,8 @@ impl StateStore for JsonStore { Ok(()) } - async fn store_room_state(&self, path: &Path, room: &Room) -> Result<()> { - let mut path = path.to_path_buf(); + async fn store_room_state(&self, room: &Room) -> Result<()> { + let mut path = self.path.clone(); path.push(&format!("rooms/{}.json", room.room_id)); if !Path::new(&path).exists() { @@ -153,10 +163,11 @@ mod test { } async fn test_store_client_state() { - let store = JsonStore; + let path: &Path = &PATH; + let store = JsonStore::open(path).unwrap(); let state = ClientState::default(); - store.store_client_state(&PATH, state).await.unwrap(); - let loaded = store.load_client_state(&PATH).await.unwrap(); + store.store_client_state(state).await.unwrap(); + let loaded = store.load_client_state().await.unwrap(); assert_eq!(loaded, ClientState::default()); } @@ -166,14 +177,15 @@ mod test { } async fn test_store_room_state() { - let store = JsonStore; + let path: &Path = &PATH; + let store = JsonStore::open(path).unwrap(); let id = RoomId::try_from("!roomid:example.com").unwrap(); let user = UserId::try_from("@example:example.com").unwrap(); let room = Room::new(&id, &user); - store.store_room_state(&PATH, &room).await.unwrap(); - let loaded = store.load_room_state(&PATH, &id).await.unwrap(); + store.store_room_state(&room).await.unwrap(); + let loaded = store.load_room_state(&id).await.unwrap(); assert_eq!(loaded, Room::new(&id, &user)); } @@ -183,14 +195,15 @@ mod test { } async fn test_load_rooms() { - let store = JsonStore; + let path: &Path = &PATH; + let store = JsonStore::open(path).unwrap(); let id = RoomId::try_from("!roomid:example.com").unwrap(); let user = UserId::try_from("@example:example.com").unwrap(); let room = Room::new(&id, &user); - store.store_room_state(&PATH, &room).await.unwrap(); - let loaded = store.load_all_rooms(&PATH).await.unwrap(); + store.store_room_state(&room).await.unwrap(); + let loaded = store.load_all_rooms().await.unwrap(); assert_eq!(&room, loaded.get(&id).unwrap()); } @@ -221,20 +234,19 @@ mod test { .with_body_from_file("tests/data/login_response.json") .create(); - let mut path = PATH.clone(); - path.push(session.user_id.to_string()); - // a sync response to populate our JSON store with user_id added to path - let config = AsyncClientConfig::default().state_store_path(&path); + let path: &Path = &PATH; + // a sync response to populate our JSON store + let config = + AsyncClientConfig::default().state_store(Box::new(JsonStore::open(path).unwrap())); let client = AsyncClient::new_with_config(homeserver.clone(), Some(session.clone()), config) .unwrap(); let sync_settings = SyncSettings::new().timeout(std::time::Duration::from_millis(3000)); let _ = client.sync(sync_settings).await.unwrap(); - // remove user_id as login will set this - path.pop(); // once logged in without syncing the client is updated from the state store - let config = AsyncClientConfig::default().state_store_path(&path); + let config = + AsyncClientConfig::default().state_store(Box::new(JsonStore::open(path).unwrap())); let client = AsyncClient::new_with_config(homeserver, None, config).unwrap(); client .login("example", "wordpass", None, None)