crypto: Hook up the crypto store to the Olm machine.

This commit is contained in:
Damir Jelić 2020-03-18 15:50:32 +01:00
parent d7ab847b98
commit 4aba058695
7 changed files with 214 additions and 111 deletions

View file

@ -645,7 +645,7 @@ impl AsyncClient {
.write()
.await
.receive_keys_upload_response(&response)
.await;
.await?;
Ok(response)
}

View file

@ -331,7 +331,7 @@ impl Client {
let olm = self.olm.lock().await;
match &*olm {
Some(o) => o.should_upload_keys(),
Some(o) => o.should_upload_keys().await,
None => false,
}
}
@ -346,7 +346,7 @@ impl Client {
let olm = self.olm.lock().await;
match &*olm {
Some(o) => o.keys_for_upload(),
Some(o) => o.keys_for_upload().await,
None => Err(()),
}
}
@ -361,10 +361,11 @@ impl Client {
/// # Panics
/// Panics if the client hasn't been logged in.
#[cfg(feature = "encryption")]
pub async fn receive_keys_upload_response(&self, response: &KeysUploadResponse) {
pub async fn receive_keys_upload_response(&self, response: &KeysUploadResponse) -> Result<()> {
let mut olm = self.olm.lock().await;
let o = olm.as_mut().expect("Client isn't logged in.");
o.receive_keys_upload_response(response).await;
o.receive_keys_upload_response(response).await?;
Ok(())
}
}

View file

@ -18,6 +18,15 @@ use thiserror::Error;
use super::store::CryptoStoreError;
pub type Result<T> = std::result::Result<T, OlmError>;
#[derive(Error, Debug)]
pub enum OlmError {
#[error("signature verification failed")]
Signature(#[from] SignatureError),
#[error("failed to read or write to the crypto store {0}")]
Store(#[from] CryptoStoreError),
}
pub type VerificationResult<T> = std::result::Result<T, SignatureError>;
#[derive(Error, Debug)]
@ -37,11 +46,3 @@ impl From<CjsonError> for SignatureError {
Self::CanonicalJsonError(error)
}
}
#[derive(Error, Debug)]
pub enum OlmError {
#[error("signature verification failed")]
Signature(#[from] SignatureError),
#[error("failed to read or write to the crypto store {0}")]
Store(#[from] CryptoStoreError),
}

View file

@ -14,10 +14,16 @@
use std::collections::HashMap;
use std::convert::TryInto;
use std::path::Path;
use std::result::Result as StdResult;
use std::sync::Arc;
use super::error::{Result, SignatureError, VerificationResult};
use super::olm::Account;
#[cfg(feature = "sqlite-cryptostore")]
use super::store::sqlite::SqliteStore;
use super::store::MemoryStore;
use super::CryptoStore;
use crate::api;
use api::r0::keys;
@ -26,8 +32,8 @@ use cjson;
use olm_rs::utility::OlmUtility;
use serde_json::json;
use serde_json::Value;
use tokio::sync::Mutex;
use super::store::CryptoStoreError;
use ruma_client_api::r0::keys::{
AlgorithmAndDeviceId, DeviceKeys, KeyAlgorithm, OneTimeKey, SignedKey,
};
@ -40,9 +46,6 @@ use ruma_identifiers::{DeviceId, UserId};
pub type OneTimeKeys = HashMap<AlgorithmAndDeviceId, OneTimeKey>;
#[cfg(feature = "sqlite-cryptostore")]
use super::store::sqlite::SqliteStore;
#[derive(Debug)]
pub struct OlmMachine {
/// The unique user id that owns this account.
@ -50,12 +53,16 @@ pub struct OlmMachine {
/// The unique device id of the device that holds this account.
device_id: DeviceId,
/// Our underlying Olm Account holding our identity keys.
account: Account,
account: Arc<Mutex<Account>>,
/// The number of signed one-time keys we have uploaded to the server. If
/// this is None, no action will be taken. After a sync request the client
/// needs to set this for us, depending on the count we will suggest the
/// client to upload new keys.
uploaded_signed_key_count: Option<u64>,
/// Store for the encryption keys.
/// Persists all the encrytpion keys so a client can resume the session
/// without the need to create new keys.
store: Box<dyn CryptoStore>,
}
impl OlmMachine {
@ -69,14 +76,39 @@ impl OlmMachine {
Ok(OlmMachine {
user_id: user_id.clone(),
device_id: device_id.to_owned(),
account: Account::new(),
account: Arc::new(Mutex::new(Account::new())),
uploaded_signed_key_count: None,
store: Box::new(MemoryStore::new()),
})
}
#[cfg(feature = "sqlite-cryptostore")]
pub async fn new_with_sqlite_store<P: AsRef<Path>>(
user_id: &UserId,
device_id: &str,
path: P,
passphrase: String,
) -> Result<Self> {
Ok(OlmMachine {
user_id: user_id.clone(),
device_id: device_id.to_owned(),
account: Arc::new(Mutex::new(Account::new())),
uploaded_signed_key_count: None,
store: Box::new(
SqliteStore::open_with_passphrase(
&user_id.to_string(),
device_id,
path,
passphrase,
)
.await?,
),
})
}
/// Should account or one-time keys be uploaded to the server.
pub fn should_upload_keys(&self) -> bool {
if !self.account.shared() {
pub async fn should_upload_keys(&self) -> bool {
if !self.account.lock().await.shared() {
return true;
}
@ -84,7 +116,7 @@ impl OlmMachine {
// max_one_time_Keys() / 2, otherwise tell the client to upload more.
match self.uploaded_signed_key_count {
Some(count) => {
let max_keys = self.account.max_one_time_keys() as u64;
let max_keys = self.account.lock().await.max_one_time_keys() as u64;
let key_count = (max_keys / 2) - count;
key_count > 0
}
@ -98,8 +130,12 @@ impl OlmMachine {
///
/// * `response` - The keys upload response of the request that the client
/// performed.
pub async fn receive_keys_upload_response(&mut self, response: &keys::upload_keys::Response) {
self.account.shared = true;
pub async fn receive_keys_upload_response(
&mut self,
response: &keys::upload_keys::Response,
) -> Result<()> {
let mut account = self.account.lock().await;
account.shared = true;
let one_time_key_count = response
.one_time_key_counts
@ -108,18 +144,22 @@ impl OlmMachine {
let count: u64 = one_time_key_count.map_or(0, |c| (*c).into());
self.uploaded_signed_key_count = Some(count);
self.account.mark_keys_as_published();
// TODO save the account here.
account.mark_keys_as_published();
drop(account);
self.store.save_account(self.account.clone()).await?;
Ok(())
}
/// Generate new one-time keys.
///
/// Returns the number of newly generated one-time keys. If no keys can be
/// generated returns an empty error.
fn generate_one_time_keys(&self) -> StdResult<u64, ()> {
async fn generate_one_time_keys(&self) -> StdResult<u64, ()> {
let account = self.account.lock().await;
match self.uploaded_signed_key_count {
Some(count) => {
let max_keys = self.account.max_one_time_keys() as u64;
let max_keys = account.max_one_time_keys() as u64;
let max_on_server = max_keys / 2;
if count >= (max_on_server) {
@ -130,9 +170,9 @@ impl OlmMachine {
let key_count: usize = key_count
.try_into()
.unwrap_or_else(|_| self.account.max_one_time_keys());
.unwrap_or_else(|_| account.max_one_time_keys());
self.account.generate_one_time_keys(key_count);
account.generate_one_time_keys(key_count);
Ok(key_count as u64)
}
None => Err(()),
@ -140,8 +180,8 @@ impl OlmMachine {
}
/// Sign the device keys and return a JSON Value to upload them.
fn device_keys(&self) -> DeviceKeys {
let identity_keys = self.account.identity_keys();
async fn device_keys(&self) -> DeviceKeys {
let identity_keys = self.account.lock().await.identity_keys();
let mut keys = HashMap::new();
@ -166,7 +206,7 @@ impl OlmMachine {
let mut signature = HashMap::new();
signature.insert(
AlgorithmAndDeviceId(KeyAlgorithm::Ed25519, self.device_id.clone()),
self.sign_json(&device_keys),
self.sign_json(&device_keys).await,
);
signatures.insert(self.user_id.clone(), signature);
@ -186,10 +226,9 @@ impl OlmMachine {
/// Generate, sign and prepare one-time keys to be uploaded.
///
/// If no one-time keys need to be uploaded returns an empty error.
fn signed_one_time_keys(&self) -> StdResult<OneTimeKeys, ()> {
let _ = self.generate_one_time_keys()?;
let one_time_keys = self.account.one_time_keys();
async fn signed_one_time_keys(&self) -> StdResult<OneTimeKeys, ()> {
let _ = self.generate_one_time_keys().await?;
let one_time_keys = self.account.lock().await.one_time_keys();
let mut one_time_key_map = HashMap::new();
for (key_id, key) in one_time_keys.curve25519().iter() {
@ -197,7 +236,7 @@ impl OlmMachine {
"key": key,
});
let signature = self.sign_json(&key_json);
let signature = self.sign_json(&key_json).await;
let mut signature_map = HashMap::new();
@ -230,10 +269,11 @@ impl OlmMachine {
///
/// * `json` - The value that should be converted into a canonical JSON
/// string.
fn sign_json(&self, json: &Value) -> String {
async fn sign_json(&self, json: &Value) -> String {
let account = self.account.lock().await;
let canonical_json = cjson::to_string(json)
.unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", json)));
self.account.sign(&canonical_json)
account.sign(&canonical_json)
}
/// Verify a signed JSON object.
@ -305,18 +345,22 @@ impl OlmMachine {
/// Get a tuple of device and one-time keys that need to be uploaded.
///
/// Returns an empty error if no keys need to be uploaded.
pub fn keys_for_upload(&self) -> StdResult<(Option<DeviceKeys>, Option<OneTimeKeys>), ()> {
if !self.should_upload_keys() {
pub async fn keys_for_upload(
&self,
) -> StdResult<(Option<DeviceKeys>, Option<OneTimeKeys>), ()> {
if !self.should_upload_keys().await {
return Err(());
}
let device_keys = if !self.account.shared() {
Some(self.device_keys())
let shared = self.account.lock().await.shared();
let device_keys = if !shared {
Some(self.device_keys().await)
} else {
None
};
let one_time_keys: Option<OneTimeKeys> = self.signed_one_time_keys().ok();
let one_time_keys: Option<OneTimeKeys> = self.signed_one_time_keys().await.ok();
Ok((device_keys, one_time_keys))
}
@ -413,10 +457,10 @@ mod test {
keys::upload_keys::Response::try_from(data).expect("Can't parse the keys upload response")
}
#[test]
fn create_olm_machine() {
#[tokio::test]
async fn create_olm_machine() {
let machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
assert!(machine.should_upload_keys());
assert!(machine.should_upload_keys().await);
}
#[tokio::test]
@ -429,23 +473,32 @@ mod test {
.remove(&keys::KeyAlgorithm::SignedCurve25519)
.unwrap();
assert!(machine.should_upload_keys());
machine.receive_keys_upload_response(&response).await;
assert!(machine.should_upload_keys());
assert!(machine.should_upload_keys().await);
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
assert!(machine.should_upload_keys().await);
response.one_time_key_counts.insert(
keys::KeyAlgorithm::SignedCurve25519,
UInt::try_from(10).unwrap(),
);
machine.receive_keys_upload_response(&response).await;
assert!(machine.should_upload_keys());
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
assert!(machine.should_upload_keys().await);
response.one_time_key_counts.insert(
keys::KeyAlgorithm::SignedCurve25519,
UInt::try_from(50).unwrap(),
);
machine.receive_keys_upload_response(&response).await;
assert!(!machine.should_upload_keys());
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
assert!(!machine.should_upload_keys().await);
}
#[tokio::test]
@ -454,27 +507,33 @@ mod test {
let mut response = keys_upload_response();
assert!(machine.should_upload_keys());
assert!(machine.generate_one_time_keys().is_err());
assert!(machine.should_upload_keys().await);
assert!(machine.generate_one_time_keys().await.is_err());
machine.receive_keys_upload_response(&response).await;
assert!(machine.should_upload_keys());
assert!(machine.generate_one_time_keys().is_ok());
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
assert!(machine.should_upload_keys().await);
assert!(machine.generate_one_time_keys().await.is_ok());
response.one_time_key_counts.insert(
keys::KeyAlgorithm::SignedCurve25519,
UInt::try_from(50).unwrap(),
);
machine.receive_keys_upload_response(&response).await;
assert!(machine.generate_one_time_keys().is_err());
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
assert!(machine.generate_one_time_keys().await.is_err());
}
#[test]
fn test_device_key_signing() {
#[tokio::test]
async fn test_device_key_signing() {
let machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
let mut device_keys = machine.device_keys();
let identity_keys = machine.account.identity_keys();
let mut device_keys = machine.device_keys().await;
let identity_keys = machine.account.lock().await.identity_keys();
let ed25519_key = identity_keys.ed25519();
let ret = machine.verify_json(
@ -486,11 +545,11 @@ mod test {
assert!(ret.is_ok());
}
#[test]
fn test_invalid_signature() {
#[tokio::test]
async fn test_invalid_signature() {
let machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
let mut device_keys = machine.device_keys();
let mut device_keys = machine.device_keys().await;
let ret = machine.verify_json(
&machine.user_id,
@ -501,13 +560,13 @@ mod test {
assert!(ret.is_err());
}
#[test]
fn test_one_time_key_signing() {
#[tokio::test]
async fn test_one_time_key_signing() {
let mut machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
machine.uploaded_signed_key_count = Some(49);
let mut one_time_keys = machine.signed_one_time_keys().unwrap();
let identity_keys = machine.account.identity_keys();
let mut one_time_keys = machine.signed_one_time_keys().await.unwrap();
let identity_keys = machine.account.lock().await.identity_keys();
let ed25519_key = identity_keys.ed25519();
let mut one_time_key = one_time_keys.values_mut().nth(0).unwrap();
@ -526,11 +585,12 @@ mod test {
let mut machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
machine.uploaded_signed_key_count = Some(0);
let identity_keys = machine.account.identity_keys();
let identity_keys = machine.account.lock().await.identity_keys();
let ed25519_key = identity_keys.ed25519();
let (device_keys, mut one_time_keys) = machine
.keys_for_upload()
.await
.expect("Can't prepare initial key upload");
let ret = machine.verify_json(
@ -555,9 +615,12 @@ mod test {
UInt::new_wrapping(one_time_keys.unwrap().len() as u64),
);
machine.receive_keys_upload_response(&response).await;
machine
.receive_keys_upload_response(&response)
.await
.unwrap();
let ret = machine.keys_for_upload();
let ret = machine.keys_for_upload().await;
assert!(ret.is_err());
}
}

View file

@ -22,3 +22,4 @@ mod store;
pub use error::OlmError;
pub use machine::{OlmMachine, OneTimeKeys};
pub use store::{CryptoStore, CryptoStoreError};

View file

@ -1,3 +1,4 @@
use core::fmt::Debug;
use std::io::Error as IoError;
use std::sync::Arc;
use url::ParseError;
@ -8,6 +9,7 @@ use tokio::sync::Mutex;
use super::olm::Account;
use olm_rs::errors::OlmAccountError;
use olm_rs::PicklingMode;
#[cfg(feature = "sqlite-cryptostore")]
pub mod sqlite;
@ -33,7 +35,41 @@ pub enum CryptoStoreError {
pub type Result<T> = std::result::Result<T, CryptoStoreError>;
#[async_trait]
pub trait CryptoStore {
async fn load_account(&self) -> Result<Option<Account>>;
async fn save_account(&self, account: Arc<Mutex<Account>>) -> Result<()>;
pub trait CryptoStore: Debug {
async fn load_account(&mut self) -> Result<Option<Account>>;
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()>;
}
#[derive(Debug)]
pub struct MemoryStore {
pub(crate) account_info: Option<(String, bool)>,
}
impl MemoryStore {
/// Create a new empty memory store.
pub fn new() -> Self {
MemoryStore { account_info: None }
}
}
#[async_trait]
impl CryptoStore for MemoryStore {
async fn load_account(&mut self) -> Result<Option<Account>> {
let result = match &self.account_info {
Some((pickle, shared)) => Some(Account::from_pickle(
pickle.to_owned(),
PicklingMode::Unencrypted,
*shared,
)?),
None => None,
};
Ok(result)
}
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()> {
let acc = account.lock().await;
let pickle = acc.pickle(PicklingMode::Unencrypted);
self.account_info = Some((pickle, acc.shared));
Ok(())
}
}

View file

@ -1,4 +1,5 @@
use std::path::Path;
use std::path::{Path, PathBuf};
use std::result::Result as StdResult;
use std::sync::Arc;
use url::Url;
@ -13,54 +14,60 @@ use super::{Account, CryptoStore, Result};
pub struct SqliteStore {
user_id: Arc<String>,
device_id: Arc<String>,
path: PathBuf,
connection: Arc<Mutex<SqliteConnection>>,
pickle_passphrase: Option<Zeroizing<String>>,
}
impl std::fmt::Debug for SqliteStore {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> StdResult<(), std::fmt::Error> {
write!(
fmt,
"SqliteStore {{ user_id: {}, device_id: {}, path: {:?} }}",
self.user_id, self.device_id, self.path
)
}
}
static DATABASE_NAME: &str = "matrix-sdk-crypto.db";
impl SqliteStore {
async fn open<P: AsRef<Path>>(user_id: &str, device_id: &str, path: P) -> Result<SqliteStore> {
let url = SqliteStore::path_to_url(path)?;
SqliteStore::open_helper(user_id, device_id, url.as_ref(), None).await
pub async fn open<P: AsRef<Path>>(
user_id: &str,
device_id: &str,
path: P,
) -> Result<SqliteStore> {
SqliteStore::open_helper(user_id, device_id, path, None).await
}
async fn open_with_passphrase<P: AsRef<Path>>(
pub async fn open_with_passphrase<P: AsRef<Path>>(
user_id: &str,
device_id: &str,
path: P,
passphrase: String,
) -> Result<SqliteStore> {
let url = SqliteStore::path_to_url(path)?;
SqliteStore::open_helper(
user_id,
device_id,
url.as_ref(),
Some(Zeroizing::new(passphrase)),
)
.await
SqliteStore::open_helper(user_id, device_id, path, Some(Zeroizing::new(passphrase))).await
}
async fn open_in_memory(user_id: &str, device_id: &str) -> Result<SqliteStore> {
SqliteStore::open_helper(user_id, device_id, "sqlite::memory:", None).await
}
fn path_to_url<P: AsRef<Path>>(path: P) -> Result<Url> {
fn path_to_url(path: &Path) -> Result<Url> {
// TODO this returns an empty error if the path isn't absolute.
let url = Url::from_directory_path(path.as_ref()).expect("Invalid path");
let url = Url::from_directory_path(path).expect("Invalid path");
Ok(url.join(DATABASE_NAME)?)
}
async fn open_helper(
async fn open_helper<P: AsRef<Path>>(
user_id: &str,
device_id: &str,
sqlite_url: &str,
path: P,
passphrase: Option<Zeroizing<String>>,
) -> Result<SqliteStore> {
let connection = SqliteConnection::connect(sqlite_url).await.unwrap();
let url = SqliteStore::path_to_url(path.as_ref())?;
let connection = SqliteConnection::connect(url.as_ref()).await.unwrap();
let store = SqliteStore {
user_id: Arc::new(user_id.to_owned()),
device_id: Arc::new(device_id.to_owned()),
path: path.as_ref().to_owned(),
connection: Arc::new(Mutex::new(connection)),
pickle_passphrase: passphrase,
};
@ -100,7 +107,7 @@ impl SqliteStore {
#[async_trait]
impl CryptoStore for SqliteStore {
async fn load_account(&self) -> Result<Option<Account>> {
async fn load_account(&mut self) -> Result<Option<Account>> {
let mut connection = self.connection.lock().await;
let row: Option<(String, bool)> = query_as(
@ -124,7 +131,7 @@ impl CryptoStore for SqliteStore {
Ok(result)
}
async fn save_account(&self, account: Arc<Mutex<Account>>) -> Result<()> {
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()> {
let acc = account.lock().await;
let pickle = acc.pickle(self.get_pickle_mode());
let mut connection = self.connection.lock().await;
@ -178,12 +185,6 @@ mod test {
.expect("Can't create store")
}
async fn get_memory_store() -> SqliteStore {
SqliteStore::open_in_memory(USER_ID, DEVICE_ID)
.await
.expect("Can't create memory store")
}
fn get_account() -> Arc<Mutex<Account>> {
let account = Account::new();
Arc::new(Mutex::new(account))
@ -200,7 +201,7 @@ mod test {
#[tokio::test]
async fn save_account() {
let store = get_store().await;
let mut store = get_store().await;
let account = get_account();
store
@ -211,7 +212,7 @@ mod test {
#[tokio::test]
async fn load_account() {
let store = get_memory_store().await;
let mut store = get_store().await;
let account = get_account();
store