crypto: Upgrade sqlx to the beta release.

This change is much needed to enable transactions in our sqlite store,
before this release creating a transaction would take ownership of the
connection, now it just mutably borrows it.
master
Damir Jelić 2020-10-16 15:05:53 +02:00
parent e7a24d5e68
commit fc54c63a4c
4 changed files with 40 additions and 40 deletions

View File

@ -52,7 +52,7 @@ default-features = false
features = ["std", "std-future"]
[target.'cfg(not(target_arch = "wasm32"))'.dependencies.sqlx]
version = "0.3.5"
version = "0.4.0-beta.1"
optional = true
default-features = false
features = ["runtime-tokio", "sqlite", "macros"]

View File

@ -1576,7 +1576,7 @@ pub(crate) mod test {
}
}
#[tokio::test]
#[tokio::test(threaded_scheduler)]
#[cfg(feature = "sqlite_cryptostore")]
async fn test_machine_with_default_store() {
let tmpdir = tempdir().unwrap();

View File

@ -55,7 +55,6 @@ use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
use serde::{Deserialize, Serialize};
use serde_json::Error as SerdeError;
use thiserror::Error;
use url::ParseError;
#[cfg_attr(feature = "docs", doc(cfg(r#sqlite_cryptostore)))]
#[cfg(not(target_arch = "wasm32"))]
@ -245,10 +244,6 @@ pub enum CryptoStoreError {
/// The store failed to (de)serialize a data type.
#[error(transparent)]
Serialization(#[from] SerdeError),
/// An error occurred while parsing an URL.
#[error(transparent)]
UrlParse(#[from] ParseError),
}
/// Trait abstracting a store that the `OlmMachine` uses to store cryptographic

View File

@ -30,8 +30,7 @@ use matrix_sdk_common::{
instant::Duration,
locks::Mutex,
};
use sqlx::{query, query_as, sqlite::SqliteQueryAs, Connect, Executor, SqliteConnection};
use url::Url;
use sqlx::{query, query_as, sqlite::SqliteConnectOptions, Connection, Executor, SqliteConnection};
use zeroize::Zeroizing;
use super::{
@ -131,20 +130,20 @@ impl SqliteStore {
.await
}
fn path_to_url(path: &Path) -> Result<Url> {
// TODO this returns an empty error if the path isn't absolute.
let url = Url::from_directory_path(path).expect("Invalid path");
Ok(url.join(DATABASE_NAME)?)
}
async fn open_helper<P: AsRef<Path>>(
user_id: &UserId,
device_id: &DeviceId,
path: P,
passphrase: Option<Zeroizing<String>>,
) -> Result<SqliteStore> {
let url = SqliteStore::path_to_url(path.as_ref())?;
let connection = SqliteConnection::connect(url.as_ref()).await?;
let path = path.as_ref().join(DATABASE_NAME);
let options = SqliteConnectOptions::new()
.foreign_keys(true)
.create_if_missing(true)
.read_only(false)
.filename(&path);
let connection = SqliteConnection::connect_with(&options).await?;
let store = SqliteStore {
user_id: Arc::new(user_id.to_owned()),
@ -153,7 +152,7 @@ impl SqliteStore {
sessions: SessionStore::new(),
inbound_group_sessions: GroupSessionStore::new(),
devices: DeviceStore::new(),
path: Arc::new(path.as_ref().to_owned()),
path: Arc::new(path),
connection: Arc::new(Mutex::new(connection)),
pickle_passphrase: Arc::new(passphrase),
tracked_users: Arc::new(DashSet::new()),
@ -1249,9 +1248,14 @@ impl CryptoStore for SqliteStore {
async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
// TODO turn this into a transaction
for session in sessions {
self.lazy_load_sessions(&session.sender_key).await?;
}
let mut connection = self.connection.lock().await;
let mut transaction = connection.begin().await?;
for session in sessions {
self.sessions.add(session.clone()).await;
let pickle = session.pickle(self.get_pickle_mode()).await;
@ -1260,8 +1264,6 @@ impl CryptoStore for SqliteStore {
let creation_time = serde_json::to_string(&pickle.creation_time)?;
let last_use_time = serde_json::to_string(&pickle.last_use_time)?;
let mut connection = self.connection.lock().await;
query(
"REPLACE INTO sessions (
session_id, account_id, creation_time, last_use_time, sender_key, pickle
@ -1273,10 +1275,12 @@ impl CryptoStore for SqliteStore {
.bind(&*last_use_time)
.bind(&pickle.sender_key)
.bind(&pickle.pickle.as_str())
.execute(&mut *connection)
.execute(&mut *transaction)
.await?;
}
transaction.commit().await?;
Ok(())
}
@ -1287,15 +1291,16 @@ impl CryptoStore for SqliteStore {
async fn save_inbound_group_sessions(&self, sessions: &[InboundGroupSession]) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
// FIXME use a transaction here once sqlx gets better support for them.
let mut transaction = connection.begin().await?;
for session in sessions {
self.save_inbound_group_session_helper(account_id, &mut connection, session)
self.save_inbound_group_session_helper(account_id, &mut transaction, session)
.await?;
self.inbound_group_sessions.add(session.clone());
}
transaction.commit().await?;
Ok(())
}
@ -1549,7 +1554,7 @@ mod test {
(alice, session)
}
#[tokio::test]
#[tokio::test(threaded_scheduler)]
async fn create_store() {
let tmpdir = tempdir().unwrap();
let tmpdir_path = tmpdir.path().to_str().unwrap();
@ -1558,7 +1563,7 @@ mod test {
.expect("Can't create store");
}
#[tokio::test]
#[tokio::test(threaded_scheduler)]
async fn save_account() {
let (store, _dir) = get_store(None).await;
assert!(store.load_account().await.unwrap().is_none());
@ -1570,7 +1575,7 @@ mod test {
.expect("Can't save account");
}
#[tokio::test]
#[tokio::test(threaded_scheduler)]
async fn load_account() {
let (store, _dir) = get_store(None).await;
let account = get_account();
@ -1586,7 +1591,7 @@ mod test {
assert_eq!(account, loaded_account);
}
#[tokio::test]
#[tokio::test(threaded_scheduler)]
async fn load_account_with_passphrase() {
let (store, _dir) = get_store(Some("secret_passphrase")).await;
let account = get_account();
@ -1602,7 +1607,7 @@ mod test {
assert_eq!(account, loaded_account);
}
#[tokio::test]
#[tokio::test(threaded_scheduler)]
async fn save_and_share_account() {
let (store, _dir) = get_store(None).await;
let account = get_account();
@ -1630,7 +1635,7 @@ mod test {
);
}
#[tokio::test]
#[tokio::test(threaded_scheduler)]
async fn save_session() {
let (store, _dir) = get_store(None).await;
let (account, session) = get_account_and_session().await;
@ -1645,7 +1650,7 @@ mod test {
store.save_sessions(&[session]).await.unwrap();
}
#[tokio::test]
#[tokio::test(threaded_scheduler)]
async fn load_sessions() {
let (store, _dir) = get_store(None).await;
let (account, session) = get_account_and_session().await;
@ -1664,7 +1669,7 @@ mod test {
assert_eq!(&session, loaded_session);
}
#[tokio::test]
#[tokio::test(threaded_scheduler)]
async fn add_and_save_session() {
let (store, dir) = get_store(None).await;
let (account, session) = get_account_and_session().await;
@ -1699,7 +1704,7 @@ mod test {
assert_eq!(session_id, session.session_id());
}
#[tokio::test]
#[tokio::test(threaded_scheduler)]
async fn save_inbound_group_session() {
let (account, store, _dir) = get_loaded_store().await;
@ -1719,7 +1724,7 @@ mod test {
.expect("Can't save group session");
}
#[tokio::test]
#[tokio::test(threaded_scheduler)]
async fn load_inbound_group_session() {
let (account, store, dir) = get_loaded_store().await;
@ -1761,7 +1766,7 @@ mod test {
assert!(!export.forwarding_curve25519_key_chain.is_empty())
}
#[tokio::test]
#[tokio::test(threaded_scheduler)]
async fn test_tracked_users() {
let (_account, store, dir) = get_loaded_store().await;
let device = get_device();
@ -1808,7 +1813,7 @@ mod test {
assert!(!store.users_for_key_query().contains(device.user_id()));
}
#[tokio::test]
#[tokio::test(threaded_scheduler)]
async fn device_saving() {
let (_account, store, dir) = get_loaded_store().await;
let device = get_device();
@ -1842,7 +1847,7 @@ mod test {
assert_eq!(user_devices.devices().next().unwrap(), &device);
}
#[tokio::test]
#[tokio::test(threaded_scheduler)]
async fn device_deleting() {
let (_account, store, dir) = get_loaded_store().await;
let device = get_device();
@ -1864,7 +1869,7 @@ mod test {
assert!(loaded_device.is_none());
}
#[tokio::test]
#[tokio::test(threaded_scheduler)]
async fn user_saving() {
let dir = tempdir().unwrap();
let tmpdir_path = dir.path().to_str().unwrap();
@ -1941,7 +1946,7 @@ mod test {
assert!(loaded_user.own().unwrap().is_verified())
}
#[tokio::test]
#[tokio::test(threaded_scheduler)]
async fn key_value_saving() {
let (_, store, _dir) = get_loaded_store().await;
let key = "test_key".to_string();