base: Re-introduce a state store trait.

master
Damir Jelić 2021-01-21 12:08:16 +01:00
parent 2bcc0afb91
commit de4df4e50a
9 changed files with 181 additions and 99 deletions

View File

@ -27,8 +27,6 @@ use std::{
#[cfg(feature = "encryption")]
use dashmap::DashMap;
#[cfg(feature = "encryption")]
use futures::TryStreamExt;
use futures_timer::Delay as sleep;
use http::HeaderValue;
use mime::{self, Mime};
@ -1155,10 +1153,10 @@ impl Client {
let _guard = mutex.lock().await;
{
let room = self.get_joined_room(room_id).unwrap();
let members = room.joined_user_ids().await;
let members_iter: Vec<UserId> = members.try_collect().await?;
self.claim_one_time_keys(&mut members_iter.iter()).await?;
let joined = self.store().get_joined_user_ids(room_id).await?;
let invited = self.store().get_invited_user_ids(room_id).await?;
let members = joined.iter().chain(&invited);
self.claim_one_time_keys(members).await?;
};
let response = self.share_group_session(room_id).await;
@ -1831,7 +1829,7 @@ impl Client {
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
#[instrument(skip(users))]
async fn claim_one_time_keys(&self, users: &mut impl Iterator<Item = &UserId>) -> Result<()> {
async fn claim_one_time_keys(&self, users: impl Iterator<Item = &UserId>) -> Result<()> {
let _lock = self.key_claim_lock.lock().await;
if let Some((request_id, request)) = self.base_client.get_missing_sessions(users).await? {
@ -2277,7 +2275,6 @@ mod test {
get_public_rooms, get_public_rooms_filtered, register::RegistrationKind, Client,
Invite3pid, Session, SyncSettings, Url,
};
use futures::TryStreamExt;
use matrix_sdk_base::RoomMember;
use matrix_sdk_common::{
api::r0::{
@ -2893,7 +2890,7 @@ mod test {
let room = client
.get_joined_room(&room_id!("!SVkFJHzfwvuaIEawgC:localhost"))
.unwrap();
let members: Vec<RoomMember> = room.active_members().await.try_collect().await.unwrap();
let members: Vec<RoomMember> = room.active_members().await.unwrap();
assert_eq!(1, members.len());
// assert!(room.power_levels.is_some())

View File

@ -1,6 +1,6 @@
use std::{convert::TryFrom, fmt::Debug, io, sync::Arc};
use futures::{executor::block_on, TryStreamExt};
use futures::executor::block_on;
use serde::Serialize;
use atty::Stream;
@ -86,14 +86,7 @@ impl InspectorHelper {
}
fn complete_rooms(&self, arg: Option<&&str>) -> Vec<Pair> {
let rooms: Vec<RoomInfo> = block_on(async {
self.store
.get_room_infos()
.await
.try_collect()
.await
.unwrap()
});
let rooms: Vec<RoomInfo> = block_on(async { self.store.get_room_infos().await.unwrap() });
rooms
.into_iter()
@ -286,24 +279,12 @@ impl Inspector {
}
async fn list_rooms(&self) {
let rooms: Vec<RoomInfo> = self
.store
.get_room_infos()
.await
.try_collect()
.await
.unwrap();
let rooms: Vec<RoomInfo> = self.store.get_room_infos().await.unwrap();
self.printer.pretty_print_struct(&rooms);
}
async fn get_profiles(&self, room_id: RoomId) {
let joined: Vec<UserId> = self
.store
.get_joined_user_ids(&room_id)
.await
.try_collect()
.await
.unwrap();
let joined: Vec<UserId> = self.store.get_joined_user_ids(&room_id).await.unwrap();
for member in joined {
let event = self.store.get_profile(&room_id, &member).await.unwrap();
@ -312,13 +293,7 @@ impl Inspector {
}
async fn get_members(&self, room_id: RoomId) {
let joined: Vec<UserId> = self
.store
.get_joined_user_ids(&room_id)
.await
.try_collect()
.await
.unwrap();
let joined: Vec<UserId> = self.store.get_joined_user_ids(&room_id).await.unwrap();
for member in joined {
let event = self

View File

@ -23,8 +23,6 @@ use std::{
time::SystemTime,
};
#[cfg(feature = "encryption")]
use futures::{StreamExt, TryStreamExt};
use matrix_sdk_common::{
api::r0 as api,
deserialized_responses::{
@ -762,11 +760,11 @@ impl BaseClient {
// The room turned on encryption in this sync, we need
// to get also all the existing users and mark them for
// tracking.
let joined = self.store.get_joined_user_ids(&room_id).await;
let invited = self.store.get_invited_user_ids(&room_id).await;
let joined = self.store.get_joined_user_ids(&room_id).await?;
let invited = self.store.get_invited_user_ids(&room_id).await?;
let user_ids: Vec<UserId> = joined.chain(invited).try_collect().await?;
o.update_tracked_users(&user_ids).await
let user_ids = joined.iter().chain(&invited);
o.update_tracked_users(user_ids).await
}
o.update_tracked_users(&user_ids).await
@ -1030,7 +1028,7 @@ impl BaseClient {
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn get_missing_sessions(
&self,
users: &mut impl Iterator<Item = &UserId>,
users: impl Iterator<Item = &UserId>,
) -> Result<Option<(Uuid, KeysClaimRequest)>> {
let olm = self.olm.lock().await;
@ -1048,11 +1046,11 @@ impl BaseClient {
match &*olm {
Some(o) => {
let joined = self.store.get_joined_user_ids(room_id).await;
let invited = self.store.get_invited_user_ids(room_id).await;
let members: Vec<UserId> = joined.chain(invited).try_collect().await?;
let joined = self.store.get_joined_user_ids(room_id).await?;
let invited = self.store.get_invited_user_ids(room_id).await?;
let members = joined.iter().chain(&invited);
Ok(
o.share_group_session(room_id, members.iter(), EncryptionSettings::default())
o.share_group_session(room_id, members, EncryptionSettings::default())
.await?,
)
}

View File

@ -19,7 +19,7 @@ use std::{
use futures::{
future,
stream::{self, Stream, StreamExt, TryStreamExt},
stream::{self, StreamExt},
};
use matrix_sdk_common::{
api::r0::sync::sync_events::RoomSummary as RumaSummary,
@ -38,7 +38,7 @@ use tracing::info;
use crate::{
deserialized_responses::UnreadNotificationsCount,
store::{Result as StoreResult, SledStore},
store::{Result as StoreResult, StateStore},
};
use super::{BaseRoomInfo, RoomMember};
@ -48,7 +48,7 @@ pub struct Room {
room_id: Arc<RoomId>,
own_user_id: Arc<UserId>,
inner: Arc<SyncRwLock<RoomInfo>>,
store: SledStore,
store: Arc<Box<dyn StateStore>>,
}
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
@ -72,7 +72,7 @@ pub enum RoomType {
impl Room {
pub(crate) fn new(
own_user_id: &UserId,
store: SledStore,
store: Arc<Box<dyn StateStore>>,
room_id: &RoomId,
room_type: RoomType,
) -> Self {
@ -91,7 +91,11 @@ impl Room {
Self::restore(own_user_id, store, room_info)
}
pub(crate) fn restore(own_user_id: &UserId, store: SledStore, room_info: RoomInfo) -> Self {
pub(crate) fn restore(
own_user_id: &UserId,
store: Arc<Box<dyn StateStore>>,
room_info: RoomInfo,
) -> Self {
Self {
own_user_id: Arc::new(own_user_id.clone()),
room_id: room_info.room_id.clone(),
@ -197,35 +201,40 @@ impl Room {
self.calculate_name().await
}
pub async fn joined_user_ids(&self) -> impl Stream<Item = StoreResult<UserId>> {
pub async fn joined_user_ids(&self) -> StoreResult<Vec<UserId>> {
self.store.get_joined_user_ids(self.room_id()).await
}
pub async fn joined_members(&self) -> impl Stream<Item = StoreResult<RoomMember>> + '_ {
let joined = self.store.get_joined_user_ids(self.room_id()).await;
pub async fn joined_members(&self) -> StoreResult<Vec<RoomMember>> {
let joined = self.store.get_joined_user_ids(self.room_id()).await?;
let mut members = Vec::new();
joined.filter_map(move |u| async move {
let ret = match u {
Ok(u) => self.get_member(&u).await,
Err(e) => Err(e),
};
for u in joined {
let m = self.get_member(&u).await?;
ret.transpose()
})
if let Some(member) = m {
members.push(member);
}
}
Ok(members)
}
pub async fn active_members(&self) -> impl Stream<Item = StoreResult<RoomMember>> + '_ {
let joined = self.store.get_joined_user_ids(self.room_id()).await;
let invited = self.store.get_invited_user_ids(self.room_id()).await;
pub async fn active_members(&self) -> StoreResult<Vec<RoomMember>> {
let joined = self.store.get_joined_user_ids(self.room_id()).await?;
let invited = self.store.get_invited_user_ids(self.room_id()).await?;
joined.chain(invited).filter_map(move |u| async move {
let ret = match u {
Ok(u) => self.get_member(&u).await,
Err(e) => Err(e),
};
let mut members = Vec::new();
ret.transpose()
})
for u in joined.iter().chain(&invited) {
let m = self.get_member(u).await?;
if let Some(member) = m {
members.push(member);
}
}
Ok(members)
}
/// Calculate the canonical display name of the room, taking into account
@ -257,13 +266,13 @@ impl Room {
let is_own_member = |m: &RoomMember| m.user_id() == &*self.own_user_id;
let is_own_user_id = |u: &str| u == self.own_user_id().as_str();
let members: StoreResult<Vec<RoomMember>> = if summary.heroes.is_empty() {
let members: Vec<RoomMember> = if summary.heroes.is_empty() {
self.active_members()
.await
.try_filter(|u| future::ready(!is_own_member(&u)))
.await?
.into_iter()
.filter(|u| !is_own_member(&u))
.take(5)
.try_collect()
.await
.collect()
} else {
let members: Vec<_> = stream::iter(summary.heroes.iter())
.filter(|u| future::ready(!is_own_user_id(u)))
@ -274,7 +283,9 @@ impl Room {
.collect()
.await;
members.into_iter().collect()
let members: StoreResult<Vec<_>> = members.into_iter().collect();
members?
};
info!(
@ -288,7 +299,7 @@ impl Room {
let inner = self.inner.read().unwrap();
Ok(inner
.base_info
.calculate_room_name(joined, invited, members?))
.calculate_room_name(joined, invited, members))
}
pub(crate) fn clone_info(&self) -> RoomInfo {

View File

@ -20,7 +20,7 @@ use matrix_sdk_common::{
};
use serde::{Deserialize, Serialize};
use crate::store::SledStore;
use crate::store::StateStore;
use super::BaseRoomInfo;
@ -29,11 +29,11 @@ pub struct StrippedRoom {
room_id: Arc<RoomId>,
own_user_id: Arc<UserId>,
inner: Arc<SyncMutex<StrippedRoomInfo>>,
store: SledStore,
store: Arc<Box<dyn StateStore>>,
}
impl StrippedRoom {
pub fn new(own_user_id: &UserId, store: SledStore, room_id: &RoomId) -> Self {
pub fn new(own_user_id: &UserId, store: Arc<Box<dyn StateStore>>, room_id: &RoomId) -> Self {
let room_id = Arc::new(room_id.clone());
Self {

View File

@ -15,14 +15,15 @@
use std::{collections::BTreeMap, ops::Deref, path::Path, sync::Arc};
use dashmap::DashMap;
use futures::stream::StreamExt;
use matrix_sdk_common::{
async_trait,
events::{
presence::PresenceEvent, room::member::MemberEventContent, AnyBasicEvent,
AnyStrippedStateEvent, AnySyncStateEvent, EventContent,
AnyStrippedStateEvent, AnySyncStateEvent, EventContent, EventType,
},
identifiers::{RoomId, UserId},
locks::RwLock,
AsyncTraitDeps,
};
use crate::{
@ -53,9 +54,48 @@ pub enum StoreError {
/// A `StateStore` specific result type.
pub type Result<T> = std::result::Result<T, StoreError>;
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait StateStore: AsyncTraitDeps {
async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()>;
async fn save_changes(&self, changes: &StateChanges) -> Result<()>;
async fn get_filter(&self, filter_id: &str) -> Result<Option<String>>;
async fn get_sync_token(&self) -> Result<Option<String>>;
async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<PresenceEvent>>;
async fn get_state_event(
&self,
room_id: &RoomId,
event_type: EventType,
state_key: &str,
) -> Result<Option<AnySyncStateEvent>>;
async fn get_profile(
&self,
room_id: &RoomId,
user_id: &UserId,
) -> Result<Option<MemberEventContent>>;
async fn get_member_event(
&self,
room_id: &RoomId,
state_key: &UserId,
) -> Result<Option<MemberEvent>>;
async fn get_invited_user_ids(&self, room_id: &RoomId) -> Result<Vec<UserId>>;
async fn get_joined_user_ids(&self, room_id: &RoomId) -> Result<Vec<UserId>>;
async fn get_room_infos(&self) -> Result<Vec<RoomInfo>>;
}
#[derive(Debug, Clone)]
pub struct Store {
inner: SledStore,
inner: Arc<Box<dyn StateStore>>,
session: Arc<RwLock<Option<Session>>>,
sync_token: Arc<RwLock<Option<String>>>,
rooms: Arc<DashMap<RoomId, Room>>,
@ -69,7 +109,7 @@ impl Store {
inner: SledStore,
) -> Self {
Self {
inner,
inner: Arc::new(Box::new(inner)),
session,
sync_token,
rooms: DashMap::new().into(),
@ -78,11 +118,8 @@ impl Store {
}
pub(crate) async fn restore_session(&self, session: Session) -> Result<()> {
let mut infos = self.inner.get_room_infos().await;
// TODO restore stripped rooms.
while let Some(info) = infos.next().await {
let room = Room::restore(&session.user_id, self.inner.clone(), info?);
for info in self.inner.get_room_infos().await?.into_iter() {
let room = Room::restore(&session.user_id, self.inner.clone(), info);
self.rooms.insert(room.room_id().to_owned(), room);
}
@ -172,7 +209,7 @@ impl Store {
}
impl Deref for Store {
type Target = SledStore;
type Target = Box<dyn StateStore>;
fn deref(&self) -> &Self::Target {
&self.inner

View File

@ -16,8 +16,12 @@ mod store_key;
use std::{convert::TryFrom, path::Path, sync::Arc, time::SystemTime};
use futures::stream::{self, Stream};
use futures::{
stream::{self, Stream},
TryStreamExt,
};
use matrix_sdk_common::{
async_trait,
events::{
presence::PresenceEvent,
room::member::{MemberEventContent, MembershipState},
@ -37,7 +41,7 @@ use crate::deserialized_responses::MemberEvent;
use self::store_key::{EncryptedEvent, StoreKey};
use super::{Result, RoomInfo, StateChanges, StoreError};
use super::{Result, RoomInfo, StateChanges, StateStore, StoreError};
#[derive(Debug, Serialize, Deserialize)]
pub enum DatabaseType {
@ -477,6 +481,66 @@ impl SledStore {
}
}
#[async_trait]
impl StateStore for SledStore {
async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> {
self.save_filter(filter_name, filter_id).await
}
async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
self.save_changes(changes).await
}
async fn get_filter(&self, filter_id: &str) -> Result<Option<String>> {
self.get_filter(filter_id).await
}
async fn get_sync_token(&self) -> Result<Option<String>> {
self.get_sync_token().await
}
async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<PresenceEvent>> {
self.get_presence_event(user_id).await
}
async fn get_state_event(
&self,
room_id: &RoomId,
event_type: EventType,
state_key: &str,
) -> Result<Option<AnySyncStateEvent>> {
self.get_state_event(room_id, event_type, state_key).await
}
async fn get_profile(
&self,
room_id: &RoomId,
user_id: &UserId,
) -> Result<Option<MemberEventContent>> {
self.get_profile(room_id, user_id).await
}
async fn get_member_event(
&self,
room_id: &RoomId,
state_key: &UserId,
) -> Result<Option<MemberEvent>> {
self.get_member_event(room_id, state_key).await
}
async fn get_invited_user_ids(&self, room_id: &RoomId) -> Result<Vec<UserId>> {
self.get_invited_user_ids(room_id).await.try_collect().await
}
async fn get_joined_user_ids(&self, room_id: &RoomId) -> Result<Vec<UserId>> {
self.get_joined_user_ids(room_id).await.try_collect().await
}
async fn get_room_infos(&self) -> Result<Vec<RoomInfo>> {
self.get_room_infos().await.try_collect().await
}
}
#[cfg(test)]
mod test {
use std::{convert::TryFrom, time::SystemTime};

View File

@ -502,7 +502,7 @@ impl OlmMachine {
/// [`mark_request_as_sent`]: #method.mark_request_as_sent
pub async fn get_missing_sessions(
&self,
users: &mut impl Iterator<Item = &UserId>,
users: impl Iterator<Item = &UserId>,
) -> OlmResult<Option<(Uuid, KeysClaimRequest)>> {
self.session_manager.get_missing_sessions(users).await
}

View File

@ -188,7 +188,7 @@ impl SessionManager {
/// [`receive_keys_claim_response`]: #method.receive_keys_claim_response
pub async fn get_missing_sessions(
&self,
users: &mut impl Iterator<Item = &UserId>,
users: impl Iterator<Item = &UserId>,
) -> OlmResult<Option<(Uuid, KeysClaimRequest)>> {
let mut missing = BTreeMap::new();