From fc54c63a4c684afd2529a8e67e4673b193e7c956 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Fri, 16 Oct 2020 15:05:53 +0200 Subject: [PATCH] crypto: Upgrade sqlx to the beta release. This change is much needed to enable transactions in our sqlite store, before this release creating a transaction would take ownership of the connection, now it just mutably borrows it. --- matrix_sdk_crypto/Cargo.toml | 2 +- matrix_sdk_crypto/src/machine.rs | 2 +- matrix_sdk_crypto/src/store/mod.rs | 5 -- matrix_sdk_crypto/src/store/sqlite.rs | 71 ++++++++++++++------------- 4 files changed, 40 insertions(+), 40 deletions(-) diff --git a/matrix_sdk_crypto/Cargo.toml b/matrix_sdk_crypto/Cargo.toml index c437da65..cb882341 100644 --- a/matrix_sdk_crypto/Cargo.toml +++ b/matrix_sdk_crypto/Cargo.toml @@ -52,7 +52,7 @@ default-features = false features = ["std", "std-future"] [target.'cfg(not(target_arch = "wasm32"))'.dependencies.sqlx] -version = "0.3.5" +version = "0.4.0-beta.1" optional = true default-features = false features = ["runtime-tokio", "sqlite", "macros"] diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 03127c25..a2b270d9 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -1576,7 +1576,7 @@ pub(crate) mod test { } } - #[tokio::test] + #[tokio::test(threaded_scheduler)] #[cfg(feature = "sqlite_cryptostore")] async fn test_machine_with_default_store() { let tmpdir = tempdir().unwrap(); diff --git a/matrix_sdk_crypto/src/store/mod.rs b/matrix_sdk_crypto/src/store/mod.rs index 767bbae4..7dad4d07 100644 --- a/matrix_sdk_crypto/src/store/mod.rs +++ b/matrix_sdk_crypto/src/store/mod.rs @@ -55,7 +55,6 @@ use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError}; use serde::{Deserialize, Serialize}; use serde_json::Error as SerdeError; use thiserror::Error; -use url::ParseError; #[cfg_attr(feature = "docs", doc(cfg(r#sqlite_cryptostore)))] #[cfg(not(target_arch = "wasm32"))] @@ -245,10 +244,6 @@ pub enum CryptoStoreError { /// The store failed to (de)serialize a data type. #[error(transparent)] Serialization(#[from] SerdeError), - - /// An error occurred while parsing an URL. - #[error(transparent)] - UrlParse(#[from] ParseError), } /// Trait abstracting a store that the `OlmMachine` uses to store cryptographic diff --git a/matrix_sdk_crypto/src/store/sqlite.rs b/matrix_sdk_crypto/src/store/sqlite.rs index fe67e2d6..e469948c 100644 --- a/matrix_sdk_crypto/src/store/sqlite.rs +++ b/matrix_sdk_crypto/src/store/sqlite.rs @@ -30,8 +30,7 @@ use matrix_sdk_common::{ instant::Duration, locks::Mutex, }; -use sqlx::{query, query_as, sqlite::SqliteQueryAs, Connect, Executor, SqliteConnection}; -use url::Url; +use sqlx::{query, query_as, sqlite::SqliteConnectOptions, Connection, Executor, SqliteConnection}; use zeroize::Zeroizing; use super::{ @@ -131,20 +130,20 @@ impl SqliteStore { .await } - fn path_to_url(path: &Path) -> Result { - // TODO this returns an empty error if the path isn't absolute. - let url = Url::from_directory_path(path).expect("Invalid path"); - Ok(url.join(DATABASE_NAME)?) - } - async fn open_helper>( user_id: &UserId, device_id: &DeviceId, path: P, passphrase: Option>, ) -> Result { - let url = SqliteStore::path_to_url(path.as_ref())?; - let connection = SqliteConnection::connect(url.as_ref()).await?; + let path = path.as_ref().join(DATABASE_NAME); + let options = SqliteConnectOptions::new() + .foreign_keys(true) + .create_if_missing(true) + .read_only(false) + .filename(&path); + + let connection = SqliteConnection::connect_with(&options).await?; let store = SqliteStore { user_id: Arc::new(user_id.to_owned()), @@ -153,7 +152,7 @@ impl SqliteStore { sessions: SessionStore::new(), inbound_group_sessions: GroupSessionStore::new(), devices: DeviceStore::new(), - path: Arc::new(path.as_ref().to_owned()), + path: Arc::new(path), connection: Arc::new(Mutex::new(connection)), pickle_passphrase: Arc::new(passphrase), tracked_users: Arc::new(DashSet::new()), @@ -1249,9 +1248,14 @@ impl CryptoStore for SqliteStore { async fn save_sessions(&self, sessions: &[Session]) -> Result<()> { let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - // TODO turn this into a transaction for session in sessions { self.lazy_load_sessions(&session.sender_key).await?; + } + + let mut connection = self.connection.lock().await; + let mut transaction = connection.begin().await?; + + for session in sessions { self.sessions.add(session.clone()).await; let pickle = session.pickle(self.get_pickle_mode()).await; @@ -1260,8 +1264,6 @@ impl CryptoStore for SqliteStore { let creation_time = serde_json::to_string(&pickle.creation_time)?; let last_use_time = serde_json::to_string(&pickle.last_use_time)?; - let mut connection = self.connection.lock().await; - query( "REPLACE INTO sessions ( session_id, account_id, creation_time, last_use_time, sender_key, pickle @@ -1273,10 +1275,12 @@ impl CryptoStore for SqliteStore { .bind(&*last_use_time) .bind(&pickle.sender_key) .bind(&pickle.pickle.as_str()) - .execute(&mut *connection) + .execute(&mut *transaction) .await?; } + transaction.commit().await?; + Ok(()) } @@ -1287,15 +1291,16 @@ impl CryptoStore for SqliteStore { async fn save_inbound_group_sessions(&self, sessions: &[InboundGroupSession]) -> Result<()> { let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; let mut connection = self.connection.lock().await; - - // FIXME use a transaction here once sqlx gets better support for them. + let mut transaction = connection.begin().await?; for session in sessions { - self.save_inbound_group_session_helper(account_id, &mut connection, session) + self.save_inbound_group_session_helper(account_id, &mut transaction, session) .await?; self.inbound_group_sessions.add(session.clone()); } + transaction.commit().await?; + Ok(()) } @@ -1549,7 +1554,7 @@ mod test { (alice, session) } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn create_store() { let tmpdir = tempdir().unwrap(); let tmpdir_path = tmpdir.path().to_str().unwrap(); @@ -1558,7 +1563,7 @@ mod test { .expect("Can't create store"); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn save_account() { let (store, _dir) = get_store(None).await; assert!(store.load_account().await.unwrap().is_none()); @@ -1570,7 +1575,7 @@ mod test { .expect("Can't save account"); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn load_account() { let (store, _dir) = get_store(None).await; let account = get_account(); @@ -1586,7 +1591,7 @@ mod test { assert_eq!(account, loaded_account); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn load_account_with_passphrase() { let (store, _dir) = get_store(Some("secret_passphrase")).await; let account = get_account(); @@ -1602,7 +1607,7 @@ mod test { assert_eq!(account, loaded_account); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn save_and_share_account() { let (store, _dir) = get_store(None).await; let account = get_account(); @@ -1630,7 +1635,7 @@ mod test { ); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn save_session() { let (store, _dir) = get_store(None).await; let (account, session) = get_account_and_session().await; @@ -1645,7 +1650,7 @@ mod test { store.save_sessions(&[session]).await.unwrap(); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn load_sessions() { let (store, _dir) = get_store(None).await; let (account, session) = get_account_and_session().await; @@ -1664,7 +1669,7 @@ mod test { assert_eq!(&session, loaded_session); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn add_and_save_session() { let (store, dir) = get_store(None).await; let (account, session) = get_account_and_session().await; @@ -1699,7 +1704,7 @@ mod test { assert_eq!(session_id, session.session_id()); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn save_inbound_group_session() { let (account, store, _dir) = get_loaded_store().await; @@ -1719,7 +1724,7 @@ mod test { .expect("Can't save group session"); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn load_inbound_group_session() { let (account, store, dir) = get_loaded_store().await; @@ -1761,7 +1766,7 @@ mod test { assert!(!export.forwarding_curve25519_key_chain.is_empty()) } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn test_tracked_users() { let (_account, store, dir) = get_loaded_store().await; let device = get_device(); @@ -1808,7 +1813,7 @@ mod test { assert!(!store.users_for_key_query().contains(device.user_id())); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn device_saving() { let (_account, store, dir) = get_loaded_store().await; let device = get_device(); @@ -1842,7 +1847,7 @@ mod test { assert_eq!(user_devices.devices().next().unwrap(), &device); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn device_deleting() { let (_account, store, dir) = get_loaded_store().await; let device = get_device(); @@ -1864,7 +1869,7 @@ mod test { assert!(loaded_device.is_none()); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn user_saving() { let dir = tempdir().unwrap(); let tmpdir_path = dir.path().to_str().unwrap(); @@ -1941,7 +1946,7 @@ mod test { assert!(loaded_user.own().unwrap().is_verified()) } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn key_value_saving() { let (_, store, _dir) = get_loaded_store().await; let key = "test_key".to_string();