crypto: Hook up the crypto store to the Olm machine.
This commit is contained in:
parent
d7ab847b98
commit
4aba058695
7 changed files with 214 additions and 111 deletions
|
@ -645,7 +645,7 @@ impl AsyncClient {
|
|||
.write()
|
||||
.await
|
||||
.receive_keys_upload_response(&response)
|
||||
.await;
|
||||
.await?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
|
|
|
@ -331,7 +331,7 @@ impl Client {
|
|||
let olm = self.olm.lock().await;
|
||||
|
||||
match &*olm {
|
||||
Some(o) => o.should_upload_keys(),
|
||||
Some(o) => o.should_upload_keys().await,
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
@ -346,7 +346,7 @@ impl Client {
|
|||
let olm = self.olm.lock().await;
|
||||
|
||||
match &*olm {
|
||||
Some(o) => o.keys_for_upload(),
|
||||
Some(o) => o.keys_for_upload().await,
|
||||
None => Err(()),
|
||||
}
|
||||
}
|
||||
|
@ -361,10 +361,11 @@ impl Client {
|
|||
/// # Panics
|
||||
/// Panics if the client hasn't been logged in.
|
||||
#[cfg(feature = "encryption")]
|
||||
pub async fn receive_keys_upload_response(&self, response: &KeysUploadResponse) {
|
||||
pub async fn receive_keys_upload_response(&self, response: &KeysUploadResponse) -> Result<()> {
|
||||
let mut olm = self.olm.lock().await;
|
||||
|
||||
let o = olm.as_mut().expect("Client isn't logged in.");
|
||||
o.receive_keys_upload_response(response).await;
|
||||
o.receive_keys_upload_response(response).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,6 +18,15 @@ use thiserror::Error;
|
|||
use super::store::CryptoStoreError;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, OlmError>;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum OlmError {
|
||||
#[error("signature verification failed")]
|
||||
Signature(#[from] SignatureError),
|
||||
#[error("failed to read or write to the crypto store {0}")]
|
||||
Store(#[from] CryptoStoreError),
|
||||
}
|
||||
|
||||
pub type VerificationResult<T> = std::result::Result<T, SignatureError>;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
|
@ -37,11 +46,3 @@ impl From<CjsonError> for SignatureError {
|
|||
Self::CanonicalJsonError(error)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum OlmError {
|
||||
#[error("signature verification failed")]
|
||||
Signature(#[from] SignatureError),
|
||||
#[error("failed to read or write to the crypto store {0}")]
|
||||
Store(#[from] CryptoStoreError),
|
||||
}
|
||||
|
|
|
@ -14,10 +14,16 @@
|
|||
|
||||
use std::collections::HashMap;
|
||||
use std::convert::TryInto;
|
||||
use std::path::Path;
|
||||
use std::result::Result as StdResult;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::error::{Result, SignatureError, VerificationResult};
|
||||
use super::olm::Account;
|
||||
#[cfg(feature = "sqlite-cryptostore")]
|
||||
use super::store::sqlite::SqliteStore;
|
||||
use super::store::MemoryStore;
|
||||
use super::CryptoStore;
|
||||
use crate::api;
|
||||
|
||||
use api::r0::keys;
|
||||
|
@ -26,8 +32,8 @@ use cjson;
|
|||
use olm_rs::utility::OlmUtility;
|
||||
use serde_json::json;
|
||||
use serde_json::Value;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use super::store::CryptoStoreError;
|
||||
use ruma_client_api::r0::keys::{
|
||||
AlgorithmAndDeviceId, DeviceKeys, KeyAlgorithm, OneTimeKey, SignedKey,
|
||||
};
|
||||
|
@ -40,9 +46,6 @@ use ruma_identifiers::{DeviceId, UserId};
|
|||
|
||||
pub type OneTimeKeys = HashMap<AlgorithmAndDeviceId, OneTimeKey>;
|
||||
|
||||
#[cfg(feature = "sqlite-cryptostore")]
|
||||
use super::store::sqlite::SqliteStore;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct OlmMachine {
|
||||
/// The unique user id that owns this account.
|
||||
|
@ -50,12 +53,16 @@ pub struct OlmMachine {
|
|||
/// The unique device id of the device that holds this account.
|
||||
device_id: DeviceId,
|
||||
/// Our underlying Olm Account holding our identity keys.
|
||||
account: Account,
|
||||
account: Arc<Mutex<Account>>,
|
||||
/// The number of signed one-time keys we have uploaded to the server. If
|
||||
/// this is None, no action will be taken. After a sync request the client
|
||||
/// needs to set this for us, depending on the count we will suggest the
|
||||
/// client to upload new keys.
|
||||
uploaded_signed_key_count: Option<u64>,
|
||||
/// Store for the encryption keys.
|
||||
/// Persists all the encrytpion keys so a client can resume the session
|
||||
/// without the need to create new keys.
|
||||
store: Box<dyn CryptoStore>,
|
||||
}
|
||||
|
||||
impl OlmMachine {
|
||||
|
@ -69,14 +76,39 @@ impl OlmMachine {
|
|||
Ok(OlmMachine {
|
||||
user_id: user_id.clone(),
|
||||
device_id: device_id.to_owned(),
|
||||
account: Account::new(),
|
||||
account: Arc::new(Mutex::new(Account::new())),
|
||||
uploaded_signed_key_count: None,
|
||||
store: Box::new(MemoryStore::new()),
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(feature = "sqlite-cryptostore")]
|
||||
pub async fn new_with_sqlite_store<P: AsRef<Path>>(
|
||||
user_id: &UserId,
|
||||
device_id: &str,
|
||||
path: P,
|
||||
passphrase: String,
|
||||
) -> Result<Self> {
|
||||
Ok(OlmMachine {
|
||||
user_id: user_id.clone(),
|
||||
device_id: device_id.to_owned(),
|
||||
account: Arc::new(Mutex::new(Account::new())),
|
||||
uploaded_signed_key_count: None,
|
||||
store: Box::new(
|
||||
SqliteStore::open_with_passphrase(
|
||||
&user_id.to_string(),
|
||||
device_id,
|
||||
path,
|
||||
passphrase,
|
||||
)
|
||||
.await?,
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
/// Should account or one-time keys be uploaded to the server.
|
||||
pub fn should_upload_keys(&self) -> bool {
|
||||
if !self.account.shared() {
|
||||
pub async fn should_upload_keys(&self) -> bool {
|
||||
if !self.account.lock().await.shared() {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -84,7 +116,7 @@ impl OlmMachine {
|
|||
// max_one_time_Keys() / 2, otherwise tell the client to upload more.
|
||||
match self.uploaded_signed_key_count {
|
||||
Some(count) => {
|
||||
let max_keys = self.account.max_one_time_keys() as u64;
|
||||
let max_keys = self.account.lock().await.max_one_time_keys() as u64;
|
||||
let key_count = (max_keys / 2) - count;
|
||||
key_count > 0
|
||||
}
|
||||
|
@ -98,8 +130,12 @@ impl OlmMachine {
|
|||
///
|
||||
/// * `response` - The keys upload response of the request that the client
|
||||
/// performed.
|
||||
pub async fn receive_keys_upload_response(&mut self, response: &keys::upload_keys::Response) {
|
||||
self.account.shared = true;
|
||||
pub async fn receive_keys_upload_response(
|
||||
&mut self,
|
||||
response: &keys::upload_keys::Response,
|
||||
) -> Result<()> {
|
||||
let mut account = self.account.lock().await;
|
||||
account.shared = true;
|
||||
|
||||
let one_time_key_count = response
|
||||
.one_time_key_counts
|
||||
|
@ -108,18 +144,22 @@ impl OlmMachine {
|
|||
let count: u64 = one_time_key_count.map_or(0, |c| (*c).into());
|
||||
self.uploaded_signed_key_count = Some(count);
|
||||
|
||||
self.account.mark_keys_as_published();
|
||||
// TODO save the account here.
|
||||
account.mark_keys_as_published();
|
||||
drop(account);
|
||||
|
||||
self.store.save_account(self.account.clone()).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Generate new one-time keys.
|
||||
///
|
||||
/// Returns the number of newly generated one-time keys. If no keys can be
|
||||
/// generated returns an empty error.
|
||||
fn generate_one_time_keys(&self) -> StdResult<u64, ()> {
|
||||
async fn generate_one_time_keys(&self) -> StdResult<u64, ()> {
|
||||
let account = self.account.lock().await;
|
||||
match self.uploaded_signed_key_count {
|
||||
Some(count) => {
|
||||
let max_keys = self.account.max_one_time_keys() as u64;
|
||||
let max_keys = account.max_one_time_keys() as u64;
|
||||
let max_on_server = max_keys / 2;
|
||||
|
||||
if count >= (max_on_server) {
|
||||
|
@ -130,9 +170,9 @@ impl OlmMachine {
|
|||
|
||||
let key_count: usize = key_count
|
||||
.try_into()
|
||||
.unwrap_or_else(|_| self.account.max_one_time_keys());
|
||||
.unwrap_or_else(|_| account.max_one_time_keys());
|
||||
|
||||
self.account.generate_one_time_keys(key_count);
|
||||
account.generate_one_time_keys(key_count);
|
||||
Ok(key_count as u64)
|
||||
}
|
||||
None => Err(()),
|
||||
|
@ -140,8 +180,8 @@ impl OlmMachine {
|
|||
}
|
||||
|
||||
/// Sign the device keys and return a JSON Value to upload them.
|
||||
fn device_keys(&self) -> DeviceKeys {
|
||||
let identity_keys = self.account.identity_keys();
|
||||
async fn device_keys(&self) -> DeviceKeys {
|
||||
let identity_keys = self.account.lock().await.identity_keys();
|
||||
|
||||
let mut keys = HashMap::new();
|
||||
|
||||
|
@ -166,7 +206,7 @@ impl OlmMachine {
|
|||
let mut signature = HashMap::new();
|
||||
signature.insert(
|
||||
AlgorithmAndDeviceId(KeyAlgorithm::Ed25519, self.device_id.clone()),
|
||||
self.sign_json(&device_keys),
|
||||
self.sign_json(&device_keys).await,
|
||||
);
|
||||
signatures.insert(self.user_id.clone(), signature);
|
||||
|
||||
|
@ -186,10 +226,9 @@ impl OlmMachine {
|
|||
/// Generate, sign and prepare one-time keys to be uploaded.
|
||||
///
|
||||
/// If no one-time keys need to be uploaded returns an empty error.
|
||||
fn signed_one_time_keys(&self) -> StdResult<OneTimeKeys, ()> {
|
||||
let _ = self.generate_one_time_keys()?;
|
||||
|
||||
let one_time_keys = self.account.one_time_keys();
|
||||
async fn signed_one_time_keys(&self) -> StdResult<OneTimeKeys, ()> {
|
||||
let _ = self.generate_one_time_keys().await?;
|
||||
let one_time_keys = self.account.lock().await.one_time_keys();
|
||||
let mut one_time_key_map = HashMap::new();
|
||||
|
||||
for (key_id, key) in one_time_keys.curve25519().iter() {
|
||||
|
@ -197,7 +236,7 @@ impl OlmMachine {
|
|||
"key": key,
|
||||
});
|
||||
|
||||
let signature = self.sign_json(&key_json);
|
||||
let signature = self.sign_json(&key_json).await;
|
||||
|
||||
let mut signature_map = HashMap::new();
|
||||
|
||||
|
@ -230,10 +269,11 @@ impl OlmMachine {
|
|||
///
|
||||
/// * `json` - The value that should be converted into a canonical JSON
|
||||
/// string.
|
||||
fn sign_json(&self, json: &Value) -> String {
|
||||
async fn sign_json(&self, json: &Value) -> String {
|
||||
let account = self.account.lock().await;
|
||||
let canonical_json = cjson::to_string(json)
|
||||
.unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", json)));
|
||||
self.account.sign(&canonical_json)
|
||||
account.sign(&canonical_json)
|
||||
}
|
||||
|
||||
/// Verify a signed JSON object.
|
||||
|
@ -305,18 +345,22 @@ impl OlmMachine {
|
|||
/// Get a tuple of device and one-time keys that need to be uploaded.
|
||||
///
|
||||
/// Returns an empty error if no keys need to be uploaded.
|
||||
pub fn keys_for_upload(&self) -> StdResult<(Option<DeviceKeys>, Option<OneTimeKeys>), ()> {
|
||||
if !self.should_upload_keys() {
|
||||
pub async fn keys_for_upload(
|
||||
&self,
|
||||
) -> StdResult<(Option<DeviceKeys>, Option<OneTimeKeys>), ()> {
|
||||
if !self.should_upload_keys().await {
|
||||
return Err(());
|
||||
}
|
||||
|
||||
let device_keys = if !self.account.shared() {
|
||||
Some(self.device_keys())
|
||||
let shared = self.account.lock().await.shared();
|
||||
|
||||
let device_keys = if !shared {
|
||||
Some(self.device_keys().await)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let one_time_keys: Option<OneTimeKeys> = self.signed_one_time_keys().ok();
|
||||
let one_time_keys: Option<OneTimeKeys> = self.signed_one_time_keys().await.ok();
|
||||
|
||||
Ok((device_keys, one_time_keys))
|
||||
}
|
||||
|
@ -413,10 +457,10 @@ mod test {
|
|||
keys::upload_keys::Response::try_from(data).expect("Can't parse the keys upload response")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_olm_machine() {
|
||||
#[tokio::test]
|
||||
async fn create_olm_machine() {
|
||||
let machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
|
||||
assert!(machine.should_upload_keys());
|
||||
assert!(machine.should_upload_keys().await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
@ -429,23 +473,32 @@ mod test {
|
|||
.remove(&keys::KeyAlgorithm::SignedCurve25519)
|
||||
.unwrap();
|
||||
|
||||
assert!(machine.should_upload_keys());
|
||||
machine.receive_keys_upload_response(&response).await;
|
||||
assert!(machine.should_upload_keys());
|
||||
assert!(machine.should_upload_keys().await);
|
||||
machine
|
||||
.receive_keys_upload_response(&response)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(machine.should_upload_keys().await);
|
||||
|
||||
response.one_time_key_counts.insert(
|
||||
keys::KeyAlgorithm::SignedCurve25519,
|
||||
UInt::try_from(10).unwrap(),
|
||||
);
|
||||
machine.receive_keys_upload_response(&response).await;
|
||||
assert!(machine.should_upload_keys());
|
||||
machine
|
||||
.receive_keys_upload_response(&response)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(machine.should_upload_keys().await);
|
||||
|
||||
response.one_time_key_counts.insert(
|
||||
keys::KeyAlgorithm::SignedCurve25519,
|
||||
UInt::try_from(50).unwrap(),
|
||||
);
|
||||
machine.receive_keys_upload_response(&response).await;
|
||||
assert!(!machine.should_upload_keys());
|
||||
machine
|
||||
.receive_keys_upload_response(&response)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!machine.should_upload_keys().await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
@ -454,27 +507,33 @@ mod test {
|
|||
|
||||
let mut response = keys_upload_response();
|
||||
|
||||
assert!(machine.should_upload_keys());
|
||||
assert!(machine.generate_one_time_keys().is_err());
|
||||
assert!(machine.should_upload_keys().await);
|
||||
assert!(machine.generate_one_time_keys().await.is_err());
|
||||
|
||||
machine.receive_keys_upload_response(&response).await;
|
||||
assert!(machine.should_upload_keys());
|
||||
assert!(machine.generate_one_time_keys().is_ok());
|
||||
machine
|
||||
.receive_keys_upload_response(&response)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(machine.should_upload_keys().await);
|
||||
assert!(machine.generate_one_time_keys().await.is_ok());
|
||||
|
||||
response.one_time_key_counts.insert(
|
||||
keys::KeyAlgorithm::SignedCurve25519,
|
||||
UInt::try_from(50).unwrap(),
|
||||
);
|
||||
machine.receive_keys_upload_response(&response).await;
|
||||
assert!(machine.generate_one_time_keys().is_err());
|
||||
machine
|
||||
.receive_keys_upload_response(&response)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(machine.generate_one_time_keys().await.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_device_key_signing() {
|
||||
#[tokio::test]
|
||||
async fn test_device_key_signing() {
|
||||
let machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
|
||||
|
||||
let mut device_keys = machine.device_keys();
|
||||
let identity_keys = machine.account.identity_keys();
|
||||
let mut device_keys = machine.device_keys().await;
|
||||
let identity_keys = machine.account.lock().await.identity_keys();
|
||||
let ed25519_key = identity_keys.ed25519();
|
||||
|
||||
let ret = machine.verify_json(
|
||||
|
@ -486,11 +545,11 @@ mod test {
|
|||
assert!(ret.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_signature() {
|
||||
#[tokio::test]
|
||||
async fn test_invalid_signature() {
|
||||
let machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
|
||||
|
||||
let mut device_keys = machine.device_keys();
|
||||
let mut device_keys = machine.device_keys().await;
|
||||
|
||||
let ret = machine.verify_json(
|
||||
&machine.user_id,
|
||||
|
@ -501,13 +560,13 @@ mod test {
|
|||
assert!(ret.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_one_time_key_signing() {
|
||||
#[tokio::test]
|
||||
async fn test_one_time_key_signing() {
|
||||
let mut machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
|
||||
machine.uploaded_signed_key_count = Some(49);
|
||||
|
||||
let mut one_time_keys = machine.signed_one_time_keys().unwrap();
|
||||
let identity_keys = machine.account.identity_keys();
|
||||
let mut one_time_keys = machine.signed_one_time_keys().await.unwrap();
|
||||
let identity_keys = machine.account.lock().await.identity_keys();
|
||||
let ed25519_key = identity_keys.ed25519();
|
||||
|
||||
let mut one_time_key = one_time_keys.values_mut().nth(0).unwrap();
|
||||
|
@ -526,11 +585,12 @@ mod test {
|
|||
let mut machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
|
||||
machine.uploaded_signed_key_count = Some(0);
|
||||
|
||||
let identity_keys = machine.account.identity_keys();
|
||||
let identity_keys = machine.account.lock().await.identity_keys();
|
||||
let ed25519_key = identity_keys.ed25519();
|
||||
|
||||
let (device_keys, mut one_time_keys) = machine
|
||||
.keys_for_upload()
|
||||
.await
|
||||
.expect("Can't prepare initial key upload");
|
||||
|
||||
let ret = machine.verify_json(
|
||||
|
@ -555,9 +615,12 @@ mod test {
|
|||
UInt::new_wrapping(one_time_keys.unwrap().len() as u64),
|
||||
);
|
||||
|
||||
machine.receive_keys_upload_response(&response).await;
|
||||
machine
|
||||
.receive_keys_upload_response(&response)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let ret = machine.keys_for_upload();
|
||||
let ret = machine.keys_for_upload().await;
|
||||
assert!(ret.is_err());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,3 +22,4 @@ mod store;
|
|||
|
||||
pub use error::OlmError;
|
||||
pub use machine::{OlmMachine, OneTimeKeys};
|
||||
pub use store::{CryptoStore, CryptoStoreError};
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use core::fmt::Debug;
|
||||
use std::io::Error as IoError;
|
||||
use std::sync::Arc;
|
||||
use url::ParseError;
|
||||
|
@ -8,6 +9,7 @@ use tokio::sync::Mutex;
|
|||
|
||||
use super::olm::Account;
|
||||
use olm_rs::errors::OlmAccountError;
|
||||
use olm_rs::PicklingMode;
|
||||
|
||||
#[cfg(feature = "sqlite-cryptostore")]
|
||||
pub mod sqlite;
|
||||
|
@ -33,7 +35,41 @@ pub enum CryptoStoreError {
|
|||
pub type Result<T> = std::result::Result<T, CryptoStoreError>;
|
||||
|
||||
#[async_trait]
|
||||
pub trait CryptoStore {
|
||||
async fn load_account(&self) -> Result<Option<Account>>;
|
||||
async fn save_account(&self, account: Arc<Mutex<Account>>) -> Result<()>;
|
||||
pub trait CryptoStore: Debug {
|
||||
async fn load_account(&mut self) -> Result<Option<Account>>;
|
||||
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MemoryStore {
|
||||
pub(crate) account_info: Option<(String, bool)>,
|
||||
}
|
||||
|
||||
impl MemoryStore {
|
||||
/// Create a new empty memory store.
|
||||
pub fn new() -> Self {
|
||||
MemoryStore { account_info: None }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl CryptoStore for MemoryStore {
|
||||
async fn load_account(&mut self) -> Result<Option<Account>> {
|
||||
let result = match &self.account_info {
|
||||
Some((pickle, shared)) => Some(Account::from_pickle(
|
||||
pickle.to_owned(),
|
||||
PicklingMode::Unencrypted,
|
||||
*shared,
|
||||
)?),
|
||||
None => None,
|
||||
};
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()> {
|
||||
let acc = account.lock().await;
|
||||
let pickle = acc.pickle(PicklingMode::Unencrypted);
|
||||
self.account_info = Some((pickle, acc.shared));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use std::path::Path;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::result::Result as StdResult;
|
||||
use std::sync::Arc;
|
||||
use url::Url;
|
||||
|
||||
|
@ -13,54 +14,60 @@ use super::{Account, CryptoStore, Result};
|
|||
pub struct SqliteStore {
|
||||
user_id: Arc<String>,
|
||||
device_id: Arc<String>,
|
||||
path: PathBuf,
|
||||
connection: Arc<Mutex<SqliteConnection>>,
|
||||
pickle_passphrase: Option<Zeroizing<String>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for SqliteStore {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> StdResult<(), std::fmt::Error> {
|
||||
write!(
|
||||
fmt,
|
||||
"SqliteStore {{ user_id: {}, device_id: {}, path: {:?} }}",
|
||||
self.user_id, self.device_id, self.path
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
static DATABASE_NAME: &str = "matrix-sdk-crypto.db";
|
||||
|
||||
impl SqliteStore {
|
||||
async fn open<P: AsRef<Path>>(user_id: &str, device_id: &str, path: P) -> Result<SqliteStore> {
|
||||
let url = SqliteStore::path_to_url(path)?;
|
||||
SqliteStore::open_helper(user_id, device_id, url.as_ref(), None).await
|
||||
pub async fn open<P: AsRef<Path>>(
|
||||
user_id: &str,
|
||||
device_id: &str,
|
||||
path: P,
|
||||
) -> Result<SqliteStore> {
|
||||
SqliteStore::open_helper(user_id, device_id, path, None).await
|
||||
}
|
||||
|
||||
async fn open_with_passphrase<P: AsRef<Path>>(
|
||||
pub async fn open_with_passphrase<P: AsRef<Path>>(
|
||||
user_id: &str,
|
||||
device_id: &str,
|
||||
path: P,
|
||||
passphrase: String,
|
||||
) -> Result<SqliteStore> {
|
||||
let url = SqliteStore::path_to_url(path)?;
|
||||
SqliteStore::open_helper(
|
||||
user_id,
|
||||
device_id,
|
||||
url.as_ref(),
|
||||
Some(Zeroizing::new(passphrase)),
|
||||
)
|
||||
.await
|
||||
SqliteStore::open_helper(user_id, device_id, path, Some(Zeroizing::new(passphrase))).await
|
||||
}
|
||||
|
||||
async fn open_in_memory(user_id: &str, device_id: &str) -> Result<SqliteStore> {
|
||||
SqliteStore::open_helper(user_id, device_id, "sqlite::memory:", None).await
|
||||
}
|
||||
|
||||
fn path_to_url<P: AsRef<Path>>(path: P) -> Result<Url> {
|
||||
fn path_to_url(path: &Path) -> Result<Url> {
|
||||
// TODO this returns an empty error if the path isn't absolute.
|
||||
let url = Url::from_directory_path(path.as_ref()).expect("Invalid path");
|
||||
let url = Url::from_directory_path(path).expect("Invalid path");
|
||||
Ok(url.join(DATABASE_NAME)?)
|
||||
}
|
||||
|
||||
async fn open_helper(
|
||||
async fn open_helper<P: AsRef<Path>>(
|
||||
user_id: &str,
|
||||
device_id: &str,
|
||||
sqlite_url: &str,
|
||||
path: P,
|
||||
passphrase: Option<Zeroizing<String>>,
|
||||
) -> Result<SqliteStore> {
|
||||
let connection = SqliteConnection::connect(sqlite_url).await.unwrap();
|
||||
let url = SqliteStore::path_to_url(path.as_ref())?;
|
||||
|
||||
let connection = SqliteConnection::connect(url.as_ref()).await.unwrap();
|
||||
let store = SqliteStore {
|
||||
user_id: Arc::new(user_id.to_owned()),
|
||||
device_id: Arc::new(device_id.to_owned()),
|
||||
path: path.as_ref().to_owned(),
|
||||
connection: Arc::new(Mutex::new(connection)),
|
||||
pickle_passphrase: passphrase,
|
||||
};
|
||||
|
@ -100,7 +107,7 @@ impl SqliteStore {
|
|||
|
||||
#[async_trait]
|
||||
impl CryptoStore for SqliteStore {
|
||||
async fn load_account(&self) -> Result<Option<Account>> {
|
||||
async fn load_account(&mut self) -> Result<Option<Account>> {
|
||||
let mut connection = self.connection.lock().await;
|
||||
|
||||
let row: Option<(String, bool)> = query_as(
|
||||
|
@ -124,7 +131,7 @@ impl CryptoStore for SqliteStore {
|
|||
Ok(result)
|
||||
}
|
||||
|
||||
async fn save_account(&self, account: Arc<Mutex<Account>>) -> Result<()> {
|
||||
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()> {
|
||||
let acc = account.lock().await;
|
||||
let pickle = acc.pickle(self.get_pickle_mode());
|
||||
let mut connection = self.connection.lock().await;
|
||||
|
@ -178,12 +185,6 @@ mod test {
|
|||
.expect("Can't create store")
|
||||
}
|
||||
|
||||
async fn get_memory_store() -> SqliteStore {
|
||||
SqliteStore::open_in_memory(USER_ID, DEVICE_ID)
|
||||
.await
|
||||
.expect("Can't create memory store")
|
||||
}
|
||||
|
||||
fn get_account() -> Arc<Mutex<Account>> {
|
||||
let account = Account::new();
|
||||
Arc::new(Mutex::new(account))
|
||||
|
@ -200,7 +201,7 @@ mod test {
|
|||
|
||||
#[tokio::test]
|
||||
async fn save_account() {
|
||||
let store = get_store().await;
|
||||
let mut store = get_store().await;
|
||||
let account = get_account();
|
||||
|
||||
store
|
||||
|
@ -211,7 +212,7 @@ mod test {
|
|||
|
||||
#[tokio::test]
|
||||
async fn load_account() {
|
||||
let store = get_memory_store().await;
|
||||
let mut store = get_store().await;
|
||||
let account = get_account();
|
||||
|
||||
store
|
||||
|
|
Loading…
Reference in a new issue