diff --git a/src/async_client.rs b/src/async_client.rs index 4cc79cfc..c8afd107 100644 --- a/src/async_client.rs +++ b/src/async_client.rs @@ -645,7 +645,7 @@ impl AsyncClient { .write() .await .receive_keys_upload_response(&response) - .await; + .await?; Ok(response) } diff --git a/src/base_client.rs b/src/base_client.rs index 398ffb3d..f7437c11 100644 --- a/src/base_client.rs +++ b/src/base_client.rs @@ -331,7 +331,7 @@ impl Client { let olm = self.olm.lock().await; match &*olm { - Some(o) => o.should_upload_keys(), + Some(o) => o.should_upload_keys().await, None => false, } } @@ -346,7 +346,7 @@ impl Client { let olm = self.olm.lock().await; match &*olm { - Some(o) => o.keys_for_upload(), + Some(o) => o.keys_for_upload().await, None => Err(()), } } @@ -361,10 +361,11 @@ impl Client { /// # Panics /// Panics if the client hasn't been logged in. #[cfg(feature = "encryption")] - pub async fn receive_keys_upload_response(&self, response: &KeysUploadResponse) { + pub async fn receive_keys_upload_response(&self, response: &KeysUploadResponse) -> Result<()> { let mut olm = self.olm.lock().await; let o = olm.as_mut().expect("Client isn't logged in."); - o.receive_keys_upload_response(response).await; + o.receive_keys_upload_response(response).await?; + Ok(()) } } diff --git a/src/crypto/error.rs b/src/crypto/error.rs index 4bc9e63d..38ba72de 100644 --- a/src/crypto/error.rs +++ b/src/crypto/error.rs @@ -18,6 +18,15 @@ use thiserror::Error; use super::store::CryptoStoreError; pub type Result = std::result::Result; + +#[derive(Error, Debug)] +pub enum OlmError { + #[error("signature verification failed")] + Signature(#[from] SignatureError), + #[error("failed to read or write to the crypto store {0}")] + Store(#[from] CryptoStoreError), +} + pub type VerificationResult = std::result::Result; #[derive(Error, Debug)] @@ -37,11 +46,3 @@ impl From for SignatureError { Self::CanonicalJsonError(error) } } - -#[derive(Error, Debug)] -pub enum OlmError { - #[error("signature verification failed")] - Signature(#[from] SignatureError), - #[error("failed to read or write to the crypto store {0}")] - Store(#[from] CryptoStoreError), -} diff --git a/src/crypto/machine.rs b/src/crypto/machine.rs index f49cb873..e7e2f548 100644 --- a/src/crypto/machine.rs +++ b/src/crypto/machine.rs @@ -14,10 +14,16 @@ use std::collections::HashMap; use std::convert::TryInto; +use std::path::Path; use std::result::Result as StdResult; +use std::sync::Arc; use super::error::{Result, SignatureError, VerificationResult}; use super::olm::Account; +#[cfg(feature = "sqlite-cryptostore")] +use super::store::sqlite::SqliteStore; +use super::store::MemoryStore; +use super::CryptoStore; use crate::api; use api::r0::keys; @@ -26,8 +32,8 @@ use cjson; use olm_rs::utility::OlmUtility; use serde_json::json; use serde_json::Value; +use tokio::sync::Mutex; -use super::store::CryptoStoreError; use ruma_client_api::r0::keys::{ AlgorithmAndDeviceId, DeviceKeys, KeyAlgorithm, OneTimeKey, SignedKey, }; @@ -40,9 +46,6 @@ use ruma_identifiers::{DeviceId, UserId}; pub type OneTimeKeys = HashMap; -#[cfg(feature = "sqlite-cryptostore")] -use super::store::sqlite::SqliteStore; - #[derive(Debug)] pub struct OlmMachine { /// The unique user id that owns this account. @@ -50,12 +53,16 @@ pub struct OlmMachine { /// The unique device id of the device that holds this account. device_id: DeviceId, /// Our underlying Olm Account holding our identity keys. - account: Account, + account: Arc>, /// The number of signed one-time keys we have uploaded to the server. If /// this is None, no action will be taken. After a sync request the client /// needs to set this for us, depending on the count we will suggest the /// client to upload new keys. uploaded_signed_key_count: Option, + /// Store for the encryption keys. + /// Persists all the encrytpion keys so a client can resume the session + /// without the need to create new keys. + store: Box, } impl OlmMachine { @@ -69,14 +76,39 @@ impl OlmMachine { Ok(OlmMachine { user_id: user_id.clone(), device_id: device_id.to_owned(), - account: Account::new(), + account: Arc::new(Mutex::new(Account::new())), uploaded_signed_key_count: None, + store: Box::new(MemoryStore::new()), + }) + } + + #[cfg(feature = "sqlite-cryptostore")] + pub async fn new_with_sqlite_store>( + user_id: &UserId, + device_id: &str, + path: P, + passphrase: String, + ) -> Result { + Ok(OlmMachine { + user_id: user_id.clone(), + device_id: device_id.to_owned(), + account: Arc::new(Mutex::new(Account::new())), + uploaded_signed_key_count: None, + store: Box::new( + SqliteStore::open_with_passphrase( + &user_id.to_string(), + device_id, + path, + passphrase, + ) + .await?, + ), }) } /// Should account or one-time keys be uploaded to the server. - pub fn should_upload_keys(&self) -> bool { - if !self.account.shared() { + pub async fn should_upload_keys(&self) -> bool { + if !self.account.lock().await.shared() { return true; } @@ -84,7 +116,7 @@ impl OlmMachine { // max_one_time_Keys() / 2, otherwise tell the client to upload more. match self.uploaded_signed_key_count { Some(count) => { - let max_keys = self.account.max_one_time_keys() as u64; + let max_keys = self.account.lock().await.max_one_time_keys() as u64; let key_count = (max_keys / 2) - count; key_count > 0 } @@ -98,8 +130,12 @@ impl OlmMachine { /// /// * `response` - The keys upload response of the request that the client /// performed. - pub async fn receive_keys_upload_response(&mut self, response: &keys::upload_keys::Response) { - self.account.shared = true; + pub async fn receive_keys_upload_response( + &mut self, + response: &keys::upload_keys::Response, + ) -> Result<()> { + let mut account = self.account.lock().await; + account.shared = true; let one_time_key_count = response .one_time_key_counts @@ -108,18 +144,22 @@ impl OlmMachine { let count: u64 = one_time_key_count.map_or(0, |c| (*c).into()); self.uploaded_signed_key_count = Some(count); - self.account.mark_keys_as_published(); - // TODO save the account here. + account.mark_keys_as_published(); + drop(account); + + self.store.save_account(self.account.clone()).await?; + Ok(()) } /// Generate new one-time keys. /// /// Returns the number of newly generated one-time keys. If no keys can be /// generated returns an empty error. - fn generate_one_time_keys(&self) -> StdResult { + async fn generate_one_time_keys(&self) -> StdResult { + let account = self.account.lock().await; match self.uploaded_signed_key_count { Some(count) => { - let max_keys = self.account.max_one_time_keys() as u64; + let max_keys = account.max_one_time_keys() as u64; let max_on_server = max_keys / 2; if count >= (max_on_server) { @@ -130,9 +170,9 @@ impl OlmMachine { let key_count: usize = key_count .try_into() - .unwrap_or_else(|_| self.account.max_one_time_keys()); + .unwrap_or_else(|_| account.max_one_time_keys()); - self.account.generate_one_time_keys(key_count); + account.generate_one_time_keys(key_count); Ok(key_count as u64) } None => Err(()), @@ -140,8 +180,8 @@ impl OlmMachine { } /// Sign the device keys and return a JSON Value to upload them. - fn device_keys(&self) -> DeviceKeys { - let identity_keys = self.account.identity_keys(); + async fn device_keys(&self) -> DeviceKeys { + let identity_keys = self.account.lock().await.identity_keys(); let mut keys = HashMap::new(); @@ -166,7 +206,7 @@ impl OlmMachine { let mut signature = HashMap::new(); signature.insert( AlgorithmAndDeviceId(KeyAlgorithm::Ed25519, self.device_id.clone()), - self.sign_json(&device_keys), + self.sign_json(&device_keys).await, ); signatures.insert(self.user_id.clone(), signature); @@ -186,10 +226,9 @@ impl OlmMachine { /// Generate, sign and prepare one-time keys to be uploaded. /// /// If no one-time keys need to be uploaded returns an empty error. - fn signed_one_time_keys(&self) -> StdResult { - let _ = self.generate_one_time_keys()?; - - let one_time_keys = self.account.one_time_keys(); + async fn signed_one_time_keys(&self) -> StdResult { + let _ = self.generate_one_time_keys().await?; + let one_time_keys = self.account.lock().await.one_time_keys(); let mut one_time_key_map = HashMap::new(); for (key_id, key) in one_time_keys.curve25519().iter() { @@ -197,7 +236,7 @@ impl OlmMachine { "key": key, }); - let signature = self.sign_json(&key_json); + let signature = self.sign_json(&key_json).await; let mut signature_map = HashMap::new(); @@ -230,10 +269,11 @@ impl OlmMachine { /// /// * `json` - The value that should be converted into a canonical JSON /// string. - fn sign_json(&self, json: &Value) -> String { + async fn sign_json(&self, json: &Value) -> String { + let account = self.account.lock().await; let canonical_json = cjson::to_string(json) .unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", json))); - self.account.sign(&canonical_json) + account.sign(&canonical_json) } /// Verify a signed JSON object. @@ -305,18 +345,22 @@ impl OlmMachine { /// Get a tuple of device and one-time keys that need to be uploaded. /// /// Returns an empty error if no keys need to be uploaded. - pub fn keys_for_upload(&self) -> StdResult<(Option, Option), ()> { - if !self.should_upload_keys() { + pub async fn keys_for_upload( + &self, + ) -> StdResult<(Option, Option), ()> { + if !self.should_upload_keys().await { return Err(()); } - let device_keys = if !self.account.shared() { - Some(self.device_keys()) + let shared = self.account.lock().await.shared(); + + let device_keys = if !shared { + Some(self.device_keys().await) } else { None }; - let one_time_keys: Option = self.signed_one_time_keys().ok(); + let one_time_keys: Option = self.signed_one_time_keys().await.ok(); Ok((device_keys, one_time_keys)) } @@ -413,10 +457,10 @@ mod test { keys::upload_keys::Response::try_from(data).expect("Can't parse the keys upload response") } - #[test] - fn create_olm_machine() { + #[tokio::test] + async fn create_olm_machine() { let machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap(); - assert!(machine.should_upload_keys()); + assert!(machine.should_upload_keys().await); } #[tokio::test] @@ -429,23 +473,32 @@ mod test { .remove(&keys::KeyAlgorithm::SignedCurve25519) .unwrap(); - assert!(machine.should_upload_keys()); - machine.receive_keys_upload_response(&response).await; - assert!(machine.should_upload_keys()); + assert!(machine.should_upload_keys().await); + machine + .receive_keys_upload_response(&response) + .await + .unwrap(); + assert!(machine.should_upload_keys().await); response.one_time_key_counts.insert( keys::KeyAlgorithm::SignedCurve25519, UInt::try_from(10).unwrap(), ); - machine.receive_keys_upload_response(&response).await; - assert!(machine.should_upload_keys()); + machine + .receive_keys_upload_response(&response) + .await + .unwrap(); + assert!(machine.should_upload_keys().await); response.one_time_key_counts.insert( keys::KeyAlgorithm::SignedCurve25519, UInt::try_from(50).unwrap(), ); - machine.receive_keys_upload_response(&response).await; - assert!(!machine.should_upload_keys()); + machine + .receive_keys_upload_response(&response) + .await + .unwrap(); + assert!(!machine.should_upload_keys().await); } #[tokio::test] @@ -454,27 +507,33 @@ mod test { let mut response = keys_upload_response(); - assert!(machine.should_upload_keys()); - assert!(machine.generate_one_time_keys().is_err()); + assert!(machine.should_upload_keys().await); + assert!(machine.generate_one_time_keys().await.is_err()); - machine.receive_keys_upload_response(&response).await; - assert!(machine.should_upload_keys()); - assert!(machine.generate_one_time_keys().is_ok()); + machine + .receive_keys_upload_response(&response) + .await + .unwrap(); + assert!(machine.should_upload_keys().await); + assert!(machine.generate_one_time_keys().await.is_ok()); response.one_time_key_counts.insert( keys::KeyAlgorithm::SignedCurve25519, UInt::try_from(50).unwrap(), ); - machine.receive_keys_upload_response(&response).await; - assert!(machine.generate_one_time_keys().is_err()); + machine + .receive_keys_upload_response(&response) + .await + .unwrap(); + assert!(machine.generate_one_time_keys().await.is_err()); } - #[test] - fn test_device_key_signing() { + #[tokio::test] + async fn test_device_key_signing() { let machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap(); - let mut device_keys = machine.device_keys(); - let identity_keys = machine.account.identity_keys(); + let mut device_keys = machine.device_keys().await; + let identity_keys = machine.account.lock().await.identity_keys(); let ed25519_key = identity_keys.ed25519(); let ret = machine.verify_json( @@ -486,11 +545,11 @@ mod test { assert!(ret.is_ok()); } - #[test] - fn test_invalid_signature() { + #[tokio::test] + async fn test_invalid_signature() { let machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap(); - let mut device_keys = machine.device_keys(); + let mut device_keys = machine.device_keys().await; let ret = machine.verify_json( &machine.user_id, @@ -501,13 +560,13 @@ mod test { assert!(ret.is_err()); } - #[test] - fn test_one_time_key_signing() { + #[tokio::test] + async fn test_one_time_key_signing() { let mut machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap(); machine.uploaded_signed_key_count = Some(49); - let mut one_time_keys = machine.signed_one_time_keys().unwrap(); - let identity_keys = machine.account.identity_keys(); + let mut one_time_keys = machine.signed_one_time_keys().await.unwrap(); + let identity_keys = machine.account.lock().await.identity_keys(); let ed25519_key = identity_keys.ed25519(); let mut one_time_key = one_time_keys.values_mut().nth(0).unwrap(); @@ -526,11 +585,12 @@ mod test { let mut machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap(); machine.uploaded_signed_key_count = Some(0); - let identity_keys = machine.account.identity_keys(); + let identity_keys = machine.account.lock().await.identity_keys(); let ed25519_key = identity_keys.ed25519(); let (device_keys, mut one_time_keys) = machine .keys_for_upload() + .await .expect("Can't prepare initial key upload"); let ret = machine.verify_json( @@ -555,9 +615,12 @@ mod test { UInt::new_wrapping(one_time_keys.unwrap().len() as u64), ); - machine.receive_keys_upload_response(&response).await; + machine + .receive_keys_upload_response(&response) + .await + .unwrap(); - let ret = machine.keys_for_upload(); + let ret = machine.keys_for_upload().await; assert!(ret.is_err()); } } diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 7571489c..e9f0b887 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -22,3 +22,4 @@ mod store; pub use error::OlmError; pub use machine::{OlmMachine, OneTimeKeys}; +pub use store::{CryptoStore, CryptoStoreError}; diff --git a/src/crypto/store/mod.rs b/src/crypto/store/mod.rs index 3ac0fe51..c8c18a3a 100644 --- a/src/crypto/store/mod.rs +++ b/src/crypto/store/mod.rs @@ -1,3 +1,4 @@ +use core::fmt::Debug; use std::io::Error as IoError; use std::sync::Arc; use url::ParseError; @@ -8,6 +9,7 @@ use tokio::sync::Mutex; use super::olm::Account; use olm_rs::errors::OlmAccountError; +use olm_rs::PicklingMode; #[cfg(feature = "sqlite-cryptostore")] pub mod sqlite; @@ -33,7 +35,41 @@ pub enum CryptoStoreError { pub type Result = std::result::Result; #[async_trait] -pub trait CryptoStore { - async fn load_account(&self) -> Result>; - async fn save_account(&self, account: Arc>) -> Result<()>; +pub trait CryptoStore: Debug { + async fn load_account(&mut self) -> Result>; + async fn save_account(&mut self, account: Arc>) -> Result<()>; +} + +#[derive(Debug)] +pub struct MemoryStore { + pub(crate) account_info: Option<(String, bool)>, +} + +impl MemoryStore { + /// Create a new empty memory store. + pub fn new() -> Self { + MemoryStore { account_info: None } + } +} + +#[async_trait] +impl CryptoStore for MemoryStore { + async fn load_account(&mut self) -> Result> { + let result = match &self.account_info { + Some((pickle, shared)) => Some(Account::from_pickle( + pickle.to_owned(), + PicklingMode::Unencrypted, + *shared, + )?), + None => None, + }; + Ok(result) + } + + async fn save_account(&mut self, account: Arc>) -> Result<()> { + let acc = account.lock().await; + let pickle = acc.pickle(PicklingMode::Unencrypted); + self.account_info = Some((pickle, acc.shared)); + Ok(()) + } } diff --git a/src/crypto/store/sqlite.rs b/src/crypto/store/sqlite.rs index 06fa3219..fca1093e 100644 --- a/src/crypto/store/sqlite.rs +++ b/src/crypto/store/sqlite.rs @@ -1,4 +1,5 @@ -use std::path::Path; +use std::path::{Path, PathBuf}; +use std::result::Result as StdResult; use std::sync::Arc; use url::Url; @@ -13,54 +14,60 @@ use super::{Account, CryptoStore, Result}; pub struct SqliteStore { user_id: Arc, device_id: Arc, + path: PathBuf, connection: Arc>, pickle_passphrase: Option>, } +impl std::fmt::Debug for SqliteStore { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> StdResult<(), std::fmt::Error> { + write!( + fmt, + "SqliteStore {{ user_id: {}, device_id: {}, path: {:?} }}", + self.user_id, self.device_id, self.path + ) + } +} + static DATABASE_NAME: &str = "matrix-sdk-crypto.db"; impl SqliteStore { - async fn open>(user_id: &str, device_id: &str, path: P) -> Result { - let url = SqliteStore::path_to_url(path)?; - SqliteStore::open_helper(user_id, device_id, url.as_ref(), None).await + pub async fn open>( + user_id: &str, + device_id: &str, + path: P, + ) -> Result { + SqliteStore::open_helper(user_id, device_id, path, None).await } - async fn open_with_passphrase>( + pub async fn open_with_passphrase>( user_id: &str, device_id: &str, path: P, passphrase: String, ) -> Result { - let url = SqliteStore::path_to_url(path)?; - SqliteStore::open_helper( - user_id, - device_id, - url.as_ref(), - Some(Zeroizing::new(passphrase)), - ) - .await + SqliteStore::open_helper(user_id, device_id, path, Some(Zeroizing::new(passphrase))).await } - async fn open_in_memory(user_id: &str, device_id: &str) -> Result { - SqliteStore::open_helper(user_id, device_id, "sqlite::memory:", None).await - } - - fn path_to_url>(path: P) -> Result { + fn path_to_url(path: &Path) -> Result { // TODO this returns an empty error if the path isn't absolute. - let url = Url::from_directory_path(path.as_ref()).expect("Invalid path"); + let url = Url::from_directory_path(path).expect("Invalid path"); Ok(url.join(DATABASE_NAME)?) } - async fn open_helper( + async fn open_helper>( user_id: &str, device_id: &str, - sqlite_url: &str, + path: P, passphrase: Option>, ) -> Result { - let connection = SqliteConnection::connect(sqlite_url).await.unwrap(); + let url = SqliteStore::path_to_url(path.as_ref())?; + + let connection = SqliteConnection::connect(url.as_ref()).await.unwrap(); let store = SqliteStore { user_id: Arc::new(user_id.to_owned()), device_id: Arc::new(device_id.to_owned()), + path: path.as_ref().to_owned(), connection: Arc::new(Mutex::new(connection)), pickle_passphrase: passphrase, }; @@ -100,7 +107,7 @@ impl SqliteStore { #[async_trait] impl CryptoStore for SqliteStore { - async fn load_account(&self) -> Result> { + async fn load_account(&mut self) -> Result> { let mut connection = self.connection.lock().await; let row: Option<(String, bool)> = query_as( @@ -124,7 +131,7 @@ impl CryptoStore for SqliteStore { Ok(result) } - async fn save_account(&self, account: Arc>) -> Result<()> { + async fn save_account(&mut self, account: Arc>) -> Result<()> { let acc = account.lock().await; let pickle = acc.pickle(self.get_pickle_mode()); let mut connection = self.connection.lock().await; @@ -178,12 +185,6 @@ mod test { .expect("Can't create store") } - async fn get_memory_store() -> SqliteStore { - SqliteStore::open_in_memory(USER_ID, DEVICE_ID) - .await - .expect("Can't create memory store") - } - fn get_account() -> Arc> { let account = Account::new(); Arc::new(Mutex::new(account)) @@ -200,7 +201,7 @@ mod test { #[tokio::test] async fn save_account() { - let store = get_store().await; + let mut store = get_store().await; let account = get_account(); store @@ -211,7 +212,7 @@ mod test { #[tokio::test] async fn load_account() { - let store = get_memory_store().await; + let mut store = get_store().await; let account = get_account(); store