base: Finish up the error handling for the new stores

master
Damir Jelić 2021-01-19 12:03:46 +01:00
parent ef95d9b539
commit b8fcc003ea
3 changed files with 113 additions and 84 deletions

View File

@ -14,7 +14,10 @@ use matrix_sdk_common::{
locks::RwLock, locks::RwLock,
}; };
use sled::{transaction::TransactionResult, Config, Db, Transactional, Tree}; use sled::{
transaction::{ConflictableTransactionError, TransactionError},
Config, Db, Transactional, Tree,
};
use tracing::info; use tracing::info;
use crate::{ use crate::{
@ -33,6 +36,15 @@ pub enum StoreError {
Identifier(#[from] matrix_sdk_common::identifiers::Error), Identifier(#[from] matrix_sdk_common::identifiers::Error),
} }
impl From<TransactionError<serde_json::Error>> for StoreError {
fn from(e: TransactionError<serde_json::Error>) -> Self {
match e {
TransactionError::Abort(e) => Self::Json(e),
TransactionError::Storage(e) => Self::Sled(e),
}
}
}
/// A `StateStore` specific result type. /// A `StateStore` specific result type.
pub type Result<T> = std::result::Result<T, StoreError>; pub type Result<T> = std::result::Result<T, StoreError>;
@ -329,7 +341,7 @@ impl SledStore {
pub async fn save_changes(&self, changes: &StateChanges) -> Result<()> { pub async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
let now = SystemTime::now(); let now = SystemTime::now();
let ret: TransactionResult<()> = ( let ret: std::result::Result<(), TransactionError<serde_json::Error>> = (
&self.session, &self.session,
&self.account_data, &self.account_data,
&self.members, &self.members,
@ -385,7 +397,8 @@ impl SledStore {
members.insert( members.insert(
format!("{}{}", room.as_str(), &event.state_key).as_str(), format!("{}{}", room.as_str(), &event.state_key).as_str(),
serde_json::to_vec(&event).unwrap(), serde_json::to_vec(&event)
.map_err(ConflictableTransactionError::Abort)?,
)?; )?;
} }
} }
@ -394,21 +407,26 @@ impl SledStore {
for (user_id, profile) in users { for (user_id, profile) in users {
profiles.insert( profiles.insert(
format!("{}{}", room.as_str(), user_id.as_str()).as_str(), format!("{}{}", room.as_str(), user_id.as_str()).as_str(),
serde_json::to_vec(&profile).unwrap(), serde_json::to_vec(&profile)
.map_err(ConflictableTransactionError::Abort)?,
)?; )?;
} }
} }
for (event_type, event) in &changes.account_data { for (event_type, event) in &changes.account_data {
account_data account_data.insert(
.insert(event_type.as_str(), serde_json::to_vec(&event).unwrap())?; event_type.as_str(),
serde_json::to_vec(&event)
.map_err(ConflictableTransactionError::Abort)?,
)?;
} }
for (room, events) in &changes.room_account_data { for (room, events) in &changes.room_account_data {
for (event_type, event) in events { for (event_type, event) in events {
room_account_data.insert( room_account_data.insert(
format!("{}{}", room.as_str(), event_type).as_str(), format!("{}{}", room.as_str(), event_type).as_str(),
serde_json::to_vec(&event).unwrap(), serde_json::to_vec(&event)
.map_err(ConflictableTransactionError::Abort)?,
)?; )?;
} }
} }
@ -424,30 +442,43 @@ impl SledStore {
event.state_key(), event.state_key(),
) )
.as_bytes(), .as_bytes(),
serde_json::to_vec(&event).unwrap(), serde_json::to_vec(&event)
.map_err(ConflictableTransactionError::Abort)?,
)?; )?;
} }
} }
} }
for (room_id, room_info) in &changes.room_infos { for (room_id, room_info) in &changes.room_infos {
rooms.insert(room_id.as_bytes(), serde_json::to_vec(room_info).unwrap())?; rooms.insert(
room_id.as_bytes(),
serde_json::to_vec(room_info)
.map_err(ConflictableTransactionError::Abort)?,
)?;
} }
for (sender, event) in &changes.presence { for (sender, event) in &changes.presence {
presence.insert(sender.as_bytes(), serde_json::to_vec(&event).unwrap())?; presence.insert(
sender.as_bytes(),
serde_json::to_vec(&event)
.map_err(ConflictableTransactionError::Abort)?,
)?;
} }
for (room_id, info) in &changes.invited_room_info { for (room_id, info) in &changes.invited_room_info {
striped_rooms striped_rooms.insert(
.insert(room_id.as_str(), serde_json::to_vec(&info).unwrap())?; room_id.as_str(),
serde_json::to_vec(&info)
.map_err(ConflictableTransactionError::Abort)?,
)?;
} }
for (room, events) in &changes.stripped_members { for (room, events) in &changes.stripped_members {
for event in events.values() { for event in events.values() {
stripped_members.insert( stripped_members.insert(
format!("{}{}", room.as_str(), &event.state_key).as_str(), format!("{}{}", room.as_str(), &event.state_key).as_str(),
serde_json::to_vec(&event).unwrap(), serde_json::to_vec(&event)
.map_err(ConflictableTransactionError::Abort)?,
)?; )?;
} }
} }
@ -463,7 +494,8 @@ impl SledStore {
event.state_key(), event.state_key(),
) )
.as_bytes(), .as_bytes(),
serde_json::to_vec(&event).unwrap(), serde_json::to_vec(&event)
.map_err(ConflictableTransactionError::Abort)?,
)?; )?;
} }
} }
@ -473,7 +505,7 @@ impl SledStore {
}, },
); );
ret.unwrap(); ret?;
self.inner.flush_async().await?; self.inner.flush_async().await?;

