crypto: Hook up the crypto store to the Olm machine.

master
Damir Jelić 2020-03-18 15:50:32 +01:00
parent d7ab847b98
commit 4aba058695
7 changed files with 214 additions and 111 deletions

View File

@ -645,7 +645,7 @@ impl AsyncClient {
.write() .write()
.await .await
.receive_keys_upload_response(&response) .receive_keys_upload_response(&response)
.await; .await?;
Ok(response) Ok(response)
} }

View File

@ -331,7 +331,7 @@ impl Client {
let olm = self.olm.lock().await; let olm = self.olm.lock().await;
match &*olm { match &*olm {
Some(o) => o.should_upload_keys(), Some(o) => o.should_upload_keys().await,
None => false, None => false,
} }
} }
@ -346,7 +346,7 @@ impl Client {
let olm = self.olm.lock().await; let olm = self.olm.lock().await;
match &*olm { match &*olm {
Some(o) => o.keys_for_upload(), Some(o) => o.keys_for_upload().await,
None => Err(()), None => Err(()),
} }
} }
@ -361,10 +361,11 @@ impl Client {
/// # Panics /// # Panics
/// Panics if the client hasn't been logged in. /// Panics if the client hasn't been logged in.
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
pub async fn receive_keys_upload_response(&self, response: &KeysUploadResponse) { pub async fn receive_keys_upload_response(&self, response: &KeysUploadResponse) -> Result<()> {
let mut olm = self.olm.lock().await; let mut olm = self.olm.lock().await;
let o = olm.as_mut().expect("Client isn't logged in."); let o = olm.as_mut().expect("Client isn't logged in.");
o.receive_keys_upload_response(response).await; o.receive_keys_upload_response(response).await?;
Ok(())
} }
} }

View File

