base: Properly handle crypto related errors in the sled store

master
Damir Jelić 2021-01-20 16:27:59 +01:00
parent 4a06c9e82d
commit 0a6b0e5804
3 changed files with 51 additions and 27 deletions

View File

@ -25,8 +25,6 @@ use matrix_sdk_common::{
locks::RwLock, locks::RwLock,
}; };
use sled::transaction::TransactionError;
use crate::{ use crate::{
deserialized_responses::{MemberEvent, StrippedMemberEvent}, deserialized_responses::{MemberEvent, StrippedMemberEvent},
rooms::{RoomInfo, RoomType, StrippedRoom}, rooms::{RoomInfo, RoomType, StrippedRoom},
@ -44,15 +42,12 @@ pub enum StoreError {
Json(#[from] serde_json::Error), Json(#[from] serde_json::Error),
#[error(transparent)] #[error(transparent)]
Identifier(#[from] matrix_sdk_common::identifiers::Error), Identifier(#[from] matrix_sdk_common::identifiers::Error),
} #[error("The store failed to be unlocked")]
StoreLocked,
impl From<TransactionError<serde_json::Error>> for StoreError { #[error("The store is not encrypted but was tried to be opened with a passphrase")]
fn from(e: TransactionError<serde_json::Error>) -> Self { UnencryptedStore,
match e { #[error("Error encrypting or decrypting data from the store: {0}")]
TransactionError::Abort(e) => Self::Json(e), Encryption(String),
TransactionError::Storage(e) => Self::Sled(e),
}
}
} }
/// A `StateStore` specific result type. /// A `StateStore` specific result type.

View File

@ -45,6 +45,35 @@ pub enum DatabaseType {
Encrypted(store_key::EncryptedStoreKey), Encrypted(store_key::EncryptedStoreKey),
} }
#[derive(Debug, thiserror::Error)]
pub enum SerializationError {
#[error(transparent)]
Json(#[from] serde_json::Error),
#[error(transparent)]
Encryption(#[from] store_key::Error),
}
impl From<TransactionError<SerializationError>> for StoreError {
fn from(e: TransactionError<SerializationError>) -> Self {
match e {
TransactionError::Abort(e) => e.into(),
TransactionError::Storage(e) => StoreError::Sled(e),
}
}
}
impl From<SerializationError> for StoreError {
fn from(e: SerializationError) -> Self {
match e {
SerializationError::Json(e) => StoreError::Json(e),
SerializationError::Encryption(e) => match e {
store_key::Error::Serialization(e) => StoreError::Json(e),
store_key::Error::Encryption(e) => StoreError::Encryption(e),
},
}
}
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct SledStore { pub struct SledStore {
inner: Db, inner: Db,
@ -119,10 +148,9 @@ impl SledStore {
let store_key = if let Some(key) = store_key { let store_key = if let Some(key) = store_key {
if let DatabaseType::Encrypted(k) = key { if let DatabaseType::Encrypted(k) = key {
let key = StoreKey::import(passphrase, k).unwrap(); StoreKey::import(passphrase, k).map_err(|_| StoreError::StoreLocked)?
key
} else { } else {
panic!("Trying to open an unencrypted store with a passphrase"); return Err(StoreError::UnencryptedStore);
} }
} else { } else {
let key = StoreKey::new(); let key = StoreKey::new();
@ -144,24 +172,24 @@ impl SledStore {
fn serialize_event( fn serialize_event(
&self, &self,
event: &impl Serialize, event: &impl Serialize,
) -> std::result::Result<Vec<u8>, serde_json::Error> { ) -> std::result::Result<Vec<u8>, SerializationError> {
if let Some(key) = &*self.store_key { if let Some(key) = &*self.store_key {
let encrypted = key.encrypt(event).unwrap(); let encrypted = key.encrypt(event)?;
serde_json::to_vec(&encrypted) Ok(serde_json::to_vec(&encrypted)?)
} else { } else {
serde_json::to_vec(event) Ok(serde_json::to_vec(event)?)
} }
} }
fn deserialize_event<T: for<'b> Deserialize<'b>>( fn deserialize_event<T: for<'b> Deserialize<'b>>(
&self, &self,
event: &[u8], event: &[u8],
) -> std::result::Result<T, serde_json::Error> { ) -> std::result::Result<T, SerializationError> {
if let Some(key) = &*self.store_key { if let Some(key) = &*self.store_key {
let encrypted: EncryptedEvent = serde_json::from_slice(&event)?; let encrypted: EncryptedEvent = serde_json::from_slice(&event)?;
Ok(key.decrypt(encrypted).unwrap()) Ok(key.decrypt(encrypted)?)
} else { } else {
serde_json::from_slice(event) Ok(serde_json::from_slice(event)?)
} }
} }
@ -189,7 +217,7 @@ impl SledStore {
pub async fn save_changes(&self, changes: &StateChanges) -> Result<()> { pub async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
let now = SystemTime::now(); let now = SystemTime::now();
let ret: std::result::Result<(), TransactionError<serde_json::Error>> = ( let ret: std::result::Result<(), TransactionError<SerializationError>> = (
&self.session, &self.session,
&self.account_data, &self.account_data,
&self.members, &self.members,
@ -440,7 +468,7 @@ impl SledStore {
stream::iter( stream::iter(
self.room_info self.room_info
.iter() .iter()
.map(move |r| db.deserialize_event(&r?.1).map_err(StoreError::Json)), .map(move |r| db.deserialize_event(&r?.1).map_err(|e| e.into())),
) )
} }
} }

View File

@ -44,8 +44,6 @@ pub enum Error {
Serialization(#[from] serde_json::Error), Serialization(#[from] serde_json::Error),
#[error("Error encrypting or decrypting an event {0}")] #[error("Error encrypting or decrypting an event {0}")]
Encryption(String), Encryption(String),
#[error("Unknown ciphertext version")]
InvalidVersion,
} }
impl From<EncryptionError> for Error { impl From<EncryptionError> for Error {
@ -129,6 +127,7 @@ impl StoreKey {
Default::default() Default::default()
} }
/// Expand the given passphrase into a KEY_SIZE long key.
fn expand_key(passphrase: &str, salt: &[u8], rounds: u32) -> Zeroizing<Vec<u8>> { fn expand_key(passphrase: &str, salt: &[u8], rounds: u32) -> Zeroizing<Vec<u8>> {
let mut key = Zeroizing::from(vec![0u8; KEY_SIZE]); let mut key = Zeroizing::from(vec![0u8; KEY_SIZE]);
pbkdf2::<Hmac<Sha256>>(passphrase.as_bytes(), &salt, rounds, &mut *key); pbkdf2::<Hmac<Sha256>>(passphrase.as_bytes(), &salt, rounds, &mut *key);
@ -172,7 +171,7 @@ impl StoreKey {
fn get_nonce() -> Vec<u8> { fn get_nonce() -> Vec<u8> {
let mut nonce = vec![0u8; XNONCE_SIZE]; let mut nonce = vec![0u8; XNONCE_SIZE];
getrandom(&mut nonce).expect("Can't generate nonce"); getrandom(&mut nonce).expect("Can't get random nonce");
nonce nonce
} }
@ -194,7 +193,9 @@ impl StoreKey {
pub fn decrypt<T: for<'b> Deserialize<'b>>(&self, event: EncryptedEvent) -> Result<T, Error> { pub fn decrypt<T: for<'b> Deserialize<'b>>(&self, event: EncryptedEvent) -> Result<T, Error> {
if event.version != VERSION { if event.version != VERSION {
return Err(Error::InvalidVersion); return Err(Error::Encryption(
"Error decrypting: Unknown ciphertext version".to_string(),
));
} }
let cipher = XChaCha20Poly1305::new(self.key()); let cipher = XChaCha20Poly1305::new(self.key());