diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index 52f6188e..d1968b02 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -27,8 +27,6 @@ use std::{ #[cfg(feature = "encryption")] use dashmap::DashMap; -#[cfg(feature = "encryption")] -use futures::TryStreamExt; use futures_timer::Delay as sleep; use http::HeaderValue; use mime::{self, Mime}; @@ -1155,10 +1153,10 @@ impl Client { let _guard = mutex.lock().await; { - let room = self.get_joined_room(room_id).unwrap(); - let members = room.joined_user_ids().await; - let members_iter: Vec = members.try_collect().await?; - self.claim_one_time_keys(&mut members_iter.iter()).await?; + let joined = self.store().get_joined_user_ids(room_id).await?; + let invited = self.store().get_invited_user_ids(room_id).await?; + let members = joined.iter().chain(&invited); + self.claim_one_time_keys(members).await?; }; let response = self.share_group_session(room_id).await; @@ -1831,7 +1829,7 @@ impl Client { #[cfg(feature = "encryption")] #[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[instrument(skip(users))] - async fn claim_one_time_keys(&self, users: &mut impl Iterator) -> Result<()> { + async fn claim_one_time_keys(&self, users: impl Iterator) -> Result<()> { let _lock = self.key_claim_lock.lock().await; if let Some((request_id, request)) = self.base_client.get_missing_sessions(users).await? { @@ -2277,7 +2275,6 @@ mod test { get_public_rooms, get_public_rooms_filtered, register::RegistrationKind, Client, Invite3pid, Session, SyncSettings, Url, }; - use futures::TryStreamExt; use matrix_sdk_base::RoomMember; use matrix_sdk_common::{ api::r0::{ @@ -2893,7 +2890,7 @@ mod test { let room = client .get_joined_room(&room_id!("!SVkFJHzfwvuaIEawgC:localhost")) .unwrap(); - let members: Vec = room.active_members().await.try_collect().await.unwrap(); + let members: Vec = room.active_members().await.unwrap(); assert_eq!(1, members.len()); // assert!(room.power_levels.is_some()) diff --git a/matrix_sdk_base/examples/state_inspector.rs b/matrix_sdk_base/examples/state_inspector.rs index a91a64bb..f4b5b35d 100644 --- a/matrix_sdk_base/examples/state_inspector.rs +++ b/matrix_sdk_base/examples/state_inspector.rs @@ -1,6 +1,6 @@ use std::{convert::TryFrom, fmt::Debug, io, sync::Arc}; -use futures::{executor::block_on, TryStreamExt}; +use futures::executor::block_on; use serde::Serialize; use atty::Stream; @@ -86,14 +86,7 @@ impl InspectorHelper { } fn complete_rooms(&self, arg: Option<&&str>) -> Vec { - let rooms: Vec = block_on(async { - self.store - .get_room_infos() - .await - .try_collect() - .await - .unwrap() - }); + let rooms: Vec = block_on(async { self.store.get_room_infos().await.unwrap() }); rooms .into_iter() @@ -286,24 +279,12 @@ impl Inspector { } async fn list_rooms(&self) { - let rooms: Vec = self - .store - .get_room_infos() - .await - .try_collect() - .await - .unwrap(); + let rooms: Vec = self.store.get_room_infos().await.unwrap(); self.printer.pretty_print_struct(&rooms); } async fn get_profiles(&self, room_id: RoomId) { - let joined: Vec = self - .store - .get_joined_user_ids(&room_id) - .await - .try_collect() - .await - .unwrap(); + let joined: Vec = self.store.get_joined_user_ids(&room_id).await.unwrap(); for member in joined { let event = self.store.get_profile(&room_id, &member).await.unwrap(); @@ -312,13 +293,7 @@ impl Inspector { } async fn get_members(&self, room_id: RoomId) { - let joined: Vec = self - .store - .get_joined_user_ids(&room_id) - .await - .try_collect() - .await - .unwrap(); + let joined: Vec = self.store.get_joined_user_ids(&room_id).await.unwrap(); for member in joined { let event = self diff --git a/matrix_sdk_base/src/client.rs b/matrix_sdk_base/src/client.rs index ef0e3957..b406e5ba 100644 --- a/matrix_sdk_base/src/client.rs +++ b/matrix_sdk_base/src/client.rs @@ -23,8 +23,6 @@ use std::{ time::SystemTime, }; -#[cfg(feature = "encryption")] -use futures::{StreamExt, TryStreamExt}; use matrix_sdk_common::{ api::r0 as api, deserialized_responses::{ @@ -762,11 +760,11 @@ impl BaseClient { // The room turned on encryption in this sync, we need // to get also all the existing users and mark them for // tracking. - let joined = self.store.get_joined_user_ids(&room_id).await; - let invited = self.store.get_invited_user_ids(&room_id).await; + let joined = self.store.get_joined_user_ids(&room_id).await?; + let invited = self.store.get_invited_user_ids(&room_id).await?; - let user_ids: Vec = joined.chain(invited).try_collect().await?; - o.update_tracked_users(&user_ids).await + let user_ids = joined.iter().chain(&invited); + o.update_tracked_users(user_ids).await } o.update_tracked_users(&user_ids).await @@ -1030,7 +1028,7 @@ impl BaseClient { #[cfg_attr(feature = "docs", doc(cfg(encryption)))] pub async fn get_missing_sessions( &self, - users: &mut impl Iterator, + users: impl Iterator, ) -> Result> { let olm = self.olm.lock().await; @@ -1048,11 +1046,11 @@ impl BaseClient { match &*olm { Some(o) => { - let joined = self.store.get_joined_user_ids(room_id).await; - let invited = self.store.get_invited_user_ids(room_id).await; - let members: Vec = joined.chain(invited).try_collect().await?; + let joined = self.store.get_joined_user_ids(room_id).await?; + let invited = self.store.get_invited_user_ids(room_id).await?; + let members = joined.iter().chain(&invited); Ok( - o.share_group_session(room_id, members.iter(), EncryptionSettings::default()) + o.share_group_session(room_id, members, EncryptionSettings::default()) .await?, ) } diff --git a/matrix_sdk_base/src/rooms/normal.rs b/matrix_sdk_base/src/rooms/normal.rs index 73e73bf3..481386c8 100644 --- a/matrix_sdk_base/src/rooms/normal.rs +++ b/matrix_sdk_base/src/rooms/normal.rs @@ -19,7 +19,7 @@ use std::{ use futures::{ future, - stream::{self, Stream, StreamExt, TryStreamExt}, + stream::{self, StreamExt}, }; use matrix_sdk_common::{ api::r0::sync::sync_events::RoomSummary as RumaSummary, @@ -38,7 +38,7 @@ use tracing::info; use crate::{ deserialized_responses::UnreadNotificationsCount, - store::{Result as StoreResult, SledStore}, + store::{Result as StoreResult, StateStore}, }; use super::{BaseRoomInfo, RoomMember}; @@ -48,7 +48,7 @@ pub struct Room { room_id: Arc, own_user_id: Arc, inner: Arc>, - store: SledStore, + store: Arc>, } #[derive(Clone, Debug, Default, Serialize, Deserialize)] @@ -72,7 +72,7 @@ pub enum RoomType { impl Room { pub(crate) fn new( own_user_id: &UserId, - store: SledStore, + store: Arc>, room_id: &RoomId, room_type: RoomType, ) -> Self { @@ -91,7 +91,11 @@ impl Room { Self::restore(own_user_id, store, room_info) } - pub(crate) fn restore(own_user_id: &UserId, store: SledStore, room_info: RoomInfo) -> Self { + pub(crate) fn restore( + own_user_id: &UserId, + store: Arc>, + room_info: RoomInfo, + ) -> Self { Self { own_user_id: Arc::new(own_user_id.clone()), room_id: room_info.room_id.clone(), @@ -197,35 +201,40 @@ impl Room { self.calculate_name().await } - pub async fn joined_user_ids(&self) -> impl Stream> { + pub async fn joined_user_ids(&self) -> StoreResult> { self.store.get_joined_user_ids(self.room_id()).await } - pub async fn joined_members(&self) -> impl Stream> + '_ { - let joined = self.store.get_joined_user_ids(self.room_id()).await; + pub async fn joined_members(&self) -> StoreResult> { + let joined = self.store.get_joined_user_ids(self.room_id()).await?; + let mut members = Vec::new(); - joined.filter_map(move |u| async move { - let ret = match u { - Ok(u) => self.get_member(&u).await, - Err(e) => Err(e), - }; + for u in joined { + let m = self.get_member(&u).await?; - ret.transpose() - }) + if let Some(member) = m { + members.push(member); + } + } + + Ok(members) } - pub async fn active_members(&self) -> impl Stream> + '_ { - let joined = self.store.get_joined_user_ids(self.room_id()).await; - let invited = self.store.get_invited_user_ids(self.room_id()).await; + pub async fn active_members(&self) -> StoreResult> { + let joined = self.store.get_joined_user_ids(self.room_id()).await?; + let invited = self.store.get_invited_user_ids(self.room_id()).await?; - joined.chain(invited).filter_map(move |u| async move { - let ret = match u { - Ok(u) => self.get_member(&u).await, - Err(e) => Err(e), - }; + let mut members = Vec::new(); - ret.transpose() - }) + for u in joined.iter().chain(&invited) { + let m = self.get_member(u).await?; + + if let Some(member) = m { + members.push(member); + } + } + + Ok(members) } /// Calculate the canonical display name of the room, taking into account @@ -257,13 +266,13 @@ impl Room { let is_own_member = |m: &RoomMember| m.user_id() == &*self.own_user_id; let is_own_user_id = |u: &str| u == self.own_user_id().as_str(); - let members: StoreResult> = if summary.heroes.is_empty() { + let members: Vec = if summary.heroes.is_empty() { self.active_members() - .await - .try_filter(|u| future::ready(!is_own_member(&u))) + .await? + .into_iter() + .filter(|u| !is_own_member(&u)) .take(5) - .try_collect() - .await + .collect() } else { let members: Vec<_> = stream::iter(summary.heroes.iter()) .filter(|u| future::ready(!is_own_user_id(u))) @@ -274,7 +283,9 @@ impl Room { .collect() .await; - members.into_iter().collect() + let members: StoreResult> = members.into_iter().collect(); + + members? }; info!( @@ -288,7 +299,7 @@ impl Room { let inner = self.inner.read().unwrap(); Ok(inner .base_info - .calculate_room_name(joined, invited, members?)) + .calculate_room_name(joined, invited, members)) } pub(crate) fn clone_info(&self) -> RoomInfo { diff --git a/matrix_sdk_base/src/rooms/stripped.rs b/matrix_sdk_base/src/rooms/stripped.rs index 61a9e677..67610806 100644 --- a/matrix_sdk_base/src/rooms/stripped.rs +++ b/matrix_sdk_base/src/rooms/stripped.rs @@ -20,7 +20,7 @@ use matrix_sdk_common::{ }; use serde::{Deserialize, Serialize}; -use crate::store::SledStore; +use crate::store::StateStore; use super::BaseRoomInfo; @@ -29,11 +29,11 @@ pub struct StrippedRoom { room_id: Arc, own_user_id: Arc, inner: Arc>, - store: SledStore, + store: Arc>, } impl StrippedRoom { - pub fn new(own_user_id: &UserId, store: SledStore, room_id: &RoomId) -> Self { + pub fn new(own_user_id: &UserId, store: Arc>, room_id: &RoomId) -> Self { let room_id = Arc::new(room_id.clone()); Self { diff --git a/matrix_sdk_base/src/store/mod.rs b/matrix_sdk_base/src/store/mod.rs index 196aaf48..9cb24673 100644 --- a/matrix_sdk_base/src/store/mod.rs +++ b/matrix_sdk_base/src/store/mod.rs @@ -15,14 +15,15 @@ use std::{collections::BTreeMap, ops::Deref, path::Path, sync::Arc}; use dashmap::DashMap; -use futures::stream::StreamExt; use matrix_sdk_common::{ + async_trait, events::{ presence::PresenceEvent, room::member::MemberEventContent, AnyBasicEvent, - AnyStrippedStateEvent, AnySyncStateEvent, EventContent, + AnyStrippedStateEvent, AnySyncStateEvent, EventContent, EventType, }, identifiers::{RoomId, UserId}, locks::RwLock, + AsyncTraitDeps, }; use crate::{ @@ -53,9 +54,48 @@ pub enum StoreError { /// A `StateStore` specific result type. pub type Result = std::result::Result; +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +pub trait StateStore: AsyncTraitDeps { + async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()>; + + async fn save_changes(&self, changes: &StateChanges) -> Result<()>; + + async fn get_filter(&self, filter_id: &str) -> Result>; + + async fn get_sync_token(&self) -> Result>; + + async fn get_presence_event(&self, user_id: &UserId) -> Result>; + + async fn get_state_event( + &self, + room_id: &RoomId, + event_type: EventType, + state_key: &str, + ) -> Result>; + + async fn get_profile( + &self, + room_id: &RoomId, + user_id: &UserId, + ) -> Result>; + + async fn get_member_event( + &self, + room_id: &RoomId, + state_key: &UserId, + ) -> Result>; + + async fn get_invited_user_ids(&self, room_id: &RoomId) -> Result>; + + async fn get_joined_user_ids(&self, room_id: &RoomId) -> Result>; + + async fn get_room_infos(&self) -> Result>; +} + #[derive(Debug, Clone)] pub struct Store { - inner: SledStore, + inner: Arc>, session: Arc>>, sync_token: Arc>>, rooms: Arc>, @@ -69,7 +109,7 @@ impl Store { inner: SledStore, ) -> Self { Self { - inner, + inner: Arc::new(Box::new(inner)), session, sync_token, rooms: DashMap::new().into(), @@ -78,11 +118,8 @@ impl Store { } pub(crate) async fn restore_session(&self, session: Session) -> Result<()> { - let mut infos = self.inner.get_room_infos().await; - - // TODO restore stripped rooms. - while let Some(info) = infos.next().await { - let room = Room::restore(&session.user_id, self.inner.clone(), info?); + for info in self.inner.get_room_infos().await?.into_iter() { + let room = Room::restore(&session.user_id, self.inner.clone(), info); self.rooms.insert(room.room_id().to_owned(), room); } @@ -172,7 +209,7 @@ impl Store { } impl Deref for Store { - type Target = SledStore; + type Target = Box; fn deref(&self) -> &Self::Target { &self.inner diff --git a/matrix_sdk_base/src/store/sled_store/mod.rs b/matrix_sdk_base/src/store/sled_store/mod.rs index 8affb114..b65c24c9 100644 --- a/matrix_sdk_base/src/store/sled_store/mod.rs +++ b/matrix_sdk_base/src/store/sled_store/mod.rs @@ -16,8 +16,12 @@ mod store_key; use std::{convert::TryFrom, path::Path, sync::Arc, time::SystemTime}; -use futures::stream::{self, Stream}; +use futures::{ + stream::{self, Stream}, + TryStreamExt, +}; use matrix_sdk_common::{ + async_trait, events::{ presence::PresenceEvent, room::member::{MemberEventContent, MembershipState}, @@ -37,7 +41,7 @@ use crate::deserialized_responses::MemberEvent; use self::store_key::{EncryptedEvent, StoreKey}; -use super::{Result, RoomInfo, StateChanges, StoreError}; +use super::{Result, RoomInfo, StateChanges, StateStore, StoreError}; #[derive(Debug, Serialize, Deserialize)] pub enum DatabaseType { @@ -477,6 +481,66 @@ impl SledStore { } } +#[async_trait] +impl StateStore for SledStore { + async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> { + self.save_filter(filter_name, filter_id).await + } + + async fn save_changes(&self, changes: &StateChanges) -> Result<()> { + self.save_changes(changes).await + } + + async fn get_filter(&self, filter_id: &str) -> Result> { + self.get_filter(filter_id).await + } + + async fn get_sync_token(&self) -> Result> { + self.get_sync_token().await + } + + async fn get_presence_event(&self, user_id: &UserId) -> Result> { + self.get_presence_event(user_id).await + } + + async fn get_state_event( + &self, + room_id: &RoomId, + event_type: EventType, + state_key: &str, + ) -> Result> { + self.get_state_event(room_id, event_type, state_key).await + } + + async fn get_profile( + &self, + room_id: &RoomId, + user_id: &UserId, + ) -> Result> { + self.get_profile(room_id, user_id).await + } + + async fn get_member_event( + &self, + room_id: &RoomId, + state_key: &UserId, + ) -> Result> { + self.get_member_event(room_id, state_key).await + } + + async fn get_invited_user_ids(&self, room_id: &RoomId) -> Result> { + self.get_invited_user_ids(room_id).await.try_collect().await + } + + async fn get_joined_user_ids(&self, room_id: &RoomId) -> Result> { + self.get_joined_user_ids(room_id).await.try_collect().await + } + + async fn get_room_infos(&self) -> Result> { + self.get_room_infos().await.try_collect().await + } +} + #[cfg(test)] mod test { use std::{convert::TryFrom, time::SystemTime}; diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 59fe3bb1..f6bdf31c 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -502,7 +502,7 @@ impl OlmMachine { /// [`mark_request_as_sent`]: #method.mark_request_as_sent pub async fn get_missing_sessions( &self, - users: &mut impl Iterator, + users: impl Iterator, ) -> OlmResult> { self.session_manager.get_missing_sessions(users).await } diff --git a/matrix_sdk_crypto/src/session_manager/sessions.rs b/matrix_sdk_crypto/src/session_manager/sessions.rs index 95afe118..311cb572 100644 --- a/matrix_sdk_crypto/src/session_manager/sessions.rs +++ b/matrix_sdk_crypto/src/session_manager/sessions.rs @@ -188,7 +188,7 @@ impl SessionManager { /// [`receive_keys_claim_response`]: #method.receive_keys_claim_response pub async fn get_missing_sessions( &self, - users: &mut impl Iterator, + users: impl Iterator, ) -> OlmResult> { let mut missing = BTreeMap::new();