@ -18,6 +18,15 @@ use thiserror::Error;
use super::store::CryptoStoreError; use super::store::CryptoStoreError;
pub type Result<T> = std::result::Result<T, OlmError>; pub type Result<T> = std::result::Result<T, OlmError>;
#[derive(Error, Debug)]
pub enum OlmError {
#[error("signature verification failed")]
Signature(#[from] SignatureError),
#[error("failed to read or write to the crypto store {0}")]
Store(#[from] CryptoStoreError),
}
pub type VerificationResult<T> = std::result::Result<T, SignatureError>; pub type VerificationResult<T> = std::result::Result<T, SignatureError>;
#[derive(Error, Debug)] #[derive(Error, Debug)]
@ -37,11 +46,3 @@ impl From<CjsonError> for SignatureError {
Self::CanonicalJsonError(error) Self::CanonicalJsonError(error)
} }
} }
#[derive(Error, Debug)]
pub enum OlmError {
#[error("signature verification failed")]
Signature(#[from] SignatureError),
#[error("failed to read or write to the crypto store {0}")]
Store(#[from] CryptoStoreError),
}

View File

@ -14,10 +14,16 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::convert::TryInto; use std::convert::TryInto;
use std::path::Path;
use std::result::Result as StdResult; use std::result::Result as StdResult;
use std::sync::Arc;
use super::error::{Result, SignatureError, VerificationResult}; use super::error::{Result, SignatureError, VerificationResult};
use super::olm::Account; use super::olm::Account;
#[cfg(feature = "sqlite-cryptostore")]
use super::store::sqlite::SqliteStore;
use super::store::MemoryStore;
use super::CryptoStore;
use crate::api; use crate::api;
use api::r0::keys; use api::r0::keys;
@ -26,8 +32,8 @@ use cjson;
use olm_rs::utility::OlmUtility; use olm_rs::utility::OlmUtility;
use serde_json::json; use serde_json::json;
use serde_json::Value; use serde_json::Value;
use tokio::sync::Mutex;
use super::store::CryptoStoreError;
use ruma_client_api::r0::keys::{ use ruma_client_api::r0::keys::{
AlgorithmAndDeviceId, DeviceKeys, KeyAlgorithm, OneTimeKey, SignedKey, AlgorithmAndDeviceId, DeviceKeys, KeyAlgorithm, OneTimeKey, SignedKey,
}; };
@ -40,9 +46,6 @@ use ruma_identifiers::{DeviceId, UserId};
pub type OneTimeKeys = HashMap<AlgorithmAndDeviceId, OneTimeKey>; pub type OneTimeKeys = HashMap<AlgorithmAndDeviceId, OneTimeKey>;
#[cfg(feature = "sqlite-cryptostore")]
use super::store::sqlite::SqliteStore;
#[derive(Debug)] #[derive(Debug)]
pub struct OlmMachine { pub struct OlmMachine {
/// The unique user id that owns this account. /// The unique user id that owns this account.
@ -50,12 +53,16 @@ pub struct OlmMachine {
/// The unique device id of the device that holds this account. /// The unique device id of the device that holds this account.
device_id: DeviceId, device_id: DeviceId,
/// Our underlying Olm Account holding our identity keys. /// Our underlying Olm Account holding our identity keys.
account: Account, account: Arc<Mutex<Account>>,
/// The number of signed one-time keys we have uploaded to the server. If /// The number of signed one-time keys we have uploaded to the server. If
/// this is None, no action will be taken. After a sync request the client /// this is None, no action will be taken. After a sync request the client
/// needs to set this for us, depending on the count we will suggest the /// needs to set this for us, depending on the count we will suggest the
/// client to upload new keys. /// client to upload new keys.
uploaded_signed_key_count: Option<u64>, uploaded_signed_key_count: Option<u64>,
/// Store for the encryption keys.
/// Persists all the encrytpion keys so a client can resume the session
/// without the need to create new keys.
store: Box<dyn CryptoStore>,
} }
impl OlmMachine { impl OlmMachine {
@ -69,14 +76,39 @@ impl OlmMachine {
Ok(OlmMachine { Ok(OlmMachine {
user_id: user_id.clone(), user_id: user_id.clone(),
device_id: device_id.to_owned(), device_id: device_id.to_owned(),
account: Account::new(), account: Arc::new(Mutex::new(Account::new())),
uploaded_signed_key_count: None, uploaded_signed_key_count: None,
store: Box::new(MemoryStore::new()),
})
}
#[cfg(feature = "sqlite-cryptostore")]
pub async fn new_with_sqlite_store<P: AsRef<Path>>(
user_id: &UserId,
device_id: &str,
path: P,
passphrase: String,
) -> Result<Self> {
Ok(OlmMachine {
user_id: user_id.clone(),
device_id: device_id.to_owned(),
account: Arc::new(Mutex::new(Account::new())),
uploaded_signed_key_count: None,
store: Box::new(
SqliteStore::open_with_passphrase(
&user_id.to_string(),
device_id,
path,
passphrase,
)
.await?,
),
}) })
} }
/// Should account or one-time keys be uploaded to the server. /// Should account or one-time keys be uploaded to the server.
pub fn should_upload_keys(&self) -> bool { pub async fn should_upload_keys(&self) -> bool {
if !self.account.shared() { if !self.account.lock().await.shared() {
return true; return true;
} }
@ -84,7 +116,7 @@ impl OlmMachine {
// max_one_time_Keys() / 2, otherwise tell the client to upload more. // max_one_time_Keys() / 2, otherwise tell the client to upload more.
match self.uploaded_signed_key_count { match self.uploaded_signed_key_count {
Some(count) => { Some(count) => {
let max_keys = self.account.max_one_time_keys() as u64; let max_keys = self.account.lock().await.max_one_time_keys() as u64;
let key_count = (max_keys / 2) - count; let key_count = (max_keys / 2) - count;
key_count > 0 key_count > 0
} }
@ -98,8 +130,12 @@ impl OlmMachine {
/// ///
/// * `response` - The keys upload response of the request that the client /// * `response` - The keys upload response of the request that the client
/// performed. /// performed.
pub async fn receive_keys_upload_response(&mut self, response: &keys::upload_keys::Response) { pub async fn receive_keys_upload_response(
self.account.shared = true; &mut self,
response: &keys::upload_keys::Response,
) -> Result<()> {
let mut account = self.account.lock().await;
account.shared = true;
let one_time_key_count = response let one_time_key_count = response
.one_time_key_counts .one_time_key_counts
@ -108,18 +144,22 @@ impl OlmMachine {
let count: u64 = one_time_key_count.map_or(0, |c| (*c).into()); let count: u64 = one_time_key_count.map_or(0, |c| (*c).into());
self.uploaded_signed_key_count = Some(count); self.uploaded_signed_key_count = Some(count);
self.account.mark_keys_as_published(); account.mark_keys_as_published();
// TODO save the account here. drop(account);
self.store.save_account(self.account.clone()).await?;
Ok(())
} }
/// Generate new one-time keys. /// Generate new one-time keys.
/// ///
/// Returns the number of newly generated one-time keys. If no keys can be /// Returns the number of newly generated one-time keys. If no keys can be
/// generated returns an empty error. /// generated returns an empty error.
fn generate_one_time_keys(&self) -> StdResult<u64, ()> { async fn generate_one_time_keys(&self) -> StdResult<u64, ()> {
let account = self.account.lock().await;
match self.uploaded_signed_key_count { match self.uploaded_signed_key_count {
Some(count) => { Some(count) => {
let max_keys = self.account.max_one_time_keys() as u64; let max_keys = account.max_one_time_keys() as u64;
let max_on_server = max_keys / 2; let max_on_server = max_keys / 2;
if count >= (max_on_server) { if count >= (max_on_server) {
@ -130,9 +170,9 @@ impl OlmMachine {
let key_count: usize = key_count let key_count: usize = key_count
.try_into() .try_into()
.unwrap_or_else(|_| self.account.max_one_time_keys()); .unwrap_or_else(|_| account.max_one_time_keys());
self.account.generate_one_time_keys(key_count); account.generate_one_time_keys(key_count);
Ok(key_count as u64) Ok(key_count as u64)
} }
None => Err(()), None => Err(()),
@ -140,8 +180,8 @@ impl OlmMachine {
} }
/// Sign the device keys and return a JSON Value to upload them. /// Sign the device keys and return a JSON Value to upload them.
fn device_keys(&self) -> DeviceKeys { async fn device_keys(&self) -> DeviceKeys {
let identity_keys = self.account.identity_keys(); let identity_keys = self.account.lock().await.identity_keys();
let mut keys = HashMap::new(); let mut keys = HashMap::new();
@ -166,7 +206,7 @@ impl OlmMachine {
let mut signature = HashMap::new(); let mut signature = HashMap::new();
signature.insert( signature.insert(
AlgorithmAndDeviceId(KeyAlgorithm::Ed25519, self.device_id.clone()), AlgorithmAndDeviceId(KeyAlgorithm::Ed25519, self.device_id.clone()),
self.sign_json(&device_keys), self.sign_json(&device_keys).await,
); );
signatures.insert(self.user_id.clone(), signature); signatures.insert(self.user_id.clone(), signature);
@ -186,10 +226,9 @@ impl OlmMachine {
/// Generate, sign and prepare one-time keys to be uploaded. /// Generate, sign and prepare one-time keys to be uploaded.
/// ///
/// If no one-time keys need to be uploaded returns an empty error. /// If no one-time keys need to be uploaded returns an empty error.
fn signed_one_time_keys(&self) -> StdResult<OneTimeKeys, ()> { async fn signed_one_time_keys(&self) -> StdResult<OneTimeKeys, ()> {
let _ = self.generate_one_time_keys()?; let _ = self.generate_one_time_keys().await?;
let one_time_keys = self.account.lock().await.one_time_keys();
let one_time_keys = self.account.one_time_keys();
let mut one_time_key_map = HashMap::new(); let mut one_time_key_map = HashMap::new();
for (key_id, key) in one_time_keys.curve25519().iter() { for (key_id, key) in one_time_keys.curve25519().iter() {
@ -197,7 +236,7 @@ impl OlmMachine {
"key": key, "key": key,
}); });
let signature = self.sign_json(&key_json); let signature = self.sign_json(&key_json).await;
let mut signature_map = HashMap::new(); let mut signature_map = HashMap::new();
@ -230,10 +269,11 @@ impl OlmMachine {
/// ///
/// * `json` - The value that should be converted into a canonical JSON /// * `json` - The value that should be converted into a canonical JSON
/// string. /// string.
fn sign_json(&self, json: &Value) -> String { async fn sign_json(&self, json: &Value) -> String {
let account = self.account.lock().await;
let canonical_json = cjson::to_string(json) let canonical_json = cjson::to_string(json)
.unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", json))); .unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", json)));
self.account.sign(&canonical_json) account.sign(&canonical_json)
} }
/// Verify a signed JSON object. /// Verify a signed JSON object.
@ -305,18 +345,22 @@ impl OlmMachine {
/// Get a tuple of device and one-time keys that need to be uploaded. /// Get a tuple of device and one-time keys that need to be uploaded.
/// ///
/// Returns an empty error if no keys need to be uploaded. /// Returns an empty error if no keys need to be uploaded.
pub fn keys_for_upload(&self) -> StdResult<(Option<DeviceKeys>, Option<OneTimeKeys>), ()> { pub async fn keys_for_upload(
if !self.should_upload_keys() { &self,
) -> StdResult<(Option<DeviceKeys>, Option<OneTimeKeys>), ()> {
if !self.should_upload_keys().await {
return Err(()); return Err(());
} }
let device_keys = if !self.account.shared() { let shared = self.account.lock().await.shared();
Some(self.device_keys())
let device_keys = if !shared {
Some(self.device_keys().await)
} else { } else {
None None
}; };
let one_time_keys: Option<OneTimeKeys> = self.signed_one_time_keys().ok(); let one_time_keys: Option<OneTimeKeys> = self.signed_one_time_keys().await.ok();
Ok((device_keys, one_time_keys)) Ok((device_keys, one_time_keys))
} }
@ -413,10 +457,10 @@ mod test {
keys::upload_keys::Response::try_from(data).expect("Can't parse the keys upload response") keys::upload_keys::Response::try_from(data).expect("Can't parse the keys upload response")
} }
#[test] #[tokio::test]
fn create_olm_machine() { async fn create_olm_machine() {
let machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap(); let machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
assert!(machine.should_upload_keys()); assert!(machine.should_upload_keys().await);
} }
#[tokio::test] #[tokio::test]
@ -429,23 +473,32 @@ mod test {
.remove(&keys::KeyAlgorithm::SignedCurve25519) .remove(&keys::KeyAlgorithm::SignedCurve25519)
.unwrap(); .unwrap();
assert!(machine.should_upload_keys()); assert!(machine.should_upload_keys().await);
machine.receive_keys_upload_response(&response).await; machine
assert!(machine.should_upload_keys()); .receive_keys_upload_response(&response)
.await
.unwrap();
assert!(machine.should_upload_keys().await);
response.one_time_key_counts.insert( response.one_time_key_counts.insert(
keys::KeyAlgorithm::SignedCurve25519, keys::KeyAlgorithm::SignedCurve25519,
UInt::try_from(10).unwrap(), UInt::try_from(10).unwrap(),
); );
machine.receive_keys_upload_response(&response).await; machine
assert!(machine.should_upload_keys()); .receive_keys_upload_response(&response)
.await
.unwrap();
assert!(machine.should_upload_keys().await);
response.one_time_key_counts.insert( response.one_time_key_counts.insert(
keys::KeyAlgorithm::SignedCurve25519, keys::KeyAlgorithm::SignedCurve25519,
UInt::try_from(50).unwrap(), UInt::try_from(50).unwrap(),
); );
machine.receive_keys_upload_response(&response).await; machine
assert!(!machine.should_upload_keys()); .receive_keys_upload_response(&response)
.await
.unwrap();
assert!(!machine.should_upload_keys().await);
} }
#[tokio::test] #[tokio::test]
@ -454,27 +507,33 @@ mod test {
let mut response = keys_upload_response(); let mut response = keys_upload_response();
assert!(machine.should_upload_keys()); assert!(machine.should_upload_keys().await);
assert!(machine.generate_one_time_keys().is_err()); assert!(machine.generate_one_time_keys().await.is_err());
machine.receive_keys_upload_response(&response).await; machine
assert!(machine.should_upload_keys()); .receive_keys_upload_response(&response)
assert!(machine.generate_one_time_keys().is_ok()); .await
.unwrap();
assert!(machine.should_upload_keys().await);
assert!(machine.generate_one_time_keys().await.is_ok());
response.one_time_key_counts.insert( response.one_time_key_counts.insert(
keys::KeyAlgorithm::SignedCurve25519, keys::KeyAlgorithm::SignedCurve25519,
UInt::try_from(50).unwrap(), UInt::try_from(50).unwrap(),
); );
machine.receive_keys_upload_response(&response).await; machine
assert!(machine.generate_one_time_keys().is_err()); .receive_keys_upload_response(&response)
.await
.unwrap();
assert!(machine.generate_one_time_keys().await.is_err());
} }
#[test] #[tokio::test]
fn test_device_key_signing() { async fn test_device_key_signing() {
let machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap(); let machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
let mut device_keys = machine.device_keys(); let mut device_keys = machine.device_keys().await;
let identity_keys = machine.account.identity_keys(); let identity_keys = machine.account.lock().await.identity_keys();
let ed25519_key = identity_keys.ed25519(); let ed25519_key = identity_keys.ed25519();
let ret = machine.verify_json( let ret = machine.verify_json(
@ -486,11 +545,11 @@ mod test {
assert!(ret.is_ok()); assert!(ret.is_ok());
} }
#[test] #[tokio::test]
fn test_invalid_signature() { async fn test_invalid_signature() {
let machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap(); let machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
let mut device_keys = machine.device_keys(); let mut device_keys = machine.device_keys().await;
let ret = machine.verify_json( let ret = machine.verify_json(
&machine.user_id, &machine.user_id,
@ -501,13 +560,13 @@ mod test {
assert!(ret.is_err()); assert!(ret.is_err());
} }
#[test] #[tokio::test]
fn test_one_time_key_signing() { async fn test_one_time_key_signing() {
let mut machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap(); let mut machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
machine.uploaded_signed_key_count = Some(49); machine.uploaded_signed_key_count = Some(49);
let mut one_time_keys = machine.signed_one_time_keys().unwrap(); let mut one_time_keys = machine.signed_one_time_keys().await.unwrap();
let identity_keys = machine.account.identity_keys(); let identity_keys = machine.account.lock().await.identity_keys();
let ed25519_key = identity_keys.ed25519(); let ed25519_key = identity_keys.ed25519();
let mut one_time_key = one_time_keys.values_mut().nth(0).unwrap(); let mut one_time_key = one_time_keys.values_mut().nth(0).unwrap();
@ -526,11 +585,12 @@ mod test {
let mut machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap(); let mut machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
machine.uploaded_signed_key_count = Some(0); machine.uploaded_signed_key_count = Some(0);
let identity_keys = machine.account.identity_keys(); let identity_keys = machine.account.lock().await.identity_keys();
let ed25519_key = identity_keys.ed25519(); let ed25519_key = identity_keys.ed25519();
let (device_keys, mut one_time_keys) = machine let (device_keys, mut one_time_keys) = machine
.keys_for_upload() .keys_for_upload()
.await
.expect("Can't prepare initial key upload"); .expect("Can't prepare initial key upload");
let ret = machine.verify_json( let ret = machine.verify_json(
@ -555,9 +615,12 @@ mod test {
UInt::new_wrapping(one_time_keys.unwrap().len() as u64), UInt::new_wrapping(one_time_keys.unwrap().len() as u64),
); );
machine.receive_keys_upload_response(&response).await; machine
.receive_keys_upload_response(&response)
.await
.unwrap();
let ret = machine.keys_for_upload(); let ret = machine.keys_for_upload().await;
assert!(ret.is_err()); assert!(ret.is_err());
} }
} }

View File

@ -22,3 +22,4 @@ mod store;
pub use error::OlmError; pub use error::OlmError;
pub use machine::{OlmMachine, OneTimeKeys}; pub use machine::{OlmMachine, OneTimeKeys};
pub use store::{CryptoStore, CryptoStoreError};

View File

@ -1,3 +1,4 @@
use core::fmt::Debug;
use std::io::Error as IoError; use std::io::Error as IoError;
use std::sync::Arc; use std::sync::Arc;
use url::ParseError; use url::ParseError;
@ -8,6 +9,7 @@ use tokio::sync::Mutex;
use super::olm::Account; use super::olm::Account;
use olm_rs::errors::OlmAccountError; use olm_rs::errors::OlmAccountError;
use olm_rs::PicklingMode;
#[cfg(feature = "sqlite-cryptostore")] #[cfg(feature = "sqlite-cryptostore")]
pub mod sqlite; pub mod sqlite;
@ -33,7 +35,41 @@ pub enum CryptoStoreError {
pub type Result<T> = std::result::Result<T, CryptoStoreError>; pub type Result<T> = std::result::Result<T, CryptoStoreError>;
#[async_trait] #[async_trait]
pub trait CryptoStore { pub trait CryptoStore: Debug {
async fn load_account(&self) -> Result<Option<Account>>; async fn load_account(&mut self) -> Result<Option<Account>>;
async fn save_account(&self, account: Arc<Mutex<Account>>) -> Result<()>; async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()>;
}
#[derive(Debug)]
pub struct MemoryStore {
pub(crate) account_info: Option<(String, bool)>,
}
impl MemoryStore {
/// Create a new empty memory store.
pub fn new() -> Self {
MemoryStore { account_info: None }
}
}
#[async_trait]
impl CryptoStore for MemoryStore {
async fn load_account(&mut self) -> Result<Option<Account>> {
let result = match &self.account_info {
Some((pickle, shared)) => Some(Account::from_pickle(
pickle.to_owned(),
PicklingMode::Unencrypted,
*shared,
)?),
None => None,
};
Ok(result)
}
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()> {
let acc = account.lock().await;
let pickle = acc.pickle(PicklingMode::Unencrypted);
self.account_info = Some((pickle, acc.shared));
Ok(())
}
} }

View File

@ -1,4 +1,5 @@
use std::path::Path; use std::path::{Path, PathBuf};
use std::result::Result as StdResult;
use std::sync::Arc; use std::sync::Arc;
use url::Url; use url::Url;
@ -13,54 +14,60 @@ use super::{Account, CryptoStore, Result};
pub struct SqliteStore { pub struct SqliteStore {
user_id: Arc<String>, user_id: Arc<String>,
device_id: Arc<String>, device_id: Arc<String>,
path: PathBuf,
connection: Arc<Mutex<SqliteConnection>>, connection: Arc<Mutex<SqliteConnection>>,
pickle_passphrase: Option<Zeroizing<String>>, pickle_passphrase: Option<Zeroizing<String>>,
} }
impl std::fmt::Debug for SqliteStore {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> StdResult<(), std::fmt::Error> {
write!(
fmt,
"SqliteStore {{ user_id: {}, device_id: {}, path: {:?} }}",
self.user_id, self.device_id, self.path
)
}
}
static DATABASE_NAME: &str = "matrix-sdk-crypto.db"; static DATABASE_NAME: &str = "matrix-sdk-crypto.db";
impl SqliteStore { impl SqliteStore {
async fn open<P: AsRef<Path>>(user_id: &str, device_id: &str, path: P) -> Result<SqliteStore> { pub async fn open<P: AsRef<Path>>(
let url = SqliteStore::path_to_url(path)?; user_id: &str,
SqliteStore::open_helper(user_id, device_id, url.as_ref(), None).await device_id: &str,
path: P,
) -> Result<SqliteStore> {
SqliteStore::open_helper(user_id, device_id, path, None).await
} }
async fn open_with_passphrase<P: AsRef<Path>>( pub async fn open_with_passphrase<P: AsRef<Path>>(
user_id: &str, user_id: &str,
device_id: &str, device_id: &str,
path: P, path: P,
passphrase: String, passphrase: String,
) -> Result<SqliteStore> { ) -> Result<SqliteStore> {
let url = SqliteStore::path_to_url(path)?; SqliteStore::open_helper(user_id, device_id, path, Some(Zeroizing::new(passphrase))).await
SqliteStore::open_helper(
user_id,
device_id,
url.as_ref(),
Some(Zeroizing::new(passphrase)),
)
.await
} }
async fn open_in_memory(user_id: &str, device_id: &str) -> Result<SqliteStore> { fn path_to_url(path: &Path) -> Result<Url> {
SqliteStore::open_helper(user_id, device_id, "sqlite::memory:", None).await
}
fn path_to_url<P: AsRef<Path>>(path: P) -> Result<Url> {
// TODO this returns an empty error if the path isn't absolute. // TODO this returns an empty error if the path isn't absolute.
let url = Url::from_directory_path(path.as_ref()).expect("Invalid path"); let url = Url::from_directory_path(path).expect("Invalid path");
Ok(url.join(DATABASE_NAME)?) Ok(url.join(DATABASE_NAME)?)
} }
async fn open_helper( async fn open_helper<P: AsRef<Path>>(
user_id: &str, user_id: &str,
device_id: &str, device_id: &str,
sqlite_url: &str, path: P,
passphrase: Option<Zeroizing<String>>, passphrase: Option<Zeroizing<String>>,
) -> Result<SqliteStore> { ) -> Result<SqliteStore> {
let connection = SqliteConnection::connect(sqlite_url).await.unwrap(); let url = SqliteStore::path_to_url(path.as_ref())?;
let connection = SqliteConnection::connect(url.as_ref()).await.unwrap();
let store = SqliteStore { let store = SqliteStore {
user_id: Arc::new(user_id.to_owned()), user_id: Arc::new(user_id.to_owned()),
device_id: Arc::new(device_id.to_owned()), device_id: Arc::new(device_id.to_owned()),
path: path.as_ref().to_owned(),
connection: Arc::new(Mutex::new(connection)), connection: Arc::new(Mutex::new(connection)),
pickle_passphrase: passphrase, pickle_passphrase: passphrase,
}; };
@ -100,7 +107,7 @@ impl SqliteStore {
#[async_trait] #[async_trait]
impl CryptoStore for SqliteStore { impl CryptoStore for SqliteStore {
async fn load_account(&self) -> Result<Option<Account>> { async fn load_account(&mut self) -> Result<Option<Account>> {
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
let row: Option<(String, bool)> = query_as( let row: Option<(String, bool)> = query_as(
@ -124,7 +131,7 @@ impl CryptoStore for SqliteStore {
Ok(result) Ok(result)
} }
async fn save_account(&self, account: Arc<Mutex<Account>>) -> Result<()> { async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()> {
let acc = account.lock().await; let acc = account.lock().await;
let pickle = acc.pickle(self.get_pickle_mode()); let pickle = acc.pickle(self.get_pickle_mode());
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
@ -178,12 +185,6 @@ mod test {
.expect("Can't create store") .expect("Can't create store")
} }
async fn get_memory_store() -> SqliteStore {
SqliteStore::open_in_memory(USER_ID, DEVICE_ID)
.await
.expect("Can't create memory store")
}
fn get_account() -> Arc<Mutex<Account>> { fn get_account() -> Arc<Mutex<Account>> {
let account = Account::new(); let account = Account::new();
Arc::new(Mutex::new(account)) Arc::new(Mutex::new(account))
@ -200,7 +201,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn save_account() { async fn save_account() {
let store = get_store().await; let mut store = get_store().await;
let account = get_account(); let account = get_account();
store store
@ -211,7 +212,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn load_account() { async fn load_account() {
let store = get_memory_store().await; let mut store = get_store().await;
let account = get_account(); let account = get_account();
store store