base: Re-introduce a state store trait.

master
Damir Jelić 2021-01-21 12:08:16 +01:00
parent 2bcc0afb91
commit de4df4e50a
9 changed files with 181 additions and 99 deletions

View File

@ -27,8 +27,6 @@ use std::{
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use dashmap::DashMap; use dashmap::DashMap;
#[cfg(feature = "encryption")]
use futures::TryStreamExt;
use futures_timer::Delay as sleep; use futures_timer::Delay as sleep;
use http::HeaderValue; use http::HeaderValue;
use mime::{self, Mime}; use mime::{self, Mime};
@ -1155,10 +1153,10 @@ impl Client {
let _guard = mutex.lock().await; let _guard = mutex.lock().await;
{ {
let room = self.get_joined_room(room_id).unwrap(); let joined = self.store().get_joined_user_ids(room_id).await?;
let members = room.joined_user_ids().await; let invited = self.store().get_invited_user_ids(room_id).await?;
let members_iter: Vec<UserId> = members.try_collect().await?; let members = joined.iter().chain(&invited);
self.claim_one_time_keys(&mut members_iter.iter()).await?; self.claim_one_time_keys(members).await?;
}; };
let response = self.share_group_session(room_id).await; let response = self.share_group_session(room_id).await;
@ -1831,7 +1829,7 @@ impl Client {
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[cfg_attr(feature = "docs", doc(cfg(encryption)))]
#[instrument(skip(users))] #[instrument(skip(users))]
async fn claim_one_time_keys(&self, users: &mut impl Iterator<Item = &UserId>) -> Result<()> { async fn claim_one_time_keys(&self, users: impl Iterator<Item = &UserId>) -> Result<()> {
let _lock = self.key_claim_lock.lock().await; let _lock = self.key_claim_lock.lock().await;
if let Some((request_id, request)) = self.base_client.get_missing_sessions(users).await? { if let Some((request_id, request)) = self.base_client.get_missing_sessions(users).await? {
@ -2277,7 +2275,6 @@ mod test {
get_public_rooms, get_public_rooms_filtered, register::RegistrationKind, Client, get_public_rooms, get_public_rooms_filtered, register::RegistrationKind, Client,
Invite3pid, Session, SyncSettings, Url, Invite3pid, Session, SyncSettings, Url,
}; };
use futures::TryStreamExt;
use matrix_sdk_base::RoomMember; use matrix_sdk_base::RoomMember;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::{ api::r0::{
@ -2893,7 +2890,7 @@ mod test {
let room = client let room = client
.get_joined_room(&room_id!("!SVkFJHzfwvuaIEawgC:localhost")) .get_joined_room(&room_id!("!SVkFJHzfwvuaIEawgC:localhost"))
.unwrap(); .unwrap();
let members: Vec<RoomMember> = room.active_members().await.try_collect().await.unwrap(); let members: Vec<RoomMember> = room.active_members().await.unwrap();
assert_eq!(1, members.len()); assert_eq!(1, members.len());
// assert!(room.power_levels.is_some()) // assert!(room.power_levels.is_some())

View File

@ -1,6 +1,6 @@
use std::{convert::TryFrom, fmt::Debug, io, sync::Arc}; use std::{convert::TryFrom, fmt::Debug, io, sync::Arc};
use futures::{executor::block_on, TryStreamExt}; use futures::executor::block_on;
use serde::Serialize; use serde::Serialize;
use atty::Stream; use atty::Stream;
@ -86,14 +86,7 @@ impl InspectorHelper {
} }
fn complete_rooms(&self, arg: Option<&&str>) -> Vec<Pair> { fn complete_rooms(&self, arg: Option<&&str>) -> Vec<Pair> {
let rooms: Vec<RoomInfo> = block_on(async { let rooms: Vec<RoomInfo> = block_on(async { self.store.get_room_infos().await.unwrap() });
self.store
.get_room_infos()
.await
.try_collect()
.await
.unwrap()
});
rooms rooms
.into_iter() .into_iter()
@ -286,24 +279,12 @@ impl Inspector {
} }
async fn list_rooms(&self) { async fn list_rooms(&self) {
let rooms: Vec<RoomInfo> = self let rooms: Vec<RoomInfo> = self.store.get_room_infos().await.unwrap();
.store
.get_room_infos()
.await
.try_collect()
.await
.unwrap();
self.printer.pretty_print_struct(&rooms); self.printer.pretty_print_struct(&rooms);
} }
async fn get_profiles(&self, room_id: RoomId) { async fn get_profiles(&self, room_id: RoomId) {
let joined: Vec<UserId> = self let joined: Vec<UserId> = self.store.get_joined_user_ids(&room_id).await.unwrap();
.store
.get_joined_user_ids(&room_id)
.await
.try_collect()
.await
.unwrap();
for member in joined { for member in joined {
let event = self.store.get_profile(&room_id, &member).await.unwrap(); let event = self.store.get_profile(&room_id, &member).await.unwrap();
@ -312,13 +293,7 @@ impl Inspector {
} }
async fn get_members(&self, room_id: RoomId) { async fn get_members(&self, room_id: RoomId) {
let joined: Vec<UserId> = self let joined: Vec<UserId> = self.store.get_joined_user_ids(&room_id).await.unwrap();
.store
.get_joined_user_ids(&room_id)
.await
.try_collect()
.await
.unwrap();
for member in joined { for member in joined {
let event = self let event = self

View File

@ -23,8 +23,6 @@ use std::{
time::SystemTime, time::SystemTime,
}; };
#[cfg(feature = "encryption")]
use futures::{StreamExt, TryStreamExt};
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0 as api, api::r0 as api,
deserialized_responses::{ deserialized_responses::{
@ -762,11 +760,11 @@ impl BaseClient {
// The room turned on encryption in this sync, we need // The room turned on encryption in this sync, we need
// to get also all the existing users and mark them for // to get also all the existing users and mark them for
// tracking. // tracking.
let joined = self.store.get_joined_user_ids(&room_id).await; let joined = self.store.get_joined_user_ids(&room_id).await?;
let invited = self.store.get_invited_user_ids(&room_id).await; let invited = self.store.get_invited_user_ids(&room_id).await?;
let user_ids: Vec<UserId> = joined.chain(invited).try_collect().await?; let user_ids = joined.iter().chain(&invited);
o.update_tracked_users(&user_ids).await o.update_tracked_users(user_ids).await
} }
o.update_tracked_users(&user_ids).await o.update_tracked_users(&user_ids).await
@ -1030,7 +1028,7 @@ impl BaseClient {
#[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn get_missing_sessions( pub async fn get_missing_sessions(
&self, &self,
users: &mut impl Iterator<Item = &UserId>, users: impl Iterator<Item = &UserId>,
) -> Result<Option<(Uuid, KeysClaimRequest)>> { ) -> Result<Option<(Uuid, KeysClaimRequest)>> {
let olm = self.olm.lock().await; let olm = self.olm.lock().await;
@ -1048,11 +1046,11 @@ impl BaseClient {
match &*olm { match &*olm {
Some(o) => { Some(o) => {
let joined = self.store.get_joined_user_ids(room_id).await; let joined = self.store.get_joined_user_ids(room_id).await?;
let invited = self.store.get_invited_user_ids(room_id).await; let invited = self.store.get_invited_user_ids(room_id).await?;
let members: Vec<UserId> = joined.chain(invited).try_collect().await?; let members = joined.iter().chain(&invited);
Ok( Ok(
o.share_group_session(room_id, members.iter(), EncryptionSettings::default()) o.share_group_session(room_id, members, EncryptionSettings::default())
.await?, .await?,
) )
} }

View File

@ -19,7 +19,7 @@ use std::{
use futures::{ use futures::{
future, future,
stream::{self, Stream, StreamExt, TryStreamExt}, stream::{self, StreamExt},
}; };
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::sync::sync_events::RoomSummary as RumaSummary, api::r0::sync::sync_events::RoomSummary as RumaSummary,
@ -38,7 +38,7 @@ use tracing::info;
use crate::{ use crate::{
deserialized_responses::UnreadNotificationsCount, deserialized_responses::UnreadNotificationsCount,
store::{Result as StoreResult, SledStore}, store::{Result as StoreResult, StateStore},
}; };
use super::{BaseRoomInfo, RoomMember}; use super::{BaseRoomInfo, RoomMember};
@ -48,7 +48,7 @@ pub struct Room {
room_id: Arc<RoomId>, room_id: Arc<RoomId>,
own_user_id: Arc<UserId>, own_user_id: Arc<UserId>,
inner: Arc<SyncRwLock<RoomInfo>>, inner: Arc<SyncRwLock<RoomInfo>>,
store: SledStore, store: Arc<Box<dyn StateStore>>,
} }
#[derive(Clone, Debug, Default, Serialize, Deserialize)] #[derive(Clone, Debug, Default, Serialize, Deserialize)]
@ -72,7 +72,7 @@ pub enum RoomType {
impl Room { impl Room {
pub(crate) fn new( pub(crate) fn new(
own_user_id: &UserId, own_user_id: &UserId,
store: SledStore, store: Arc<Box<dyn StateStore>>,
room_id: &RoomId, room_id: &RoomId,
room_type: RoomType, room_type: RoomType,
) -> Self { ) -> Self {
@ -91,7 +91,11 @@ impl Room {
Self::restore(own_user_id, store, room_info) Self::restore(own_user_id, store, room_info)
} }
pub(crate) fn restore(own_user_id: &UserId, store: SledStore, room_info: RoomInfo) -> Self { pub(crate) fn restore(
own_user_id: &UserId,
store: Arc<Box<dyn StateStore>>,
room_info: RoomInfo,
) -> Self {
Self { Self {
own_user_id: Arc::new(own_user_id.clone()), own_user_id: Arc::new(own_user_id.clone()),
room_id: room_info.room_id.clone(), room_id: room_info.room_id.clone(),
@ -197,35 +201,40 @@ impl Room {
self.calculate_name().await self.calculate_name().await
} }
pub async fn joined_user_ids(&self) -> impl Stream<Item = StoreResult<UserId>> { pub async fn joined_user_ids(&self) -> StoreResult<Vec<UserId>> {
self.store.get_joined_user_ids(self.room_id()).await self.store.get_joined_user_ids(self.room_id()).await
} }
pub async fn joined_members(&self) -> impl Stream<Item = StoreResult<RoomMember>> + '_ { pub async fn joined_members(&self) -> StoreResult<Vec<RoomMember>> {
let joined = self.store.get_joined_user_ids(self.room_id()).await; let joined = self.store.get_joined_user_ids(self.room_id()).await?;
let mut members = Vec::new();
joined.filter_map(move |u| async move { for u in joined {
let ret = match u { let m = self.get_member(&u).await?;
Ok(u) => self.get_member(&u).await,
Err(e) => Err(e),
};
ret.transpose() if let Some(member) = m {
}) members.push(member);
}
}
Ok(members)
} }
pub async fn active_members(&self) -> impl Stream<Item = StoreResult<RoomMember>> + '_ { pub async fn active_members(&self) -> StoreResult<Vec<RoomMember>> {
let joined = self.store.get_joined_user_ids(self.room_id()).await; let joined = self.store.get_joined_user_ids(self.room_id()).await?;
let invited = self.store.get_invited_user_ids(self.room_id()).await; let invited = self.store.get_invited_user_ids(self.room_id()).await?;
joined.chain(invited).filter_map(move |u| async move { let mut members = Vec::new();
let ret = match u {
Ok(u) => self.get_member(&u).await,
Err(e) => Err(e),
};
ret.transpose() for u in joined.iter().chain(&invited) {
}) let m = self.get_member(u).await?;
if let Some(member) = m {
members.push(member);
}
}
Ok(members)
} }
/// Calculate the canonical display name of the room, taking into account /// Calculate the canonical display name of the room, taking into account
@ -257,13 +266,13 @@ impl Room {
let is_own_member = |m: &RoomMember| m.user_id() == &*self.own_user_id; let is_own_member = |m: &RoomMember| m.user_id() == &*self.own_user_id;
let is_own_user_id = |u: &str| u == self.own_user_id().as_str(); let is_own_user_id = |u: &str| u == self.own_user_id().as_str();
let members: StoreResult<Vec<RoomMember>> = if summary.heroes.is_empty() { let members: Vec<RoomMember> = if summary.heroes.is_empty() {
self.active_members() self.active_members()
.await .await?
.try_filter(|u| future::ready(!is_own_member(&u))) .into_iter()
.filter(|u| !is_own_member(&u))
.take(5) .take(5)
.try_collect() .collect()
.await
} else { } else {
let members: Vec<_> = stream::iter(summary.heroes.iter()) let members: Vec<_> = stream::iter(summary.heroes.iter())
.filter(|u| future::ready(!is_own_user_id(u))) .filter(|u| future::ready(!is_own_user_id(u)))
@ -274,7 +283,9 @@ impl Room {
.collect() .collect()
.await; .await;
members.into_iter().collect() let members: StoreResult<Vec<_>> = members.into_iter().collect();
members?
}; };
info!( info!(
@ -288,7 +299,7 @@ impl Room {
let inner = self.inner.read().unwrap(); let inner = self.inner.read().unwrap();
Ok(inner Ok(inner
.base_info .base_info
.calculate_room_name(joined, invited, members?)) .calculate_room_name(joined, invited, members))
} }
pub(crate) fn clone_info(&self) -> RoomInfo { pub(crate) fn clone_info(&self) -> RoomInfo {

View File

@ -20,7 +20,7 @@ use matrix_sdk_common::{
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::store::SledStore; use crate::store::StateStore;
use super::BaseRoomInfo; use super::BaseRoomInfo;
@ -29,11 +29,11 @@ pub struct StrippedRoom {
room_id: Arc<RoomId>, room_id: Arc<RoomId>,
own_user_id: Arc<UserId>, own_user_id: Arc<UserId>,
inner: Arc<SyncMutex<StrippedRoomInfo>>, inner: Arc<SyncMutex<StrippedRoomInfo>>,
store: SledStore, store: Arc<Box<dyn StateStore>>,
} }
impl StrippedRoom { impl StrippedRoom {
pub fn new(own_user_id: &UserId, store: SledStore, room_id: &RoomId) -> Self { pub fn new(own_user_id: &UserId, store: Arc<Box<dyn StateStore>>, room_id: &RoomId) -> Self {
let room_id = Arc::new(room_id.clone()); let room_id = Arc::new(room_id.clone());
Self { Self {

View File

@ -15,14 +15,15 @@
use std::{collections::BTreeMap, ops::Deref, path::Path, sync::Arc}; use std::{collections::BTreeMap, ops::Deref, path::Path, sync::Arc};
use dashmap::DashMap; use dashmap::DashMap;
use futures::stream::StreamExt;
use matrix_sdk_common::{ use matrix_sdk_common::{
async_trait,
events::{ events::{
presence::PresenceEvent, room::member::MemberEventContent, AnyBasicEvent, presence::PresenceEvent, room::member::MemberEventContent, AnyBasicEvent,
AnyStrippedStateEvent, AnySyncStateEvent, EventContent, AnyStrippedStateEvent, AnySyncStateEvent, EventContent, EventType,
}, },
identifiers::{RoomId, UserId}, identifiers::{RoomId, UserId},
locks::RwLock, locks::RwLock,
AsyncTraitDeps,
}; };
use crate::{ use crate::{
@ -53,9 +54,48 @@ pub enum StoreError {
/// A `StateStore` specific result type. /// A `StateStore` specific result type.
pub type Result<T> = std::result::Result<T, StoreError>; pub type Result<T> = std::result::Result<T, StoreError>;
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait StateStore: AsyncTraitDeps {
async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()>;
async fn save_changes(&self, changes: &StateChanges) -> Result<()>;
async fn get_filter(&self, filter_id: &str) -> Result<Option<String>>;
async fn get_sync_token(&self) -> Result<Option<String>>;
async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<PresenceEvent>>;
async fn get_state_event(
&self,
room_id: &RoomId,
event_type: EventType,
state_key: &str,
) -> Result<Option<AnySyncStateEvent>>;
async fn get_profile(
&self,
room_id: &RoomId,
user_id: &UserId,
) -> Result<Option<MemberEventContent>>;
async fn get_member_event(
&self,
room_id: &RoomId,
state_key: &UserId,
) -> Result<Option<MemberEvent>>;
async fn get_invited_user_ids(&self, room_id: &RoomId) -> Result<Vec<UserId>>;
async fn get_joined_user_ids(&self, room_id: &RoomId) -> Result<Vec<UserId>>;
async fn get_room_infos(&self) -> Result<Vec<RoomInfo>>;
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Store { pub struct Store {
inner: SledStore, inner: Arc<Box<dyn StateStore>>,
session: Arc<RwLock<Option<Session>>>, session: Arc<RwLock<Option<Session>>>,
sync_token: Arc<RwLock<Option<String>>>, sync_token: Arc<RwLock<Option<String>>>,
rooms: Arc<DashMap<RoomId, Room>>, rooms: Arc<DashMap<RoomId, Room>>,
@ -69,7 +109,7 @@ impl Store {
inner: SledStore, inner: SledStore,
) -> Self { ) -> Self {
Self { Self {
inner, inner: Arc::new(Box::new(inner)),
session, session,
sync_token, sync_token,
rooms: DashMap::new().into(), rooms: DashMap::new().into(),
@ -78,11 +118,8 @@ impl Store {
} }
pub(crate) async fn restore_session(&self, session: Session) -> Result<()> { pub(crate) async fn restore_session(&self, session: Session) -> Result<()> {
let mut infos = self.inner.get_room_infos().await; for info in self.inner.get_room_infos().await?.into_iter() {
let room = Room::restore(&session.user_id, self.inner.clone(), info);
// TODO restore stripped rooms.
while let Some(info) = infos.next().await {
let room = Room::restore(&session.user_id, self.inner.clone(), info?);
self.rooms.insert(room.room_id().to_owned(), room); self.rooms.insert(room.room_id().to_owned(), room);
} }
@ -172,7 +209,7 @@ impl Store {
} }
impl Deref for Store { impl Deref for Store {
type Target = SledStore; type Target = Box<dyn StateStore>;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
&self.inner &self.inner

View File

@ -16,8 +16,12 @@ mod store_key;
use std::{convert::TryFrom, path::Path, sync::Arc, time::SystemTime}; use std::{convert::TryFrom, path::Path, sync::Arc, time::SystemTime};
use futures::stream::{self, Stream}; use futures::{
stream::{self, Stream},
TryStreamExt,
};
use matrix_sdk_common::{ use matrix_sdk_common::{
async_trait,
events::{ events::{
presence::PresenceEvent, presence::PresenceEvent,
room::member::{MemberEventContent, MembershipState}, room::member::{MemberEventContent, MembershipState},
@ -37,7 +41,7 @@ use crate::deserialized_responses::MemberEvent;
use self::store_key::{EncryptedEvent, StoreKey}; use self::store_key::{EncryptedEvent, StoreKey};
use super::{Result, RoomInfo, StateChanges, StoreError}; use super::{Result, RoomInfo, StateChanges, StateStore, StoreError};
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub enum DatabaseType { pub enum DatabaseType {
@ -477,6 +481,66 @@ impl SledStore {
} }
} }
#[async_trait]
impl StateStore for SledStore {
async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> {
self.save_filter(filter_name, filter_id).await
}
async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
self.save_changes(changes).await
}
async fn get_filter(&self, filter_id: &str) -> Result<Option<String>> {
self.get_filter(filter_id).await
}
async fn get_sync_token(&self) -> Result<Option<String>> {
self.get_sync_token().await
}
async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<PresenceEvent>> {
self.get_presence_event(user_id).await
}
async fn get_state_event(
&self,
room_id: &RoomId,
event_type: EventType,
state_key: &str,
) -> Result<Option<AnySyncStateEvent>> {
self.get_state_event(room_id, event_type, state_key).await
}
async fn get_profile(
&self,
room_id: &RoomId,
user_id: &UserId,
) -> Result<Option<MemberEventContent>> {
self.get_profile(room_id, user_id).await
}
async fn get_member_event(
&self,
room_id: &RoomId,
state_key: &UserId,
) -> Result<Option<MemberEvent>> {
self.get_member_event(room_id, state_key).await
}
async fn get_invited_user_ids(&self, room_id: &RoomId) -> Result<Vec<UserId>> {
self.get_invited_user_ids(room_id).await.try_collect().await
}
async fn get_joined_user_ids(&self, room_id: &RoomId) -> Result<Vec<UserId>> {
self.get_joined_user_ids(room_id).await.try_collect().await
}
async fn get_room_infos(&self) -> Result<Vec<RoomInfo>> {
self.get_room_infos().await.try_collect().await
}
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use std::{convert::TryFrom, time::SystemTime}; use std::{convert::TryFrom, time::SystemTime};

View File

@ -502,7 +502,7 @@ impl OlmMachine {
/// [`mark_request_as_sent`]: #method.mark_request_as_sent /// [`mark_request_as_sent`]: #method.mark_request_as_sent
pub async fn get_missing_sessions( pub async fn get_missing_sessions(
&self, &self,
users: &mut impl Iterator<Item = &UserId>, users: impl Iterator<Item = &UserId>,
) -> OlmResult<Option<(Uuid, KeysClaimRequest)>> { ) -> OlmResult<Option<(Uuid, KeysClaimRequest)>> {
self.session_manager.get_missing_sessions(users).await self.session_manager.get_missing_sessions(users).await
} }

View File

@ -188,7 +188,7 @@ impl SessionManager {
/// [`receive_keys_claim_response`]: #method.receive_keys_claim_response /// [`receive_keys_claim_response`]: #method.receive_keys_claim_response
pub async fn get_missing_sessions( pub async fn get_missing_sessions(
&self, &self,
users: &mut impl Iterator<Item = &UserId>, users: impl Iterator<Item = &UserId>,
) -> OlmResult<Option<(Uuid, KeysClaimRequest)>> { ) -> OlmResult<Option<(Uuid, KeysClaimRequest)>> {
let mut missing = BTreeMap::new(); let mut missing = BTreeMap::new();