diff --git a/matrix_sdk_base/src/store.rs b/matrix_sdk_base/src/store.rs index fedbfcc2..0ac7f54f 100644 --- a/matrix_sdk_base/src/store.rs +++ b/matrix_sdk_base/src/store.rs @@ -14,7 +14,10 @@ use matrix_sdk_common::{ locks::RwLock, }; -use sled::{transaction::TransactionResult, Config, Db, Transactional, Tree}; +use sled::{ + transaction::{ConflictableTransactionError, TransactionError}, + Config, Db, Transactional, Tree, +}; use tracing::info; use crate::{ @@ -33,6 +36,15 @@ pub enum StoreError { Identifier(#[from] matrix_sdk_common::identifiers::Error), } +impl From> for StoreError { + fn from(e: TransactionError) -> Self { + match e { + TransactionError::Abort(e) => Self::Json(e), + TransactionError::Storage(e) => Self::Sled(e), + } + } +} + /// A `StateStore` specific result type. pub type Result = std::result::Result; @@ -329,7 +341,7 @@ impl SledStore { pub async fn save_changes(&self, changes: &StateChanges) -> Result<()> { let now = SystemTime::now(); - let ret: TransactionResult<()> = ( + let ret: std::result::Result<(), TransactionError> = ( &self.session, &self.account_data, &self.members, @@ -385,7 +397,8 @@ impl SledStore { members.insert( format!("{}{}", room.as_str(), &event.state_key).as_str(), - serde_json::to_vec(&event).unwrap(), + serde_json::to_vec(&event) + .map_err(ConflictableTransactionError::Abort)?, )?; } } @@ -394,21 +407,26 @@ impl SledStore { for (user_id, profile) in users { profiles.insert( format!("{}{}", room.as_str(), user_id.as_str()).as_str(), - serde_json::to_vec(&profile).unwrap(), + serde_json::to_vec(&profile) + .map_err(ConflictableTransactionError::Abort)?, )?; } } for (event_type, event) in &changes.account_data { - account_data - .insert(event_type.as_str(), serde_json::to_vec(&event).unwrap())?; + account_data.insert( + event_type.as_str(), + serde_json::to_vec(&event) + .map_err(ConflictableTransactionError::Abort)?, + )?; } for (room, events) in &changes.room_account_data { for (event_type, event) in events { room_account_data.insert( format!("{}{}", room.as_str(), event_type).as_str(), - serde_json::to_vec(&event).unwrap(), + serde_json::to_vec(&event) + .map_err(ConflictableTransactionError::Abort)?, )?; } } @@ -424,30 +442,43 @@ impl SledStore { event.state_key(), ) .as_bytes(), - serde_json::to_vec(&event).unwrap(), + serde_json::to_vec(&event) + .map_err(ConflictableTransactionError::Abort)?, )?; } } } for (room_id, room_info) in &changes.room_infos { - rooms.insert(room_id.as_bytes(), serde_json::to_vec(room_info).unwrap())?; + rooms.insert( + room_id.as_bytes(), + serde_json::to_vec(room_info) + .map_err(ConflictableTransactionError::Abort)?, + )?; } for (sender, event) in &changes.presence { - presence.insert(sender.as_bytes(), serde_json::to_vec(&event).unwrap())?; + presence.insert( + sender.as_bytes(), + serde_json::to_vec(&event) + .map_err(ConflictableTransactionError::Abort)?, + )?; } for (room_id, info) in &changes.invited_room_info { - striped_rooms - .insert(room_id.as_str(), serde_json::to_vec(&info).unwrap())?; + striped_rooms.insert( + room_id.as_str(), + serde_json::to_vec(&info) + .map_err(ConflictableTransactionError::Abort)?, + )?; } for (room, events) in &changes.stripped_members { for event in events.values() { stripped_members.insert( format!("{}{}", room.as_str(), &event.state_key).as_str(), - serde_json::to_vec(&event).unwrap(), + serde_json::to_vec(&event) + .map_err(ConflictableTransactionError::Abort)?, )?; } } @@ -463,7 +494,8 @@ impl SledStore { event.state_key(), ) .as_bytes(), - serde_json::to_vec(&event).unwrap(), + serde_json::to_vec(&event) + .map_err(ConflictableTransactionError::Abort)?, )?; } } @@ -473,7 +505,7 @@ impl SledStore { }, ); - ret.unwrap(); + ret?; self.inner.flush_async().await?; diff --git a/matrix_sdk_crypto/src/store/mod.rs b/matrix_sdk_crypto/src/store/mod.rs index df430719..caaed436 100644 --- a/matrix_sdk_crypto/src/store/mod.rs +++ b/matrix_sdk_crypto/src/store/mod.rs @@ -296,7 +296,12 @@ pub enum CryptoStoreError { // implementations. #[cfg(feature = "sqlite_cryptostore")] #[error(transparent)] - DatabaseError(#[from] SqlxError), + Database(#[from] SqlxError), + + /// Error in the internal database + #[cfg(feature = "sled_cryptostore")] + #[error(transparent)] + Database(#[from] sled::Error), /// An IO error occurred. #[error(transparent)] diff --git a/matrix_sdk_crypto/src/store/sled.rs b/matrix_sdk_crypto/src/store/sled.rs index fb5f3e2f..9db91d1f 100644 --- a/matrix_sdk_crypto/src/store/sled.rs +++ b/matrix_sdk_crypto/src/store/sled.rs @@ -21,6 +21,7 @@ use std::{ use dashmap::DashSet; use olm_rs::PicklingMode; +pub use sled::Error; use sled::{ transaction::{ConflictableTransactionError, TransactionError}, Config, Db, Transactional, Tree, @@ -38,7 +39,7 @@ use super::{ }; use crate::{ identities::{ReadOnlyDevice, UserIdentities}, - olm::{PickledInboundGroupSession, PickledSession, PrivateCrossSigningIdentity}, + olm::{PickledInboundGroupSession, PrivateCrossSigningIdentity}, }; /// This needs to be 32 bytes long since AES-GCM requires it, otherwise we will @@ -70,27 +71,36 @@ pub struct SledStore { values: Tree, } +impl From> for CryptoStoreError { + fn from(e: TransactionError) -> Self { + match e { + TransactionError::Abort(e) => CryptoStoreError::Serialization(e), + TransactionError::Storage(e) => CryptoStoreError::Database(e), + } + } +} + impl SledStore { pub fn open_with_passphrase(path: impl AsRef, passphrase: &str) -> Result { let path = path.as_ref().join("matrix-sdk-crypto"); - let db = Config::new().temporary(false).path(path).open().unwrap(); + let db = Config::new().temporary(false).path(path).open()?; SledStore::open_helper(db, Some(passphrase)) } fn open_helper(db: Db, passphrase: Option<&str>) -> Result { - let account = db.open_tree("account").unwrap(); - let private_identity = db.open_tree("private_identity").unwrap(); + let account = db.open_tree("account")?; + let private_identity = db.open_tree("private_identity")?; - let sessions = db.open_tree("session").unwrap(); - let inbound_group_sessions = db.open_tree("inbound_group_sessions").unwrap(); - let tracked_users = db.open_tree("tracked_users").unwrap(); - let users_for_key_query = db.open_tree("users_for_key_query").unwrap(); - let olm_hashes = db.open_tree("olm_hashes").unwrap(); + let sessions = db.open_tree("session")?; + let inbound_group_sessions = db.open_tree("inbound_group_sessions")?; + let tracked_users = db.open_tree("tracked_users")?; + let users_for_key_query = db.open_tree("users_for_key_query")?; + let olm_hashes = db.open_tree("olm_hashes")?; - let devices = db.open_tree("devices").unwrap(); - let identities = db.open_tree("identities").unwrap(); - let values = db.open_tree("values").unwrap(); + let devices = db.open_tree("devices")?; + let identities = db.open_tree("identities")?; + let values = db.open_tree("values")?; let session_cache = SessionStore::new(); @@ -122,8 +132,7 @@ impl SledStore { fn get_or_create_pickle_key(passphrase: &str, database: &Db) -> Result { let key = if let Some(key) = database - .get("pickle_key") - .unwrap() + .get("pickle_key")? .map(|v| serde_json::from_slice(&v)) { PickleKey::from_encrypted(passphrase, key?) @@ -131,9 +140,7 @@ impl SledStore { } else { let key = PickleKey::new(); let encrypted = key.encrypt(passphrase); - database - .insert("pickle_key", serde_json::to_vec(&encrypted)?) - .unwrap(); + database.insert("pickle_key", serde_json::to_vec(&encrypted)?)?; key }; @@ -148,10 +155,10 @@ impl SledStore { self.pickle_key.key() } - async fn load_tracked_users(&self) { + async fn load_tracked_users(&self) -> Result<()> { for value in self.tracked_users.iter() { - let (user, dirty) = value.unwrap(); - let user = UserId::try_from(String::from_utf8_lossy(&user).to_string()).unwrap(); + let (user, dirty) = value?; + let user = UserId::try_from(String::from_utf8_lossy(&user).to_string())?; let dirty = dirty.get(0).map(|d| *d == 1).unwrap_or(true); self.tracked_users_cache.insert(user.clone()); @@ -160,6 +167,8 @@ impl SledStore { self.users_for_key_query_cache.insert(user); } } + + Ok(()) } pub async fn save_changes(&self, changes: Changes) -> Result<()> { @@ -170,7 +179,7 @@ impl SledStore { }; let private_identity_pickle = if let Some(i) = changes.private_identity { - Some(i.pickle(DEFAULT_PICKLE.as_bytes()).await.unwrap()) + Some(i.pickle(DEFAULT_PICKLE.as_bytes()).await?) } else { None }; @@ -285,14 +294,8 @@ impl SledStore { }, ); - if let Err(e) = ret { - match e { - TransactionError::Abort(e) => return Err(e.into()), - TransactionError::Storage(e) => panic!("Internal sled error {:?}", e), - } - } - - self.inner.flush_async().await.unwrap(); + ret?; + self.inner.flush_async().await?; Ok(()) } @@ -301,10 +304,10 @@ impl SledStore { #[async_trait] impl CryptoStore for SledStore { async fn load_account(&self) -> Result> { - if let Some(pickle) = self.account.get("account").unwrap() { + if let Some(pickle) = self.account.get("account")? { let pickle = serde_json::from_slice(&pickle)?; - self.load_tracked_users().await; + self.load_tracked_users().await?; Ok(Some(ReadOnlyAccount::from_pickle( pickle, @@ -318,8 +321,7 @@ impl CryptoStore for SledStore { async fn save_account(&self, account: ReadOnlyAccount) -> Result<()> { let pickle = account.pickle(self.get_pickle_mode()).await; self.account - .insert("account", serde_json::to_vec(&pickle)?) - .unwrap(); + .insert("account", serde_json::to_vec(&pickle)?)?; Ok(()) } @@ -335,22 +337,19 @@ impl CryptoStore for SledStore { .ok_or(CryptoStoreError::AccountUnset)?; if self.session_cache.get(sender_key).is_none() { - let sessions: std::result::Result, _> = self + let sessions: Result> = self .sessions .scan_prefix(sender_key) - .map(|s| serde_json::from_slice(&s.unwrap().1)) - .collect(); - - let sessions: std::result::Result, _> = sessions? - .into_iter() + .map(|s| serde_json::from_slice(&s?.1).map_err(CryptoStoreError::Serialization)) .map(|p| { Session::from_pickle( account.user_id.clone(), account.device_id.clone(), account.identity_keys.clone(), - p, + p?, self.get_pickle_mode(), ) + .map_err(CryptoStoreError::SessionUnpickling) }) .collect(); @@ -369,8 +368,7 @@ impl CryptoStore for SledStore { let key = format!("{}{}{}", room_id, sender_key, session_id); let pickle = self .inbound_group_sessions - .get(&key) - .unwrap() + .get(&key)? .map(|p| serde_json::from_slice(&p)); if let Some(pickle) = pickle { @@ -384,10 +382,10 @@ impl CryptoStore for SledStore { } async fn get_inbound_group_sessions(&self) -> Result> { - let pickles: std::result::Result, _> = self + let pickles: Result> = self .inbound_group_sessions .iter() - .map(|p| serde_json::from_slice(&p.unwrap().1)) + .map(|p| serde_json::from_slice(&p?.1).map_err(CryptoStoreError::Serialization)) .collect(); Ok(pickles? @@ -421,9 +419,7 @@ impl CryptoStore for SledStore { self.users_for_key_query_cache.remove(user); } - self.tracked_users - .insert(user.as_str(), &[dirty as u8]) - .unwrap(); + self.tracked_users.insert(user.as_str(), &[dirty as u8])?; Ok(already_added) } @@ -435,7 +431,7 @@ impl CryptoStore for SledStore { ) -> Result> { let key = format!("{}{}", user_id, device_id); - if let Some(d) = self.devices.get(key).unwrap() { + if let Some(d) = self.devices.get(key)? { Ok(Some(serde_json::from_slice(&d)?)) } else { Ok(None) @@ -446,51 +442,48 @@ impl CryptoStore for SledStore { &self, user_id: &UserId, ) -> Result> { - let devices: std::result::Result, _> = self - .devices + self.devices .scan_prefix(user_id.as_str()) - .map(|d| serde_json::from_slice(&d.unwrap().1)) - .collect(); - - Ok(devices? - .into_iter() - .map(|d| (d.device_id().to_owned(), d)) - .collect()) + .map(|d| serde_json::from_slice(&d?.1).map_err(CryptoStoreError::Serialization)) + .map(|d| { + let d: ReadOnlyDevice = d?; + Ok((d.device_id().to_owned(), d)) + }) + .collect() } async fn get_user_identity(&self, user_id: &UserId) -> Result> { Ok(self .identities - .get(user_id.as_str()) - .unwrap() - .map(|i| serde_json::from_slice(&i).unwrap())) + .get(user_id.as_str())? + .map(|i| serde_json::from_slice(&i)) + .transpose()?) } async fn save_value(&self, key: String, value: String) -> Result<()> { - self.values.insert(key.as_str(), value.as_str()).unwrap(); + self.values.insert(key.as_str(), value.as_str())?; Ok(()) } async fn remove_value(&self, key: &str) -> Result<()> { - self.values.remove(key).unwrap(); + self.values.remove(key)?; Ok(()) } async fn get_value(&self, key: &str) -> Result> { Ok(self .values - .get(key) - .unwrap() + .get(key)? .map(|v| String::from_utf8_lossy(&v).to_string())) } async fn load_identity(&self) -> Result> { - if let Some(i) = self.private_identity.get("identity").unwrap() { + if let Some(i) = self.private_identity.get("identity")? { let pickle = serde_json::from_slice(&i)?; Ok(Some( PrivateCrossSigningIdentity::from_pickle(pickle, self.get_pickle_key()) .await - .unwrap(), + .map_err(|_| CryptoStoreError::UnpicklingError)?, )) } else { Ok(None) @@ -500,8 +493,7 @@ impl CryptoStore for SledStore { async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result { Ok(self .olm_hashes - .contains_key(serde_json::to_vec(message_hash)?) - .unwrap()) + .contains_key(serde_json::to_vec(message_hash)?)?) } }