From 1746690eda07768ea739eb7a9c27ce2ea756ac85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Mon, 18 Jan 2021 13:38:00 +0100 Subject: [PATCH] crypto: Add a sled cryptostore --- matrix_sdk/Cargo.toml | 6 +- matrix_sdk_base/Cargo.toml | 4 +- matrix_sdk_base/src/client.rs | 4 +- matrix_sdk_crypto/Cargo.toml | 2 + matrix_sdk_crypto/src/machine.rs | 19 +- matrix_sdk_crypto/src/store/mod.rs | 2 + matrix_sdk_crypto/src/store/sled.rs | 509 ++++++++++++++++++++++++++++ 7 files changed, 530 insertions(+), 16 deletions(-) create mode 100644 matrix_sdk_crypto/src/store/sled.rs diff --git a/matrix_sdk/Cargo.toml b/matrix_sdk/Cargo.toml index 9219005c..df579477 100644 --- a/matrix_sdk/Cargo.toml +++ b/matrix_sdk/Cargo.toml @@ -15,17 +15,17 @@ features = ["docs"] rustdoc-args = ["--cfg", "feature=\"docs\""] [features] -default = ["encryption", "sqlite_cryptostore", "messages", "native-tls"] +default = ["encryption", "sled_cryptostore", "messages", "native-tls"] messages = ["matrix-sdk-base/messages"] encryption = ["matrix-sdk-base/encryption", "dashmap"] -sqlite_cryptostore = ["matrix-sdk-base/sqlite_cryptostore"] +sled_cryptostore = ["matrix-sdk-base/sled_cryptostore"] unstable-synapse-quirks = ["matrix-sdk-base/unstable-synapse-quirks"] native-tls = ["reqwest/native-tls"] rustls-tls = ["reqwest/rustls-tls"] socks = ["reqwest/socks"] -docs = ["encryption", "sqlite_cryptostore", "messages"] +docs = ["encryption", "sled_cryptostore", "messages"] [dependencies] dashmap = { version = "4.0.1", optional = true } diff --git a/matrix_sdk_base/Cargo.toml b/matrix_sdk_base/Cargo.toml index 5569c448..542a4b54 100644 --- a/matrix_sdk_base/Cargo.toml +++ b/matrix_sdk_base/Cargo.toml @@ -18,10 +18,10 @@ rustdoc-args = ["--cfg", "feature=\"docs\""] default = [] messages = [] encryption = ["matrix-sdk-crypto"] -sqlite_cryptostore = ["matrix-sdk-crypto/sqlite_cryptostore"] +sled_cryptostore = ["matrix-sdk-crypto/sled_cryptostore"] unstable-synapse-quirks = ["matrix-sdk-common/unstable-synapse-quirks"] -docs = ["encryption", "sqlite_cryptostore", "messages"] +docs = ["encryption", "sled_cryptostore", "messages"] [dependencies] dashmap= "4.0.1" diff --git a/matrix_sdk_base/src/client.rs b/matrix_sdk_base/src/client.rs index 5d35f95c..ce7bcc90 100644 --- a/matrix_sdk_base/src/client.rs +++ b/matrix_sdk_base/src/client.rs @@ -370,7 +370,7 @@ impl BaseClient { .map_err(OlmError::from)?, ); } else if let Some(path) = self.store_path.as_ref() { - #[cfg(feature = "sqlite_cryptostore")] + #[cfg(feature = "sled_cryptostore")] { *olm = Some( OlmMachine::new_with_default_store( @@ -383,7 +383,7 @@ impl BaseClient { .map_err(OlmError::from)?, ); } - #[cfg(not(feature = "sqlite_cryptostore"))] + #[cfg(not(feature = "sled_cryptostore"))] { *olm = Some(OlmMachine::new(&session.user_id, &session.device_id)); } diff --git a/matrix_sdk_crypto/Cargo.toml b/matrix_sdk_crypto/Cargo.toml index 1d24c208..9e025a70 100644 --- a/matrix_sdk_crypto/Cargo.toml +++ b/matrix_sdk_crypto/Cargo.toml @@ -17,6 +17,7 @@ rustdoc-args = ["--cfg", "feature=\"docs\""] [features] default = [] sqlite_cryptostore = ["sqlx"] +sled_cryptostore = ["sled"] docs = ["sqlite_cryptostore"] [dependencies] @@ -29,6 +30,7 @@ serde_json = "1.0.61" zeroize = { version = "1.2.0", features = ["zeroize_derive"] } # Misc dependencies +sled = { version = "0.34.6", optional = true } thiserror = "1.0.23" tracing = "0.1.22" atomic = "0.5.0" diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index b7921aed..4abcccd9 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#[cfg(feature = "sqlite_cryptostore")] +#[cfg(feature = "sled_cryptostore")] use std::path::Path; use std::{collections::BTreeMap, mem, sync::Arc}; @@ -44,6 +44,8 @@ use matrix_sdk_common::{ Raw, UInt, }; +#[cfg(feature = "sled_cryptostore")] +use crate::store::sled::SledStore; use crate::{ error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult}, identities::{Device, IdentityManager, UserDevices}, @@ -59,11 +61,9 @@ use crate::{ Changes, CryptoStore, DeviceChanges, IdentityChanges, MemoryStore, Result as StoreResult, Store, }, - verification::{Sas, VerificationMachine}, + verification::{Sas, VerificationMachine, VerificationRequest}, ToDeviceRequest, }; -#[cfg(feature = "sqlite_cryptostore")] -use crate::{store::sqlite::SqliteStore, verification::VerificationRequest}; /// State machine implementation of the Olm/Megolm encryption protocol used for /// Matrix end to end encryption. @@ -258,7 +258,7 @@ impl OlmMachine { /// * `user_id` - The unique id of the user that owns this machine. /// /// * `device_id` - The unique id of the device that owns this machine. - #[cfg(feature = "sqlite_cryptostore")] + #[cfg(feature = "sled_cryptostore")] #[cfg_attr(feature = "docs", doc(cfg(r#sqlite_cryptostore)))] pub async fn new_with_default_store( user_id: &UserId, @@ -266,8 +266,7 @@ impl OlmMachine { path: impl AsRef, passphrase: &str, ) -> StoreResult { - let store = - SqliteStore::open_with_passphrase(&user_id, device_id, path, passphrase).await?; + let store = SledStore::open_with_passphrase(path, passphrase)?; OlmMachine::new_with_store(user_id.to_owned(), device_id.into(), Box::new(store)).await } @@ -1756,9 +1755,11 @@ pub(crate) mod test { } } - #[tokio::test(flavor = "multi_thread")] - #[cfg(feature = "sqlite_cryptostore")] + #[tokio::test] + #[cfg(feature = "sled_cryptostore")] async fn test_machine_with_default_store() { + use tempfile::tempdir; + let tmpdir = tempdir().unwrap(); let machine = OlmMachine::new_with_default_store( diff --git a/matrix_sdk_crypto/src/store/mod.rs b/matrix_sdk_crypto/src/store/mod.rs index 54f27b11..df430719 100644 --- a/matrix_sdk_crypto/src/store/mod.rs +++ b/matrix_sdk_crypto/src/store/mod.rs @@ -40,6 +40,8 @@ pub mod caches; mod memorystore; mod pickle_key; +#[cfg(feature = "sled_cryptostore")] +pub(crate) mod sled; #[cfg(not(target_arch = "wasm32"))] #[cfg(feature = "sqlite_cryptostore")] pub(crate) mod sqlite; diff --git a/matrix_sdk_crypto/src/store/sled.rs b/matrix_sdk_crypto/src/store/sled.rs new file mode 100644 index 00000000..fb5f3e2f --- /dev/null +++ b/matrix_sdk_crypto/src/store/sled.rs @@ -0,0 +1,509 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + collections::{HashMap, HashSet}, + convert::TryFrom, + path::Path, + sync::Arc, +}; + +use dashmap::DashSet; +use olm_rs::PicklingMode; +use sled::{ + transaction::{ConflictableTransactionError, TransactionError}, + Config, Db, Transactional, Tree, +}; + +use matrix_sdk_common::{ + async_trait, + identifiers::{DeviceId, DeviceIdBox, RoomId, UserId}, + locks::Mutex, +}; + +use super::{ + caches::SessionStore, Changes, CryptoStore, CryptoStoreError, InboundGroupSession, PickleKey, + ReadOnlyAccount, Result, Session, +}; +use crate::{ + identities::{ReadOnlyDevice, UserIdentities}, + olm::{PickledInboundGroupSession, PickledSession, PrivateCrossSigningIdentity}, +}; + +/// This needs to be 32 bytes long since AES-GCM requires it, otherwise we will +/// panic once we try to pickle a Signing object. +const DEFAULT_PICKLE: &str = "DEFAULT_PICKLE_PASSPHRASE_123456"; + +/// An in-memory only store that will forget all the E2EE key once it's dropped. +#[derive(Debug, Clone)] +pub struct SledStore { + inner: Db, + pickle_key: Arc, + + session_cache: SessionStore, + tracked_users_cache: Arc>, + users_for_key_query_cache: Arc>, + + account: Tree, + private_identity: Tree, + + olm_hashes: Tree, + sessions: Tree, + inbound_group_sessions: Tree, + + devices: Tree, + identities: Tree, + + tracked_users: Tree, + users_for_key_query: Tree, + values: Tree, +} + +impl SledStore { + pub fn open_with_passphrase(path: impl AsRef, passphrase: &str) -> Result { + let path = path.as_ref().join("matrix-sdk-crypto"); + let db = Config::new().temporary(false).path(path).open().unwrap(); + + SledStore::open_helper(db, Some(passphrase)) + } + + fn open_helper(db: Db, passphrase: Option<&str>) -> Result { + let account = db.open_tree("account").unwrap(); + let private_identity = db.open_tree("private_identity").unwrap(); + + let sessions = db.open_tree("session").unwrap(); + let inbound_group_sessions = db.open_tree("inbound_group_sessions").unwrap(); + let tracked_users = db.open_tree("tracked_users").unwrap(); + let users_for_key_query = db.open_tree("users_for_key_query").unwrap(); + let olm_hashes = db.open_tree("olm_hashes").unwrap(); + + let devices = db.open_tree("devices").unwrap(); + let identities = db.open_tree("identities").unwrap(); + let values = db.open_tree("values").unwrap(); + + let session_cache = SessionStore::new(); + + let pickle_key = if let Some(passphrase) = passphrase { + Self::get_or_create_pickle_key(&passphrase, &db)? + } else { + PickleKey::try_from(DEFAULT_PICKLE.as_bytes().to_vec()) + .expect("Can't create default pickle key") + }; + + Ok(Self { + inner: db, + pickle_key: pickle_key.into(), + account, + private_identity, + sessions, + session_cache, + tracked_users_cache: DashSet::new().into(), + users_for_key_query_cache: DashSet::new().into(), + inbound_group_sessions, + devices, + tracked_users, + users_for_key_query, + olm_hashes, + identities, + values, + }) + } + + fn get_or_create_pickle_key(passphrase: &str, database: &Db) -> Result { + let key = if let Some(key) = database + .get("pickle_key") + .unwrap() + .map(|v| serde_json::from_slice(&v)) + { + PickleKey::from_encrypted(passphrase, key?) + .map_err(|_| CryptoStoreError::UnpicklingError)? + } else { + let key = PickleKey::new(); + let encrypted = key.encrypt(passphrase); + database + .insert("pickle_key", serde_json::to_vec(&encrypted)?) + .unwrap(); + key + }; + + Ok(key) + } + + fn get_pickle_mode(&self) -> PicklingMode { + self.pickle_key.pickle_mode() + } + + fn get_pickle_key(&self) -> &[u8] { + self.pickle_key.key() + } + + async fn load_tracked_users(&self) { + for value in self.tracked_users.iter() { + let (user, dirty) = value.unwrap(); + let user = UserId::try_from(String::from_utf8_lossy(&user).to_string()).unwrap(); + let dirty = dirty.get(0).map(|d| *d == 1).unwrap_or(true); + + self.tracked_users_cache.insert(user.clone()); + + if dirty { + self.users_for_key_query_cache.insert(user); + } + } + } + + pub async fn save_changes(&self, changes: Changes) -> Result<()> { + let account_pickle = if let Some(a) = changes.account { + Some(a.pickle(self.get_pickle_mode()).await) + } else { + None + }; + + let private_identity_pickle = if let Some(i) = changes.private_identity { + Some(i.pickle(DEFAULT_PICKLE.as_bytes()).await.unwrap()) + } else { + None + }; + + let device_changes = changes.devices; + let mut session_changes = HashMap::new(); + + for session in changes.sessions { + let sender_key = session.sender_key(); + let session_id = session.session_id(); + + let pickle = session.pickle(self.get_pickle_mode()).await; + let key = format!("{}{}", sender_key, session_id); + + self.session_cache.add(session).await; + session_changes.insert(key, pickle); + } + + let mut inbound_session_changes = HashMap::new(); + + for session in changes.inbound_group_sessions { + let room_id = session.room_id(); + let sender_key = session.sender_key(); + let session_id = session.session_id(); + let key = format!("{}{}{}", room_id, sender_key, session_id); + let pickle = session.pickle(self.get_pickle_mode()).await; + + inbound_session_changes.insert(key, pickle); + } + + let identity_changes = changes.identities; + let olm_hashes = changes.message_hashes; + + let ret: std::result::Result<(), TransactionError> = ( + &self.account, + &self.private_identity, + &self.devices, + &self.identities, + &self.sessions, + &self.inbound_group_sessions, + &self.olm_hashes, + ) + .transaction( + |( + account, + private_identity, + devices, + identities, + sessions, + inbound_sessions, + hashes, + )| { + if let Some(a) = &account_pickle { + account.insert( + "account", + serde_json::to_vec(a).map_err(ConflictableTransactionError::Abort)?, + )?; + } + + if let Some(i) = &private_identity_pickle { + private_identity.insert( + "identity", + serde_json::to_vec(&i).map_err(ConflictableTransactionError::Abort)?, + )?; + } + + for device in device_changes.new.iter().chain(&device_changes.changed) { + let key = format!("{}{}", device.user_id(), device.device_id()); + let device = serde_json::to_vec(&device) + .map_err(ConflictableTransactionError::Abort)?; + devices.insert(key.as_str(), device)?; + } + + for device in &device_changes.deleted { + let key = format!("{}{}", device.user_id(), device.device_id()); + devices.remove(key.as_str())?; + } + + for identity in identity_changes.changed.iter().chain(&identity_changes.new) { + identities.insert( + identity.user_id().as_str(), + serde_json::to_vec(&identity) + .map_err(ConflictableTransactionError::Abort)?, + )?; + } + + for (key, session) in &session_changes { + sessions.insert( + key.as_str(), + serde_json::to_vec(&session) + .map_err(ConflictableTransactionError::Abort)?, + )?; + } + + for (key, session) in &inbound_session_changes { + inbound_sessions.insert( + key.as_str(), + serde_json::to_vec(&session) + .map_err(ConflictableTransactionError::Abort)?, + )?; + } + + for hash in &olm_hashes { + hashes.insert( + serde_json::to_vec(&hash) + .map_err(ConflictableTransactionError::Abort)?, + &[0], + )?; + } + + Ok(()) + }, + ); + + if let Err(e) = ret { + match e { + TransactionError::Abort(e) => return Err(e.into()), + TransactionError::Storage(e) => panic!("Internal sled error {:?}", e), + } + } + + self.inner.flush_async().await.unwrap(); + + Ok(()) + } +} + +#[async_trait] +impl CryptoStore for SledStore { + async fn load_account(&self) -> Result> { + if let Some(pickle) = self.account.get("account").unwrap() { + let pickle = serde_json::from_slice(&pickle)?; + + self.load_tracked_users().await; + + Ok(Some(ReadOnlyAccount::from_pickle( + pickle, + self.get_pickle_mode(), + )?)) + } else { + Ok(None) + } + } + + async fn save_account(&self, account: ReadOnlyAccount) -> Result<()> { + let pickle = account.pickle(self.get_pickle_mode()).await; + self.account + .insert("account", serde_json::to_vec(&pickle)?) + .unwrap(); + + Ok(()) + } + + async fn save_changes(&self, changes: Changes) -> Result<()> { + self.save_changes(changes).await + } + + async fn get_sessions(&self, sender_key: &str) -> Result>>>> { + let account = self + .load_account() + .await? + .ok_or(CryptoStoreError::AccountUnset)?; + + if self.session_cache.get(sender_key).is_none() { + let sessions: std::result::Result, _> = self + .sessions + .scan_prefix(sender_key) + .map(|s| serde_json::from_slice(&s.unwrap().1)) + .collect(); + + let sessions: std::result::Result, _> = sessions? + .into_iter() + .map(|p| { + Session::from_pickle( + account.user_id.clone(), + account.device_id.clone(), + account.identity_keys.clone(), + p, + self.get_pickle_mode(), + ) + }) + .collect(); + + self.session_cache.set_for_sender(sender_key, sessions?); + } + + Ok(self.session_cache.get(sender_key)) + } + + async fn get_inbound_group_session( + &self, + room_id: &RoomId, + sender_key: &str, + session_id: &str, + ) -> Result> { + let key = format!("{}{}{}", room_id, sender_key, session_id); + let pickle = self + .inbound_group_sessions + .get(&key) + .unwrap() + .map(|p| serde_json::from_slice(&p)); + + if let Some(pickle) = pickle { + Ok(Some(InboundGroupSession::from_pickle( + pickle?, + self.get_pickle_mode(), + )?)) + } else { + Ok(None) + } + } + + async fn get_inbound_group_sessions(&self) -> Result> { + let pickles: std::result::Result, _> = self + .inbound_group_sessions + .iter() + .map(|p| serde_json::from_slice(&p.unwrap().1)) + .collect(); + + Ok(pickles? + .into_iter() + .filter_map(|p| InboundGroupSession::from_pickle(p, self.get_pickle_mode()).ok()) + .collect()) + } + + fn users_for_key_query(&self) -> HashSet { + #[allow(clippy::map_clone)] + self.users_for_key_query_cache + .iter() + .map(|u| u.clone()) + .collect() + } + + fn is_user_tracked(&self, user_id: &UserId) -> bool { + self.tracked_users_cache.contains(user_id) + } + + fn has_users_for_key_query(&self) -> bool { + !self.users_for_key_query_cache.is_empty() + } + + async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result { + let already_added = self.tracked_users_cache.insert(user.clone()); + + if dirty { + self.users_for_key_query_cache.insert(user.clone()); + } else { + self.users_for_key_query_cache.remove(user); + } + + self.tracked_users + .insert(user.as_str(), &[dirty as u8]) + .unwrap(); + + Ok(already_added) + } + + async fn get_device( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result> { + let key = format!("{}{}", user_id, device_id); + + if let Some(d) = self.devices.get(key).unwrap() { + Ok(Some(serde_json::from_slice(&d)?)) + } else { + Ok(None) + } + } + + async fn get_user_devices( + &self, + user_id: &UserId, + ) -> Result> { + let devices: std::result::Result, _> = self + .devices + .scan_prefix(user_id.as_str()) + .map(|d| serde_json::from_slice(&d.unwrap().1)) + .collect(); + + Ok(devices? + .into_iter() + .map(|d| (d.device_id().to_owned(), d)) + .collect()) + } + + async fn get_user_identity(&self, user_id: &UserId) -> Result> { + Ok(self + .identities + .get(user_id.as_str()) + .unwrap() + .map(|i| serde_json::from_slice(&i).unwrap())) + } + + async fn save_value(&self, key: String, value: String) -> Result<()> { + self.values.insert(key.as_str(), value.as_str()).unwrap(); + Ok(()) + } + + async fn remove_value(&self, key: &str) -> Result<()> { + self.values.remove(key).unwrap(); + Ok(()) + } + + async fn get_value(&self, key: &str) -> Result> { + Ok(self + .values + .get(key) + .unwrap() + .map(|v| String::from_utf8_lossy(&v).to_string())) + } + + async fn load_identity(&self) -> Result> { + if let Some(i) = self.private_identity.get("identity").unwrap() { + let pickle = serde_json::from_slice(&i)?; + Ok(Some( + PrivateCrossSigningIdentity::from_pickle(pickle, self.get_pickle_key()) + .await + .unwrap(), + )) + } else { + Ok(None) + } + } + + async fn is_message_known(&self, message_hash: &crate::olm::OlmMessageHash) -> Result { + Ok(self + .olm_hashes + .contains_key(serde_json::to_vec(message_hash)?) + .unwrap()) + } +} + +#[cfg(test)] +mod test {}