From bf4f32eccfef9891a4ff263652b76bc0456822b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Wed, 27 Jan 2021 15:29:42 +0100 Subject: [PATCH] crypto: Remove the sqlite store for now --- matrix_sdk/src/lib.rs | 2 +- matrix_sdk_base/src/lib.rs | 2 +- matrix_sdk_crypto/Cargo.toml | 9 +- matrix_sdk_crypto/src/machine.rs | 6 +- matrix_sdk_crypto/src/store/mod.rs | 20 +- matrix_sdk_crypto/src/store/sqlite.rs | 2469 ------------------------- 6 files changed, 6 insertions(+), 2502 deletions(-) delete mode 100644 matrix_sdk_crypto/src/store/sqlite.rs diff --git a/matrix_sdk/src/lib.rs b/matrix_sdk/src/lib.rs index a8566864..8528e920 100644 --- a/matrix_sdk/src/lib.rs +++ b/matrix_sdk/src/lib.rs @@ -37,7 +37,7 @@ //! The following crate feature flags are available: //! //! * `encryption`: Enables end-to-end encryption support in the library. -//! * `sqlite_cryptostore`: Enables a SQLite based store for the encryption +//! * `sled_cryptostore`: Enables a Sled based store for the encryption //! keys. If this is disabled and `encryption` support is enabled the keys will //! by default be stored only in memory and thus lost after the client is //! destroyed. diff --git a/matrix_sdk_base/src/lib.rs b/matrix_sdk_base/src/lib.rs index e10b6fdb..2f27b8e4 100644 --- a/matrix_sdk_base/src/lib.rs +++ b/matrix_sdk_base/src/lib.rs @@ -20,7 +20,7 @@ //! The following crate feature flags are available: //! //! * `encryption`: Enables end-to-end encryption support in the library. -//! * `sqlite_cryptostore`: Enables a SQLite based store for the encryption +//! * `sled_cryptostore`: Enables a Sled based store for the encryption //! keys. If this is disabled and `encryption` support is enabled the keys will //! by default be stored only in memory and thus lost after the client is //! destroyed. diff --git a/matrix_sdk_crypto/Cargo.toml b/matrix_sdk_crypto/Cargo.toml index 9b53e062..a14cf0f6 100644 --- a/matrix_sdk_crypto/Cargo.toml +++ b/matrix_sdk_crypto/Cargo.toml @@ -16,9 +16,8 @@ rustdoc-args = ["--cfg", "feature=\"docs\""] [features] default = [] -sqlite_cryptostore = ["sqlx"] sled_cryptostore = ["sled"] -docs = ["sqlite_cryptostore"] +docs = ["sled_cryptostore"] [dependencies] matrix-sdk-common = { version = "0.2.0", path = "../matrix_sdk_common" } @@ -43,12 +42,6 @@ hmac = "0.10.1" base64 = "0.13.0" byteorder = "1.4.2" -[target.'cfg(not(target_arch = "wasm32"))'.dependencies.sqlx] -version = "0.4.2" -optional = true -default-features = false -features = ["runtime-tokio-native-tls", "sqlite", "macros"] - [dev-dependencies] tokio = { version = "1.1.0", default-features = false, features = ["rt-multi-thread", "macros"] } futures = "0.3.12" diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index f6bdf31c..b8405aa4 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -252,7 +252,7 @@ impl OlmMachine { /// Create a new machine with the default crypto store. /// - /// The default store uses a SQLite database to store the encryption keys. + /// The default store uses a Sled database to store the encryption keys. /// /// # Arguments /// @@ -260,7 +260,7 @@ impl OlmMachine { /// /// * `device_id` - The unique id of the device that owns this machine. #[cfg(feature = "sled_cryptostore")] - #[cfg_attr(feature = "docs", doc(cfg(r#sqlite_cryptostore)))] + #[cfg_attr(feature = "docs", doc(cfg(sled_cryptostore)))] pub async fn new_with_default_store( user_id: &UserId, device_id: &DeviceId, @@ -1162,8 +1162,6 @@ pub(crate) mod test { use http::Response; use serde_json::json; - #[cfg(feature = "sqlite_cryptostore")] - use tempfile::tempdir; use crate::{ machine::OlmMachine, diff --git a/matrix_sdk_crypto/src/store/mod.rs b/matrix_sdk_crypto/src/store/mod.rs index 06dd52dd..b72e5f01 100644 --- a/matrix_sdk_crypto/src/store/mod.rs +++ b/matrix_sdk_crypto/src/store/mod.rs @@ -17,7 +17,7 @@ //! The storage layer for the [`OlmMachine`] can be customized using a trait. //! Implementing your own [`CryptoStore`] //! -//! An in-memory only store is provided as well as a SQLite based one, depending +//! An in-memory only store is provided as well as a Sled based one, depending //! on your needs and targets a custom store may be implemented, e.g. for //! `wasm-unknown-unknown` an indexeddb store would be needed //! @@ -42,17 +42,11 @@ mod memorystore; mod pickle_key; #[cfg(feature = "sled_cryptostore")] pub(crate) mod sled; -#[cfg(not(target_arch = "wasm32"))] -#[cfg(feature = "sqlite_cryptostore")] -pub(crate) mod sqlite; #[cfg(feature = "sled_cryptostore")] pub use self::sled::SledStore; pub use memorystore::MemoryStore; pub use pickle_key::{EncryptedPickleKey, PickleKey}; -#[cfg(not(target_arch = "wasm32"))] -#[cfg(feature = "sqlite_cryptostore")] -pub use sqlite::SqliteStore; use std::{ collections::{HashMap, HashSet}, @@ -67,11 +61,6 @@ use serde::{Deserialize, Serialize}; use serde_json::Error as SerdeError; use thiserror::Error; -#[cfg_attr(feature = "docs", doc(cfg(r#sqlite_cryptostore)))] -#[cfg(not(target_arch = "wasm32"))] -#[cfg(feature = "sqlite_cryptostore")] -use sqlx::Error as SqlxError; - use matrix_sdk_common::{ async_trait, identifiers::{ @@ -295,13 +284,6 @@ pub enum CryptoStoreError { #[error("can't save/load sessions or group sessions in the store before an account is stored")] AccountUnset, - /// SQL error occurred. - // TODO flatten the SqlxError to make it easier for other store - // implementations. - #[cfg(feature = "sqlite_cryptostore")] - #[error(transparent)] - Database(#[from] SqlxError), - /// Error in the internal database #[cfg(feature = "sled_cryptostore")] #[error(transparent)] diff --git a/matrix_sdk_crypto/src/store/sqlite.rs b/matrix_sdk_crypto/src/store/sqlite.rs deleted file mode 100644 index bedbe4dd..00000000 --- a/matrix_sdk_crypto/src/store/sqlite.rs +++ /dev/null @@ -1,2469 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::{ - collections::{BTreeMap, HashMap, HashSet}, - convert::TryFrom, - path::Path, - result::Result as StdResult, - sync::{Arc, Mutex as SyncMutex}, -}; - -use dashmap::DashSet; -use matrix_sdk_common::{ - api::r0::keys::{CrossSigningKey, KeyUsage}, - async_trait, - identifiers::{ - DeviceId, DeviceIdBox, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, RoomId, - UserId, - }, - instant::Duration, - locks::Mutex, -}; -use sqlx::{query, query_as, sqlite::SqliteConnectOptions, Connection, Executor, SqliteConnection}; - -use super::{ - caches::SessionStore, - pickle_key::{EncryptedPickleKey, PickleKey}, - Changes, CryptoStore, CryptoStoreError, Result, -}; -use crate::{ - identities::{LocalTrust, OwnUserIdentity, ReadOnlyDevice, UserIdentities, UserIdentity}, - olm::{ - AccountPickle, IdentityKeys, InboundGroupSession, InboundGroupSessionPickle, - OlmMessageHash, PickledAccount, PickledCrossSigningIdentity, PickledInboundGroupSession, - PickledSession, PicklingMode, PrivateCrossSigningIdentity, ReadOnlyAccount, Session, - SessionPickle, - }, -}; - -/// This needs to be 32 bytes long since AES-GCM requires it, otherwise we will -/// panic once we try to pickle a Signing object. -const DEFAULT_PICKLE: &str = "DEFAULT_PICKLE_PASSPHRASE_123456"; - -/// SQLite based implementation of a `CryptoStore`. -#[derive(Clone)] -#[cfg_attr(feature = "docs", doc(cfg(r#sqlite_cryptostore)))] -pub struct SqliteStore { - user_id: Arc, - device_id: Arc>, - account_info: Arc>>, - path: Arc, - - sessions: SessionStore, - tracked_users: Arc>, - users_for_key_query: Arc>, - - connection: Arc>, - pickle_key: Arc, -} - -#[derive(Clone)] -struct AccountInfo { - account_id: i64, - identity_keys: Arc, -} - -#[derive(Debug, PartialEq, Copy, Clone, sqlx::Type)] -#[repr(i32)] -enum CrosssigningKeyType { - Master = 0, - SelfSigning = 1, - UserSigning = 2, -} - -static DATABASE_NAME: &str = "matrix-sdk-crypto.db"; - -impl SqliteStore { - /// Open a new `SqliteStore`. - /// - /// # Arguments - /// - /// * `user_id` - The unique id of the user for which the store should be - /// opened. - /// - /// * `device_id` - The unique id of the device for which the store should - /// be opened. - /// - /// * `path` - The path where the database file should reside in. - pub async fn open>( - user_id: &UserId, - device_id: &DeviceId, - path: P, - ) -> Result { - SqliteStore::open_helper(user_id, device_id, path, None).await - } - - /// Open a new `SqliteStore`. - /// - /// # Arguments - /// - /// * `user_id` - The unique id of the user for which the store should be - /// opened. - /// - /// * `device_id` - The unique id of the device for which the store should - /// be opened. - /// - /// * `path` - The path where the database file should reside in. - /// - /// * `passphrase` - The passphrase that should be used to securely store - /// the encryption keys. - pub async fn open_with_passphrase>( - user_id: &UserId, - device_id: &DeviceId, - path: P, - passphrase: &str, - ) -> Result { - SqliteStore::open_helper(user_id, device_id, path, Some(passphrase)).await - } - - async fn open_helper>( - user_id: &UserId, - device_id: &DeviceId, - path: P, - passphrase: Option<&str>, - ) -> Result { - let path = path.as_ref().join(DATABASE_NAME); - let options = SqliteConnectOptions::new() - .foreign_keys(true) - .create_if_missing(true) - .read_only(false) - .filename(&path); - - let mut connection = SqliteConnection::connect_with(&options).await?; - Self::create_tables(&mut connection).await?; - - let pickle_key = if let Some(passphrase) = passphrase { - Self::get_or_create_pickle_key(user_id, device_id, &passphrase, &mut connection).await? - } else { - PickleKey::try_from(DEFAULT_PICKLE.as_bytes().to_vec()) - .expect("Can't create default pickle key") - }; - - let store = SqliteStore { - user_id: Arc::new(user_id.to_owned()), - device_id: Arc::new(device_id.into()), - account_info: Arc::new(SyncMutex::new(None)), - sessions: SessionStore::new(), - path: path.into(), - connection: Arc::new(Mutex::new(connection)), - tracked_users: Arc::new(DashSet::new()), - users_for_key_query: Arc::new(DashSet::new()), - pickle_key: Arc::new(pickle_key), - }; - - Ok(store) - } - - fn account_id(&self) -> Option { - self.account_info - .lock() - .unwrap() - .as_ref() - .map(|i| i.account_id) - } - - async fn create_tables(connection: &mut SqliteConnection) -> Result<()> { - connection - .execute( - r#" - CREATE TABLE IF NOT EXISTS accounts ( - "id" INTEGER NOT NULL PRIMARY KEY, - "user_id" TEXT NOT NULL, - "device_id" TEXT NOT NULL, - "pickle" BLOB NOT NULL, - "shared" INTEGER NOT NULL, - "uploaded_key_count" INTEGER NOT NULL, - UNIQUE(user_id,device_id) - ); - "#, - ) - .await?; - - connection - .execute( - r#" - CREATE TABLE IF NOT EXISTS private_identities ( - "id" INTEGER NOT NULL PRIMARY KEY, - "account_id" INTEGER NOT NULL, - "user_id" TEXT NOT NULL, - "pickle" TEXT NOT NULL, - "shared" INTEGER NOT NULL, - FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") - ON DELETE CASCADE - UNIQUE(account_id, user_id) - ); - "#, - ) - .await?; - - connection - .execute( - r#" - CREATE TABLE IF NOT EXISTS pickle_keys ( - "id" INTEGER NOT NULL PRIMARY KEY, - "user_id" TEXT NOT NULL, - "device_id" TEXT NOT NULL, - "key" TEXT NOT NULL, - UNIQUE(user_id,device_id) - ); - "#, - ) - .await?; - - connection - .execute( - r#" - CREATE TABLE IF NOT EXISTS sessions ( - "session_id" TEXT NOT NULL PRIMARY KEY, - "account_id" INTEGER NOT NULL, - "creation_time" TEXT NOT NULL, - "last_use_time" TEXT NOT NULL, - "sender_key" TEXT NOT NULL, - "pickle" BLOB NOT NULL, - FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") - ON DELETE CASCADE - ); - - CREATE INDEX IF NOT EXISTS "olmsessions_account_id" ON "sessions" ("account_id"); - "#, - ) - .await?; - - connection - .execute( - r#" - CREATE TABLE IF NOT EXISTS tracked_users ( - "id" INTEGER NOT NULL PRIMARY KEY, - "account_id" INTEGER NOT NULL, - "user_id" TEXT NOT NULL, - "dirty" INTEGER NOT NULL, - FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") - ON DELETE CASCADE - UNIQUE(account_id,user_id) - ); - - CREATE INDEX IF NOT EXISTS "tracked_users_account_id" ON "tracked_users" ("account_id"); - "#, - ) - .await?; - - connection - .execute( - r#" - CREATE TABLE IF NOT EXISTS inbound_group_sessions ( - "id" INTEGER NOT NULL PRIMARY KEY, - "session_id" TEXT NOT NULL, - "account_id" INTEGER NOT NULL, - "sender_key" TEXT NOT NULL, - "room_id" TEXT NOT NULL, - "pickle" BLOB NOT NULL, - "imported" INTEGER NOT NULL, - FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") - ON DELETE CASCADE - UNIQUE(account_id,session_id,sender_key) - ); - - CREATE INDEX IF NOT EXISTS "olm_groups_sessions_account_id" ON "inbound_group_sessions" ("account_id"); - "#, - ) - .await?; - - connection - .execute( - r#" - CREATE TABLE IF NOT EXISTS group_session_claimed_keys ( - "id" INTEGER NOT NULL PRIMARY KEY, - "session_id" INTEGER NOT NULL, - "algorithm" TEXT NOT NULL, - "key" TEXT NOT NULL, - FOREIGN KEY ("session_id") REFERENCES "inbound_group_sessions" ("id") - ON DELETE CASCADE - UNIQUE(session_id, algorithm) - ); - - CREATE INDEX IF NOT EXISTS "group_session_claimed_keys_session_id" ON "group_session_claimed_keys" ("session_id"); - "#, - ) - .await?; - - connection - .execute( - r#" - CREATE TABLE IF NOT EXISTS group_session_chains ( - "id" INTEGER NOT NULL PRIMARY KEY, - "key" TEXT NOT NULL, - "session_id" INTEGER NOT NULL, - FOREIGN KEY ("session_id") REFERENCES "inbound_group_sessions" ("id") - ON DELETE CASCADE - UNIQUE(session_id, key) - ); - - CREATE INDEX IF NOT EXISTS "group_session_chains_session_id" ON "group_session_chains" ("session_id"); - "#, - ) - .await?; - - connection - .execute( - r#" - CREATE TABLE IF NOT EXISTS devices ( - "id" INTEGER NOT NULL PRIMARY KEY, - "account_id" INTEGER NOT NULL, - "user_id" TEXT NOT NULL, - "device_id" TEXT NOT NULL, - "display_name" TEXT, - "trust_state" INTEGER NOT NULL, - FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") - ON DELETE CASCADE - UNIQUE(account_id,user_id,device_id) - ); - - CREATE INDEX IF NOT EXISTS "devices_account_id" ON "devices" ("account_id"); - "#, - ) - .await?; - - connection - .execute( - r#" - CREATE TABLE IF NOT EXISTS algorithms ( - "id" INTEGER NOT NULL PRIMARY KEY, - "device_id" INTEGER NOT NULL, - "algorithm" TEXT NOT NULL, - FOREIGN KEY ("device_id") REFERENCES "devices" ("id") - ON DELETE CASCADE - UNIQUE(device_id, algorithm) - ); - - CREATE INDEX IF NOT EXISTS "algorithms_device_id" ON "algorithms" ("device_id"); - "#, - ) - .await?; - - connection - .execute( - r#" - CREATE TABLE IF NOT EXISTS device_keys ( - "id" INTEGER NOT NULL PRIMARY KEY, - "device_id" INTEGER NOT NULL, - "algorithm" TEXT NOT NULL, - "key" TEXT NOT NULL, - FOREIGN KEY ("device_id") REFERENCES "devices" ("id") - ON DELETE CASCADE - UNIQUE(device_id, algorithm) - ); - - CREATE INDEX IF NOT EXISTS "device_keys_device_id" ON "device_keys" ("device_id"); - "#, - ) - .await?; - - connection - .execute( - r#" - CREATE TABLE IF NOT EXISTS device_signatures ( - "id" INTEGER NOT NULL PRIMARY KEY, - "device_id" INTEGER NOT NULL, - "user_id" TEXT NOT NULL, - "key_algorithm" TEXT NOT NULL, - "signature" TEXT NOT NULL, - FOREIGN KEY ("device_id") REFERENCES "devices" ("id") - ON DELETE CASCADE - UNIQUE(device_id, user_id, key_algorithm) - ); - - CREATE INDEX IF NOT EXISTS "device_keys_device_id" ON "device_keys" ("device_id"); - "#, - ) - .await?; - - connection - .execute( - r#" - CREATE TABLE IF NOT EXISTS users ( - "id" INTEGER NOT NULL PRIMARY KEY, - "account_id" INTEGER NOT NULL, - "user_id" TEXT NOT NULL, - FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") - ON DELETE CASCADE - UNIQUE(account_id,user_id) - ); - - CREATE INDEX IF NOT EXISTS "users_account_id" ON "users" ("account_id"); - "#, - ) - .await?; - - connection - .execute( - r#" - CREATE TABLE IF NOT EXISTS users_trust_state ( - "id" INTEGER NOT NULL PRIMARY KEY, - "trusted" INTEGER NOT NULL, - "user_id" INTEGER NOT NULL, - FOREIGN KEY ("user_id") REFERENCES "users" ("id") - ON DELETE CASCADE - UNIQUE(user_id) - ); - "#, - ) - .await?; - - connection - .execute( - r#" - CREATE TABLE IF NOT EXISTS cross_signing_keys ( - "id" INTEGER NOT NULL PRIMARY KEY, - "key_type" INTEGER NOT NULL, - "usage" STRING NOT NULL, - "user_id" INTEGER NOT NULL, - FOREIGN KEY ("user_id") REFERENCES "users" ("id") ON DELETE CASCADE - UNIQUE(user_id, key_type) - ); - - CREATE INDEX IF NOT EXISTS "cross_signing_keys_users" ON "users" ("user_id"); - "#, - ) - .await?; - - connection - .execute( - r#" - CREATE TABLE IF NOT EXISTS user_keys ( - "id" INTEGER NOT NULL PRIMARY KEY, - "key" TEXT NOT NULL, - "key_id" TEXT NOT NULL, - "cross_signing_key" INTEGER NOT NULL, - FOREIGN KEY ("cross_signing_key") REFERENCES "cross_signing_keys" ("id") ON DELETE CASCADE - UNIQUE(cross_signing_key, key_id) - ); - - CREATE INDEX IF NOT EXISTS "cross_signing_keys_keys" ON "cross_signing_keys" ("cross_signing_key"); - "#, - ) - .await?; - - connection - .execute( - r#" - CREATE TABLE IF NOT EXISTS user_key_signatures ( - "id" INTEGER NOT NULL PRIMARY KEY, - "user_id" TEXT NOT NULL, - "key_id" INTEGER NOT NULL, - "signature" TEXT NOT NULL, - "cross_signing_key" INTEGER NOT NULL, - FOREIGN KEY ("cross_signing_key") REFERENCES "cross_signing_keys" ("id") - ON DELETE CASCADE - UNIQUE(user_id, key_id, cross_signing_key) - ); - - CREATE INDEX IF NOT EXISTS "cross_signing_keys_signatures" ON "cross_signing_keys" ("cross_signing_key"); - "#, - ) - .await?; - - connection - .execute( - r#" - CREATE TABLE IF NOT EXISTS key_value ( - "id" INTEGER NOT NULL PRIMARY KEY, - "account_id" INTEGER NOT NULL, - "key" TEXT NOT NULL, - "value" TEXT NOT NULL, - FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") - ON DELETE CASCADE - UNIQUE(account_id,key) - ); - - CREATE INDEX IF NOT EXISTS "key_values_index" ON "key_value" ("account_id"); - "#, - ) - .await?; - - connection - .execute( - r#" - CREATE TABLE IF NOT EXISTS olm_hashes ( - "id" INTEGER NOT NULL PRIMARY KEY, - "account_id" INTEGER NOT NULL, - "sender_key" TEXT NOT NULL, - "hash" TEXT NOT NULL, - FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") - ON DELETE CASCADE - UNIQUE(account_id,sender_key,hash) - ); - - CREATE INDEX IF NOT EXISTS "olm_hashes_index" ON "olm_hashes" ("account_id"); - "#, - ) - .await?; - - Ok(()) - } - - async fn save_pickle_key( - user_id: &UserId, - device_id: &DeviceId, - key: EncryptedPickleKey, - connection: &mut SqliteConnection, - ) -> Result<()> { - let key = serde_json::to_string(&key)?; - - query( - "INSERT INTO pickle_keys ( - user_id, device_id, key - ) VALUES (?1, ?2, ?3) - ON CONFLICT(user_id, device_id) DO UPDATE SET - key = excluded.key - ", - ) - .bind(user_id.as_str()) - .bind(device_id.as_str()) - .bind(key) - .execute(&mut *connection) - .await?; - - Ok(()) - } - - async fn get_or_create_pickle_key( - user_id: &UserId, - device_id: &DeviceId, - passphrase: &str, - connection: &mut SqliteConnection, - ) -> Result { - let row: Option<(String,)> = - query_as("SELECT key FROM pickle_keys WHERE user_id = ? and device_id = ?") - .bind(user_id.as_str()) - .bind(device_id.as_str()) - .fetch_optional(&mut *connection) - .await?; - - Ok(if let Some(row) = row { - let encrypted: EncryptedPickleKey = serde_json::from_str(&row.0)?; - PickleKey::from_encrypted(passphrase, encrypted) - .map_err(|_| CryptoStoreError::UnpicklingError)? - } else { - let key = PickleKey::new(); - let encrypted = key.encrypt(passphrase); - Self::save_pickle_key(user_id, device_id, encrypted, connection).await?; - - key - }) - } - - async fn lazy_load_sessions( - &self, - connection: &mut SqliteConnection, - sender_key: &str, - ) -> Result<()> { - let loaded_sessions = self.sessions.get(sender_key).is_some(); - - if !loaded_sessions { - let sessions = self - .load_sessions_for_helper(connection, sender_key) - .await?; - - if !sessions.is_empty() { - self.sessions.set_for_sender(sender_key, sessions); - } - } - - Ok(()) - } - - async fn get_sessions_for( - &self, - connection: &mut SqliteConnection, - sender_key: &str, - ) -> Result>>>> { - self.lazy_load_sessions(connection, sender_key).await?; - Ok(self.sessions.get(sender_key)) - } - - #[cfg(test)] - async fn load_sessions_for(&self, sender_key: &str) -> Result> { - let mut connection = self.connection.lock().await; - self.load_sessions_for_helper(&mut connection, sender_key) - .await - } - - async fn load_sessions_for_helper( - &self, - connection: &mut SqliteConnection, - sender_key: &str, - ) -> Result> { - let account_info = self - .account_info - .lock() - .unwrap() - .clone() - .ok_or(CryptoStoreError::AccountUnset)?; - let mut rows: Vec<(String, String, String, String)> = query_as( - "SELECT pickle, sender_key, creation_time, last_use_time - FROM sessions WHERE account_id = ? and sender_key = ?", - ) - .bind(account_info.account_id) - .bind(sender_key) - .fetch_all(&mut *connection) - .await?; - - Ok(rows - .drain(..) - .map(|row| { - let pickle = row.0; - let sender_key = row.1; - let creation_time = serde_json::from_str::(&row.2)?; - let last_use_time = serde_json::from_str::(&row.3)?; - - let pickle = PickledSession { - pickle: SessionPickle::from(pickle), - last_use_time, - creation_time, - sender_key, - }; - - Ok(Session::from_pickle( - self.user_id.clone(), - self.device_id.clone(), - account_info.identity_keys.clone(), - pickle, - self.get_pickle_mode(), - )?) - }) - .collect::>>()?) - } - - async fn load_inbound_session_data( - &self, - connection: &mut SqliteConnection, - session_row_id: i64, - pickle: String, - sender_key: String, - room_id: RoomId, - imported: bool, - ) -> Result { - let key_rows: Vec<(String, String)> = - query_as("SELECT algorithm, key FROM group_session_claimed_keys WHERE session_id = ?") - .bind(session_row_id) - .fetch_all(&mut *connection) - .await?; - - let claimed_keys: BTreeMap = key_rows - .into_iter() - .filter_map(|row| { - let algorithm = DeviceKeyAlgorithm::try_from(row.0).ok()?; - let key = row.1; - - Some((algorithm, key)) - }) - .collect(); - - let mut chain_rows: Vec<(String,)> = - query_as("SELECT key, key FROM group_session_chains WHERE session_id = ?") - .bind(session_row_id) - .fetch_all(&mut *connection) - .await?; - - let chains: Vec = chain_rows.drain(..).map(|r| r.0).collect(); - - let chains = if chains.is_empty() { - None - } else { - Some(chains) - }; - - let pickle = PickledInboundGroupSession { - pickle: InboundGroupSessionPickle::from(pickle), - sender_key, - signing_key: claimed_keys, - room_id, - forwarding_chains: chains, - imported, - }; - - Ok(InboundGroupSession::from_pickle( - pickle, - self.get_pickle_mode(), - )?) - } - - async fn load_inbound_group_session_helper( - &self, - room_id: &RoomId, - sender_key: &str, - session_id: &str, - ) -> Result> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - let mut connection = self.connection.lock().await; - - let row: Option<(i64, String, bool)> = query_as( - "SELECT id, pickle, imported - FROM inbound_group_sessions - WHERE ( - account_id = ? and - room_id = ? and - sender_key = ? and - session_id = ? - )", - ) - .bind(account_id) - .bind(room_id.as_str()) - .bind(sender_key) - .bind(session_id) - .fetch_optional(&mut *connection) - .await?; - - let row = if let Some(r) = row { - r - } else { - return Ok(None); - }; - - let session_row_id = row.0; - let pickle = row.1; - let imported = row.2; - - let session = self - .load_inbound_session_data( - &mut connection, - session_row_id, - pickle, - sender_key.to_owned(), - room_id.to_owned(), - imported, - ) - .await?; - - Ok(Some(session)) - } - - async fn load_inbound_group_sessions(&self) -> Result> { - let mut sessions = Vec::new(); - - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - let mut connection = self.connection.lock().await; - - let mut rows: Vec<(i64, String, String, String, bool)> = query_as( - "SELECT id, pickle, sender_key, room_id, imported - FROM inbound_group_sessions WHERE account_id = ?", - ) - .bind(account_id) - .fetch_all(&mut *connection) - .await?; - - for row in rows.drain(..) { - let session_row_id = row.0; - let pickle = row.1; - let sender_key = row.2; - let room_id = RoomId::try_from(row.3)?; - let imported = row.4; - - let session = self - .load_inbound_session_data( - &mut connection, - session_row_id, - pickle, - sender_key, - room_id.to_owned(), - imported, - ) - .await?; - - sessions.push(session); - } - - Ok(sessions) - } - - async fn save_tracked_user(&self, user: &UserId, dirty: bool) -> Result<()> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - let mut connection = self.connection.lock().await; - // TODO see the todo in the memory store, we need to avoid a race - // between a sync and key query. - - query( - "INSERT INTO tracked_users ( - account_id, user_id, dirty - ) VALUES (?1, ?2, ?3) - ON CONFLICT(account_id, user_id) DO UPDATE SET - dirty = excluded.dirty - ", - ) - .bind(account_id) - .bind(user.to_string()) - .bind(dirty) - .execute(&mut *connection) - .await?; - - Ok(()) - } - - async fn load_tracked_users(&self) -> Result<()> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - let mut connection = self.connection.lock().await; - - let rows: Vec<(String, bool)> = query_as( - "SELECT user_id, dirty - FROM tracked_users WHERE account_id = ?", - ) - .bind(account_id) - .fetch_all(&mut *connection) - .await?; - - for row in rows { - let user_id: &str = &row.0; - let dirty: bool = row.1; - - if let Ok(u) = UserId::try_from(user_id) { - self.tracked_users.insert(u.clone()); - if dirty { - self.users_for_key_query.insert(u); - } - } else { - continue; - }; - } - - Ok(()) - } - - async fn load_device_data( - &self, - connection: &mut SqliteConnection, - device_row_id: i64, - user_id: &UserId, - device_id: DeviceIdBox, - trust_state: LocalTrust, - display_name: Option, - ) -> Result { - let algorithm_rows: Vec<(String,)> = - query_as("SELECT algorithm FROM algorithms WHERE device_id = ?") - .bind(device_row_id) - .fetch_all(&mut *connection) - .await?; - - let algorithms = algorithm_rows - .iter() - .map(|row| { - let algorithm: &str = &row.0; - EventEncryptionAlgorithm::from(algorithm) - }) - .collect::>(); - - let key_rows: Vec<(String, String)> = - query_as("SELECT algorithm, key FROM device_keys WHERE device_id = ?") - .bind(device_row_id) - .fetch_all(&mut *connection) - .await?; - - let keys: BTreeMap = key_rows - .into_iter() - .filter_map(|row| { - let algorithm = DeviceKeyAlgorithm::try_from(row.0).ok()?; - let key = row.1; - - Some((DeviceKeyId::from_parts(algorithm, &device_id), key)) - }) - .collect(); - - let signature_rows: Vec<(String, String, String)> = query_as( - "SELECT user_id, key_algorithm, signature - FROM device_signatures WHERE device_id = ?", - ) - .bind(device_row_id) - .fetch_all(&mut *connection) - .await?; - - let mut signatures: BTreeMap> = BTreeMap::new(); - - for row in signature_rows { - let user_id = if let Ok(u) = UserId::try_from(&*row.0) { - u - } else { - continue; - }; - - let key_algorithm = if let Ok(k) = DeviceKeyAlgorithm::try_from(row.1) { - k - } else { - continue; - }; - - let signature = row.2; - - signatures - .entry(user_id) - .or_insert_with(BTreeMap::new) - .insert( - DeviceKeyId::from_parts(key_algorithm, device_id.as_str().into()), - signature.to_owned(), - ); - } - - Ok(ReadOnlyDevice::new( - user_id.to_owned(), - device_id, - display_name.clone(), - trust_state, - algorithms, - keys, - signatures, - )) - } - - async fn get_single_device( - &self, - user_id: &UserId, - device_id: &DeviceId, - ) -> Result> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - let mut connection = self.connection.lock().await; - - let row: Option<(i64, Option, i64)> = query_as( - "SELECT id, display_name, trust_state - FROM devices WHERE account_id = ? and user_id = ? and device_id = ?", - ) - .bind(account_id) - .bind(user_id.as_str()) - .bind(device_id.as_str()) - .fetch_optional(&mut *connection) - .await?; - - let row = if let Some(r) = row { - r - } else { - return Ok(None); - }; - - let device_row_id = row.0; - let display_name = row.1; - let trust_state = LocalTrust::from(row.2); - let device = self - .load_device_data( - &mut connection, - device_row_id, - user_id, - device_id.into(), - trust_state, - display_name, - ) - .await?; - - Ok(Some(device)) - } - - async fn load_devices(&self, user_id: &UserId) -> Result> { - let mut devices = HashMap::new(); - - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - let mut connection = self.connection.lock().await; - - let mut rows: Vec<(i64, String, Option, i64)> = query_as( - "SELECT id, device_id, display_name, trust_state - FROM devices WHERE account_id = ? and user_id = ?", - ) - .bind(account_id) - .bind(user_id.as_str()) - .fetch_all(&mut *connection) - .await?; - - for row in rows.drain(..) { - let device_row_id = row.0; - let device_id: DeviceIdBox = row.1.into(); - let display_name = row.2; - let trust_state = LocalTrust::from(row.3); - - let device = self - .load_device_data( - &mut connection, - device_row_id, - user_id, - device_id.clone(), - trust_state, - display_name, - ) - .await?; - - devices.insert(device_id, device); - } - - Ok(devices) - } - - async fn save_device_helper( - &self, - connection: &mut SqliteConnection, - device: ReadOnlyDevice, - ) -> Result<()> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - - query( - "INSERT INTO devices ( - account_id, user_id, device_id, - display_name, trust_state - ) VALUES (?1, ?2, ?3, ?4, ?5) - ON CONFLICT(account_id, user_id, device_id) DO UPDATE SET - display_name = excluded.display_name, - trust_state = excluded.trust_state - ", - ) - .bind(account_id) - .bind(device.user_id().as_str()) - .bind(device.device_id().as_str()) - .bind(device.display_name()) - .bind(device.local_trust_state() as i64) - .execute(&mut *connection) - .await?; - - let row: (i64,) = query_as( - "SELECT id FROM devices - WHERE user_id = ? and device_id = ?", - ) - .bind(device.user_id().as_str()) - .bind(device.device_id().as_str()) - .fetch_one(&mut *connection) - .await?; - - let device_row_id = row.0; - - for algorithm in device.algorithms() { - query( - "INSERT OR IGNORE INTO algorithms ( - device_id, algorithm - ) VALUES (?1, ?2) - ", - ) - .bind(device_row_id) - .bind(algorithm.to_string()) - .execute(&mut *connection) - .await?; - } - - for (key_id, key) in device.keys() { - query( - "INSERT OR IGNORE INTO device_keys ( - device_id, algorithm, key - ) VALUES (?1, ?2, ?3) - ", - ) - .bind(device_row_id) - .bind(key_id.algorithm().to_string()) - .bind(key) - .execute(&mut *connection) - .await?; - } - - for (user_id, signature_map) in device.signatures() { - for (key_id, signature) in signature_map { - query( - "INSERT OR IGNORE INTO device_signatures ( - device_id, user_id, key_algorithm, signature - ) VALUES (?1, ?2, ?3, ?4) - ", - ) - .bind(device_row_id) - .bind(user_id.as_str()) - .bind(key_id.algorithm().to_string()) - .bind(signature) - .execute(&mut *connection) - .await?; - } - } - - Ok(()) - } - - fn get_pickle_mode(&self) -> PicklingMode { - self.pickle_key.pickle_mode() - } - - fn get_pickle_key(&self) -> &[u8] { - self.pickle_key.key() - } - - async fn save_inbound_group_session_helper( - &self, - account_id: i64, - connection: &mut SqliteConnection, - session: &InboundGroupSession, - ) -> Result<()> { - let pickle = session.pickle(self.get_pickle_mode()).await; - let session_id = session.session_id(); - - query( - "REPLACE INTO inbound_group_sessions ( - session_id, account_id, sender_key, - room_id, pickle, imported - ) VALUES (?1, ?2, ?3, ?4, ?5, ?6) - ", - ) - .bind(session_id) - .bind(account_id) - .bind(&pickle.sender_key) - .bind(pickle.room_id.as_str()) - .bind(pickle.pickle.as_str()) - .bind(pickle.imported) - .execute(&mut *connection) - .await?; - - let row: (i64,) = query_as( - "SELECT id FROM inbound_group_sessions - WHERE account_id = ? and session_id = ? and sender_key = ?", - ) - .bind(account_id) - .bind(session_id) - .bind(pickle.sender_key) - .fetch_one(&mut *connection) - .await?; - - let session_row_id = row.0; - - for (key_id, key) in pickle.signing_key { - query( - "REPLACE INTO group_session_claimed_keys ( - session_id, algorithm, key - ) VALUES (?1, ?2, ?3) - ", - ) - .bind(session_row_id) - .bind(serde_json::to_string(&key_id)?) - .bind(key) - .execute(&mut *connection) - .await?; - } - - if let Some(chains) = pickle.forwarding_chains { - for key in chains { - query( - "REPLACE INTO group_session_chains ( - session_id, key - ) VALUES (?1, ?2) - ", - ) - .bind(session_row_id) - .bind(key) - .execute(&mut *connection) - .await?; - } - } - - Ok(()) - } - - async fn load_cross_signing_key( - connection: &mut SqliteConnection, - user_id: &UserId, - user_row_id: i64, - key_type: CrosssigningKeyType, - ) -> Result { - let row: (i64, String) = - query_as("SELECT id, usage FROM cross_signing_keys WHERE user_id =? and key_type =?") - .bind(user_row_id) - .bind(key_type) - .fetch_one(&mut *connection) - .await?; - - let key_row_id = row.0; - let usage: Vec = serde_json::from_str(&row.1)?; - - let key_rows: Vec<(String, String)> = - query_as("SELECT key_id, key FROM user_keys WHERE cross_signing_key = ?") - .bind(key_row_id) - .fetch_all(&mut *connection) - .await?; - - let mut keys = BTreeMap::new(); - let mut signatures = BTreeMap::new(); - - for row in key_rows { - let key_id = row.0; - let key = row.1; - - keys.insert(key_id, key); - } - - let mut signature_rows: Vec<(String, String, String)> = query_as( - "SELECT user_id, key_id, signature FROM user_key_signatures WHERE cross_signing_key = ?", - ) - .bind(key_row_id) - .fetch_all(&mut *connection) - .await?; - - for row in signature_rows.drain(..) { - let user_id = if let Ok(u) = UserId::try_from(row.0) { - u - } else { - continue; - }; - - let key_id = row.1; - let signature = row.2; - - signatures - .entry(user_id) - .or_insert_with(BTreeMap::new) - .insert(key_id, signature); - } - - Ok(CrossSigningKey { - user_id: user_id.to_owned(), - usage, - keys, - signatures, - }) - } - - async fn load_user(&self, user_id: &UserId) -> Result> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - let mut connection = self.connection.lock().await; - - let row: Option<(i64,)> = - query_as("SELECT id FROM users WHERE account_id = ? and user_id = ?") - .bind(account_id) - .bind(user_id.as_str()) - .fetch_optional(&mut *connection) - .await?; - - let user_row_id = if let Some(row) = row { - row.0 - } else { - return Ok(None); - }; - - let master = SqliteStore::load_cross_signing_key( - &mut connection, - user_id, - user_row_id, - CrosssigningKeyType::Master, - ) - .await?; - let self_singing = SqliteStore::load_cross_signing_key( - &mut connection, - user_id, - user_row_id, - CrosssigningKeyType::SelfSigning, - ) - .await?; - - if user_id == &*self.user_id { - let user_signing = SqliteStore::load_cross_signing_key( - &mut connection, - user_id, - user_row_id, - CrosssigningKeyType::UserSigning, - ) - .await?; - - let verified: Option<(bool,)> = - query_as("SELECT trusted FROM users_trust_state WHERE user_id = ?") - .bind(user_row_id) - .fetch_optional(&mut *connection) - .await?; - - let verified = verified.map_or(false, |r| r.0); - - let identity = - OwnUserIdentity::new(master.into(), self_singing.into(), user_signing.into()) - .expect("Signature check failed on stored identity"); - - if verified { - identity.mark_as_verified(); - } - - Ok(Some(UserIdentities::Own(identity))) - } else { - Ok(Some(UserIdentities::Other( - UserIdentity::new(master.into(), self_singing.into()) - .expect("Signature check failed on stored identity"), - ))) - } - } - - async fn save_cross_signing_key( - connection: &mut SqliteConnection, - user_row_id: i64, - key_type: CrosssigningKeyType, - cross_signing_key: impl AsRef, - ) -> Result<()> { - let cross_signing_key: &CrossSigningKey = cross_signing_key.as_ref(); - - query( - "REPLACE INTO cross_signing_keys ( - user_id, key_type, usage - ) VALUES (?1, ?2, ?3) - ", - ) - .bind(user_row_id) - .bind(key_type) - .bind(serde_json::to_string(&cross_signing_key.usage)?) - .execute(&mut *connection) - .await?; - - let row: (i64,) = query_as( - "SELECT id FROM cross_signing_keys - WHERE user_id = ? and key_type = ?", - ) - .bind(user_row_id) - .bind(key_type) - .fetch_one(&mut *connection) - .await?; - - let key_row_id = row.0; - - for (key_id, key) in &cross_signing_key.keys { - query( - "REPLACE INTO user_keys ( - cross_signing_key, key_id, key - ) VALUES (?1, ?2, ?3) - ", - ) - .bind(key_row_id) - .bind(key_id.as_str()) - .bind(key) - .execute(&mut *connection) - .await?; - } - - for (user_id, signature_map) in &cross_signing_key.signatures { - for (key_id, signature) in signature_map { - query( - "REPLACE INTO user_key_signatures ( - cross_signing_key, user_id, key_id, signature - ) VALUES (?1, ?2, ?3, ?4) - ", - ) - .bind(key_row_id) - .bind(user_id.as_str()) - .bind(key_id.as_str()) - .bind(signature) - .execute(&mut *connection) - .await?; - } - } - - Ok(()) - } - - #[cfg(test)] - async fn save_sessions(&self, sessions: &[Session]) -> Result<()> { - let mut connection = self.connection.lock().await; - let mut transaction = connection.begin().await?; - - self.save_sessions_helper(&mut transaction, sessions) - .await?; - transaction.commit().await?; - - Ok(()) - } - - async fn save_sessions_helper( - &self, - connection: &mut SqliteConnection, - sessions: &[Session], - ) -> Result<()> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - - for session in sessions { - self.lazy_load_sessions(connection, &session.sender_key) - .await?; - } - - for session in sessions { - self.sessions.add(session.clone()).await; - - let pickle = session.pickle(self.get_pickle_mode()).await; - - let session_id = session.session_id(); - let creation_time = serde_json::to_string(&pickle.creation_time)?; - let last_use_time = serde_json::to_string(&pickle.last_use_time)?; - - query( - "REPLACE INTO sessions ( - session_id, account_id, creation_time, last_use_time, sender_key, pickle - ) VALUES (?, ?, ?, ?, ?, ?)", - ) - .bind(&session_id) - .bind(&account_id) - .bind(&*creation_time) - .bind(&*last_use_time) - .bind(&pickle.sender_key) - .bind(&pickle.pickle.as_str()) - .execute(&mut *connection) - .await?; - } - - Ok(()) - } - - async fn save_devices( - &self, - mut connection: &mut SqliteConnection, - devices: &[ReadOnlyDevice], - ) -> Result<()> { - for device in devices { - self.save_device_helper(&mut connection, device.clone()) - .await? - } - - Ok(()) - } - - async fn delete_devices( - &self, - connection: &mut SqliteConnection, - devices: &[ReadOnlyDevice], - ) -> Result<()> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - - for device in devices { - query( - "DELETE FROM devices - WHERE account_id = ?1 and user_id = ?2 and device_id = ?3 - ", - ) - .bind(account_id) - .bind(&device.user_id().to_string()) - .bind(device.device_id().as_str()) - .execute(&mut *connection) - .await?; - } - - Ok(()) - } - - #[cfg(test)] - async fn save_inbound_group_sessions_test( - &self, - sessions: &[InboundGroupSession], - ) -> Result<()> { - let mut connection = self.connection.lock().await; - let mut transaction = connection.begin().await?; - - self.save_inbound_group_sessions(&mut transaction, sessions) - .await?; - - transaction.commit().await?; - Ok(()) - } - - async fn save_inbound_group_sessions( - &self, - connection: &mut SqliteConnection, - sessions: &[InboundGroupSession], - ) -> Result<()> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - - for session in sessions { - self.save_inbound_group_session_helper(account_id, connection, session) - .await?; - } - - Ok(()) - } - - async fn save_user_identities( - &self, - mut connection: &mut SqliteConnection, - users: &[UserIdentities], - ) -> Result<()> { - for user in users { - self.save_user_helper(&mut connection, user).await?; - } - Ok(()) - } - - async fn save_olm_hashses( - &self, - connection: &mut SqliteConnection, - hashes: &[OlmMessageHash], - ) -> Result<()> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - - for hash in hashes { - query("REPLACE INTO olm_hashes (account_id, sender_key, hash) VALUES (?1, ?2, ?3)") - .bind(account_id) - .bind(&hash.sender_key) - .bind(&hash.hash) - .execute(&mut *connection) - .await?; - } - - Ok(()) - } - - async fn save_identity( - &self, - connection: &mut SqliteConnection, - identity: PrivateCrossSigningIdentity, - ) -> Result<()> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - let pickle = identity.pickle(self.get_pickle_key()).await?; - - query( - "INSERT INTO private_identities ( - account_id, user_id, pickle, shared - ) VALUES (?1, ?2, ?3, ?4) - ON CONFLICT(account_id, user_id) DO UPDATE SET - pickle = excluded.pickle, - shared = excluded.shared - ", - ) - .bind(account_id) - .bind(pickle.user_id.as_str()) - .bind(pickle.pickle) - .bind(pickle.shared) - .execute(&mut *connection) - .await?; - - Ok(()) - } - - async fn save_account_helper( - &self, - connection: &mut SqliteConnection, - account: ReadOnlyAccount, - ) -> Result<()> { - let pickle = account.pickle(self.get_pickle_mode()).await; - - query( - "INSERT INTO accounts ( - user_id, device_id, pickle, shared, uploaded_key_count - ) VALUES (?1, ?2, ?3, ?4, ?5) - ON CONFLICT(user_id, device_id) DO UPDATE SET - pickle = excluded.pickle, - shared = excluded.shared, - uploaded_key_count = excluded.uploaded_key_count - ", - ) - .bind(pickle.user_id.as_str()) - .bind(pickle.device_id.as_str()) - .bind(pickle.pickle.as_str()) - .bind(pickle.shared) - .bind(pickle.uploaded_signed_key_count) - .execute(&mut *connection) - .await?; - - let account_id: (i64,) = - query_as("SELECT id FROM accounts WHERE user_id = ? and device_id = ?") - .bind(self.user_id.as_str()) - .bind(self.device_id.as_str()) - .fetch_one(&mut *connection) - .await?; - - *self.account_info.lock().unwrap() = Some(AccountInfo { - account_id: account_id.0, - identity_keys: account.identity_keys.clone(), - }); - - Ok(()) - } - - async fn save_user_helper( - &self, - mut connection: &mut SqliteConnection, - user: &UserIdentities, - ) -> Result<()> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - - query("REPLACE INTO users (account_id, user_id) VALUES (?1, ?2)") - .bind(account_id) - .bind(user.user_id().as_str()) - .execute(&mut *connection) - .await?; - - let row: (i64,) = query_as( - "SELECT id FROM users - WHERE account_id = ? and user_id = ?", - ) - .bind(account_id) - .bind(user.user_id().as_str()) - .fetch_one(&mut *connection) - .await?; - - let user_row_id = row.0; - - SqliteStore::save_cross_signing_key( - &mut connection, - user_row_id, - CrosssigningKeyType::Master, - user.master_key(), - ) - .await?; - SqliteStore::save_cross_signing_key( - &mut connection, - user_row_id, - CrosssigningKeyType::SelfSigning, - user.self_signing_key(), - ) - .await?; - - if let UserIdentities::Own(own_identity) = user { - SqliteStore::save_cross_signing_key( - &mut connection, - user_row_id, - CrosssigningKeyType::UserSigning, - own_identity.user_signing_key(), - ) - .await?; - - query("REPLACE INTO users_trust_state (user_id, trusted) VALUES (?1, ?2)") - .bind(user_row_id) - .bind(own_identity.is_verified()) - .execute(&mut *connection) - .await?; - } - - Ok(()) - } -} - -#[async_trait] -impl CryptoStore for SqliteStore { - async fn load_account(&self) -> Result> { - let mut connection = self.connection.lock().await; - - let row: Option<(i64, String, bool, i64)> = query_as( - "SELECT id, pickle, shared, uploaded_key_count FROM accounts - WHERE user_id = ? and device_id = ?", - ) - .bind(self.user_id.as_str()) - .bind(self.device_id.as_str()) - .fetch_optional(&mut *connection) - .await?; - - let result = if let Some((id, pickle, shared, uploaded_key_count)) = row { - let pickle = PickledAccount { - user_id: (&*self.user_id).clone(), - device_id: (&*self.device_id).clone(), - pickle: AccountPickle::from(pickle), - shared, - uploaded_signed_key_count: uploaded_key_count, - }; - - let account = ReadOnlyAccount::from_pickle(pickle, self.get_pickle_mode())?; - - *self.account_info.lock().unwrap() = Some(AccountInfo { - account_id: id, - identity_keys: account.identity_keys.clone(), - }); - - Some(account) - } else { - return Ok(None); - }; - - drop(connection); - - self.load_tracked_users().await?; - - Ok(result) - } - - async fn save_account(&self, account: ReadOnlyAccount) -> Result<()> { - let mut connection = self.connection.lock().await; - self.save_account_helper(&mut connection, account).await - } - - async fn load_identity(&self) -> Result> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - let mut connection = self.connection.lock().await; - - let row: Option<(String, bool)> = query_as( - "SELECT pickle, shared FROM private_identities - WHERE account_id = ?", - ) - .bind(account_id) - .fetch_optional(&mut *connection) - .await?; - - if let Some(row) = row { - let pickle = PickledCrossSigningIdentity { - user_id: (&*self.user_id).clone(), - pickle: row.0, - shared: row.1, - }; - - // TODO remove this unwrap - let identity = PrivateCrossSigningIdentity::from_pickle(pickle, self.get_pickle_key()) - .await - .unwrap(); - - Ok(Some(identity)) - } else { - Ok(None) - } - } - - async fn save_changes(&self, changes: Changes) -> Result<()> { - let mut connection = self.connection.lock().await; - let mut transaction = connection.begin().await?; - - if let Some(account) = changes.account { - self.save_account_helper(&mut transaction, account).await?; - } - - if let Some(identity) = changes.private_identity { - self.save_identity(&mut transaction, identity).await?; - } - - self.save_sessions_helper(&mut transaction, &changes.sessions) - .await?; - self.save_inbound_group_sessions(&mut transaction, &changes.inbound_group_sessions) - .await?; - - self.save_devices(&mut transaction, &changes.devices.new) - .await?; - self.save_devices(&mut transaction, &changes.devices.changed) - .await?; - self.delete_devices(&mut transaction, &changes.devices.deleted) - .await?; - - self.save_user_identities(&mut transaction, &changes.identities.new) - .await?; - self.save_user_identities(&mut transaction, &changes.identities.changed) - .await?; - self.save_olm_hashses(&mut transaction, &changes.message_hashes) - .await?; - - transaction.commit().await?; - - Ok(()) - } - - async fn get_sessions(&self, sender_key: &str) -> Result>>>> { - let mut connection = self.connection.lock().await; - Ok(self.get_sessions_for(&mut connection, sender_key).await?) - } - - async fn get_inbound_group_session( - &self, - room_id: &RoomId, - sender_key: &str, - session_id: &str, - ) -> Result> { - Ok(self - .load_inbound_group_session_helper(room_id, sender_key, session_id) - .await?) - } - - async fn get_inbound_group_sessions(&self) -> Result> { - Ok(self.load_inbound_group_sessions().await?) - } - - fn is_user_tracked(&self, user_id: &UserId) -> bool { - self.tracked_users.contains(user_id) - } - - fn has_users_for_key_query(&self) -> bool { - !self.users_for_key_query.is_empty() - } - - fn users_for_key_query(&self) -> HashSet { - #[allow(clippy::map_clone)] - self.users_for_key_query.iter().map(|u| u.clone()).collect() - } - - async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result { - let already_added = self.tracked_users.insert(user.clone()); - - if dirty { - self.users_for_key_query.insert(user.clone()); - } else { - self.users_for_key_query.remove(user); - } - - self.save_tracked_user(user, dirty).await?; - - Ok(already_added) - } - - async fn get_device( - &self, - user_id: &UserId, - device_id: &DeviceId, - ) -> Result> { - self.get_single_device(user_id, device_id).await - } - - async fn get_user_devices( - &self, - user_id: &UserId, - ) -> Result> { - Ok(self.load_devices(user_id).await?) - } - - async fn get_user_identity(&self, user_id: &UserId) -> Result> { - self.load_user(user_id).await - } - - async fn save_value(&self, key: String, value: String) -> Result<()> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - let mut connection = self.connection.lock().await; - - query("REPLACE INTO key_value (account_id, key, value) VALUES (?1, ?2, ?3)") - .bind(account_id) - .bind(&key) - .bind(&value) - .execute(&mut *connection) - .await?; - - Ok(()) - } - - async fn remove_value(&self, key: &str) -> Result<()> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - let mut connection = self.connection.lock().await; - - query( - "DELETE FROM key_value - WHERE account_id = ?1 and key = ?2 - ", - ) - .bind(account_id) - .bind(key) - .execute(&mut *connection) - .await?; - - Ok(()) - } - - async fn get_value(&self, key: &str) -> Result> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - let mut connection = self.connection.lock().await; - - let row: Option<(String,)> = - query_as("SELECT value FROM key_value WHERE account_id = ? and key = ?") - .bind(account_id) - .bind(key) - .fetch_optional(&mut *connection) - .await?; - - Ok(row.map(|r| r.0)) - } - - async fn is_message_known(&self, message_hash: &OlmMessageHash) -> Result { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - let mut connection = self.connection.lock().await; - - let row: Option<(String,)> = query_as( - "SELECT hash FROM olm_hashes WHERE account_id = ? and sender_key = ? and hash = ?", - ) - .bind(account_id) - .bind(&message_hash.sender_key) - .bind(&message_hash.hash) - .fetch_optional(&mut *connection) - .await?; - - Ok(row.is_some()) - } -} - -#[cfg(not(tarpaulin_include))] -impl std::fmt::Debug for SqliteStore { - fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> StdResult<(), std::fmt::Error> { - fmt.debug_struct("SqliteStore") - .field("user_id", &self.user_id) - .field("device_id", &self.device_id) - .field("path", &self.path) - .finish() - } -} - -#[cfg(test)] -mod test { - use crate::{ - identities::{ - device::test::get_device, - user::test::{get_other_identity, get_own_identity}, - }, - olm::{ - GroupSessionKey, InboundGroupSession, OlmMessageHash, PrivateCrossSigningIdentity, - ReadOnlyAccount, Session, - }, - store::{Changes, DeviceChanges, IdentityChanges}, - }; - use matrix_sdk_common::{ - api::r0::keys::SignedKey, - identifiers::{room_id, user_id, DeviceId, UserId}, - }; - use olm_rs::outbound_group_session::OlmOutboundGroupSession; - use std::collections::BTreeMap; - use tempfile::tempdir; - - use super::{CryptoStore, SqliteStore}; - - fn alice_id() -> UserId { - user_id!("@alice:example.org") - } - - fn alice_device_id() -> Box { - "ALICEDEVICE".into() - } - - fn bob_id() -> UserId { - user_id!("@bob:example.org") - } - - fn bob_device_id() -> Box { - "BOBDEVICE".into() - } - - async fn get_store(passphrase: Option<&str>) -> (SqliteStore, tempfile::TempDir) { - let tmpdir = tempdir().unwrap(); - let tmpdir_path = tmpdir.path().to_str().unwrap(); - - let store = if let Some(passphrase) = passphrase { - SqliteStore::open_with_passphrase( - &alice_id(), - &alice_device_id(), - tmpdir_path, - passphrase, - ) - .await - .expect("Can't create a passphrase protected store") - } else { - SqliteStore::open(&alice_id(), &alice_device_id(), tmpdir_path) - .await - .expect("Can't create store") - }; - - (store, tmpdir) - } - - async fn get_loaded_store() -> (ReadOnlyAccount, SqliteStore, tempfile::TempDir) { - let (store, dir) = get_store(None).await; - let account = get_account(); - store - .save_account(account.clone()) - .await - .expect("Can't save account"); - - (account, store, dir) - } - - fn get_account() -> ReadOnlyAccount { - ReadOnlyAccount::new(&alice_id(), &alice_device_id()) - } - - async fn get_account_and_session() -> (ReadOnlyAccount, Session) { - let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id()); - let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id()); - - bob.generate_one_time_keys_helper(1).await; - let one_time_key = bob - .one_time_keys() - .await - .curve25519() - .iter() - .next() - .unwrap() - .1 - .to_owned(); - let one_time_key = SignedKey { - key: one_time_key, - signatures: BTreeMap::new(), - }; - let sender_key = bob.identity_keys().curve25519().to_owned(); - let session = alice - .create_outbound_session_helper(&sender_key, &one_time_key) - .await - .unwrap(); - - (alice, session) - } - - #[tokio::test(flavor = "multi_thread")] - async fn create_store() { - let tmpdir = tempdir().unwrap(); - let tmpdir_path = tmpdir.path().to_str().unwrap(); - let _ = SqliteStore::open(&alice_id(), &alice_device_id(), tmpdir_path) - .await - .expect("Can't create store"); - } - - #[tokio::test(flavor = "multi_thread")] - async fn save_account() { - let (store, _dir) = get_store(None).await; - assert!(store.load_account().await.unwrap().is_none()); - let account = get_account(); - - store - .save_account(account) - .await - .expect("Can't save account"); - } - - #[tokio::test(flavor = "multi_thread")] - async fn load_account() { - let (store, _dir) = get_store(None).await; - let account = get_account(); - - store - .save_account(account.clone()) - .await - .expect("Can't save account"); - - let loaded_account = store.load_account().await.expect("Can't load account"); - let loaded_account = loaded_account.unwrap(); - - assert_eq!(account, loaded_account); - } - - #[tokio::test(flavor = "multi_thread")] - async fn load_account_with_passphrase() { - let (store, _dir) = get_store(Some("secret_passphrase")).await; - let account = get_account(); - - store - .save_account(account.clone()) - .await - .expect("Can't save account"); - - let loaded_account = store.load_account().await.expect("Can't load account"); - let loaded_account = loaded_account.unwrap(); - - assert_eq!(account, loaded_account); - } - - #[tokio::test(flavor = "multi_thread")] - async fn save_and_share_account() { - let (store, _dir) = get_store(None).await; - let account = get_account(); - - store - .save_account(account.clone()) - .await - .expect("Can't save account"); - - account.mark_as_shared(); - account.update_uploaded_key_count(50); - - store - .save_account(account.clone()) - .await - .expect("Can't save account"); - - let loaded_account = store.load_account().await.expect("Can't load account"); - let loaded_account = loaded_account.unwrap(); - - assert_eq!(account, loaded_account); - assert_eq!( - account.uploaded_key_count(), - loaded_account.uploaded_key_count() - ); - } - - #[tokio::test(flavor = "multi_thread")] - async fn save_session() { - let (store, _dir) = get_store(None).await; - let (account, session) = get_account_and_session().await; - - assert!(store.save_sessions(&[session.clone()]).await.is_err()); - - store - .save_account(account.clone()) - .await - .expect("Can't save account"); - - store.save_sessions(&[session]).await.unwrap(); - } - - #[tokio::test(flavor = "multi_thread")] - async fn load_sessions() { - let (store, _dir) = get_store(None).await; - let (account, session) = get_account_and_session().await; - store - .save_account(account.clone()) - .await - .expect("Can't save account"); - store.save_sessions(&[session.clone()]).await.unwrap(); - - let sessions = store - .load_sessions_for(&session.sender_key) - .await - .expect("Can't load sessions"); - let loaded_session = &sessions[0]; - - assert_eq!(&session, loaded_session); - } - - #[tokio::test(flavor = "multi_thread")] - async fn add_and_save_session() { - let (store, dir) = get_store(None).await; - let (account, session) = get_account_and_session().await; - let sender_key = session.sender_key.to_owned(); - let session_id = session.session_id().to_owned(); - - store - .save_account(account.clone()) - .await - .expect("Can't save account"); - store.save_sessions(&[session]).await.unwrap(); - - let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap(); - let sessions_lock = sessions.lock().await; - let session = &sessions_lock[0]; - - assert_eq!(session_id, session.session_id()); - - drop(store); - - let store = SqliteStore::open(&alice_id(), &alice_device_id(), dir.path()) - .await - .expect("Can't create store"); - - let loaded_account = store.load_account().await.unwrap().unwrap(); - assert_eq!(account, loaded_account); - - let sessions = store.get_sessions(&sender_key).await.unwrap().unwrap(); - let sessions_lock = sessions.lock().await; - let session = &sessions_lock[0]; - - assert_eq!(session_id, session.session_id()); - } - - #[tokio::test(flavor = "multi_thread")] - async fn save_inbound_group_session() { - let (account, store, _dir) = get_loaded_store().await; - - let identity_keys = account.identity_keys(); - let outbound_session = OlmOutboundGroupSession::new(); - let session = InboundGroupSession::new( - identity_keys.curve25519(), - identity_keys.ed25519(), - &room_id!("!test:localhost"), - GroupSessionKey(outbound_session.session_key()), - ) - .expect("Can't create session"); - - store - .save_inbound_group_sessions_test(&[session]) - .await - .expect("Can't save group session"); - } - - #[tokio::test(flavor = "multi_thread")] - async fn load_inbound_group_session() { - let (account, store, dir) = get_loaded_store().await; - - let identity_keys = account.identity_keys(); - let outbound_session = OlmOutboundGroupSession::new(); - let session = InboundGroupSession::new( - identity_keys.curve25519(), - identity_keys.ed25519(), - &room_id!("!test:localhost"), - GroupSessionKey(outbound_session.session_key()), - ) - .expect("Can't create session"); - - let mut export = session.export().await; - - export.forwarding_curve25519_key_chain = vec!["some_chain".to_owned()]; - - let session = InboundGroupSession::from_export(export).unwrap(); - - store - .save_inbound_group_sessions_test(&[session.clone()]) - .await - .expect("Can't save group session"); - - let store = SqliteStore::open(&alice_id(), &alice_device_id(), dir.path()) - .await - .expect("Can't create store"); - - store.load_account().await.unwrap(); - - let loaded_session = store - .get_inbound_group_session(&session.room_id, &session.sender_key, session.session_id()) - .await - .unwrap() - .unwrap(); - assert_eq!(session, loaded_session); - let export = loaded_session.export().await; - assert!(!export.forwarding_curve25519_key_chain.is_empty()) - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_tracked_users() { - let (_account, store, dir) = get_loaded_store().await; - let device = get_device(); - - assert!(store - .update_tracked_user(device.user_id(), false) - .await - .unwrap()); - assert!(!store - .update_tracked_user(device.user_id(), false) - .await - .unwrap()); - - assert!(store.is_user_tracked(device.user_id())); - assert!(!store.users_for_key_query().contains(device.user_id())); - assert!(!store - .update_tracked_user(device.user_id(), true) - .await - .unwrap()); - assert!(store.users_for_key_query().contains(device.user_id())); - drop(store); - - let store = SqliteStore::open(&alice_id(), &alice_device_id(), dir.path()) - .await - .expect("Can't create store"); - - store.load_account().await.unwrap(); - - assert!(store.is_user_tracked(device.user_id())); - assert!(store.users_for_key_query().contains(device.user_id())); - - store - .update_tracked_user(device.user_id(), false) - .await - .unwrap(); - assert!(!store.users_for_key_query().contains(device.user_id())); - - let store = SqliteStore::open(&alice_id(), &alice_device_id(), dir.path()) - .await - .expect("Can't create store"); - - store.load_account().await.unwrap(); - - assert!(!store.users_for_key_query().contains(device.user_id())); - } - - #[tokio::test(flavor = "multi_thread")] - async fn device_saving() { - let (_account, store, dir) = get_loaded_store().await; - let device = get_device(); - - let changes = Changes { - devices: DeviceChanges { - changed: vec![device.clone()], - ..Default::default() - }, - ..Default::default() - }; - - store.save_changes(changes).await.unwrap(); - - drop(store); - - let store = SqliteStore::open(&alice_id(), &alice_device_id(), dir.path()) - .await - .expect("Can't create store"); - - store.load_account().await.unwrap(); - - let loaded_device = store - .get_device(device.user_id(), device.device_id()) - .await - .unwrap() - .unwrap(); - - assert_eq!(device, loaded_device); - - for algorithm in loaded_device.algorithms() { - assert!(device.algorithms().contains(algorithm)); - } - assert_eq!(device.algorithms().len(), loaded_device.algorithms().len()); - assert_eq!(device.keys(), loaded_device.keys()); - - let user_devices = store.get_user_devices(device.user_id()).await.unwrap(); - assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id()); - assert_eq!(user_devices.values().next().unwrap(), &device); - } - - #[tokio::test(flavor = "multi_thread")] - async fn device_deleting() { - let (_account, store, dir) = get_loaded_store().await; - let device = get_device(); - - let changes = Changes { - devices: DeviceChanges { - changed: vec![device.clone()], - ..Default::default() - }, - ..Default::default() - }; - - store.save_changes(changes).await.unwrap(); - - let changes = Changes { - devices: DeviceChanges { - deleted: vec![device.clone()], - ..Default::default() - }, - ..Default::default() - }; - - store.save_changes(changes).await.unwrap(); - - let store = SqliteStore::open(&alice_id(), &alice_device_id(), dir.path()) - .await - .expect("Can't create store"); - - store.load_account().await.unwrap(); - - let loaded_device = store - .get_device(device.user_id(), device.device_id()) - .await - .unwrap(); - - assert!(loaded_device.is_none()); - } - - #[tokio::test(flavor = "multi_thread")] - async fn user_saving() { - let dir = tempdir().unwrap(); - let tmpdir_path = dir.path().to_str().unwrap(); - - let user_id = user_id!("@example:localhost"); - let device_id: &DeviceId = "WSKKLTJZCL".into(); - - let store = SqliteStore::open(&user_id, &device_id, tmpdir_path) - .await - .expect("Can't create store"); - - let account = ReadOnlyAccount::new(&user_id, &device_id); - - store - .save_account(account.clone()) - .await - .expect("Can't save account"); - - let own_identity = get_own_identity(); - - let changes = Changes { - identities: IdentityChanges { - changed: vec![own_identity.clone().into()], - ..Default::default() - }, - ..Default::default() - }; - - store - .save_changes(changes) - .await - .expect("Can't save identity"); - - drop(store); - - let store = SqliteStore::open(&user_id, &device_id, dir.path()) - .await - .expect("Can't create store"); - - store.load_account().await.unwrap(); - - let loaded_user = store - .get_user_identity(own_identity.user_id()) - .await - .unwrap() - .unwrap(); - - assert_eq!(loaded_user.master_key(), own_identity.master_key()); - assert_eq!( - loaded_user.self_signing_key(), - own_identity.self_signing_key() - ); - assert_eq!(loaded_user, own_identity.clone().into()); - - let other_identity = get_other_identity(); - - let changes = Changes { - identities: IdentityChanges { - changed: vec![other_identity.clone().into()], - ..Default::default() - }, - ..Default::default() - }; - - store.save_changes(changes).await.unwrap(); - - let loaded_user = store - .load_user(other_identity.user_id()) - .await - .unwrap() - .unwrap(); - - assert_eq!(loaded_user.master_key(), other_identity.master_key()); - assert_eq!( - loaded_user.self_signing_key(), - other_identity.self_signing_key() - ); - assert_eq!(loaded_user, other_identity.into()); - - own_identity.mark_as_verified(); - - let changes = Changes { - identities: IdentityChanges { - changed: vec![own_identity.into()], - ..Default::default() - }, - ..Default::default() - }; - - store.save_changes(changes).await.unwrap(); - let loaded_user = store.load_user(&user_id).await.unwrap().unwrap(); - assert!(loaded_user.own().unwrap().is_verified()) - } - - #[tokio::test(flavor = "multi_thread")] - async fn private_identity_saving() { - let (_, store, _dir) = get_loaded_store().await; - assert!(store.load_identity().await.unwrap().is_none()); - let identity = PrivateCrossSigningIdentity::new((&*store.user_id).clone()).await; - - let changes = Changes { - private_identity: Some(identity.clone()), - ..Default::default() - }; - - store.save_changes(changes).await.unwrap(); - let loaded_identity = store.load_identity().await.unwrap().unwrap(); - assert_eq!(identity.user_id(), loaded_identity.user_id()); - } - - #[tokio::test(flavor = "multi_thread")] - async fn key_value_saving() { - let (_, store, _dir) = get_loaded_store().await; - let key = "test_key".to_string(); - let value = "secret value".to_string(); - - store.save_value(key.clone(), value.clone()).await.unwrap(); - let stored_value = store.get_value(&key).await.unwrap().unwrap(); - - assert_eq!(value, stored_value); - - store.remove_value(&key).await.unwrap(); - assert!(store.get_value(&key).await.unwrap().is_none()); - } - - #[tokio::test(flavor = "multi_thread")] - async fn olm_hash_saving() { - let (_, store, _dir) = get_loaded_store().await; - - let hash = OlmMessageHash { - sender_key: "test_sender".to_owned(), - hash: "test_hash".to_owned(), - }; - - let mut changes = Changes::default(); - changes.message_hashes.push(hash.clone()); - - assert!(!store.is_message_known(&hash).await.unwrap()); - store.save_changes(changes).await.unwrap(); - assert!(store.is_message_known(&hash).await.unwrap()); - } -}