diff --git a/src/crypto/store/sqlite.rs b/src/crypto/store/sqlite.rs index cc29f88b..211075cb 100644 --- a/src/crypto/store/sqlite.rs +++ b/src/crypto/store/sqlite.rs @@ -21,7 +21,8 @@ static DATABASE_NAME: &str = "matrix-sdk-crypto.db"; impl SqliteStore { async fn open>(user_id: &str, device_id: &str, path: P) -> Result { - SqliteStore::open_helper(user_id, device_id, path, None).await + let url = SqliteStore::path_to_url(path)?; + SqliteStore::open_helper(user_id, device_id, url.as_ref(), None).await } async fn open_with_passphrase>( @@ -30,19 +31,35 @@ impl SqliteStore { path: P, passphrase: String, ) -> Result { - SqliteStore::open_helper(user_id, device_id, path, Some(Zeroizing::new(passphrase))).await + let url = SqliteStore::path_to_url(path)?; + SqliteStore::open_helper( + user_id, + device_id, + url.as_ref(), + Some(Zeroizing::new(passphrase)), + ) + .await } - async fn open_helper>( + async fn open_in_memory(user_id: &str, device_id: &str) -> Result { + SqliteStore::open_helper(user_id, device_id, "sqlite::memory:", None).await + } + + fn path_to_url>(path: P) -> Result { + let url = Url::from_directory_path(path.as_ref()).expect("Can't create URL from directory"); + let url = url + .join(DATABASE_NAME) + .expect("Can't append database name to URL"); + Ok(url) + } + + async fn open_helper( user_id: &str, device_id: &str, - path: P, + sqlite_url: &str, passphrase: Option>, ) -> Result { - let url = Url::from_directory_path(path.as_ref()).unwrap(); - let url = url.join(DATABASE_NAME).unwrap(); - - let connection = SqliteConnection::connect(url.as_ref()).await.unwrap(); + let connection = SqliteConnection::connect(sqlite_url).await.unwrap(); let store = SqliteStore { user_id: Arc::new(user_id.to_owned()), device_id: Arc::new(device_id.to_owned()), @@ -114,8 +131,23 @@ impl CryptoStore for SqliteStore { ) .bind(&*self.user_id) .bind(&*self.device_id) + .bind(&pickle) + .bind(acc.shared) + .execute(&mut *connection) + .await + .unwrap(); + + query( + "UPDATE account + SET pickle = ?, + shared = ? + WHERE user_id = ? and + device_id = ?", + ) .bind(pickle) - .bind(true) + .bind(acc.shared) + .bind(&*self.user_id) + .bind(&*self.device_id) .execute(&mut *connection) .await .unwrap(); @@ -132,14 +164,23 @@ mod test { use super::{Account, CryptoStore, SqliteStore}; + static USER_ID: &str = "@example:localhost"; + static DEVICE_ID: &str = "DEVICEID"; + async fn get_store() -> SqliteStore { let tmpdir = tempdir().unwrap(); let tmpdir_path = tmpdir.path().to_str().unwrap(); - SqliteStore::open("@example:localhost", "DEVICEID", tmpdir_path) + SqliteStore::open(USER_ID, DEVICE_ID, tmpdir_path) .await .expect("Can't create store") } + async fn get_memory_store() -> SqliteStore { + SqliteStore::open_in_memory(USER_ID, DEVICE_ID) + .await + .expect("Can't create memory store") + } + fn get_account() -> Arc> { let account = Account::new(); Arc::new(Mutex::new(account)) @@ -167,7 +208,7 @@ mod test { #[tokio::test] async fn load_account() { - let store = get_store().await; + let store = get_memory_store().await; let account = get_account(); store