1197 lines
40 KiB
Rust
1197 lines
40 KiB
Rust
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
mod store_key;
|
|
|
|
use std::{
|
|
collections::BTreeSet,
|
|
convert::{TryFrom, TryInto},
|
|
path::{Path, PathBuf},
|
|
sync::Arc,
|
|
time::Instant,
|
|
};
|
|
|
|
use futures::{
|
|
stream::{self, Stream},
|
|
TryStreamExt,
|
|
};
|
|
use matrix_sdk_common::async_trait;
|
|
use ruma::{
|
|
events::{
|
|
presence::PresenceEvent,
|
|
receipt::Receipt,
|
|
room::member::{MemberEventContent, MembershipState},
|
|
AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnySyncStateEvent, EventType,
|
|
},
|
|
receipt::ReceiptType,
|
|
serde::Raw,
|
|
EventId, MxcUri, RoomId, UserId,
|
|
};
|
|
use serde::{Deserialize, Serialize};
|
|
use sled::{
|
|
transaction::{ConflictableTransactionError, TransactionError},
|
|
Config, Db, Transactional, Tree,
|
|
};
|
|
use tracing::info;
|
|
|
|
use self::store_key::{EncryptedEvent, StoreKey};
|
|
use super::{Result, RoomInfo, StateChanges, StateStore, StoreError};
|
|
use crate::{
|
|
deserialized_responses::MemberEvent,
|
|
media::{MediaRequest, UniqueKey},
|
|
};
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
pub enum DatabaseType {
|
|
Unencrypted,
|
|
Encrypted(store_key::EncryptedStoreKey),
|
|
}
|
|
|
|
#[derive(Debug, thiserror::Error)]
|
|
pub enum SerializationError {
|
|
#[error(transparent)]
|
|
Json(#[from] serde_json::Error),
|
|
#[error(transparent)]
|
|
Encryption(#[from] store_key::Error),
|
|
}
|
|
|
|
impl From<TransactionError<SerializationError>> for StoreError {
|
|
fn from(e: TransactionError<SerializationError>) -> Self {
|
|
match e {
|
|
TransactionError::Abort(e) => e.into(),
|
|
TransactionError::Storage(e) => StoreError::Sled(e),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl From<SerializationError> for StoreError {
|
|
fn from(e: SerializationError) -> Self {
|
|
match e {
|
|
SerializationError::Json(e) => StoreError::Json(e),
|
|
SerializationError::Encryption(e) => match e {
|
|
store_key::Error::Random(e) => StoreError::Encryption(e.to_string()),
|
|
store_key::Error::Serialization(e) => StoreError::Json(e),
|
|
store_key::Error::Encryption(e) => StoreError::Encryption(e),
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
const ENCODE_SEPARATOR: u8 = 0xff;
|
|
|
|
trait EncodeKey {
|
|
fn encode(&self) -> Vec<u8>;
|
|
}
|
|
|
|
impl EncodeKey for &UserId {
|
|
fn encode(&self) -> Vec<u8> {
|
|
self.as_str().encode()
|
|
}
|
|
}
|
|
|
|
impl EncodeKey for &RoomId {
|
|
fn encode(&self) -> Vec<u8> {
|
|
self.as_str().encode()
|
|
}
|
|
}
|
|
|
|
impl EncodeKey for &str {
|
|
fn encode(&self) -> Vec<u8> {
|
|
[self.as_bytes(), &[ENCODE_SEPARATOR]].concat()
|
|
}
|
|
}
|
|
|
|
impl EncodeKey for (&str, &str) {
|
|
fn encode(&self) -> Vec<u8> {
|
|
[self.0.as_bytes(), &[ENCODE_SEPARATOR], self.1.as_bytes(), &[ENCODE_SEPARATOR]].concat()
|
|
}
|
|
}
|
|
|
|
impl EncodeKey for (&str, &str, &str) {
|
|
fn encode(&self) -> Vec<u8> {
|
|
[
|
|
self.0.as_bytes(),
|
|
&[ENCODE_SEPARATOR],
|
|
self.1.as_bytes(),
|
|
&[ENCODE_SEPARATOR],
|
|
self.2.as_bytes(),
|
|
&[ENCODE_SEPARATOR],
|
|
]
|
|
.concat()
|
|
}
|
|
}
|
|
|
|
impl EncodeKey for (&str, &str, &str, &str) {
|
|
fn encode(&self) -> Vec<u8> {
|
|
[
|
|
self.0.as_bytes(),
|
|
&[ENCODE_SEPARATOR],
|
|
self.1.as_bytes(),
|
|
&[ENCODE_SEPARATOR],
|
|
self.2.as_bytes(),
|
|
&[ENCODE_SEPARATOR],
|
|
self.3.as_bytes(),
|
|
&[ENCODE_SEPARATOR],
|
|
]
|
|
.concat()
|
|
}
|
|
}
|
|
|
|
impl EncodeKey for EventType {
|
|
fn encode(&self) -> Vec<u8> {
|
|
self.as_str().encode()
|
|
}
|
|
}
|
|
|
|
/// Get the value at `position` in encoded `key`.
|
|
///
|
|
/// The key must have been encoded with the `EncodeKey` trait. `position`
|
|
/// corresponds to the position in the tuple before the key was encoded. If it
|
|
/// wasn't encoded in a tuple, use `0`.
|
|
///
|
|
/// Returns `None` if there is no key at `position`.
|
|
pub fn decode_key_value(key: &[u8], position: usize) -> Option<String> {
|
|
let values: Vec<&[u8]> = key.split(|v| *v == ENCODE_SEPARATOR).collect();
|
|
|
|
values.get(position).map(|s| String::from_utf8_lossy(s).to_string())
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct SledStore {
|
|
path: Option<PathBuf>,
|
|
pub(crate) inner: Db,
|
|
store_key: Arc<Option<StoreKey>>,
|
|
session: Tree,
|
|
account_data: Tree,
|
|
members: Tree,
|
|
profiles: Tree,
|
|
display_names: Tree,
|
|
joined_user_ids: Tree,
|
|
invited_user_ids: Tree,
|
|
room_info: Tree,
|
|
room_state: Tree,
|
|
room_account_data: Tree,
|
|
stripped_room_info: Tree,
|
|
stripped_room_state: Tree,
|
|
stripped_members: Tree,
|
|
presence: Tree,
|
|
room_user_receipts: Tree,
|
|
room_event_receipts: Tree,
|
|
media: Tree,
|
|
custom: Tree,
|
|
}
|
|
|
|
impl std::fmt::Debug for SledStore {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
if let Some(path) = &self.path {
|
|
f.debug_struct("SledStore").field("path", &path).finish()
|
|
} else {
|
|
f.debug_struct("SledStore").field("path", &"memory store").finish()
|
|
}
|
|
}
|
|
}
|
|
|
|
impl SledStore {
|
|
fn open_helper(db: Db, path: Option<PathBuf>, store_key: Option<StoreKey>) -> Result<Self> {
|
|
let session = db.open_tree("session")?;
|
|
let account_data = db.open_tree("account_data")?;
|
|
|
|
let members = db.open_tree("members")?;
|
|
let profiles = db.open_tree("profiles")?;
|
|
let display_names = db.open_tree("display_names")?;
|
|
let joined_user_ids = db.open_tree("joined_user_ids")?;
|
|
let invited_user_ids = db.open_tree("invited_user_ids")?;
|
|
|
|
let room_state = db.open_tree("room_state")?;
|
|
let room_info = db.open_tree("room_infos")?;
|
|
let presence = db.open_tree("presence")?;
|
|
let room_account_data = db.open_tree("room_account_data")?;
|
|
|
|
let stripped_room_info = db.open_tree("stripped_room_info")?;
|
|
let stripped_members = db.open_tree("stripped_members")?;
|
|
let stripped_room_state = db.open_tree("stripped_room_state")?;
|
|
|
|
let room_user_receipts = db.open_tree("room_user_receipts")?;
|
|
let room_event_receipts = db.open_tree("room_event_receipts")?;
|
|
|
|
let media = db.open_tree("media")?;
|
|
|
|
let custom = db.open_tree("custom")?;
|
|
|
|
Ok(Self {
|
|
path,
|
|
inner: db,
|
|
store_key: store_key.into(),
|
|
session,
|
|
account_data,
|
|
members,
|
|
profiles,
|
|
display_names,
|
|
joined_user_ids,
|
|
invited_user_ids,
|
|
room_account_data,
|
|
presence,
|
|
room_state,
|
|
room_info,
|
|
stripped_room_info,
|
|
stripped_members,
|
|
stripped_room_state,
|
|
room_user_receipts,
|
|
room_event_receipts,
|
|
media,
|
|
custom,
|
|
})
|
|
}
|
|
|
|
pub fn open() -> Result<Self> {
|
|
let db = Config::new().temporary(true).open()?;
|
|
|
|
SledStore::open_helper(db, None, None)
|
|
}
|
|
|
|
pub fn open_with_passphrase(path: impl AsRef<Path>, passphrase: &str) -> Result<Self> {
|
|
let path = path.as_ref().join("matrix-sdk-state");
|
|
let db = Config::new().temporary(false).path(&path).open()?;
|
|
|
|
let store_key: Option<DatabaseType> = db
|
|
.get("store_key".encode())?
|
|
.map(|k| serde_json::from_slice(&k).map_err(StoreError::Json))
|
|
.transpose()?;
|
|
|
|
let store_key = if let Some(key) = store_key {
|
|
if let DatabaseType::Encrypted(k) = key {
|
|
StoreKey::import(passphrase, k).map_err(|_| StoreError::StoreLocked)?
|
|
} else {
|
|
return Err(StoreError::UnencryptedStore);
|
|
}
|
|
} else {
|
|
let key = StoreKey::new().map_err::<StoreError, _>(|e| e.into())?;
|
|
let encrypted_key = DatabaseType::Encrypted(
|
|
key.export(passphrase).map_err::<StoreError, _>(|e| e.into())?,
|
|
);
|
|
db.insert("store_key".encode(), serde_json::to_vec(&encrypted_key)?)?;
|
|
key
|
|
};
|
|
|
|
SledStore::open_helper(db, Some(path), Some(store_key))
|
|
}
|
|
|
|
pub fn open_with_path(path: impl AsRef<Path>) -> Result<Self> {
|
|
let path = path.as_ref().join("matrix-sdk-state");
|
|
let db = Config::new().temporary(false).path(&path).open()?;
|
|
|
|
SledStore::open_helper(db, Some(path), None)
|
|
}
|
|
|
|
fn serialize_event(&self, event: &impl Serialize) -> Result<Vec<u8>, SerializationError> {
|
|
if let Some(key) = &*self.store_key {
|
|
let encrypted = key.encrypt(event)?;
|
|
Ok(serde_json::to_vec(&encrypted)?)
|
|
} else {
|
|
Ok(serde_json::to_vec(event)?)
|
|
}
|
|
}
|
|
|
|
fn deserialize_event<T: for<'b> Deserialize<'b>>(
|
|
&self,
|
|
event: &[u8],
|
|
) -> Result<T, SerializationError> {
|
|
if let Some(key) = &*self.store_key {
|
|
let encrypted: EncryptedEvent = serde_json::from_slice(event)?;
|
|
Ok(key.decrypt(encrypted)?)
|
|
} else {
|
|
Ok(serde_json::from_slice(event)?)
|
|
}
|
|
}
|
|
|
|
pub async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> {
|
|
self.session.insert(("filter", filter_name).encode(), filter_id)?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn get_filter(&self, filter_name: &str) -> Result<Option<String>> {
|
|
Ok(self
|
|
.session
|
|
.get(("filter", filter_name).encode())?
|
|
.map(|f| String::from_utf8_lossy(&f).to_string()))
|
|
}
|
|
|
|
pub async fn get_sync_token(&self) -> Result<Option<String>> {
|
|
Ok(self
|
|
.session
|
|
.get("sync_token".encode())?
|
|
.map(|t| String::from_utf8_lossy(&t).to_string()))
|
|
}
|
|
|
|
pub async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
|
|
let now = Instant::now();
|
|
|
|
let ret: Result<(), TransactionError<SerializationError>> = (
|
|
&self.session,
|
|
&self.account_data,
|
|
&self.members,
|
|
&self.profiles,
|
|
&self.display_names,
|
|
&self.joined_user_ids,
|
|
&self.invited_user_ids,
|
|
&self.room_info,
|
|
&self.room_state,
|
|
&self.room_account_data,
|
|
&self.presence,
|
|
&self.stripped_room_info,
|
|
&self.stripped_members,
|
|
&self.stripped_room_state,
|
|
)
|
|
.transaction(
|
|
|(
|
|
session,
|
|
account_data,
|
|
members,
|
|
profiles,
|
|
display_names,
|
|
joined,
|
|
invited,
|
|
rooms,
|
|
state,
|
|
room_account_data,
|
|
presence,
|
|
striped_rooms,
|
|
stripped_members,
|
|
stripped_state,
|
|
)| {
|
|
if let Some(s) = &changes.sync_token {
|
|
session.insert("sync_token".encode(), s.as_str())?;
|
|
}
|
|
|
|
for (room, events) in &changes.members {
|
|
let profile_changes = changes.profiles.get(room);
|
|
|
|
for event in events.values() {
|
|
let key = (room.as_str(), event.state_key.as_str()).encode();
|
|
|
|
match event.content.membership {
|
|
MembershipState::Join => {
|
|
joined.insert(key.as_slice(), event.state_key.as_str())?;
|
|
invited.remove(key.as_slice())?;
|
|
}
|
|
MembershipState::Invite => {
|
|
invited.insert(key.as_slice(), event.state_key.as_str())?;
|
|
joined.remove(key.as_slice())?;
|
|
}
|
|
_ => {
|
|
joined.remove(key.as_slice())?;
|
|
invited.remove(key.as_slice())?;
|
|
}
|
|
}
|
|
|
|
members.insert(
|
|
key.as_slice(),
|
|
self.serialize_event(&event)
|
|
.map_err(ConflictableTransactionError::Abort)?,
|
|
)?;
|
|
|
|
if let Some(profile) =
|
|
profile_changes.and_then(|p| p.get(&event.state_key))
|
|
{
|
|
profiles.insert(
|
|
key.as_slice(),
|
|
self.serialize_event(&profile)
|
|
.map_err(ConflictableTransactionError::Abort)?,
|
|
)?;
|
|
}
|
|
}
|
|
}
|
|
|
|
for (room_id, ambiguity_maps) in &changes.ambiguity_maps {
|
|
for (display_name, map) in ambiguity_maps {
|
|
display_names.insert(
|
|
(room_id.as_str(), display_name.as_str()).encode(),
|
|
self.serialize_event(&map)
|
|
.map_err(ConflictableTransactionError::Abort)?,
|
|
)?;
|
|
}
|
|
}
|
|
|
|
for (event_type, event) in &changes.account_data {
|
|
account_data.insert(
|
|
event_type.as_str().encode(),
|
|
self.serialize_event(&event)
|
|
.map_err(ConflictableTransactionError::Abort)?,
|
|
)?;
|
|
}
|
|
|
|
for (room, events) in &changes.room_account_data {
|
|
for (event_type, event) in events {
|
|
room_account_data.insert(
|
|
(room.as_str(), event_type.as_str()).encode(),
|
|
self.serialize_event(&event)
|
|
.map_err(ConflictableTransactionError::Abort)?,
|
|
)?;
|
|
}
|
|
}
|
|
|
|
for (room, event_types) in &changes.state {
|
|
for (event_type, events) in event_types {
|
|
for (state_key, event) in events {
|
|
state.insert(
|
|
(room.as_str(), event_type.as_str(), state_key.as_str())
|
|
.encode(),
|
|
self.serialize_event(&event)
|
|
.map_err(ConflictableTransactionError::Abort)?,
|
|
)?;
|
|
}
|
|
}
|
|
}
|
|
|
|
for (room_id, room_info) in &changes.room_infos {
|
|
rooms.insert(
|
|
room_id.encode(),
|
|
self.serialize_event(room_info)
|
|
.map_err(ConflictableTransactionError::Abort)?,
|
|
)?;
|
|
}
|
|
|
|
for (sender, event) in &changes.presence {
|
|
presence.insert(
|
|
sender.encode(),
|
|
self.serialize_event(&event)
|
|
.map_err(ConflictableTransactionError::Abort)?,
|
|
)?;
|
|
}
|
|
|
|
for (room_id, info) in &changes.invited_room_info {
|
|
striped_rooms.insert(
|
|
room_id.encode(),
|
|
self.serialize_event(&info)
|
|
.map_err(ConflictableTransactionError::Abort)?,
|
|
)?;
|
|
}
|
|
|
|
for (room, events) in &changes.stripped_members {
|
|
for event in events.values() {
|
|
stripped_members.insert(
|
|
(room.as_str(), event.state_key.as_str()).encode(),
|
|
self.serialize_event(&event)
|
|
.map_err(ConflictableTransactionError::Abort)?,
|
|
)?;
|
|
}
|
|
}
|
|
|
|
for (room, event_types) in &changes.stripped_state {
|
|
for (event_type, events) in event_types {
|
|
for (state_key, event) in events {
|
|
stripped_state.insert(
|
|
(room.as_str(), event_type.as_str(), state_key.as_str())
|
|
.encode(),
|
|
self.serialize_event(&event)
|
|
.map_err(ConflictableTransactionError::Abort)?,
|
|
)?;
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
},
|
|
);
|
|
|
|
ret?;
|
|
|
|
let ret: Result<(), TransactionError<SerializationError>> =
|
|
(&self.room_user_receipts, &self.room_event_receipts).transaction(
|
|
|(room_user_receipts, room_event_receipts)| {
|
|
for (room, content) in &changes.receipts {
|
|
for (event_id, receipts) in &content.0 {
|
|
for (receipt_type, receipts) in receipts {
|
|
for (user_id, receipt) in receipts {
|
|
// Add the receipt to the room user receipts
|
|
if let Some(old) = room_user_receipts.insert(
|
|
(room.as_str(), receipt_type.as_ref(), user_id.as_str())
|
|
.encode(),
|
|
self.serialize_event(&(event_id, receipt))
|
|
.map_err(ConflictableTransactionError::Abort)?,
|
|
)? {
|
|
// Remove the old receipt from the room event receipts
|
|
let (old_event, _): (EventId, Receipt) = self
|
|
.deserialize_event(&old)
|
|
.map_err(ConflictableTransactionError::Abort)?;
|
|
room_event_receipts.remove(
|
|
(
|
|
room.as_str(),
|
|
receipt_type.as_ref(),
|
|
old_event.as_str(),
|
|
user_id.as_str(),
|
|
)
|
|
.encode(),
|
|
)?;
|
|
}
|
|
|
|
// Add the receipt to the room event receipts
|
|
room_event_receipts.insert(
|
|
(
|
|
room.as_str(),
|
|
receipt_type.as_ref(),
|
|
event_id.as_str(),
|
|
user_id.as_str(),
|
|
)
|
|
.encode(),
|
|
self.serialize_event(receipt)
|
|
.map_err(ConflictableTransactionError::Abort)?,
|
|
)?;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
},
|
|
);
|
|
|
|
ret?;
|
|
|
|
self.inner.flush_async().await?;
|
|
|
|
info!("Saved changes in {:?}", now.elapsed());
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<Raw<PresenceEvent>>> {
|
|
Ok(self.presence.get(user_id.encode())?.map(|e| self.deserialize_event(&e)).transpose()?)
|
|
}
|
|
|
|
pub async fn get_state_event(
|
|
&self,
|
|
room_id: &RoomId,
|
|
event_type: EventType,
|
|
state_key: &str,
|
|
) -> Result<Option<Raw<AnySyncStateEvent>>> {
|
|
Ok(self
|
|
.room_state
|
|
.get((room_id.as_str(), event_type.as_str(), state_key).encode())?
|
|
.map(|e| self.deserialize_event(&e))
|
|
.transpose()?)
|
|
}
|
|
|
|
pub async fn get_state_events(
|
|
&self,
|
|
room_id: &RoomId,
|
|
event_type: EventType,
|
|
) -> Result<Vec<Raw<AnySyncStateEvent>>> {
|
|
Ok(self
|
|
.room_state
|
|
.scan_prefix((room_id.as_str(), event_type.as_str()).encode())
|
|
.flat_map(|e| e.map(|(_, e)| self.deserialize_event(&e)))
|
|
.collect::<Result<_, _>>()?)
|
|
}
|
|
|
|
pub async fn get_profile(
|
|
&self,
|
|
room_id: &RoomId,
|
|
user_id: &UserId,
|
|
) -> Result<Option<MemberEventContent>> {
|
|
Ok(self
|
|
.profiles
|
|
.get((room_id.as_str(), user_id.as_str()).encode())?
|
|
.map(|p| self.deserialize_event(&p))
|
|
.transpose()?)
|
|
}
|
|
|
|
pub async fn get_member_event(
|
|
&self,
|
|
room_id: &RoomId,
|
|
state_key: &UserId,
|
|
) -> Result<Option<MemberEvent>> {
|
|
Ok(self
|
|
.members
|
|
.get((room_id.as_str(), state_key.as_str()).encode())?
|
|
.map(|v| self.deserialize_event(&v))
|
|
.transpose()?)
|
|
}
|
|
|
|
pub async fn get_user_ids_stream(
|
|
&self,
|
|
room_id: &RoomId,
|
|
) -> impl Stream<Item = Result<UserId>> {
|
|
let decode = |key: &[u8]| -> Result<UserId> {
|
|
let mut iter = key.split(|c| c == &ENCODE_SEPARATOR);
|
|
// Our key is a the room id separated from the user id by a null
|
|
// byte, discard the first value of the split.
|
|
iter.next();
|
|
|
|
let user_id = iter.next().expect("User ids weren't properly encoded");
|
|
|
|
Ok(UserId::try_from(String::from_utf8_lossy(user_id).to_string())?)
|
|
};
|
|
|
|
stream::iter(self.members.scan_prefix(room_id.encode()).map(move |u| decode(&u?.0)))
|
|
}
|
|
|
|
pub async fn get_invited_user_ids(
|
|
&self,
|
|
room_id: &RoomId,
|
|
) -> impl Stream<Item = Result<UserId>> {
|
|
stream::iter(self.invited_user_ids.scan_prefix(room_id.encode()).map(|u| {
|
|
UserId::try_from(String::from_utf8_lossy(&u?.1).to_string())
|
|
.map_err(StoreError::Identifier)
|
|
}))
|
|
}
|
|
|
|
pub async fn get_joined_user_ids(
|
|
&self,
|
|
room_id: &RoomId,
|
|
) -> impl Stream<Item = Result<UserId>> {
|
|
stream::iter(self.joined_user_ids.scan_prefix(room_id.encode()).map(|u| {
|
|
UserId::try_from(String::from_utf8_lossy(&u?.1).to_string())
|
|
.map_err(StoreError::Identifier)
|
|
}))
|
|
}
|
|
|
|
pub async fn get_room_infos(&self) -> impl Stream<Item = Result<RoomInfo>> {
|
|
let db = self.clone();
|
|
stream::iter(
|
|
self.room_info.iter().map(move |r| db.deserialize_event(&r?.1).map_err(|e| e.into())),
|
|
)
|
|
}
|
|
|
|
pub async fn get_stripped_room_infos(&self) -> impl Stream<Item = Result<RoomInfo>> {
|
|
let db = self.clone();
|
|
stream::iter(
|
|
self.stripped_room_info
|
|
.iter()
|
|
.map(move |r| db.deserialize_event(&r?.1).map_err(|e| e.into())),
|
|
)
|
|
}
|
|
|
|
pub async fn get_users_with_display_name(
|
|
&self,
|
|
room_id: &RoomId,
|
|
display_name: &str,
|
|
) -> Result<BTreeSet<UserId>> {
|
|
let key = (room_id.as_str(), display_name).encode();
|
|
|
|
Ok(self
|
|
.display_names
|
|
.get(key)?
|
|
.map(|m| self.deserialize_event(&m))
|
|
.transpose()?
|
|
.unwrap_or_default())
|
|
}
|
|
|
|
pub async fn get_account_data_event(
|
|
&self,
|
|
event_type: EventType,
|
|
) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> {
|
|
Ok(self
|
|
.account_data
|
|
.get(event_type.encode())?
|
|
.map(|m| self.deserialize_event(&m))
|
|
.transpose()?)
|
|
}
|
|
|
|
pub async fn get_room_account_data_event(
|
|
&self,
|
|
room_id: &RoomId,
|
|
event_type: EventType,
|
|
) -> Result<Option<Raw<AnyRoomAccountDataEvent>>> {
|
|
Ok(self
|
|
.room_account_data
|
|
.get((room_id.as_str(), event_type.as_str()).encode())?
|
|
.map(|m| self.deserialize_event(&m))
|
|
.transpose()?)
|
|
}
|
|
|
|
async fn get_user_room_receipt_event(
|
|
&self,
|
|
room_id: &RoomId,
|
|
receipt_type: ReceiptType,
|
|
user_id: &UserId,
|
|
) -> Result<Option<(EventId, Receipt)>> {
|
|
Ok(self
|
|
.room_user_receipts
|
|
.get((room_id.as_str(), receipt_type.as_ref(), user_id.as_str()).encode())?
|
|
.map(|m| self.deserialize_event(&m))
|
|
.transpose()?)
|
|
}
|
|
|
|
async fn get_event_room_receipt_events(
|
|
&self,
|
|
room_id: &RoomId,
|
|
receipt_type: ReceiptType,
|
|
event_id: &EventId,
|
|
) -> Result<Vec<(UserId, Receipt)>> {
|
|
self.room_event_receipts
|
|
.scan_prefix((room_id.as_str(), receipt_type.as_ref(), event_id.as_str()).encode())
|
|
.map(|u| {
|
|
u.map_err(StoreError::Sled).and_then(|(key, value)| {
|
|
self.deserialize_event(&value)
|
|
// TODO remove this unwrapping
|
|
.map(|receipt| {
|
|
(decode_key_value(&key, 3).unwrap().try_into().unwrap(), receipt)
|
|
})
|
|
.map_err(Into::into)
|
|
})
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
async fn add_media_content(&self, request: &MediaRequest, data: Vec<u8>) -> Result<()> {
|
|
self.media.insert(
|
|
(request.media_type.unique_key().as_str(), request.format.unique_key().as_str())
|
|
.encode(),
|
|
data,
|
|
)?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn get_media_content(&self, request: &MediaRequest) -> Result<Option<Vec<u8>>> {
|
|
Ok(self
|
|
.media
|
|
.get(
|
|
(request.media_type.unique_key().as_str(), request.format.unique_key().as_str())
|
|
.encode(),
|
|
)?
|
|
.map(|m| m.to_vec()))
|
|
}
|
|
|
|
async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
|
|
Ok(self.custom.get(key)?.map(|v| v.to_vec()))
|
|
}
|
|
|
|
async fn set_custom_value(&self, key: &[u8], value: Vec<u8>) -> Result<Option<Vec<u8>>> {
|
|
let ret = self.custom.insert(key, value)?.map(|v| v.to_vec());
|
|
self.inner.flush_async().await?;
|
|
|
|
Ok(ret)
|
|
}
|
|
|
|
async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> {
|
|
self.media.remove(
|
|
(request.media_type.unique_key().as_str(), request.format.unique_key().as_str())
|
|
.encode(),
|
|
)?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> {
|
|
let keys = self.media.scan_prefix(uri.as_str().encode()).keys();
|
|
|
|
let mut batch = sled::Batch::default();
|
|
for key in keys {
|
|
batch.remove(key?);
|
|
}
|
|
|
|
Ok(self.media.apply_batch(batch)?)
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl StateStore for SledStore {
|
|
async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> {
|
|
self.save_filter(filter_name, filter_id).await
|
|
}
|
|
|
|
async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
|
|
self.save_changes(changes).await
|
|
}
|
|
|
|
async fn get_filter(&self, filter_id: &str) -> Result<Option<String>> {
|
|
self.get_filter(filter_id).await
|
|
}
|
|
|
|
async fn get_sync_token(&self) -> Result<Option<String>> {
|
|
self.get_sync_token().await
|
|
}
|
|
|
|
async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<Raw<PresenceEvent>>> {
|
|
self.get_presence_event(user_id).await
|
|
}
|
|
|
|
async fn get_state_event(
|
|
&self,
|
|
room_id: &RoomId,
|
|
event_type: EventType,
|
|
state_key: &str,
|
|
) -> Result<Option<Raw<AnySyncStateEvent>>> {
|
|
self.get_state_event(room_id, event_type, state_key).await
|
|
}
|
|
|
|
async fn get_state_events(
|
|
&self,
|
|
room_id: &RoomId,
|
|
event_type: EventType,
|
|
) -> Result<Vec<Raw<AnySyncStateEvent>>> {
|
|
self.get_state_events(room_id, event_type).await
|
|
}
|
|
|
|
async fn get_profile(
|
|
&self,
|
|
room_id: &RoomId,
|
|
user_id: &UserId,
|
|
) -> Result<Option<MemberEventContent>> {
|
|
self.get_profile(room_id, user_id).await
|
|
}
|
|
|
|
async fn get_member_event(
|
|
&self,
|
|
room_id: &RoomId,
|
|
state_key: &UserId,
|
|
) -> Result<Option<MemberEvent>> {
|
|
self.get_member_event(room_id, state_key).await
|
|
}
|
|
|
|
async fn get_user_ids(&self, room_id: &RoomId) -> Result<Vec<UserId>> {
|
|
self.get_user_ids_stream(room_id).await.try_collect().await
|
|
}
|
|
|
|
async fn get_invited_user_ids(&self, room_id: &RoomId) -> Result<Vec<UserId>> {
|
|
self.get_invited_user_ids(room_id).await.try_collect().await
|
|
}
|
|
|
|
async fn get_joined_user_ids(&self, room_id: &RoomId) -> Result<Vec<UserId>> {
|
|
self.get_joined_user_ids(room_id).await.try_collect().await
|
|
}
|
|
|
|
async fn get_room_infos(&self) -> Result<Vec<RoomInfo>> {
|
|
self.get_room_infos().await.try_collect().await
|
|
}
|
|
|
|
async fn get_stripped_room_infos(&self) -> Result<Vec<RoomInfo>> {
|
|
self.get_stripped_room_infos().await.try_collect().await
|
|
}
|
|
|
|
async fn get_users_with_display_name(
|
|
&self,
|
|
room_id: &RoomId,
|
|
display_name: &str,
|
|
) -> Result<BTreeSet<UserId>> {
|
|
self.get_users_with_display_name(room_id, display_name).await
|
|
}
|
|
|
|
async fn get_account_data_event(
|
|
&self,
|
|
event_type: EventType,
|
|
) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> {
|
|
self.get_account_data_event(event_type).await
|
|
}
|
|
|
|
async fn get_room_account_data_event(
|
|
&self,
|
|
room_id: &RoomId,
|
|
event_type: EventType,
|
|
) -> Result<Option<Raw<AnyRoomAccountDataEvent>>> {
|
|
self.get_room_account_data_event(room_id, event_type).await
|
|
}
|
|
|
|
async fn get_user_room_receipt_event(
|
|
&self,
|
|
room_id: &RoomId,
|
|
receipt_type: ReceiptType,
|
|
user_id: &UserId,
|
|
) -> Result<Option<(EventId, Receipt)>> {
|
|
self.get_user_room_receipt_event(room_id, receipt_type, user_id).await
|
|
}
|
|
|
|
async fn get_event_room_receipt_events(
|
|
&self,
|
|
room_id: &RoomId,
|
|
receipt_type: ReceiptType,
|
|
event_id: &EventId,
|
|
) -> Result<Vec<(UserId, Receipt)>> {
|
|
self.get_event_room_receipt_events(room_id, receipt_type, event_id).await
|
|
}
|
|
|
|
async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
|
|
self.get_custom_value(key).await
|
|
}
|
|
|
|
async fn set_custom_value(&self, key: &[u8], value: Vec<u8>) -> Result<Option<Vec<u8>>> {
|
|
self.set_custom_value(key, value).await
|
|
}
|
|
|
|
async fn add_media_content(&self, request: &MediaRequest, data: Vec<u8>) -> Result<()> {
|
|
self.add_media_content(request, data).await
|
|
}
|
|
|
|
async fn get_media_content(&self, request: &MediaRequest) -> Result<Option<Vec<u8>>> {
|
|
self.get_media_content(request).await
|
|
}
|
|
|
|
async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> {
|
|
self.remove_media_content(request).await
|
|
}
|
|
|
|
async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> {
|
|
self.remove_media_content_for_uri(uri).await
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod test {
|
|
use std::convert::TryFrom;
|
|
|
|
use matrix_sdk_test::async_test;
|
|
use ruma::{
|
|
api::client::r0::media::get_content_thumbnail::Method,
|
|
event_id,
|
|
events::{
|
|
room::{
|
|
member::{MemberEventContent, MembershipState},
|
|
power_levels::PowerLevelsEventContent,
|
|
},
|
|
AnySyncStateEvent, EventType, Unsigned,
|
|
},
|
|
mxc_uri,
|
|
receipt::ReceiptType,
|
|
room_id,
|
|
serde::Raw,
|
|
uint, user_id, EventId, MilliSecondsSinceUnixEpoch, UserId,
|
|
};
|
|
use serde_json::json;
|
|
|
|
use super::{Result, SledStore, StateChanges};
|
|
use crate::{
|
|
deserialized_responses::MemberEvent,
|
|
media::{MediaFormat, MediaRequest, MediaThumbnailSize, MediaType},
|
|
StateStore,
|
|
};
|
|
|
|
fn user_id() -> UserId {
|
|
user_id!("@example:localhost")
|
|
}
|
|
|
|
fn power_level_event() -> Raw<AnySyncStateEvent> {
|
|
let content = PowerLevelsEventContent::default();
|
|
|
|
let event = json!({
|
|
"event_id": EventId::try_from("$h29iv0s8:example.com").unwrap(),
|
|
"content": content,
|
|
"sender": user_id(),
|
|
"type": "m.room.power_levels",
|
|
"origin_server_ts": 0u64,
|
|
"state_key": "",
|
|
"unsigned": Unsigned::default(),
|
|
});
|
|
|
|
serde_json::from_value(event).unwrap()
|
|
}
|
|
|
|
fn membership_event() -> MemberEvent {
|
|
MemberEvent {
|
|
event_id: EventId::try_from("$h29iv0s8:example.com").unwrap(),
|
|
content: MemberEventContent::new(MembershipState::Join),
|
|
sender: user_id(),
|
|
origin_server_ts: MilliSecondsSinceUnixEpoch::now(),
|
|
state_key: user_id(),
|
|
prev_content: None,
|
|
unsigned: Unsigned::default(),
|
|
}
|
|
}
|
|
|
|
#[async_test]
|
|
async fn test_member_saving() {
|
|
let store = SledStore::open().unwrap();
|
|
let room_id = room_id!("!test:localhost");
|
|
let user_id = user_id();
|
|
|
|
assert!(store.get_member_event(&room_id, &user_id).await.unwrap().is_none());
|
|
let mut changes = StateChanges::default();
|
|
changes
|
|
.members
|
|
.entry(room_id.clone())
|
|
.or_default()
|
|
.insert(user_id.clone(), membership_event());
|
|
|
|
store.save_changes(&changes).await.unwrap();
|
|
assert!(store.get_member_event(&room_id, &user_id).await.unwrap().is_some());
|
|
|
|
let members = store.get_user_ids(&room_id).await.unwrap();
|
|
assert!(!members.is_empty())
|
|
}
|
|
|
|
#[async_test]
|
|
async fn test_power_level_saving() {
|
|
let store = SledStore::open().unwrap();
|
|
let room_id = room_id!("!test:localhost");
|
|
|
|
let raw_event = power_level_event();
|
|
let event = raw_event.deserialize().unwrap();
|
|
|
|
assert!(store
|
|
.get_state_event(&room_id, EventType::RoomPowerLevels, "")
|
|
.await
|
|
.unwrap()
|
|
.is_none());
|
|
let mut changes = StateChanges::default();
|
|
changes.add_state_event(&room_id, event, raw_event);
|
|
|
|
store.save_changes(&changes).await.unwrap();
|
|
assert!(store
|
|
.get_state_event(&room_id, EventType::RoomPowerLevels, "")
|
|
.await
|
|
.unwrap()
|
|
.is_some());
|
|
}
|
|
|
|
#[async_test]
|
|
async fn test_receipts_saving() {
|
|
let store = SledStore::open().unwrap();
|
|
|
|
let room_id = room_id!("!test:localhost");
|
|
|
|
let first_event_id = event_id!("$1435641916114394fHBLK:matrix.org");
|
|
let second_event_id = event_id!("$fHBLK1435641916114394:matrix.org");
|
|
|
|
let first_receipt_event = serde_json::from_value(json!({
|
|
first_event_id.clone(): {
|
|
"m.read": {
|
|
user_id(): {
|
|
"ts": 1436451550453u64
|
|
}
|
|
}
|
|
}
|
|
}))
|
|
.unwrap();
|
|
|
|
let second_receipt_event = serde_json::from_value(json!({
|
|
second_event_id.clone(): {
|
|
"m.read": {
|
|
user_id(): {
|
|
"ts": 1436451551453u64
|
|
}
|
|
}
|
|
}
|
|
}))
|
|
.unwrap();
|
|
|
|
assert!(store
|
|
.get_user_room_receipt_event(&room_id, ReceiptType::Read, &user_id())
|
|
.await
|
|
.unwrap()
|
|
.is_none());
|
|
assert!(store
|
|
.get_event_room_receipt_events(&room_id, ReceiptType::Read, &first_event_id)
|
|
.await
|
|
.unwrap()
|
|
.is_empty());
|
|
assert!(store
|
|
.get_event_room_receipt_events(&room_id, ReceiptType::Read, &second_event_id)
|
|
.await
|
|
.unwrap()
|
|
.is_empty());
|
|
|
|
let mut changes = StateChanges::default();
|
|
changes.add_receipts(&room_id, first_receipt_event);
|
|
|
|
store.save_changes(&changes).await.unwrap();
|
|
assert!(store
|
|
.get_user_room_receipt_event(&room_id, ReceiptType::Read, &user_id())
|
|
.await
|
|
.unwrap()
|
|
.is_some(),);
|
|
assert_eq!(
|
|
store
|
|
.get_event_room_receipt_events(&room_id, ReceiptType::Read, &first_event_id)
|
|
.await
|
|
.unwrap()
|
|
.len(),
|
|
1
|
|
);
|
|
assert!(store
|
|
.get_event_room_receipt_events(&room_id, ReceiptType::Read, &second_event_id)
|
|
.await
|
|
.unwrap()
|
|
.is_empty());
|
|
|
|
let mut changes = StateChanges::default();
|
|
changes.add_receipts(&room_id, second_receipt_event);
|
|
|
|
store.save_changes(&changes).await.unwrap();
|
|
assert!(store
|
|
.get_user_room_receipt_event(&room_id, ReceiptType::Read, &user_id())
|
|
.await
|
|
.unwrap()
|
|
.is_some());
|
|
assert!(store
|
|
.get_event_room_receipt_events(&room_id, ReceiptType::Read, &first_event_id)
|
|
.await
|
|
.unwrap()
|
|
.is_empty());
|
|
assert_eq!(
|
|
store
|
|
.get_event_room_receipt_events(&room_id, ReceiptType::Read, &second_event_id)
|
|
.await
|
|
.unwrap()
|
|
.len(),
|
|
1
|
|
);
|
|
}
|
|
|
|
#[async_test]
|
|
async fn test_media_content() {
|
|
let store = SledStore::open().unwrap();
|
|
|
|
let uri = mxc_uri!("mxc://localhost/media");
|
|
let content: Vec<u8> = "somebinarydata".into();
|
|
|
|
let request_file =
|
|
MediaRequest { media_type: MediaType::Uri(uri.clone()), format: MediaFormat::File };
|
|
|
|
let request_thumbnail = MediaRequest {
|
|
media_type: MediaType::Uri(uri.clone()),
|
|
format: MediaFormat::Thumbnail(MediaThumbnailSize {
|
|
method: Method::Crop,
|
|
width: uint!(100),
|
|
height: uint!(100),
|
|
}),
|
|
};
|
|
|
|
assert!(store.get_media_content(&request_file).await.unwrap().is_none());
|
|
assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none());
|
|
|
|
store.add_media_content(&request_file, content.clone()).await.unwrap();
|
|
assert!(store.get_media_content(&request_file).await.unwrap().is_some());
|
|
|
|
store.remove_media_content(&request_file).await.unwrap();
|
|
assert!(store.get_media_content(&request_file).await.unwrap().is_none());
|
|
|
|
store.add_media_content(&request_file, content.clone()).await.unwrap();
|
|
assert!(store.get_media_content(&request_file).await.unwrap().is_some());
|
|
|
|
store.add_media_content(&request_thumbnail, content.clone()).await.unwrap();
|
|
assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_some());
|
|
|
|
store.remove_media_content_for_uri(&uri).await.unwrap();
|
|
assert!(store.get_media_content(&request_file).await.unwrap().is_none());
|
|
assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none());
|
|
}
|
|
|
|
#[async_test]
|
|
async fn test_custom_storage() -> Result<()> {
|
|
let key = "my_key";
|
|
let value = &[0, 1, 2, 3];
|
|
let store = SledStore::open()?;
|
|
|
|
store.set_custom_value(key.as_bytes(), value.to_vec()).await?;
|
|
|
|
let read = store.get_custom_value(key.as_bytes()).await?;
|
|
|
|
assert_eq!(Some(value.as_ref()), read.as_deref());
|
|
|
|
Ok(())
|
|
}
|
|
}
|