base: Add some error handling to the state store

master
Damir Jelić 2021-01-18 17:31:33 +01:00
parent e5ba0298d0
commit d07063af2b
11 changed files with 229 additions and 146 deletions

View File

@ -26,7 +26,7 @@ impl EventEmitter for EventCallback {
.. ..
} = event } = event
{ {
let member = room.get_member(&sender).await.unwrap(); let member = room.get_member(&sender).await.unwrap().unwrap();
let name = member let name = member
.display_name() .display_name()
.unwrap_or_else(|| member.user_id().as_str()); .unwrap_or_else(|| member.user_id().as_str());

View File

@ -27,7 +27,7 @@ use std::{
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use dashmap::DashMap; use dashmap::DashMap;
use futures::StreamExt; use futures::TryStreamExt;
use futures_timer::Delay as sleep; use futures_timer::Delay as sleep;
use http::HeaderValue; use http::HeaderValue;
use mime::{self, Mime}; use mime::{self, Mime};
@ -738,7 +738,7 @@ impl Client {
filter_name: &str, filter_name: &str,
definition: FilterDefinition<'_>, definition: FilterDefinition<'_>,
) -> Result<String> { ) -> Result<String> {
if let Some(filter) = self.base_client.get_filter(filter_name).await { if let Some(filter) = self.base_client.get_filter(filter_name).await? {
Ok(filter) Ok(filter)
} else { } else {
let user_id = self.user_id().await.ok_or(Error::AuthenticationRequired)?; let user_id = self.user_id().await.ok_or(Error::AuthenticationRequired)?;
@ -747,7 +747,7 @@ impl Client {
self.base_client self.base_client
.receive_filter_upload(filter_name, &response) .receive_filter_upload(filter_name, &response)
.await; .await?;
Ok(response.filter_id) Ok(response.filter_id)
} }
@ -1156,8 +1156,7 @@ impl Client {
{ {
let room = self.get_joined_room(room_id).unwrap(); let room = self.get_joined_room(room_id).unwrap();
let members = room.joined_user_ids().await; let members = room.joined_user_ids().await;
// TODO don't collect here. let members_iter: Vec<UserId> = members.try_collect().await?;
let members_iter: Vec<UserId> = members.collect().await;
self.claim_one_time_keys(&mut members_iter.iter()).await?; self.claim_one_time_keys(&mut members_iter.iter()).await?;
}; };
@ -2275,7 +2274,7 @@ mod test {
get_public_rooms, get_public_rooms_filtered, register::RegistrationKind, Client, get_public_rooms, get_public_rooms_filtered, register::RegistrationKind, Client,
Invite3pid, Session, SyncSettings, Url, Invite3pid, Session, SyncSettings, Url,
}; };
use futures::StreamExt; use futures::TryStreamExt;
use matrix_sdk_base::RoomMember; use matrix_sdk_base::RoomMember;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::{ api::r0::{
@ -2891,7 +2890,7 @@ mod test {
let room = client let room = client
.get_joined_room(&room_id!("!SVkFJHzfwvuaIEawgC:localhost")) .get_joined_room(&room_id!("!SVkFJHzfwvuaIEawgC:localhost"))
.unwrap(); .unwrap();
let members: Vec<RoomMember> = room.active_members().await.collect().await; let members: Vec<RoomMember> = room.active_members().await.try_collect().await.unwrap();
assert_eq!(1, members.len()); assert_eq!(1, members.len());
// assert!(room.power_levels.is_some()) // assert!(room.power_levels.is_some())
@ -2916,7 +2915,7 @@ mod test {
.get_joined_room(&room_id!("!SVkFJHzfwvuaIEawgC:localhost")) .get_joined_room(&room_id!("!SVkFJHzfwvuaIEawgC:localhost"))
.unwrap(); .unwrap();
assert_eq!("example2", room.display_name().await); assert_eq!("example2", room.display_name().await.unwrap());
} }
#[tokio::test] #[tokio::test]
@ -3010,7 +3009,7 @@ mod test {
.get_joined_room(&room_id!("!SVkFJHzfwvuaIEawgC:localhost")) .get_joined_room(&room_id!("!SVkFJHzfwvuaIEawgC:localhost"))
.unwrap(); .unwrap();
assert_eq!("tutorial".to_string(), room.display_name().await); assert_eq!("tutorial".to_string(), room.display_name().await.unwrap());
} }
#[tokio::test] #[tokio::test]

View File

@ -14,7 +14,7 @@
//! Error conditions. //! Error conditions.
use matrix_sdk_base::Error as MatrixError; use matrix_sdk_base::{Error as MatrixError, StoreError};
use matrix_sdk_common::{ use matrix_sdk_common::{
api::{ api::{
r0::uiaa::{UiaaInfo, UiaaResponse as UiaaError}, r0::uiaa::{UiaaInfo, UiaaResponse as UiaaError},
@ -73,6 +73,10 @@ pub enum Error {
#[error(transparent)] #[error(transparent)]
CryptoStoreError(#[from] CryptoStoreError), CryptoStoreError(#[from] CryptoStoreError),
/// An error occured in the state store.
#[error(transparent)]
StateStore(#[from] StoreError),
/// An error occurred while authenticating. /// An error occurred while authenticating.
/// ///
/// When registering or authenticating the Matrix server can send a `UiaaResponse` /// When registering or authenticating the Matrix server can send a `UiaaResponse`

View File

@ -68,7 +68,7 @@ compile_error!("only one of 'native-tls' or 'rustls-tls' features can be enabled
pub use matrix_sdk_base::crypto::LocalTrust; pub use matrix_sdk_base::crypto::LocalTrust;
pub use matrix_sdk_base::{ pub use matrix_sdk_base::{
Error as BaseError, EventEmitter, InvitedRoom, JoinedRoom, LeftRoom, RoomInfo, RoomMember, Error as BaseError, EventEmitter, InvitedRoom, JoinedRoom, LeftRoom, RoomInfo, RoomMember,
RoomState, Session, RoomState, Session, StoreError,
}; };
pub use matrix_sdk_common::*; pub use matrix_sdk_common::*;

View File

@ -1,6 +1,6 @@
use std::{convert::TryFrom, fmt::Debug, io, sync::Arc}; use std::{convert::TryFrom, fmt::Debug, io, sync::Arc};
use futures::{executor::block_on, StreamExt}; use futures::{executor::block_on, TryStreamExt};
use serde::Serialize; use serde::Serialize;
use atty::Stream; use atty::Stream;
@ -86,8 +86,14 @@ impl InspectorHelper {
} }
fn complete_rooms(&self, arg: Option<&&str>) -> Vec<Pair> { fn complete_rooms(&self, arg: Option<&&str>) -> Vec<Pair> {
let rooms: Vec<RoomInfo> = let rooms: Vec<RoomInfo> = block_on(async {
block_on(async { self.store.get_room_infos().await.collect().await }); self.store
.get_room_infos()
.await
.try_collect()
.await
.unwrap()
});
rooms rooms
.into_iter() .into_iter()
@ -248,7 +254,7 @@ impl Printer {
impl Inspector { impl Inspector {
fn new(database_path: &str, json: bool, color: bool) -> Self { fn new(database_path: &str, json: bool, color: bool) -> Self {
let printer = Printer::new(json, color); let printer = Printer::new(json, color);
let store = Store::open_default(database_path); let store = Store::open_default(database_path).unwrap();
Self { store, printer } Self { store, printer }
} }
@ -280,7 +286,13 @@ impl Inspector {
} }
async fn list_rooms(&self) { async fn list_rooms(&self) {
let rooms: Vec<RoomInfo> = self.store.get_room_infos().await.collect().await; let rooms: Vec<RoomInfo> = self
.store
.get_room_infos()
.await
.try_collect()
.await
.unwrap();
self.printer.pretty_print_struct(&rooms); self.printer.pretty_print_struct(&rooms);
} }
@ -289,11 +301,12 @@ impl Inspector {
.store .store
.get_joined_user_ids(&room_id) .get_joined_user_ids(&room_id)
.await .await
.collect() .try_collect()
.await; .await
.unwrap();
for member in joined { for member in joined {
let event = self.store.get_profile(&room_id, &member).await; let event = self.store.get_profile(&room_id, &member).await.unwrap();
self.printer.pretty_print_struct(&event); self.printer.pretty_print_struct(&event);
} }
} }
@ -303,18 +316,28 @@ impl Inspector {
.store .store
.get_joined_user_ids(&room_id) .get_joined_user_ids(&room_id)
.await .await
.collect() .try_collect()
.await; .await
.unwrap();
for member in joined { for member in joined {
let event = self.store.get_member_event(&room_id, &member).await; let event = self
.store
.get_member_event(&room_id, &member)
.await
.unwrap();
self.printer.pretty_print_struct(&event); self.printer.pretty_print_struct(&event);
} }
} }
async fn get_state(&self, room_id: RoomId, event_type: EventType) { async fn get_state(&self, room_id: RoomId, event_type: EventType) {
self.printer self.printer.pretty_print_struct(
.pretty_print_struct(&self.store.get_state_event(&room_id, event_type, "").await); &self
.store
.get_state_event(&room_id, event_type, "")
.await
.unwrap(),
);
} }
fn subcommands() -> Vec<Argparse<'static, 'static>> { fn subcommands() -> Vec<Argparse<'static, 'static>> {

View File

@ -25,6 +25,7 @@ use std::{
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use futures::StreamExt; use futures::StreamExt;
use futures::TryStreamExt;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0 as api, api::r0 as api,
events::{ events::{
@ -63,7 +64,7 @@ use crate::{
}, },
rooms::{RoomInfo, RoomType, StrippedRoomInfo}, rooms::{RoomInfo, RoomType, StrippedRoomInfo},
session::Session, session::Session,
store::{SledStore, StateChanges, Store}, store::{Result as StoreResult, SledStore, StateChanges, Store},
EventEmitter, RoomState, EventEmitter, RoomState,
}; };
@ -283,9 +284,9 @@ impl BaseClient {
pub fn new_with_config(config: BaseClientConfig) -> Result<Self> { pub fn new_with_config(config: BaseClientConfig) -> Result<Self> {
let store = if let Some(path) = &config.store_path { let store = if let Some(path) = &config.store_path {
info!("Opening store in path {}", path.display()); info!("Opening store in path {}", path.display());
SledStore::open_with_path(path) SledStore::open_with_path(path)?
} else { } else {
SledStore::open() SledStore::open()?
}; };
let session = Arc::new(RwLock::new(None)); let session = Arc::new(RwLock::new(None));
@ -352,7 +353,7 @@ impl BaseClient {
/// * `session` - An session that the user already has from a /// * `session` - An session that the user already has from a
/// previous login call. /// previous login call.
pub async fn restore_login(&self, session: Session) -> Result<()> { pub async fn restore_login(&self, session: Session) -> Result<()> {
self.store.restore_session(session.clone()).await; self.store.restore_session(session.clone()).await?;
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
{ {
@ -747,8 +748,7 @@ impl BaseClient {
let joined = self.store.get_joined_user_ids(&room_id).await; let joined = self.store.get_joined_user_ids(&room_id).await;
let invited = self.store.get_invited_user_ids(&room_id).await; let invited = self.store.get_invited_user_ids(&room_id).await;
// TODO don't use collect here. let user_ids: Vec<UserId> = joined.chain(invited).try_collect().await?;
let user_ids: Vec<UserId> = joined.chain(invited).collect().await;
o.update_tracked_users(&user_ids).await o.update_tracked_users(&user_ids).await
} }
@ -853,7 +853,7 @@ impl BaseClient {
self.handle_account_data(response.account_data.events, &mut changes) self.handle_account_data(response.account_data.events, &mut changes)
.await; .await;
self.store.save_changes(&changes).await; self.store.save_changes(&changes).await?;
*self.sync_token.write().await = Some(response.next_batch.clone()); *self.sync_token.write().await = Some(response.next_batch.clone());
self.apply_changes(&changes).await; self.apply_changes(&changes).await;
@ -921,7 +921,7 @@ impl BaseClient {
if self if self
.store .store
.get_member_event(&room_id, &member.state_key) .get_member_event(&room_id, &member.state_key)
.await .await?
.is_none() .is_none()
{ {
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
@ -948,7 +948,7 @@ impl BaseClient {
changes.add_room(room_info); changes.add_room(room_info);
self.store.save_changes(&changes).await; self.store.save_changes(&changes).await?;
self.apply_changes(&changes).await; self.apply_changes(&changes).await;
} }
@ -959,13 +959,14 @@ impl BaseClient {
&self, &self,
filter_name: &str, filter_name: &str,
response: &api::filter::create_filter::Response, response: &api::filter::create_filter::Response,
) { ) -> Result<()> {
self.store Ok(self
.store
.save_filter(filter_name, &response.filter_id) .save_filter(filter_name, &response.filter_id)
.await; .await?)
} }
pub async fn get_filter(&self, filter_name: &str) -> Option<String> { pub async fn get_filter(&self, filter_name: &str) -> StoreResult<Option<String>> {
self.store.get_filter(filter_name).await self.store.get_filter(filter_name).await
} }
@ -1038,8 +1039,7 @@ impl BaseClient {
Some(o) => { Some(o) => {
let joined = self.store.get_joined_user_ids(room_id).await; let joined = self.store.get_joined_user_ids(room_id).await;
let invited = self.store.get_invited_user_ids(room_id).await; let invited = self.store.get_invited_user_ids(room_id).await;
// TODO don't use collect here. let members: Vec<UserId> = joined.chain(invited).try_collect().await?;
let members: Vec<UserId> = joined.chain(invited).collect().await;
Ok( Ok(
o.share_group_session(room_id, members.iter(), EncryptionSettings::default()) o.share_group_session(room_id, members.iter(), EncryptionSettings::default())
.await?, .await?,

View File

@ -34,8 +34,8 @@ pub enum Error {
/// A generic error returned when the state store fails not due to /// A generic error returned when the state store fails not due to
/// IO or (de)serialization. /// IO or (de)serialization.
#[error("state store: {0}")] #[error(transparent)]
StateStore(String), StateStore(#[from] crate::store::StoreError),
/// An error when (de)serializing JSON. /// An error when (de)serializing JSON.
#[error(transparent)] #[error(transparent)]

View File

@ -292,7 +292,7 @@ pub enum CustomEvent<'c> {
/// .. /// ..
/// } = event /// } = event
/// { /// {
/// let member = room.get_member(&sender).await.unwrap(); /// let member = room.get_member(&sender).await.unwrap().unwrap();
/// let name = member /// let name = member
/// .display_name() /// .display_name()
/// .unwrap_or_else(|| member.user_id().as_str()); /// .unwrap_or_else(|| member.user_id().as_str());

View File

@ -52,7 +52,7 @@ mod store;
pub use event_emitter::EventEmitter; pub use event_emitter::EventEmitter;
pub use rooms::{InvitedRoom, JoinedRoom, LeftRoom, Room, RoomInfo, RoomMember, RoomState}; pub use rooms::{InvitedRoom, JoinedRoom, LeftRoom, Room, RoomInfo, RoomMember, RoomState};
pub use store::Store; pub use store::{Store, StoreError};
pub use client::{BaseClient, BaseClientConfig, RoomStateType}; pub use client::{BaseClient, BaseClientConfig, RoomStateType};

View File

@ -19,7 +19,7 @@ use std::{
use futures::{ use futures::{
future, future,
stream::{self, Stream, StreamExt}, stream::{self, Stream, StreamExt, TryStreamExt},
}; };
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::sync::sync_events::RoomSummary as RumaSummary, api::r0::sync::sync_events::RoomSummary as RumaSummary,
@ -36,7 +36,10 @@ use matrix_sdk_common::{
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tracing::info; use tracing::info;
use crate::{responses::UnreadNotificationsCount, store::SledStore}; use crate::{
responses::UnreadNotificationsCount,
store::{Result as StoreResult, SledStore},
};
use super::{BaseRoomInfo, RoomMember}; use super::{BaseRoomInfo, RoomMember};
@ -190,27 +193,39 @@ impl Room {
self.inner.read().unwrap().base_info.topic.clone() self.inner.read().unwrap().base_info.topic.clone()
} }
pub async fn display_name(&self) -> String { pub async fn display_name(&self) -> StoreResult<String> {
self.calculate_name().await self.calculate_name().await
} }
pub async fn joined_user_ids(&self) -> impl Stream<Item = UserId> { pub async fn joined_user_ids(&self) -> impl Stream<Item = StoreResult<UserId>> {
self.store.get_joined_user_ids(self.room_id()).await self.store.get_joined_user_ids(self.room_id()).await
} }
pub async fn joined_members(&self) -> impl Stream<Item = RoomMember> + '_ { pub async fn joined_members(&self) -> impl Stream<Item = StoreResult<RoomMember>> + '_ {
let joined = self.store.get_joined_user_ids(self.room_id()).await; let joined = self.store.get_joined_user_ids(self.room_id()).await;
joined.filter_map(move |u| async move { self.get_member(&u).await }) joined.filter_map(move |u| async move {
let ret = match u {
Ok(u) => self.get_member(&u).await,
Err(e) => Err(e),
};
ret.transpose()
})
} }
pub async fn active_members(&self) -> impl Stream<Item = RoomMember> + '_ { pub async fn active_members(&self) -> impl Stream<Item = StoreResult<RoomMember>> + '_ {
let joined = self.store.get_joined_user_ids(self.room_id()).await; let joined = self.store.get_joined_user_ids(self.room_id()).await;
let invited = self.store.get_invited_user_ids(self.room_id()).await; let invited = self.store.get_invited_user_ids(self.room_id()).await;
joined joined.chain(invited).filter_map(move |u| async move {
.chain(invited) let ret = match u {
.filter_map(move |u| async move { self.get_member(&u).await }) Ok(u) => self.get_member(&u).await,
Err(e) => Err(e),
};
ret.transpose()
})
} }
/// Calculate the canonical display name of the room, taking into account /// Calculate the canonical display name of the room, taking into account
@ -220,16 +235,16 @@ impl Room {
/// ///
/// [spec]: /// [spec]:
/// <https://matrix.org/docs/spec/client_server/latest#calculating-the-display-name-for-a-room> /// <https://matrix.org/docs/spec/client_server/latest#calculating-the-display-name-for-a-room>
async fn calculate_name(&self) -> String { async fn calculate_name(&self) -> StoreResult<String> {
let summary = { let summary = {
let inner = self.inner.read().unwrap(); let inner = self.inner.read().unwrap();
if let Some(name) = &inner.base_info.name { if let Some(name) = &inner.base_info.name {
let name = name.trim(); let name = name.trim();
return name.to_string(); return Ok(name.to_string());
} else if let Some(alias) = &inner.base_info.canonical_alias { } else if let Some(alias) = &inner.base_info.canonical_alias {
let alias = alias.alias().trim(); let alias = alias.alias().trim();
return alias.to_string(); return Ok(alias.to_string());
} }
inner.summary.clone() inner.summary.clone()
}; };
@ -242,22 +257,24 @@ impl Room {
let is_own_member = |m: &RoomMember| m.user_id() == &*self.own_user_id; let is_own_member = |m: &RoomMember| m.user_id() == &*self.own_user_id;
let is_own_user_id = |u: &str| u == self.own_user_id().as_str(); let is_own_user_id = |u: &str| u == self.own_user_id().as_str();
let members: Vec<RoomMember> = if summary.heroes.is_empty() { let members: StoreResult<Vec<RoomMember>> = if summary.heroes.is_empty() {
self.active_members() self.active_members()
.await .await
.filter(|m| future::ready(!is_own_member(m))) .try_filter(|u| future::ready(!is_own_member(&u)))
.take(5) .take(5)
.collect() .try_collect()
.await .await
} else { } else {
stream::iter(summary.heroes.iter()) let members: Vec<_> = stream::iter(summary.heroes.iter())
.filter(|u| future::ready(!is_own_user_id(u))) .filter(|u| future::ready(!is_own_user_id(u)))
.filter_map(|u| async move { .filter_map(|u| async move {
let user_id = UserId::try_from(u.as_str()).ok()?; let user_id = UserId::try_from(u.as_str()).ok()?;
self.get_member(&user_id).await self.get_member(&user_id).await.transpose()
}) })
.collect() .collect()
.await .await;
members.into_iter().collect()
}; };
info!( info!(
@ -269,9 +286,9 @@ impl Room {
); );
let inner = self.inner.read().unwrap(); let inner = self.inner.read().unwrap();
inner Ok(inner
.base_info .base_info
.calculate_room_name(joined, invited, members) .calculate_room_name(joined, invited, members?))
} }
pub(crate) fn clone_info(&self) -> RoomInfo { pub(crate) fn clone_info(&self) -> RoomInfo {
@ -283,9 +300,9 @@ impl Room {
*inner = summary; *inner = summary;
} }
pub async fn get_member(&self, user_id: &UserId) -> Option<RoomMember> { pub async fn get_member(&self, user_id: &UserId) -> StoreResult<Option<RoomMember>> {
let presence = self.store.get_presence_event(user_id).await; let presence = self.store.get_presence_event(user_id).await?;
let profile = self.store.get_profile(self.room_id(), user_id).await; let profile = self.store.get_profile(self.room_id(), user_id).await?;
let max_power_level = self.max_power_level(); let max_power_level = self.max_power_level();
let is_room_creator = self let is_room_creator = self
.inner .inner
@ -300,7 +317,7 @@ impl Room {
let power = self let power = self
.store .store
.get_state_event(self.room_id(), EventType::RoomPowerLevels, "") .get_state_event(self.room_id(), EventType::RoomPowerLevels, "")
.await .await?
.map(|e| { .map(|e| {
if let AnySyncStateEvent::RoomPowerLevels(e) = e { if let AnySyncStateEvent::RoomPowerLevels(e) = e {
Some(e) Some(e)
@ -310,9 +327,10 @@ impl Room {
}) })
.flatten(); .flatten();
self.store Ok(self
.store
.get_member_event(&self.room_id, user_id) .get_member_event(&self.room_id, user_id)
.await .await?
.map(|e| RoomMember { .map(|e| RoomMember {
event: e.into(), event: e.into(),
profile: profile.into(), profile: profile.into(),
@ -320,7 +338,7 @@ impl Room {
power_levles: power.into(), power_levles: power.into(),
max_power_level, max_power_level,
is_room_creator, is_room_creator,
}) }))
} }
} }

View File

@ -23,6 +23,19 @@ use crate::{
InvitedRoom, JoinedRoom, LeftRoom, Room, RoomState, Session, InvitedRoom, JoinedRoom, LeftRoom, Room, RoomState, Session,
}; };
#[derive(Debug, thiserror::Error)]
pub enum StoreError {
#[error(transparent)]
Sled(#[from] sled::Error),
#[error(transparent)]
Json(#[from] serde_json::Error),
#[error(transparent)]
Identifier(#[from] matrix_sdk_common::identifiers::Error),
}
/// A `StateStore` specific result type.
pub type Result<T> = std::result::Result<T, StoreError>;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Store { pub struct Store {
inner: SledStore, inner: SledStore,
@ -47,28 +60,30 @@ impl Store {
} }
} }
pub(crate) async fn restore_session(&self, session: Session) { pub(crate) async fn restore_session(&self, session: Session) -> Result<()> {
let mut infos = self.inner.get_room_infos().await; let mut infos = self.inner.get_room_infos().await;
// TODO restore stripped rooms. // TODO restore stripped rooms.
while let Some(info) = infos.next().await { while let Some(info) = infos.next().await {
let room = Room::restore(&session.user_id, self.inner.clone(), info); let room = Room::restore(&session.user_id, self.inner.clone(), info?);
self.rooms.insert(room.room_id().to_owned(), room); self.rooms.insert(room.room_id().to_owned(), room);
} }
let token = self.get_sync_token().await; let token = self.get_sync_token().await?;
*self.sync_token.write().await = token; *self.sync_token.write().await = token;
*self.session.write().await = Some(session); *self.session.write().await = Some(session);
Ok(())
} }
pub fn open_default(path: impl AsRef<Path>) -> Self { pub fn open_default(path: impl AsRef<Path>) -> Result<Self> {
let inner = SledStore::open_with_path(path); let inner = SledStore::open_with_path(path)?;
Self::new( Ok(Self::new(
Arc::new(RwLock::new(None)), Arc::new(RwLock::new(None)),
Arc::new(RwLock::new(None)), Arc::new(RwLock::new(None)),
inner, inner,
) ))
} }
pub(crate) fn get_bare_room(&self, room_id: &RoomId) -> Option<Room> { pub(crate) fn get_bare_room(&self, room_id: &RoomId) -> Option<Room> {
@ -222,7 +237,8 @@ impl StateChanges {
} }
pub fn add_stripped_member(&mut self, room_id: &RoomId, event: StrippedMemberEvent) { pub fn add_stripped_member(&mut self, room_id: &RoomId, event: StrippedMemberEvent) {
let user_id = UserId::try_from(event.state_key.as_str()).unwrap(); let user_id = event.state_key.clone();
self.stripped_members self.stripped_members
.entry(room_id.to_owned()) .entry(room_id.to_owned())
.or_insert_with(BTreeMap::new) .or_insert_with(BTreeMap::new)
@ -240,25 +256,25 @@ impl StateChanges {
} }
impl SledStore { impl SledStore {
fn open_helper(db: Db) -> Self { fn open_helper(db: Db) -> Result<Self> {
let session = db.open_tree("session").unwrap(); let session = db.open_tree("session")?;
let account_data = db.open_tree("account_data").unwrap(); let account_data = db.open_tree("account_data")?;
let members = db.open_tree("members").unwrap(); let members = db.open_tree("members")?;
let profiles = db.open_tree("profiles").unwrap(); let profiles = db.open_tree("profiles")?;
let joined_user_ids = db.open_tree("joined_user_ids").unwrap(); let joined_user_ids = db.open_tree("joined_user_ids")?;
let invited_user_ids = db.open_tree("invited_user_ids").unwrap(); let invited_user_ids = db.open_tree("invited_user_ids")?;
let room_state = db.open_tree("room_state").unwrap(); let room_state = db.open_tree("room_state")?;
let room_info = db.open_tree("room_infos").unwrap(); let room_info = db.open_tree("room_infos")?;
let presence = db.open_tree("presence").unwrap(); let presence = db.open_tree("presence")?;
let room_account_data = db.open_tree("room_account_data").unwrap(); let room_account_data = db.open_tree("room_account_data")?;
let stripped_room_info = db.open_tree("stripped_room_info").unwrap(); let stripped_room_info = db.open_tree("stripped_room_info")?;
let stripped_members = db.open_tree("stripped_members").unwrap(); let stripped_members = db.open_tree("stripped_members")?;
let stripped_room_state = db.open_tree("stripped_room_state").unwrap(); let stripped_room_state = db.open_tree("stripped_room_state")?;
Self { Ok(Self {
inner: db, inner: db,
session, session,
account_data, account_data,
@ -273,43 +289,44 @@ impl SledStore {
stripped_room_info, stripped_room_info,
stripped_members, stripped_members,
stripped_room_state, stripped_room_state,
} })
} }
pub fn open() -> Self { pub fn open() -> Result<Self> {
let db = Config::new().temporary(true).open().unwrap(); let db = Config::new().temporary(true).open()?;
SledStore::open_helper(db) SledStore::open_helper(db)
} }
pub fn open_with_path(path: impl AsRef<Path>) -> Self { pub fn open_with_path(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref().join("matrix-sdk-state"); let path = path.as_ref().join("matrix-sdk-state");
let db = Config::new().temporary(false).path(path).open().unwrap(); let db = Config::new().temporary(false).path(path).open()?;
SledStore::open_helper(db) SledStore::open_helper(db)
} }
pub async fn save_filter(&self, filter_name: &str, filter_id: &str) { pub async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> {
self.session self.session
.insert(&format!("filter{}", filter_name), filter_id) .insert(&format!("filter{}", filter_name), filter_id)?;
.unwrap();
Ok(())
} }
pub async fn get_filter(&self, filter_name: &str) -> Option<String> { pub async fn get_filter(&self, filter_name: &str) -> Result<Option<String>> {
self.session Ok(self
.get(&format!("filter{}", filter_name)) .session
.unwrap() .get(&format!("filter{}", filter_name))?
.map(|f| String::from_utf8_lossy(&f).to_string()) .map(|f| String::from_utf8_lossy(&f).to_string()))
} }
pub async fn get_sync_token(&self) -> Option<String> { pub async fn get_sync_token(&self) -> Result<Option<String>> {
self.session Ok(self
.get("sync_token") .session
.unwrap() .get("sync_token")?
.map(|t| String::from_utf8_lossy(&t).to_string()) .map(|t| String::from_utf8_lossy(&t).to_string()))
} }
pub async fn save_changes(&self, changes: &StateChanges) { pub async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
let now = SystemTime::now(); let now = SystemTime::now();
let ret: TransactionResult<()> = ( let ret: TransactionResult<()> = (
@ -458,16 +475,19 @@ impl SledStore {
ret.unwrap(); ret.unwrap();
self.inner.flush_async().await.unwrap(); self.inner.flush_async().await?;
info!("Saved changes in {:?}", now.elapsed().unwrap()); info!("Saved changes in {:?}", now.elapsed());
Ok(())
} }
pub async fn get_presence_event(&self, user_id: &UserId) -> Option<PresenceEvent> { pub async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<PresenceEvent>> {
self.presence Ok(self
.get(user_id.as_bytes()) .presence
.unwrap() .get(user_id.as_bytes())?
.map(|e| serde_json::from_slice(&e).unwrap()) .map(|e| serde_json::from_slice(&e))
.transpose()?)
} }
pub async fn get_state_event( pub async fn get_state_event(
@ -475,60 +495,71 @@ impl SledStore {
room_id: &RoomId, room_id: &RoomId,
event_type: EventType, event_type: EventType,
state_key: &str, state_key: &str,
) -> Option<AnySyncStateEvent> { ) -> Result<Option<AnySyncStateEvent>> {
self.room_state Ok(self
.get(format!("{}{}{}", room_id.as_str(), event_type, state_key).as_bytes()) .room_state
.unwrap() .get(format!("{}{}{}", room_id.as_str(), event_type, state_key).as_bytes())?
.map(|e| serde_json::from_slice(&e).unwrap()) .map(|e| serde_json::from_slice(&e))
.transpose()?)
} }
pub async fn get_profile( pub async fn get_profile(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
user_id: &UserId, user_id: &UserId,
) -> Option<MemberEventContent> { ) -> Result<Option<MemberEventContent>> {
self.profiles Ok(self
.get(format!("{}{}", room_id.as_str(), user_id.as_str())) .profiles
.unwrap() .get(format!("{}{}", room_id.as_str(), user_id.as_str()))?
.map(|p| serde_json::from_slice(&p).unwrap()) .map(|p| serde_json::from_slice(&p))
.transpose()?)
} }
pub async fn get_member_event( pub async fn get_member_event(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
state_key: &UserId, state_key: &UserId,
) -> Option<MemberEvent> { ) -> Result<Option<MemberEvent>> {
self.members Ok(self
.get(format!("{}{}", room_id.as_str(), state_key.as_str())) .members
.unwrap() .get(format!("{}{}", room_id.as_str(), state_key.as_str()))?
.map(|v| serde_json::from_slice(&v).unwrap()) .map(|v| serde_json::from_slice(&v))
.transpose()?)
} }
pub async fn get_invited_user_ids(&self, room_id: &RoomId) -> impl Stream<Item = UserId> { pub async fn get_invited_user_ids(
&self,
room_id: &RoomId,
) -> impl Stream<Item = Result<UserId>> {
stream::iter( stream::iter(
self.invited_user_ids self.invited_user_ids
.scan_prefix(room_id.as_bytes()) .scan_prefix(room_id.as_bytes())
.map(|u| { .map(|u| {
UserId::try_from(String::from_utf8_lossy(&u.unwrap().1).to_string()).unwrap() UserId::try_from(String::from_utf8_lossy(&u?.1).to_string())
.map_err(StoreError::Identifier)
}), }),
) )
} }
pub async fn get_joined_user_ids(&self, room_id: &RoomId) -> impl Stream<Item = UserId> { pub async fn get_joined_user_ids(
&self,
room_id: &RoomId,
) -> impl Stream<Item = Result<UserId>> {
stream::iter( stream::iter(
self.joined_user_ids self.joined_user_ids
.scan_prefix(room_id.as_bytes()) .scan_prefix(room_id.as_bytes())
.map(|u| { .map(|u| {
UserId::try_from(String::from_utf8_lossy(&u.unwrap().1).to_string()).unwrap() UserId::try_from(String::from_utf8_lossy(&u?.1).to_string())
.map_err(StoreError::Identifier)
}), }),
) )
} }
pub async fn get_room_infos(&self) -> impl Stream<Item = RoomInfo> { pub async fn get_room_infos(&self) -> impl Stream<Item = Result<RoomInfo>> {
stream::iter( stream::iter(
self.room_info self.room_info
.iter() .iter()
.map(|r| serde_json::from_slice(&r.unwrap().1).unwrap()), .map(|r| serde_json::from_slice(&r?.1).map_err(StoreError::Json)),
) )
} }
} }
@ -575,11 +606,15 @@ mod test {
#[async_test] #[async_test]
async fn test_member_saving() { async fn test_member_saving() {
let store = SledStore::open(); let store = SledStore::open().unwrap();
let room_id = room_id!("!test:localhost"); let room_id = room_id!("!test:localhost");
let user_id = user_id(); let user_id = user_id();
assert!(store.get_member_event(&room_id, &user_id).await.is_none()); assert!(store
.get_member_event(&room_id, &user_id)
.await
.unwrap()
.is_none());
let mut changes = StateChanges::default(); let mut changes = StateChanges::default();
changes changes
.members .members
@ -587,7 +622,11 @@ mod test {
.or_default() .or_default()
.insert(user_id.clone(), membership_event()); .insert(user_id.clone(), membership_event());
store.save_changes(&changes).await; store.save_changes(&changes).await.unwrap();
assert!(store.get_member_event(&room_id, &user_id).await.is_some()); assert!(store
.get_member_event(&room_id, &user_id)
.await
.unwrap()
.is_some());
} }
} }