From 15d8de56e14cc3d7c2852a9a088a2797b057c6bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Tue, 25 Feb 2020 14:24:18 +0100 Subject: [PATCH] crypto: Add an initial version of the olm state machine. --- Cargo.toml | 1 + src/crypto/machine.rs | 181 ++++++++++++++++++++++++++++++++++++++++++ src/crypto/mod.rs | 3 +- src/crypto/olm.rs | 21 ++++- 4 files changed, 201 insertions(+), 5 deletions(-) create mode 100644 src/crypto/machine.rs diff --git a/Cargo.toml b/Cargo.toml index 7edebbbf..74061d10 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,5 +34,6 @@ serde_json = { version = "*", optional = true } [dev-dependencies] tokio = { version = "0.2.11", features = ["full"] } +async-std = { version = "1.5.0", features = ["attributes"] } url = "2.1.1" mockito = "0.23.3" diff --git a/src/crypto/machine.rs b/src/crypto/machine.rs new file mode 100644 index 00000000..2037b5c6 --- /dev/null +++ b/src/crypto/machine.rs @@ -0,0 +1,181 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::convert::TryInto; + +use super::olm::Account; +use crate::api; + +use api::r0::keys; + +struct OlmMachine { + /// The unique user id that owns this account. + user_id: String, + /// The unique device id of the device that holds this account. + device_id: String, + /// Our underlying Olm Account holding our identity keys. + account: Account, + /// The number of one-time keys we have uploaded to the server. If this is + /// None, no action will be taken. After a sync request the client needs to + /// set this for us, depending on the count we will suggest the client + /// to upload new keys. + uploaded_key_count: Option, +} + +impl OlmMachine { + /// Create a new account. + pub fn new(user_id: &str, device_id: &str) -> Self { + OlmMachine { + user_id: user_id.to_owned(), + device_id: device_id.to_owned(), + account: Account::new(), + uploaded_key_count: None, + } + } + + /// Should account or one-time keys be uploaded to the server. + pub fn should_upload_keys(&self) -> bool { + if !self.account.shared() { + return true; + } + + // If we have a known key count, check that we have more than + // max_one_time_Keys() / 2, otherwise tell the client to upload more. + match self.uploaded_key_count { + Some(count) => { + let max_keys = self.account.max_one_time_keys() as u64; + let key_count = (max_keys / 2) - count; + key_count > 0 + } + None => false, + } + } + + /// Receive a successfull keys upload response. + /// + /// # Arugments + /// + /// `response` - The keys upload response of the request that the client + /// performed. + pub async fn receive_keys_upload_response(&mut self, response: &keys::upload_keys::Response) { + self.account.shared = true; + let one_time_key_count = response + .one_time_key_counts + .get(&keys::KeyAlgorithm::SignedCurve25519); + + if let Some(c) = one_time_key_count { + let count: u64 = (*c).into(); + self.uploaded_key_count = Some(count); + } + + self.account.mark_keys_as_published(); + // TODO save the account here. + } + + /// Generate new one-time keys. + /// + /// Returns the number of newly generated one-time keys. If no keys can be + /// generated returns an empty error. + fn generate_one_time_keys(&self) -> Result { + match self.uploaded_key_count { + Some(count) => { + let max_keys = self.account.max_one_time_keys() as u64; + let key_count = (max_keys / 2) - count; + + if key_count <= 0 { + return Err(()) + } + + let key_count: usize = key_count.try_into().unwrap_or_else(|_| self.account.max_one_time_keys()); + + self.account.generate_one_time_keys(key_count); + Ok(key_count as u64) + }, + None => Err(()) + } + } +} + +#[cfg(test)] +mod test { + const USER_ID: &str = "@test:example.org"; + const DEVICE_ID: &str = "DEVICEID"; + + use std::convert::TryFrom; + use std::fs::File; + use std::io::prelude::*; + use js_int::UInt; + + use crate::api::r0::keys; + use crate::crypto::machine::OlmMachine; + use http::Response; + + fn response_from_file(path: &str) -> Response> { + let mut file = File::open(path).expect(&format!("No such data file found {}", path)); + let mut contents = Vec::new(); + file.read_to_end(&mut contents) + .expect(&format!("Can't read data file {}", path)); + + Response::builder().status(200).body(contents).unwrap() + } + + fn keys_upload_response() -> keys::upload_keys::Response { + let data = response_from_file("tests/data/keys_upload.json"); + keys::upload_keys::Response::try_from(data).expect("Can't parse the keys upload response") + } + + #[test] + fn create_olm_machine() { + let machine = OlmMachine::new(USER_ID, DEVICE_ID); + assert!(machine.should_upload_keys()); + } + + #[async_std::test] + async fn receive_keys_upload_response() { + let mut machine = OlmMachine::new(USER_ID, DEVICE_ID); + let mut response = keys_upload_response(); + + response.one_time_key_counts.remove(&keys::KeyAlgorithm::SignedCurve25519).unwrap(); + + assert!(machine.should_upload_keys()); + machine.receive_keys_upload_response(&response).await; + assert!(!machine.should_upload_keys()); + + response.one_time_key_counts.insert(keys::KeyAlgorithm::SignedCurve25519, UInt::try_from(10).unwrap()); + machine.receive_keys_upload_response(&response).await; + assert!(machine.should_upload_keys()); + + response.one_time_key_counts.insert(keys::KeyAlgorithm::SignedCurve25519, UInt::try_from(50).unwrap()); + machine.receive_keys_upload_response(&response).await; + assert!(!machine.should_upload_keys()); + } + + #[async_std::test] + async fn generate_one_time_keys() { + let mut machine = OlmMachine::new(USER_ID, DEVICE_ID); + + let mut response = keys_upload_response(); + + assert!(machine.should_upload_keys()); + assert!(machine.generate_one_time_keys().is_err()); + + machine.receive_keys_upload_response(&response).await; + assert!(machine.should_upload_keys()); + assert!(machine.generate_one_time_keys().is_ok()); + + response.one_time_key_counts.insert(keys::KeyAlgorithm::SignedCurve25519, UInt::try_from(50).unwrap()); + machine.receive_keys_upload_response(&response).await; + assert!(machine.generate_one_time_keys().is_err()); + } +} diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 08b5dba4..ab8cce3f 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -14,5 +14,6 @@ // TODO remove this. #[allow(dead_code)] - +mod machine; +#[allow(dead_code)] mod olm; diff --git a/src/crypto/olm.rs b/src/crypto/olm.rs index a049cc69..33617b68 100644 --- a/src/crypto/olm.rs +++ b/src/crypto/olm.rs @@ -18,7 +18,6 @@ use std::collections::{hash_map::Iter, hash_map::Keys, hash_map::Values, HashMap use serde; use serde::Deserialize; - /// Struct representing the parsed result of `OlmAccount::identity_keys()`. #[derive(Deserialize, Debug, PartialEq)] pub struct IdentityKeys { @@ -109,7 +108,7 @@ impl OneTimeKeys { pub struct Account { inner: OlmAccount, - shared: bool, + pub(crate) shared: bool, } impl Account { @@ -148,6 +147,10 @@ impl Account { self.inner.max_number_of_one_time_keys() } + /// Mark the current set of one-time keys as being published. + pub fn mark_keys_as_published(&self) { + self.inner.mark_keys_as_published(); + } } #[cfg(test)] @@ -165,7 +168,10 @@ mod test { assert_ne!(identyty_keys.keys().len(), 0); assert_ne!(identyty_keys.iter().len(), 0); assert!(identyty_keys.contains_key("ed25519")); - assert_eq!(identyty_keys.ed25519(), identyty_keys.get("ed25519").unwrap()); + assert_eq!( + identyty_keys.ed25519(), + identyty_keys.get("ed25519").unwrap() + ); assert!(!identyty_keys.curve25519().is_empty()); } @@ -186,6 +192,13 @@ mod test { assert_ne!(one_time_keys.iter().len(), 0); assert!(one_time_keys.contains_key("curve25519")); assert_eq!(one_time_keys.curve25519().keys().len(), 10); - assert_eq!(one_time_keys.curve25519(), one_time_keys.get("curve25519").unwrap()); + assert_eq!( + one_time_keys.curve25519(), + one_time_keys.get("curve25519").unwrap() + ); + + account.mark_keys_as_published(); + let one_time_keys = account.one_time_keys(); + assert!(one_time_keys.curve25519().is_empty()); } }