diff --git a/Cargo.toml b/Cargo.toml index 798fad8b..81dfb96a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,8 @@ version = "0.1.0" [features] default = [] -encryption = ["olm-rs", "serde/derive", "serde_json", "cjson"] +encryption = ["olm-rs", "serde/derive", "serde_json", "cjson", "async-trait"] +sqlite-cryptostore = ["sqlx", "zeroize"] [dependencies] js_int = "0.1.2" @@ -27,13 +28,22 @@ log = "0.4.8" ruma-identifiers = "0.14.1" url = "2.1.1" -olm-rs = { git = "https://gitlab.gnome.org/jhaye/olm-rs/", optional = true, features = ["serde"]} +olm-rs = { path = "/home/poljar/werk/matrix/olm-rs", optional = true, features = ["serde"]} serde = { version = "1.0.104", optional = true, features = ["derive"] } serde_json = { version = "1.0.48", optional = true } cjson = { version = "0.1.0", optional = true } tokio = { version = "0.2.13", default-features = false, features = ["sync", "time"] } +async-trait = { version = "0.1.24", optional = true } +zeroize = { version = "*", optional = true} + +[dependencies.sqlx] +git = "https://github.com/launchbadge/sqlx/" +optional = true +default-features = false +features = ["runtime-tokio", "sqlite"] [dev-dependencies] tokio = { version = "0.2.13", features = ["rt-threaded", "macros"] } +tempfile = "3.1.0" url = "2.1.1" mockito = "0.23.3" diff --git a/Makefile b/Makefile index 33681eaf..ccc1cdfd 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,8 @@ test: - cargo test --features encryption + cargo test --features 'encryption sqlite-cryptostore' coverage: - cargo tarpaulin --features encryption -v + cargo tarpaulin --features 'encryption sqlite-cryptostore' -v clean: cargo clean diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 922497d3..79b417af 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -18,5 +18,6 @@ mod error; mod machine; #[allow(dead_code)] mod olm; +mod store; pub use machine::{OlmMachine, OneTimeKeys}; diff --git a/src/crypto/store/mod.rs b/src/crypto/store/mod.rs new file mode 100644 index 00000000..5a09928e --- /dev/null +++ b/src/crypto/store/mod.rs @@ -0,0 +1,16 @@ +use std::io::Result; +use std::sync::Arc; + +use async_trait::async_trait; +use tokio::sync::Mutex; + +use super::olm::Account; + +#[cfg(feature = "sqlite-cryptostore")] +pub mod sqlite; + +#[async_trait] +pub trait CryptoStore { + async fn load_account(&self) -> Result; + async fn save_account(&self, account: Arc>) -> Result<()>; +} diff --git a/src/crypto/store/sqlite.rs b/src/crypto/store/sqlite.rs new file mode 100644 index 00000000..cc29f88b --- /dev/null +++ b/src/crypto/store/sqlite.rs @@ -0,0 +1,183 @@ +use std::path::Path; +use std::sync::Arc; +use url::Url; + +use async_trait::async_trait; +use olm_rs::PicklingMode; +use sqlx::{query, query_as, sqlite::SqliteQueryAs, Connect, Executor, SqliteConnection}; +use tokio::sync::Mutex; +use zeroize::Zeroizing; + +use super::{Account, CryptoStore, Result}; + +pub struct SqliteStore { + user_id: Arc, + device_id: Arc, + connection: Arc>, + pickle_passphrase: Option>, +} + +static DATABASE_NAME: &str = "matrix-sdk-crypto.db"; + +impl SqliteStore { + async fn open>(user_id: &str, device_id: &str, path: P) -> Result { + SqliteStore::open_helper(user_id, device_id, path, None).await + } + + async fn open_with_passphrase>( + user_id: &str, + device_id: &str, + path: P, + passphrase: String, + ) -> Result { + SqliteStore::open_helper(user_id, device_id, path, Some(Zeroizing::new(passphrase))).await + } + + async fn open_helper>( + user_id: &str, + device_id: &str, + path: P, + passphrase: Option>, + ) -> Result { + let url = Url::from_directory_path(path.as_ref()).unwrap(); + let url = url.join(DATABASE_NAME).unwrap(); + + let connection = SqliteConnection::connect(url.as_ref()).await.unwrap(); + let store = SqliteStore { + user_id: Arc::new(user_id.to_owned()), + device_id: Arc::new(device_id.to_owned()), + connection: Arc::new(Mutex::new(connection)), + pickle_passphrase: passphrase, + }; + store.create_tables().await?; + Ok(store) + } + + async fn create_tables(&self) -> Result<()> { + let mut connection = self.connection.lock().await; + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS account ( + "id" INTEGER NOT NULL PRIMARY KEY, + "user_id" TEXT NOT NULL, + "device_id" TEXT NOT NULL, + "pickle" BLOB NOT NULL, + "shared" INTEGER NOT NULL, + UNIQUE(user_id,device_id) + ); + "#, + ) + .await + .unwrap(); + + Ok(()) + } + + fn get_pickle_mode(&self) -> PicklingMode { + match &self.pickle_passphrase { + Some(p) => PicklingMode::Encrypted { + key: p.as_bytes().to_vec(), + }, + None => PicklingMode::Unencrypted, + } + } +} + +#[async_trait] +impl CryptoStore for SqliteStore { + async fn load_account(&self) -> Result { + let mut connection = self.connection.lock().await; + + let (pickle, shared): (String, bool) = query_as( + "SELECT pickle, shared FROM account + WHERE user_id = ? and device_id = ?", + ) + .bind(&*self.user_id) + .bind(&*self.device_id) + .fetch_one(&mut *connection) + .await + .unwrap(); + + Ok(Account::from_pickle(pickle, self.get_pickle_mode(), shared).unwrap()) + } + + async fn save_account(&self, account: Arc>) -> Result<()> { + let acc = account.lock().await; + let pickle = acc.pickle(self.get_pickle_mode()); + let mut connection = self.connection.lock().await; + + query( + "INSERT OR IGNORE INTO account ( + user_id, device_id, pickle, shared + ) VALUES (?, ?, ?, ?)", + ) + .bind(&*self.user_id) + .bind(&*self.device_id) + .bind(pickle) + .bind(true) + .execute(&mut *connection) + .await + .unwrap(); + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + use tempfile::tempdir; + use tokio::sync::Mutex; + + use super::{Account, CryptoStore, SqliteStore}; + + async fn get_store() -> SqliteStore { + let tmpdir = tempdir().unwrap(); + let tmpdir_path = tmpdir.path().to_str().unwrap(); + SqliteStore::open("@example:localhost", "DEVICEID", tmpdir_path) + .await + .expect("Can't create store") + } + + fn get_account() -> Arc> { + let account = Account::new(); + Arc::new(Mutex::new(account)) + } + + #[tokio::test] + async fn create_store() { + let tmpdir = tempdir().unwrap(); + let tmpdir_path = tmpdir.path().to_str().unwrap(); + let _ = SqliteStore::open("@example:localhost", "DEVICEID", tmpdir_path) + .await + .expect("Can't create store"); + } + + #[tokio::test] + async fn save_account() { + let store = get_store().await; + let account = get_account(); + + store + .save_account(account) + .await + .expect("Can't save account"); + } + + #[tokio::test] + async fn load_account() { + let store = get_store().await; + let account = get_account(); + + store + .save_account(account.clone()) + .await + .expect("Can't save account"); + + let acc = account.lock().await; + let loaded_account = store.load_account().await.expect("Can't load account"); + + assert_eq!(*acc, loaded_account); + } +}