View File

@ -296,7 +296,12 @@ pub enum CryptoStoreError {
// implementations. // implementations.
#[cfg(feature = "sqlite_cryptostore")] #[cfg(feature = "sqlite_cryptostore")]
#[error(transparent)] #[error(transparent)]
DatabaseError(#[from] SqlxError), Database(#[from] SqlxError),
/// Error in the internal database
#[cfg(feature = "sled_cryptostore")]
#[error(transparent)]
Database(#[from] sled::Error),
/// An IO error occurred. /// An IO error occurred.
#[error(transparent)] #[error(transparent)]

View File

@ -21,6 +21,7 @@ use std::{
use dashmap::DashSet; use dashmap::DashSet;
use olm_rs::PicklingMode; use olm_rs::PicklingMode;
pub use sled::Error;
use sled::{ use sled::{
transaction::{ConflictableTransactionError, TransactionError}, transaction::{ConflictableTransactionError, TransactionError},
Config, Db, Transactional, Tree, Config, Db, Transactional, Tree,
@ -38,7 +39,7 @@ use super::{
}; };
use crate::{ use crate::{
identities::{ReadOnlyDevice, UserIdentities}, identities::{ReadOnlyDevice, UserIdentities},
olm::{PickledInboundGroupSession, PickledSession, PrivateCrossSigningIdentity}, olm::{PickledInboundGroupSession, PrivateCrossSigningIdentity},
}; };
/// This needs to be 32 bytes long since AES-GCM requires it, otherwise we will /// This needs to be 32 bytes long since AES-GCM requires it, otherwise we will
@ -70,27 +71,36 @@ pub struct SledStore {
values: Tree, values: Tree,
} }
impl From<TransactionError<serde_json::Error>> for CryptoStoreError {
fn from(e: TransactionError<serde_json::Error>) -> Self {
match e {
TransactionError::Abort(e) => CryptoStoreError::Serialization(e),
TransactionError::Storage(e) => CryptoStoreError::Database(e),
}
}
}
impl SledStore { impl SledStore {
pub fn open_with_passphrase(path: impl AsRef<Path>, passphrase: &str) -> Result<Self> { pub fn open_with_passphrase(path: impl AsRef<Path>, passphrase: &str) -> Result<Self> {
let path = path.as_ref().join("matrix-sdk-crypto"); let path = path.as_ref().join("matrix-sdk-crypto");
let db = Config::new().temporary(false).path(path).open().unwrap(); let db = Config::new().temporary(false).path(path).open()?;
SledStore::open_helper(db, Some(passphrase)) SledStore::open_helper(db, Some(passphrase))
} }
fn open_helper(db: Db, passphrase: Option<&str>) -> Result<Self> { fn open_helper(db: Db, passphrase: Option<&str>) -> Result<Self> {
let account = db.open_tree("account").unwrap(); let account = db.open_tree("account")?;
let private_identity = db.open_tree("private_identity").unwrap(); let private_identity = db.open_tree("private_identity")?;
let sessions = db.open_tree("session").unwrap(); let sessions = db.open_tree("session")?;
let inbound_group_sessions = db.open_tree("inbound_group_sessions").unwrap(); let inbound_group_sessions = db.open_tree("inbound_group_sessions")?;
let tracked_users = db.open_tree("tracked_users").unwrap(); let tracked_users = db.open_tree("tracked_users")?;
let users_for_key_query = db.open_tree("users_for_key_query").unwrap(); let users_for_key_query = db.open_tree("users_for_key_query")?;
let olm_hashes = db.open_tree("olm_hashes").unwrap(); let olm_hashes = db.open_tree("olm_hashes")?;
let devices = db.open_tree("devices").unwrap(); let devices = db.open_tree("devices")?;
let identities = db.open_tree("identities").unwrap(); let identities = db.open_tree("identities")?;
let values = db.open_tree("values").unwrap(); let values = db.open_tree("values")?;
let session_cache = SessionStore::new(); let session_cache = SessionStore::new();
@ -122,8 +132,7 @@ impl SledStore {
fn get_or_create_pickle_key(passphrase: &str, database: &Db) -> Result<PickleKey> { fn get_or_create_pickle_key(passphrase: &str, database: &Db) -> Result<PickleKey> {
let key = if let Some(key) = database let key = if let Some(key) = database
.get("pickle_key") .get("pickle_key")?
.unwrap()
.map(|v| serde_json::from_slice(&v)) .map(|v| serde_json::from_slice(&v))
{ {
PickleKey::from_encrypted(passphrase, key?) PickleKey::from_encrypted(passphrase, key?)
@ -131,9 +140,7 @@ impl SledStore {
} else { } else {
let key = PickleKey::new(); let key = PickleKey::new();
let encrypted = key.encrypt(passphrase); let encrypted = key.encrypt(passphrase);
database database.insert("pickle_key", serde_json::to_vec(&encrypted)?)?;
.insert("pickle_key", serde_json::to_vec(&encrypted)?)
.unwrap();
key key
}; };
@ -148,10 +155,10 @@ impl SledStore {
self.pickle_key.key() self.pickle_key.key()
} }
async fn load_tracked_users(&self) { async fn load_tracked_users(&self) -> Result<()> {
for value in self.tracked_users.iter() { for value in self.tracked_users.iter() {
let (user, dirty) = value.unwrap(); let (user, dirty) = value?;
let user = UserId::try_from(String::from_utf8_lossy(&user).to_string()).unwrap(); let user = UserId::try_from(String::from_utf8_lossy(&user).to_string())?;
let dirty = dirty.get(0).map(|d| *d == 1).unwrap_or(true); let dirty = dirty.get(0).map(|d| *d == 1).unwrap_or(true);
self.tracked_users_cache.insert(user.clone()); self.tracked_users_cache.insert(user.clone());
@ -160,6 +167,8 @@ impl SledStore {
self.users_for_key_query_cache.insert(user); self.users_for_key_query_cache.insert(user);
} }
} }
Ok(())
} }
pub async fn save_changes(&self, changes: Changes) -> Result<()> { pub async fn save_changes(&self, changes: Changes) -> Result<()> {
@ -170,7 +179,7 @@ impl SledStore {
}; };
let private_identity_pickle = if let Some(i) = changes.private_identity { let private_identity_pickle = if let Some(i) = changes.private_identity {
Some(i.pickle(DEFAULT_PICKLE.as_bytes()).await.unwrap()) Some(i.pickle(DEFAULT_PICKLE.as_bytes()).await?)
} else { } else {
None None
}; };
@ -285,14 +294,8 @@ impl SledStore {
}, },
); );
if let Err(e) = ret { ret?;
match e { self.inner.flush_async().await?;
TransactionError::Abort(e) => return Err(e.into()),
TransactionError::Storage(e) => panic!("Internal sled error {:?}", e),
}
}
self.inner.flush_async().await.unwrap();
Ok(()) Ok(())
} }
@ -301,10 +304,10 @@ impl SledStore {
#[async_trait] #[async_trait]
impl CryptoStore for SledStore { impl CryptoStore for SledStore {
async fn load_account(&self) -> Result<Option<ReadOnlyAccount>> { async fn load_account(&self) -> Result<Option<ReadOnlyAccount>> {
if let Some(pickle) = self.account.get("account").unwrap() { if let Some(pickle) = self.account.get("account")? {
let pickle = serde_json::from_slice(&pickle)?; let pickle = serde_json::from_slice(&pickle)?;
self.load_tracked_users().await; self.load_tracked_users().await?;
Ok(Some(ReadOnlyAccount::from_pickle( Ok(Some(ReadOnlyAccount::from_pickle(
pickle, pickle,
@ -318,8 +321,7 @@ impl CryptoStore for SledStore {
async fn save_account(&self, account: ReadOnlyAccount) -> Result<()> { async fn save_account(&self, account: ReadOnlyAccount) -> Result<()> {
let pickle = account.pickle(self.get_pickle_mode()).await; let pickle = account.pickle(self.get_pickle_mode()).await;
self.account self.account
.insert("account", serde_json::to_vec(&pickle)?) .insert("account", serde_json::to_vec(&pickle)?)?;
.unwrap();
Ok(()) Ok(())
} }
@ -335,22 +337,19 @@ impl CryptoStore for SledStore {
.ok_or(CryptoStoreError::AccountUnset)?; .ok_or(CryptoStoreError::AccountUnset)?;
if self.session_cache.get(sender_key).is_none() { if self.session_cache.get(sender_key).is_none() {
let sessions: std::result::Result<Vec<PickledSession>, _> = self let sessions: Result<Vec<Session>> = self
.sessions .sessions
.scan_prefix(sender_key) .scan_prefix(sender_key)
.map(|s| serde_json::from_slice(&s.unwrap().1)) .map(|s| serde_json::from_slice(&s?.1).map_err(CryptoStoreError::Serialization))
.collect();
let sessions: std::result::Result<Vec<Session>, _> = sessions?
.into_iter()
.map(|p| { .map(|p| {
Session::from_pickle( Session::from_pickle(
account.user_id.clone(), account.user_id.clone(),
account.device_id.clone(), account.device_id.clone(),
account.identity_keys.clone(), account.identity_keys.clone(),
p, p?,
self.get_pickle_mode(), self.get_pickle_mode(),
) )
.map_err(CryptoStoreError::SessionUnpickling)
}) })
.collect(); .collect();
@ -369,8 +368,7 @@ impl CryptoStore for SledStore {
let key = format!("{}{}{}", room_id, sender_key, session_id); let key = format!("{}{}{}", room_id, sender_key, session_id);
let pickle = self let pickle = self
.inbound_group_sessions .inbound_group_sessions
.get(&key) .get(&key)?
.unwrap()
.map(|p| serde_json::from_slice(&p)); .map(|p| serde_json::from_slice(&p));
if let Some(pickle) = pickle { if let Some(pickle) = pickle {
@ -384,10 +382,10 @@ impl CryptoStore for SledStore {
} }
async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> { async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
let pickles: std::result::Result<Vec<PickledInboundGroupSession>, _> = self let pickles: Result<Vec<PickledInboundGroupSession>> = self
.inbound_group_sessions .inbound_group_sessions
.iter() .iter()
.map(|p| serde_json::from_slice(&p.unwrap().1)) .map(|p| serde_json::from_slice(&p?.1).map_err(CryptoStoreError::Serialization))
.collect(); .collect();
Ok(pickles? Ok(pickles?
@ -421,9 +419,7 @@ impl CryptoStore for SledStore {
self.users_for_key_query_cache.remove(user); self.users_for_key_query_cache.remove(user);
} }
self.tracked_users self.tracked_users.insert(user.as_str(), &[dirty as u8])?;
.insert(user.as_str(), &[dirty as u8])
.unwrap();
Ok(already_added) Ok(already_added)
} }
@ -435,7 +431,7 @@ impl CryptoStore for SledStore {
) -> Result<Option<ReadOnlyDevice>> { ) -> Result<Option<ReadOnlyDevice>> {
let key = format!("{}{}", user_id, device_id); let key = format!("{}{}", user_id, device_id);
if let Some(d) = self.devices.get(key).unwrap() { if let Some(d) = self.devices.get(key)? {
Ok(Some(serde_json::from_slice(&d)?)) Ok(Some(serde_json::from_slice(&d)?))
} else { } else {
Ok(None) Ok(None)
@ -446,51 +442,48 @@ impl CryptoStore for SledStore {
&self, &self,
user_id: &UserId, user_id: &UserId,
) -> Result<HashMap<DeviceIdBox, ReadOnlyDevice>> { ) -> Result<HashMap<DeviceIdBox, ReadOnlyDevice>> {
let devices: std::result::Result<Vec<ReadOnlyDevice>, _> = self self.devices
.devices
.scan_prefix(user_id.as_str()) .scan_prefix(user_id.as_str())
.map(|d| serde_json::from_slice(&d.unwrap().1)) .map(|d| serde_json::from_slice(&d?.1).map_err(CryptoStoreError::Serialization))
.collect(); .map(|d| {
let d: ReadOnlyDevice = d?;
Ok(devices? Ok((d.device_id().to_owned(), d))
.into_iter() })
.map(|d| (d.device_id().to_owned(), d)) .collect()
.collect())
} }
async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentities>> { async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentities>> {
Ok(self Ok(self
.identities .identities
.get(user_id.as_str()) .get(user_id.as_str())?
.unwrap() .map(|i| serde_json::from_slice(&i))
.map(|i| serde_json::from_slice(&i).unwrap())) .transpose()?)
} }
async fn save_value(&self, key: String, value: String) -> Result<()> { async fn save_value(&self, key: String, value: String) -> Result<()> {
self.values.insert(key.as_str(), value.as_str()).unwrap(); self.values.insert(key.as_str(), value.as_str())?;
Ok(()) Ok(())
} }
async fn remove_value(&self, key: &str) -> Result<()> { async fn remove_value(&self, key: &str) -> Result<()> {
self.values.remove(key).unwrap(); self.values.remove(key)?;
Ok(()) Ok(())
} }
async fn get_value(&self, key: &str) -> Result<Option<String>> { async fn get_value(&self, key: &str) -> Result<Option<String>> {
Ok(self Ok(self
.values .values
.get(key) .get(key)?
.unwrap()
.map(|v| String::from_utf8_lossy(&v).to_string())) .map(|v| String::from_utf8_lossy(&v).to_string()))
} }
async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> { async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
if let Some(i) = self.private_identity.get("identity").unwrap() { if let Some(i) = self.private_identity.get("identity")? {
let pickle = serde_json::from_slice(&i)?; let pickle = serde_json::from_slice(&i)?;
Ok(Some( Ok(Some(
PrivateCrossSigningIdentity::from_pickle(pickle, self.get_pickle_key()) PrivateCrossSigningIdentity::from_pickle(pickle, self.get_pickle_key())
.await .await
.unwrap(), .map_err(|_| CryptoStoreError::UnpicklingError)?,
)) ))
} else { } else {
Ok(None) Ok(None)
@ -500,8 +493,7 @@ impl CryptoStore for SledStore {
async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result<bool> { async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result<bool> {
Ok(self Ok(self
.olm_hashes .olm_hashes
.contains_key(serde_json::to_vec(message_hash)?) .contains_key(serde_json::to_vec(message_hash)?)?)
.unwrap())
} }
} }