crypto: Upgrade sqlx to the beta release.

This change is much needed to enable transactions in our sqlite store,
before this release creating a transaction would take ownership of the
connection, now it just mutably borrows it.
master
Damir Jelić 2020-10-16 15:05:53 +02:00
parent e7a24d5e68
commit fc54c63a4c
4 changed files with 40 additions and 40 deletions

View File

@ -52,7 +52,7 @@ default-features = false
features = ["std", "std-future"] features = ["std", "std-future"]
[target.'cfg(not(target_arch = "wasm32"))'.dependencies.sqlx] [target.'cfg(not(target_arch = "wasm32"))'.dependencies.sqlx]
version = "0.3.5" version = "0.4.0-beta.1"
optional = true optional = true
default-features = false default-features = false
features = ["runtime-tokio", "sqlite", "macros"] features = ["runtime-tokio", "sqlite", "macros"]

View File

@ -1576,7 +1576,7 @@ pub(crate) mod test {
} }
} }
#[tokio::test] #[tokio::test(threaded_scheduler)]
#[cfg(feature = "sqlite_cryptostore")] #[cfg(feature = "sqlite_cryptostore")]
async fn test_machine_with_default_store() { async fn test_machine_with_default_store() {
let tmpdir = tempdir().unwrap(); let tmpdir = tempdir().unwrap();

View File

@ -55,7 +55,6 @@ use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Error as SerdeError; use serde_json::Error as SerdeError;
use thiserror::Error; use thiserror::Error;
use url::ParseError;
#[cfg_attr(feature = "docs", doc(cfg(r#sqlite_cryptostore)))] #[cfg_attr(feature = "docs", doc(cfg(r#sqlite_cryptostore)))]
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
@ -245,10 +244,6 @@ pub enum CryptoStoreError {
/// The store failed to (de)serialize a data type. /// The store failed to (de)serialize a data type.
#[error(transparent)] #[error(transparent)]
Serialization(#[from] SerdeError), Serialization(#[from] SerdeError),
/// An error occurred while parsing an URL.
#[error(transparent)]
UrlParse(#[from] ParseError),
} }
/// Trait abstracting a store that the `OlmMachine` uses to store cryptographic /// Trait abstracting a store that the `OlmMachine` uses to store cryptographic

View File

@ -30,8 +30,7 @@ use matrix_sdk_common::{
instant::Duration, instant::Duration,
locks::Mutex, locks::Mutex,
}; };
use sqlx::{query, query_as, sqlite::SqliteQueryAs, Connect, Executor, SqliteConnection}; use sqlx::{query, query_as, sqlite::SqliteConnectOptions, Connection, Executor, SqliteConnection};
use url::Url;
use zeroize::Zeroizing; use zeroize::Zeroizing;
use super::{ use super::{
@ -131,20 +130,20 @@ impl SqliteStore {
.await .await
} }
fn path_to_url(path: &Path) -> Result<Url> {
// TODO this returns an empty error if the path isn't absolute.
let url = Url::from_directory_path(path).expect("Invalid path");
Ok(url.join(DATABASE_NAME)?)
}
async fn open_helper<P: AsRef<Path>>( async fn open_helper<P: AsRef<Path>>(
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
path: P, path: P,
passphrase: Option<Zeroizing<String>>, passphrase: Option<Zeroizing<String>>,
) -> Result<SqliteStore> { ) -> Result<SqliteStore> {
let url = SqliteStore::path_to_url(path.as_ref())?; let path = path.as_ref().join(DATABASE_NAME);
let connection = SqliteConnection::connect(url.as_ref()).await?; let options = SqliteConnectOptions::new()
.foreign_keys(true)
.create_if_missing(true)
.read_only(false)
.filename(&path);
let connection = SqliteConnection::connect_with(&options).await?;
let store = SqliteStore { let store = SqliteStore {
user_id: Arc::new(user_id.to_owned()), user_id: Arc::new(user_id.to_owned()),
@ -153,7 +152,7 @@ impl SqliteStore {
sessions: SessionStore::new(), sessions: SessionStore::new(),
inbound_group_sessions: GroupSessionStore::new(), inbound_group_sessions: GroupSessionStore::new(),
devices: DeviceStore::new(), devices: DeviceStore::new(),
path: Arc::new(path.as_ref().to_owned()), path: Arc::new(path),
connection: Arc::new(Mutex::new(connection)), connection: Arc::new(Mutex::new(connection)),
pickle_passphrase: Arc::new(passphrase), pickle_passphrase: Arc::new(passphrase),
tracked_users: Arc::new(DashSet::new()), tracked_users: Arc::new(DashSet::new()),
@ -1249,9 +1248,14 @@ impl CryptoStore for SqliteStore {
async fn save_sessions(&self, sessions: &[Session]) -> Result<()> { async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
// TODO turn this into a transaction
for session in sessions { for session in sessions {
self.lazy_load_sessions(&session.sender_key).await?; self.lazy_load_sessions(&session.sender_key).await?;
}
let mut connection = self.connection.lock().await;
let mut transaction = connection.begin().await?;
for session in sessions {
self.sessions.add(session.clone()).await; self.sessions.add(session.clone()).await;
let pickle = session.pickle(self.get_pickle_mode()).await; let pickle = session.pickle(self.get_pickle_mode()).await;
@ -1260,8 +1264,6 @@ impl CryptoStore for SqliteStore {
let creation_time = serde_json::to_string(&pickle.creation_time)?; let creation_time = serde_json::to_string(&pickle.creation_time)?;
let last_use_time = serde_json::to_string(&pickle.last_use_time)?; let last_use_time = serde_json::to_string(&pickle.last_use_time)?;
let mut connection = self.connection.lock().await;
query( query(
"REPLACE INTO sessions ( "REPLACE INTO sessions (
session_id, account_id, creation_time, last_use_time, sender_key, pickle session_id, account_id, creation_time, last_use_time, sender_key, pickle
@ -1273,10 +1275,12 @@ impl CryptoStore for SqliteStore {
.bind(&*last_use_time) .bind(&*last_use_time)
.bind(&pickle.sender_key) .bind(&pickle.sender_key)
.bind(&pickle.pickle.as_str()) .bind(&pickle.pickle.as_str())
.execute(&mut *connection) .execute(&mut *transaction)
.await?; .await?;
} }
transaction.commit().await?;
Ok(()) Ok(())
} }
@ -1287,15 +1291,16 @@ impl CryptoStore for SqliteStore {
async fn save_inbound_group_sessions(&self, sessions: &[InboundGroupSession]) -> Result<()> { async fn save_inbound_group_sessions(&self, sessions: &[InboundGroupSession]) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
let mut transaction = connection.begin().await?;
// FIXME use a transaction here once sqlx gets better support for them.
for session in sessions { for session in sessions {
self.save_inbound_group_session_helper(account_id, &mut connection, session) self.save_inbound_group_session_helper(account_id, &mut transaction, session)
.await?; .await?;
self.inbound_group_sessions.add(session.clone()); self.inbound_group_sessions.add(session.clone());
} }
transaction.commit().await?;
Ok(()) Ok(())
} }
@ -1549,7 +1554,7 @@ mod test {
(alice, session) (alice, session)
} }
#[tokio::test] #[tokio::test(threaded_scheduler)]
async fn create_store() { async fn create_store() {
let tmpdir = tempdir().unwrap(); let tmpdir = tempdir().unwrap();
let tmpdir_path = tmpdir.path().to_str().unwrap(); let tmpdir_path = tmpdir.path().to_str().unwrap();
@ -1558,7 +1563,7 @@ mod test {
.expect("Can't create store"); .expect("Can't create store");
} }
#[tokio::test] #[tokio::test(threaded_scheduler)]
async fn save_account() { async fn save_account() {
let (store, _dir) = get_store(None).await; let (store, _dir) = get_store(None).await;
assert!(store.load_account().await.unwrap().is_none()); assert!(store.load_account().await.unwrap().is_none());
@ -1570,7 +1575,7 @@ mod test {
.expect("Can't save account"); .expect("Can't save account");
} }
#[tokio::test] #[tokio::test(threaded_scheduler)]
async fn load_account() { async fn load_account() {
let (store, _dir) = get_store(None).await; let (store, _dir) = get_store(None).await;
let account = get_account(); let account = get_account();
@ -1586,7 +1591,7 @@ mod test {
assert_eq!(account, loaded_account); assert_eq!(account, loaded_account);
} }
#[tokio::test] #[tokio::test(threaded_scheduler)]
async fn load_account_with_passphrase() { async fn load_account_with_passphrase() {
let (store, _dir) = get_store(Some("secret_passphrase")).await; let (store, _dir) = get_store(Some("secret_passphrase")).await;
let account = get_account(); let account = get_account();
@ -1602,7 +1607,7 @@ mod test {
assert_eq!(account, loaded_account); assert_eq!(account, loaded_account);
} }
#[tokio::test] #[tokio::test(threaded_scheduler)]
async fn save_and_share_account() { async fn save_and_share_account() {
let (store, _dir) = get_store(None).await; let (store, _dir) = get_store(None).await;
let account = get_account(); let account = get_account();
@ -1630,7 +1635,7 @@ mod test {
); );
} }
#[tokio::test] #[tokio::test(threaded_scheduler)]
async fn save_session() { async fn save_session() {
let (store, _dir) = get_store(None).await; let (store, _dir) = get_store(None).await;
let (account, session) = get_account_and_session().await; let (account, session) = get_account_and_session().await;
@ -1645,7 +1650,7 @@ mod test {
store.save_sessions(&[session]).await.unwrap(); store.save_sessions(&[session]).await.unwrap();
} }
#[tokio::test] #[tokio::test(threaded_scheduler)]
async fn load_sessions() { async fn load_sessions() {
let (store, _dir) = get_store(None).await; let (store, _dir) = get_store(None).await;
let (account, session) = get_account_and_session().await; let (account, session) = get_account_and_session().await;
@ -1664,7 +1669,7 @@ mod test {
assert_eq!(&session, loaded_session); assert_eq!(&session, loaded_session);
} }
#[tokio::test] #[tokio::test(threaded_scheduler)]
async fn add_and_save_session() { async fn add_and_save_session() {
let (store, dir) = get_store(None).await; let (store, dir) = get_store(None).await;
let (account, session) = get_account_and_session().await; let (account, session) = get_account_and_session().await;
@ -1699,7 +1704,7 @@ mod test {
assert_eq!(session_id, session.session_id()); assert_eq!(session_id, session.session_id());
} }
#[tokio::test] #[tokio::test(threaded_scheduler)]
async fn save_inbound_group_session() { async fn save_inbound_group_session() {
let (account, store, _dir) = get_loaded_store().await; let (account, store, _dir) = get_loaded_store().await;
@ -1719,7 +1724,7 @@ mod test {
.expect("Can't save group session"); .expect("Can't save group session");
} }
#[tokio::test] #[tokio::test(threaded_scheduler)]
async fn load_inbound_group_session() { async fn load_inbound_group_session() {
let (account, store, dir) = get_loaded_store().await; let (account, store, dir) = get_loaded_store().await;
@ -1761,7 +1766,7 @@ mod test {
assert!(!export.forwarding_curve25519_key_chain.is_empty()) assert!(!export.forwarding_curve25519_key_chain.is_empty())
} }
#[tokio::test] #[tokio::test(threaded_scheduler)]
async fn test_tracked_users() { async fn test_tracked_users() {
let (_account, store, dir) = get_loaded_store().await; let (_account, store, dir) = get_loaded_store().await;
let device = get_device(); let device = get_device();
@ -1808,7 +1813,7 @@ mod test {
assert!(!store.users_for_key_query().contains(device.user_id())); assert!(!store.users_for_key_query().contains(device.user_id()));
} }
#[tokio::test] #[tokio::test(threaded_scheduler)]
async fn device_saving() { async fn device_saving() {
let (_account, store, dir) = get_loaded_store().await; let (_account, store, dir) = get_loaded_store().await;
let device = get_device(); let device = get_device();
@ -1842,7 +1847,7 @@ mod test {
assert_eq!(user_devices.devices().next().unwrap(), &device); assert_eq!(user_devices.devices().next().unwrap(), &device);
} }
#[tokio::test] #[tokio::test(threaded_scheduler)]
async fn device_deleting() { async fn device_deleting() {
let (_account, store, dir) = get_loaded_store().await; let (_account, store, dir) = get_loaded_store().await;
let device = get_device(); let device = get_device();
@ -1864,7 +1869,7 @@ mod test {
assert!(loaded_device.is_none()); assert!(loaded_device.is_none());
} }
#[tokio::test] #[tokio::test(threaded_scheduler)]
async fn user_saving() { async fn user_saving() {
let dir = tempdir().unwrap(); let dir = tempdir().unwrap();
let tmpdir_path = dir.path().to_str().unwrap(); let tmpdir_path = dir.path().to_str().unwrap();
@ -1941,7 +1946,7 @@ mod test {
assert!(loaded_user.own().unwrap().is_verified()) assert!(loaded_user.own().unwrap().is_verified())
} }
#[tokio::test] #[tokio::test(threaded_scheduler)]
async fn key_value_saving() { async fn key_value_saving() {
let (_, store, _dir) = get_loaded_store().await; let (_, store, _dir) = get_loaded_store().await;
let key = "test_key".to_string(); let key = "test_key".to_string();