diff --git a/matrix_sdk_base/src/client.rs b/matrix_sdk_base/src/client.rs index 3685b997..3fc1cdc8 100644 --- a/matrix_sdk_base/src/client.rs +++ b/matrix_sdk_base/src/client.rs @@ -23,8 +23,6 @@ use std::{ time::SystemTime, }; -use dashmap::DashMap; - #[cfg(feature = "encryption")] use futures::StreamExt; use matrix_sdk_common::{ @@ -64,10 +62,10 @@ use crate::{ JoinedRoom as JoinedRoomResponse, LeftRoom as LeftRoomResponse, MemberEvent, Presence, Rooms, State, StrippedMemberEvent, SyncResponse, Timeline, }, - rooms::{Room, RoomInfo, RoomType, StrippedRoom, StrippedRoomInfo}, + rooms::{RoomInfo, RoomType, StrippedRoomInfo}, session::Session, - store::{StateChanges, Store}, - EventEmitter, InvitedRoom, JoinedRoom, LeftRoom, RoomState, + store::{SledStore, StateChanges, Store}, + EventEmitter, JoinedRoom, RoomState, }; pub type Token = String; @@ -179,8 +177,6 @@ pub struct BaseClient { pub(crate) sync_token: Arc>>, /// Database store: Store, - rooms: Arc>, - stripped_rooms: Arc>, #[cfg(feature = "encryption")] olm: Arc>>, #[cfg(feature = "encryption")] @@ -288,17 +284,18 @@ impl BaseClient { pub fn new_with_config(config: BaseClientConfig) -> Result { let store = if let Some(path) = &config.store_path { info!("Opening store in path {}", path.display()); - Store::open_with_path(path) + SledStore::open_with_path(path) } else { - Store::open() + SledStore::open() }; + let session = Arc::new(RwLock::new(None)); + let store = Store::new(session.clone(), store); + Ok(BaseClient { - session: RwLock::new(None).into(), + session, sync_token: RwLock::new(None).into(), store, - rooms: DashMap::new().into(), - stripped_rooms: DashMap::new().into(), #[cfg(feature = "encryption")] olm: Mutex::new(None).into(), #[cfg(feature = "encryption")] @@ -424,32 +421,6 @@ impl BaseClient { *self.event_emitter.write().await = Some(emitter); } - async fn get_or_create_stripped_room(&self, room_id: &RoomId) -> StrippedRoom { - let session = self.session.read().await; - let user_id = &session - .as_ref() - .expect("Creating room while not being logged in") - .user_id; - - self.stripped_rooms - .entry(room_id.clone()) - .or_insert_with(|| StrippedRoom::new(user_id, self.store.clone(), room_id)) - .clone() - } - - async fn get_or_create_room(&self, room_id: &RoomId, room_type: RoomType) -> Room { - let session = self.session.read().await; - let user_id = &session - .as_ref() - .expect("Creating room while not being logged in") - .user_id; - - self.rooms - .entry(room_id.clone()) - .or_insert_with(|| Room::new(user_id, self.store.clone(), room_id, room_type)) - .clone() - } - async fn handle_timeline( &self, room_id: &RoomId, @@ -673,7 +644,10 @@ impl BaseClient { let mut rooms = Rooms::default(); for (room_id, new_info) in response.rooms.join { - let room = self.get_or_create_room(&room_id, RoomType::Joined).await; + let room = self + .store + .get_or_create_room(&room_id, RoomType::Joined) + .await; let mut room_info = room.clone_info(); room_info.update_summary(&new_info.summary); @@ -750,7 +724,10 @@ impl BaseClient { } for (room_id, new_info) in response.rooms.leave { - let room = self.get_or_create_room(&room_id, RoomType::Left).await; + let room = self + .store + .get_or_create_room(&room_id, RoomType::Left) + .await; let mut room_info = room.clone_info(); room_info.mark_as_left(); @@ -782,13 +759,16 @@ impl BaseClient { for (room_id, new_info) in response.rooms.invite { { - let room = self.get_or_create_room(&room_id, RoomType::Invited).await; + let room = self + .store + .get_or_create_room(&room_id, RoomType::Invited) + .await; let mut room_info = room.clone_info(); room_info.mark_as_invited(); changes.add_room(room_info); } - let room = self.get_or_create_stripped_room(&room_id).await; + let room = self.store.get_or_create_stripped_room(&room_id).await; let mut room_info = room.clone_info(); let (state, members, state_events) = @@ -862,7 +842,7 @@ impl BaseClient { async fn apply_changes(&self, changes: &StateChanges) { for (room_id, room_info) in &changes.room_infos { - if let Some(room) = self.get_bare_room(&room_id) { + if let Some(room) = self.store.get_bare_room(&room_id) { room.update_summary(room_info.clone()) } } @@ -873,7 +853,7 @@ impl BaseClient { room_id: &RoomId, response: &api::membership::get_member_events::Response, ) -> Result<()> { - if let Some(room) = self.get_bare_room(room_id) { + if let Some(room) = self.store.get_bare_room(room_id) { let mut room_info = room.clone_info(); room_info.mark_members_synced(); @@ -1038,30 +1018,12 @@ impl BaseClient { } } - fn get_bare_room(&self, room_id: &RoomId) -> Option { - #[allow(clippy::map_clone)] - self.rooms.get(room_id).map(|r| r.clone()) - } - pub fn get_joined_room(&self, room_id: &RoomId) -> Option { - self.get_room(room_id).map(|r| r.joined()).flatten() + self.store.get_joined_room(room_id) } pub fn get_room(&self, room_id: &RoomId) -> Option { - self.get_bare_room(room_id) - .map(|r| match r.room_type() { - RoomType::Joined => Some(RoomState::Joined(JoinedRoom { inner: r })), - RoomType::Left => Some(RoomState::Left(LeftRoom { inner: r })), - RoomType::Invited => self - .get_stripped_room(room_id) - .map(|r| RoomState::Invited(InvitedRoom { inner: r })), - }) - .flatten() - } - - fn get_stripped_room(&self, room_id: &RoomId) -> Option { - #[allow(clippy::map_clone)] - self.stripped_rooms.get(room_id).map(|r| r.clone()) + self.store.get_room(room_id) } /// Encrypt a message event content. diff --git a/matrix_sdk_base/src/rooms/normal.rs b/matrix_sdk_base/src/rooms/normal.rs index 7ee21458..9e997dea 100644 --- a/matrix_sdk_base/src/rooms/normal.rs +++ b/matrix_sdk_base/src/rooms/normal.rs @@ -29,7 +29,7 @@ use matrix_sdk_common::{ use serde::{Deserialize, Serialize}; use tracing::info; -use crate::{responses::UnreadNotificationsCount, store::Store}; +use crate::{responses::UnreadNotificationsCount, store::SledStore}; use super::{BaseRoomInfo, RoomMember}; @@ -38,7 +38,7 @@ pub struct Room { room_id: Arc, own_user_id: Arc, inner: Arc>, - store: Store, + store: SledStore, } #[derive(Clone, Debug, Default, Serialize, Deserialize)] @@ -60,7 +60,12 @@ pub enum RoomType { } impl Room { - pub fn new(own_user_id: &UserId, store: Store, room_id: &RoomId, room_type: RoomType) -> Self { + pub fn new( + own_user_id: &UserId, + store: SledStore, + room_id: &RoomId, + room_type: RoomType, + ) -> Self { let room_id = Arc::new(room_id.clone()); Self { diff --git a/matrix_sdk_base/src/rooms/stripped.rs b/matrix_sdk_base/src/rooms/stripped.rs index 0a2adf18..805904ee 100644 --- a/matrix_sdk_base/src/rooms/stripped.rs +++ b/matrix_sdk_base/src/rooms/stripped.rs @@ -20,7 +20,7 @@ use matrix_sdk_common::{ }; use serde::{Deserialize, Serialize}; -use crate::store::Store; +use crate::store::SledStore; use super::BaseRoomInfo; @@ -29,11 +29,11 @@ pub struct StrippedRoom { room_id: Arc, own_user_id: Arc, inner: Arc>, - store: Store, + store: SledStore, } impl StrippedRoom { - pub fn new(own_user_id: &UserId, store: Store, room_id: &RoomId) -> Self { + pub fn new(own_user_id: &UserId, store: SledStore, room_id: &RoomId) -> Self { let room_id = Arc::new(room_id.clone()); Self { diff --git a/matrix_sdk_base/src/store.rs b/matrix_sdk_base/src/store.rs index 7489a6a4..1518457f 100644 --- a/matrix_sdk_base/src/store.rs +++ b/matrix_sdk_base/src/store.rs @@ -1,5 +1,8 @@ -use std::{collections::BTreeMap, convert::TryFrom, path::Path, time::SystemTime}; +use std::{ + collections::BTreeMap, convert::TryFrom, ops::Deref, path::Path, sync::Arc, time::SystemTime, +}; +use dashmap::DashMap; use futures::stream::{self, Stream}; use matrix_sdk_common::{ events::{ @@ -7,6 +10,7 @@ use matrix_sdk_common::{ AnyStrippedStateEvent, AnySyncStateEvent, EventContent, EventType, }, identifiers::{RoomId, UserId}, + locks::RwLock, }; use sled::{transaction::TransactionResult, Config, Db, Transactional, Tree}; @@ -14,12 +18,91 @@ use tracing::info; use crate::{ responses::{MemberEvent, StrippedMemberEvent}, - rooms::RoomInfo, - Session, + rooms::{RoomInfo, RoomType, StrippedRoom}, + InvitedRoom, JoinedRoom, LeftRoom, Room, RoomState, Session, }; #[derive(Debug, Clone)] pub struct Store { + inner: SledStore, + session: Arc>>, + rooms: Arc>, + stripped_rooms: Arc>, +} + +impl Store { + pub fn new(session: Arc>>, inner: SledStore) -> Self { + Self { + inner, + session, + rooms: DashMap::new().into(), + stripped_rooms: DashMap::new().into(), + } + } + + pub(crate) fn get_bare_room(&self, room_id: &RoomId) -> Option { + #[allow(clippy::map_clone)] + self.rooms.get(room_id).map(|r| r.clone()) + } + + pub(crate) fn get_joined_room(&self, room_id: &RoomId) -> Option { + self.get_room(room_id).map(|r| r.joined()).flatten() + } + + pub(crate) fn get_room(&self, room_id: &RoomId) -> Option { + self.get_bare_room(room_id) + .map(|r| match r.room_type() { + RoomType::Joined => Some(RoomState::Joined(JoinedRoom { inner: r })), + RoomType::Left => Some(RoomState::Left(LeftRoom { inner: r })), + RoomType::Invited => self + .get_stripped_room(room_id) + .map(|r| RoomState::Invited(InvitedRoom { inner: r })), + }) + .flatten() + } + + fn get_stripped_room(&self, room_id: &RoomId) -> Option { + #[allow(clippy::map_clone)] + self.stripped_rooms.get(room_id).map(|r| r.clone()) + } + + pub(crate) async fn get_or_create_stripped_room(&self, room_id: &RoomId) -> StrippedRoom { + let session = self.session.read().await; + let user_id = &session + .as_ref() + .expect("Creating room while not being logged in") + .user_id; + + self.stripped_rooms + .entry(room_id.clone()) + .or_insert_with(|| StrippedRoom::new(user_id, self.inner.clone(), room_id)) + .clone() + } + + pub(crate) async fn get_or_create_room(&self, room_id: &RoomId, room_type: RoomType) -> Room { + let session = self.session.read().await; + let user_id = &session + .as_ref() + .expect("Creating room while not being logged in") + .user_id; + + self.rooms + .entry(room_id.clone()) + .or_insert_with(|| Room::new(user_id, self.inner.clone(), room_id, room_type)) + .clone() + } +} + +impl Deref for Store { + type Target = SledStore; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +#[derive(Debug, Clone)] +pub struct SledStore { inner: Db, session: Tree, account_data: Tree, @@ -109,7 +192,7 @@ impl From for StateChanges { } } -impl Store { +impl SledStore { fn open_helper(db: Db) -> Self { let session = db.open_tree("session").unwrap(); let account_data = db.open_tree("account_data").unwrap(); @@ -147,14 +230,14 @@ impl Store { pub fn open() -> Self { let db = Config::new().temporary(true).open().unwrap(); - Store::open_helper(db) + SledStore::open_helper(db) } pub fn open_with_path(path: impl AsRef) -> Self { let path = path.as_ref().join("matrix-sdk-state"); let db = Config::new().temporary(false).path(path).open().unwrap(); - Store::open_helper(db) + SledStore::open_helper(db) } pub async fn save_filter(&self, filter_name: &str, filter_id: &str) { @@ -393,7 +476,7 @@ mod test { }; use matrix_sdk_test::async_test; - use super::{StateChanges, Store}; + use super::{SledStore, StateChanges}; use crate::{responses::MemberEvent, Session}; fn user_id() -> UserId { @@ -432,7 +515,7 @@ mod test { access_token: "TEST_TOKEN".to_owned(), }; - let store = Store::open(); + let store = SledStore::open(); store.save_changes(&session.clone().into()).await; let stored_session = store.get_session().unwrap(); @@ -442,7 +525,7 @@ mod test { #[async_test] async fn test_member_saving() { - let store = Store::open(); + let store = SledStore::open(); let room_id = room_id!("!test:localhost"); let user_id = user_id(); diff --git a/matrix_sdk_crypto/Cargo.toml b/matrix_sdk_crypto/Cargo.toml index 63353cb2..9c46d866 100644 --- a/matrix_sdk_crypto/Cargo.toml +++ b/matrix_sdk_crypto/Cargo.toml @@ -30,7 +30,6 @@ getrandom = "0.2.0" serde = { version = "1.0.117", features = ["derive", "rc"] } serde_json = "1.0.59" zeroize = { version = "1.1.1", features = ["zeroize_derive"] } -url = "2.1.1" # Misc dependencies thiserror = "1.0.21"