diff --git a/matrix_sdk_crypto/Cargo.toml b/matrix_sdk_crypto/Cargo.toml index 70f65023..1ed7498c 100644 --- a/matrix_sdk_crypto/Cargo.toml +++ b/matrix_sdk_crypto/Cargo.toml @@ -25,7 +25,8 @@ async-trait = "0.1.40" matrix-sdk-common-macros = { version = "0.1.0", path = "../matrix_sdk_common_macros" } matrix-sdk-common = { version = "0.1.0", path = "../matrix_sdk_common" } -olm-rs = { git = 'https://gitlab.gnome.org/jhaye/olm-rs/', features = ["serde"]} +olm-rs = { version = "0.6.0", features = ["serde"] } +getrandom = "0.1.14" serde = { version = "1.0.115", features = ["derive", "rc"] } serde_json = "1.0.57" cjson = "0.1.1" @@ -37,6 +38,12 @@ thiserror = "1.0.20" tracing = "0.1.19" atomic = "0.5.0" dashmap = "3.11.10" +sha2 = "0.9.1" +aes-ctr = "0.4.0" +pbkdf2 = { version = "0.5.0", default-features = false } +hmac = "0.9.0" +base64 = "0.12.3" +byteorder = "1.3.4" [dependencies.tracing-futures] version = "0.2.4" @@ -47,7 +54,7 @@ features = ["std", "std-future"] version = "0.3.5" optional = true default-features = false -features = ["runtime-tokio", "sqlite"] +features = ["runtime-tokio", "sqlite", "macros"] [dev-dependencies] tokio = { version = "0.2.22", features = ["rt-threaded", "macros"] } @@ -57,3 +64,4 @@ serde_json = "1.0.57" tempfile = "3.1.0" http = "0.2.1" matrix-sdk-test = { version = "0.1.0", path = "../matrix_sdk_test" } +indoc = "1.0.2" diff --git a/matrix_sdk_crypto/src/identities/mod.rs b/matrix_sdk_crypto/src/identities/mod.rs index 853447dc..c1941b50 100644 --- a/matrix_sdk_crypto/src/identities/mod.rs +++ b/matrix_sdk_crypto/src/identities/mod.rs @@ -41,7 +41,7 @@ //! Both identity sets need to reqularly fetched from the server using the //! `/keys/query` API call. pub(crate) mod device; -mod user; +pub(crate) mod user; pub use device::{Device, LocalTrust, ReadOnlyDevice, UserDevices}; pub use user::{ diff --git a/matrix_sdk_crypto/src/identities/user.rs b/matrix_sdk_crypto/src/identities/user.rs index 1444e7ac..f50e02cd 100644 --- a/matrix_sdk_crypto/src/identities/user.rs +++ b/matrix_sdk_crypto/src/identities/user.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::{ + collections::{btree_map::Iter, BTreeMap}, convert::TryFrom, sync::{ atomic::{AtomicBool, Ordering}, @@ -24,7 +25,7 @@ use serde::{Deserialize, Serialize}; use serde_json::to_value; use matrix_sdk_common::{ - api::r0::keys::CrossSigningKey, + api::r0::keys::{CrossSigningKey, KeyUsage}, identifiers::{DeviceKeyId, UserId}, }; @@ -55,6 +56,54 @@ impl PartialEq for MasterPubkey { } } +impl PartialEq for SelfSigningPubkey { + fn eq(&self, other: &SelfSigningPubkey) -> bool { + self.0.user_id == other.0.user_id && self.0.keys == other.0.keys + } +} + +impl PartialEq for UserSigningPubkey { + fn eq(&self, other: &UserSigningPubkey) -> bool { + self.0.user_id == other.0.user_id && self.0.keys == other.0.keys + } +} + +impl From for MasterPubkey { + fn from(key: CrossSigningKey) -> Self { + Self(Arc::new(key)) + } +} + +impl From for SelfSigningPubkey { + fn from(key: CrossSigningKey) -> Self { + Self(Arc::new(key)) + } +} + +impl From for UserSigningPubkey { + fn from(key: CrossSigningKey) -> Self { + Self(Arc::new(key)) + } +} + +impl AsRef for MasterPubkey { + fn as_ref(&self) -> &CrossSigningKey { + &self.0 + } +} + +impl AsRef for SelfSigningPubkey { + fn as_ref(&self) -> &CrossSigningKey { + &self.0 + } +} + +impl AsRef for UserSigningPubkey { + fn as_ref(&self) -> &CrossSigningKey { + &self.0 + } +} + impl From<&CrossSigningKey> for MasterPubkey { fn from(key: &CrossSigningKey) -> Self { Self(Arc::new(key.clone())) @@ -117,6 +166,21 @@ impl MasterPubkey { &self.0.user_id } + /// Get the keys map of containing the master keys. + pub fn keys(&self) -> &BTreeMap { + &self.0.keys + } + + /// Get the list of `KeyUsage` that is set for this key. + pub fn usage(&self) -> &[KeyUsage] { + &self.0.usage + } + + /// Get the signatures map of this cross signing key. + pub fn signatures(&self) -> &BTreeMap> { + &self.0.signatures + } + /// Get the master key with the given key id. /// /// # Arguments @@ -167,12 +231,26 @@ impl MasterPubkey { } } +impl<'a> IntoIterator for &'a MasterPubkey { + type Item = (&'a String, &'a String); + type IntoIter = Iter<'a, String, String>; + + fn into_iter(self) -> Self::IntoIter { + self.keys().iter() + } +} + impl UserSigningPubkey { /// Get the user id of the user signing key's owner. pub fn user_id(&self) -> &UserId { &self.0.user_id } + /// Get the keys map of containing the user signing keys. + pub fn keys(&self) -> &BTreeMap { + &self.0.keys + } + /// Check if the given master key is signed by this user signing key. /// /// # Arguments @@ -202,12 +280,26 @@ impl UserSigningPubkey { } } +impl<'a> IntoIterator for &'a UserSigningPubkey { + type Item = (&'a String, &'a String); + type IntoIter = Iter<'a, String, String>; + + fn into_iter(self) -> Self::IntoIter { + self.keys().iter() + } +} + impl SelfSigningPubkey { /// Get the user id of the self signing key's owner. pub fn user_id(&self) -> &UserId { &self.0.user_id } + /// Get the keys map of containing the self signing keys. + pub fn keys(&self) -> &BTreeMap { + &self.0.keys + } + /// Check if the given device is signed by this self signing key. /// /// # Arguments @@ -236,6 +328,15 @@ impl SelfSigningPubkey { } } +impl<'a> IntoIterator for &'a SelfSigningPubkey { + type Item = (&'a String, &'a String); + type IntoIter = Iter<'a, String, String>; + + fn into_iter(self) -> Self::IntoIter { + self.keys().iter() + } +} + /// Enum over the different user identity types we can have. #[derive(Debug, Clone)] pub enum UserIdentities { @@ -245,6 +346,18 @@ pub enum UserIdentities { Other(UserIdentity), } +impl From for UserIdentities { + fn from(identity: OwnUserIdentity) -> Self { + UserIdentities::Own(identity) + } +} + +impl From for UserIdentities { + fn from(identity: UserIdentity) -> Self { + UserIdentities::Other(identity) + } +} + impl UserIdentities { /// The unique user id of this identity. pub fn user_id(&self) -> &UserId { @@ -262,6 +375,23 @@ impl UserIdentities { } } + /// Get the self-signing key of the identity. + pub fn self_signing_key(&self) -> &SelfSigningPubkey { + match self { + UserIdentities::Own(i) => &i.self_signing_key, + UserIdentities::Other(i) => &i.self_signing_key, + } + } + + /// Get the user-signing key of the identity, this is only present for our + /// own user identity.. + pub fn user_signing_key(&self) -> Option<&UserSigningPubkey> { + match self { + UserIdentities::Own(i) => Some(&i.user_signing_key), + UserIdentities::Other(_) => None, + } + } + /// Destructure the enum into an `OwnUserIdentity` if it's of the correct /// type. pub fn own(&self) -> Option<&OwnUserIdentity> { @@ -324,6 +454,11 @@ impl UserIdentity { &self.master_key } + /// Get the public self-signing key of the identity. + pub fn self_signing_key(&self) -> &SelfSigningPubkey { + &self.self_signing_key + } + /// Update the identity with a new master key and self signing key. /// /// # Arguments @@ -424,6 +559,16 @@ impl OwnUserIdentity { &self.master_key } + /// Get the public self-signing key of the identity. + pub fn self_signing_key(&self) -> &SelfSigningPubkey { + &self.self_signing_key + } + + /// Get the public user-signing key of the identity. + pub fn user_signing_key(&self) -> &UserSigningPubkey { + &self.user_signing_key + } + /// Check if the given identity has been signed by this identity. /// /// # Arguments @@ -504,7 +649,7 @@ impl OwnUserIdentity { } #[cfg(test)] -mod test { +pub(crate) mod test { use serde_json::json; use std::{convert::TryFrom, sync::Arc}; @@ -697,6 +842,20 @@ mod test { OwnUserIdentity::new(master_key.into(), self_signing.into(), user_signing.into()).unwrap() } + pub(crate) fn get_own_identity() -> OwnUserIdentity { + own_identity(&own_key_query()) + } + + pub(crate) fn get_other_identity() -> UserIdentity { + let user_id = user_id!("@example2:localhost"); + let response = other_key_query(); + + let master_key = response.master_keys.get(&user_id).unwrap(); + let self_signing = response.self_signing_keys.get(&user_id).unwrap(); + + UserIdentity::new(master_key.into(), self_signing.into()).unwrap() + } + #[test] fn own_identity_create() { let user_id = user_id!("@example:localhost"); @@ -711,19 +870,13 @@ mod test { #[test] fn other_identity_create() { - let user_id = user_id!("@example2:localhost"); - let response = other_key_query(); - - let master_key = response.master_keys.get(&user_id).unwrap(); - let self_signing = response.self_signing_keys.get(&user_id).unwrap(); - - UserIdentity::new(master_key.into(), self_signing.into()).unwrap(); + get_other_identity(); } #[test] fn own_identity_check_signatures() { let response = own_key_query(); - let identity = own_identity(&response); + let identity = get_own_identity(); let (first, second) = device(&response); assert!(identity.is_device_signed(&first).is_err()); diff --git a/matrix_sdk_crypto/src/key_export.rs b/matrix_sdk_crypto/src/key_export.rs new file mode 100644 index 00000000..cc54a3cd --- /dev/null +++ b/matrix_sdk_crypto/src/key_export.rs @@ -0,0 +1,295 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::io::{Cursor, Read, Seek, SeekFrom}; + +use base64::{decode_config, encode_config, DecodeError, STANDARD_NO_PAD}; +use byteorder::{BigEndian, ReadBytesExt}; +use getrandom::getrandom; + +use aes_ctr::{ + stream_cipher::{NewStreamCipher, SyncStreamCipher}, + Aes256Ctr, +}; +use hmac::{Hmac, Mac, NewMac}; +use pbkdf2::pbkdf2; +use sha2::{Sha256, Sha512}; + +use crate::olm::ExportedRoomKey; + +const SALT_SIZE: usize = 16; +const IV_SIZE: usize = 16; +const MAC_SIZE: usize = 32; +const KEY_SIZE: usize = 32; +const VERSION: u8 = 1; + +const HEADER: &str = "-----BEGIN MEGOLM SESSION DATA-----"; +const FOOTER: &str = "-----END MEGOLM SESSION DATA-----"; + +fn decode(input: impl AsRef<[u8]>) -> Result, DecodeError> { + decode_config(input, STANDARD_NO_PAD) +} + +fn encode(input: impl AsRef<[u8]>) -> String { + encode_config(input, STANDARD_NO_PAD) +} + +/// Try to decrypt a reader into a list of exported room keys. +/// +/// # Arguments +/// +/// * `passphrase` - The passphrase that was used to encrypt the exported keys. +/// +/// # Examples +/// ```no_run +/// # use std::io::Cursor; +/// # use matrix_sdk_crypto::{OlmMachine, decrypt_key_export}; +/// # use matrix_sdk_common::identifiers::user_id; +/// # use futures::executor::block_on; +/// # let alice = user_id!("@alice:example.org"); +/// # let machine = OlmMachine::new(&alice, "DEVICEID".into()); +/// # block_on(async { +/// # let export = Cursor::new("".to_owned()); +/// let exported_keys = decrypt_key_export(export, "1234").unwrap(); +/// machine.import_keys(exported_keys).await.unwrap(); +/// # }); +/// ``` +pub fn decrypt_key_export( + mut input: impl Read, + passphrase: &str, +) -> Result, DecodeError> { + let mut x: String = String::new(); + + input.read_to_string(&mut x).expect("Can't read string"); + + if !(x.trim_start().starts_with(HEADER) && x.trim_end().ends_with(FOOTER)) { + panic!("Invalid header/footer"); + } + + let payload: String = x + .lines() + .filter(|l| !(l.starts_with(HEADER) || l.starts_with(FOOTER))) + .collect(); + + Ok(serde_json::from_str(&decrypt_helper(&payload, passphrase)?).unwrap()) +} + +/// Encrypt the list of exported room keys using the given passphrase. +/// +/// # Arguments +/// +/// * `keys` - A list of sessions that should be encrypted. +/// +/// * `passphrase` - The passphrase that will be used to encrypt the exported +/// room keys. +/// +/// * `rounds` - The number of rounds that should be used for the key +/// derivation when the passphrase gets turned into an AES key. More rounds are +/// increasingly computationally intensive and as such help against bruteforce +/// attacks. Should be at least `10000`, while values in the `100000` ranges +/// should be preferred. +/// +/// # Examples +/// ```no_run +/// # use matrix_sdk_crypto::{OlmMachine, encrypt_key_export}; +/// # use matrix_sdk_common::identifiers::{user_id, room_id}; +/// # use futures::executor::block_on; +/// # let alice = user_id!("@alice:example.org"); +/// # let machine = OlmMachine::new(&alice, "DEVICEID".into()); +/// # block_on(async { +/// let room_id = room_id!("!test:localhost"); +/// let exported_keys = machine.export_keys(|s| s.room_id() == &room_id).await.unwrap(); +/// let encrypted_export = encrypt_key_export(&exported_keys, "1234", 1); +/// # }); +/// ``` +pub fn encrypt_key_export(keys: &[ExportedRoomKey], passphrase: &str, rounds: u32) -> String { + let mut plaintext = serde_json::to_string(keys).unwrap().into_bytes(); + let ciphertext = encrypt_helper(&mut plaintext, passphrase, rounds); + [HEADER.to_owned(), ciphertext, FOOTER.to_owned()].join("\n") +} + +fn encrypt_helper(mut plaintext: &mut [u8], passphrase: &str, rounds: u32) -> String { + let mut salt = [0u8; SALT_SIZE]; + let mut iv = [0u8; IV_SIZE]; + let mut derived_keys = [0u8; KEY_SIZE * 2]; + + getrandom(&mut salt).expect("Can't generate randomness"); + getrandom(&mut iv).expect("Can't generate randomness"); + + let mut iv = u128::from_be_bytes(iv); + iv &= !(1 << 63); + + pbkdf2::>(passphrase.as_bytes(), &salt, rounds, &mut derived_keys); + let (key, hmac_key) = derived_keys.split_at(KEY_SIZE); + + let mut aes = Aes256Ctr::new_var(&key, &iv.to_be_bytes()).expect("Can't create AES"); + + aes.apply_keystream(&mut plaintext); + + let mut payload: Vec = vec![]; + + payload.extend(&VERSION.to_be_bytes()); + payload.extend(&salt); + payload.extend(&iv.to_be_bytes()); + payload.extend(&rounds.to_be_bytes()); + payload.extend_from_slice(&plaintext); + + let mut hmac = Hmac::::new_varkey(hmac_key).unwrap(); + hmac.update(&payload); + let mac = hmac.finalize(); + + payload.extend(mac.into_bytes()); + + encode(payload) +} + +fn decrypt_helper(ciphertext: &str, passphrase: &str) -> Result { + let decoded = decode(ciphertext)?; + + let mut decoded = Cursor::new(decoded); + + let mut salt = [0u8; SALT_SIZE]; + let mut iv = [0u8; IV_SIZE]; + let mut mac = [0u8; MAC_SIZE]; + let mut derived_keys = [0u8; KEY_SIZE * 2]; + + let version = decoded.read_u8().unwrap(); + decoded.read_exact(&mut salt).unwrap(); + decoded.read_exact(&mut iv).unwrap(); + + let rounds = decoded.read_u32::().unwrap(); + let ciphertext_start = decoded.position() as usize; + + decoded.seek(SeekFrom::End(-32)).unwrap(); + let ciphertext_end = decoded.position() as usize; + + decoded.read_exact(&mut mac).unwrap(); + + let mut decoded = decoded.into_inner(); + + if version != VERSION { + panic!("Unsupported version") + } + + pbkdf2::>(passphrase.as_bytes(), &salt, rounds, &mut derived_keys); + let (key, hmac_key) = derived_keys.split_at(KEY_SIZE); + + let mut hmac = Hmac::::new_varkey(hmac_key).unwrap(); + hmac.update(&decoded[0..ciphertext_end]); + hmac.verify(&mac).expect("MAC DOESN'T MATCH"); + + let mut ciphertext = &mut decoded[ciphertext_start..ciphertext_end]; + let mut aes = Aes256Ctr::new_var(&key, &iv).expect("Can't create AES"); + aes.apply_keystream(&mut ciphertext); + + Ok(String::from_utf8(ciphertext.to_owned()).expect("Invalid utf-8")) +} + +#[cfg(test)] +mod test { + use indoc::indoc; + use proptest::prelude::*; + use std::io::Cursor; + + use matrix_sdk_common::identifiers::room_id; + use matrix_sdk_test::async_test; + + use super::{decode, decrypt_helper, decrypt_key_export, encrypt_helper, encrypt_key_export}; + use crate::machine::test::get_prepared_machine; + + const PASSPHRASE: &str = "1234"; + + const TEST_EXPORT: &str = indoc! {" + -----BEGIN MEGOLM SESSION DATA----- + Af7mGhlzQ+eGvHu93u0YXd3D/+vYMs3E7gQqOhuCtkvGAAAAASH7pEdWvFyAP1JUisAcpEo + Xke2Q7Kr9hVl/SCc6jXBNeJCZcrUbUV4D/tRQIl3E9L4fOk928YI1J+3z96qiH0uE7hpsCI + CkHKwjPU+0XTzFdIk1X8H7sZ+MD/2Sg/q3y8rtUjz7uEj4GUTnb+9SCOTVmJsRfqgUpM1CU + bDLytHf1JkohY4tWEgpsCc67xdzgodjr12qYrfg/zNm3LGpxlrffJknw4rk5QFTj4kMbqbD + ZZgDTni+HxRTDGge2J620lMOiznvXX+H09Rwruqx5aJvvaaKd86jWRpiO2oSFqHn4u5ONl9 + 41uzm62Sj0eIm6ZbA9NQs87jQw4LxsejhZVL+NdjIg80zVSBTWhTdo0DTnbFSNP4ReOiz0U + XosOF8A5T8Vdx2nvA0GXltfcHKVKQYh/LJAkNQ7P9UYL4ae/5TtQZkhB1KxCLTRWqADCl53 + uBMGpG53EMgY6G6K2DEIOkcv7sdXQF5WpemiSWZqJRWj+cjfs9BpCTbkp/rszWFl2TniWpR + RqIbT2jORlN4rTvdtF0F4z1pqP4qWyR3sLNTkXm9CFRzWADNG0RDZKxbCoo6RPvtaCTfaHo + SwfvzBS6CjfAG+FOugpV48o7+XetaUUPZ6/tZSPhCdeV8eP9q5r0QwWeXFogzoNzWt4HYx9 + MdXxzD+f0mtg5gzehrrEEARwI2bCvPpHxlt/Na9oW/GBpkjwR1LSKgg4CtpRyWngPjdEKpZ + GYW19pdjg0qdXNk/eqZsQTsNWVo6A + -----END MEGOLM SESSION DATA----- + "}; + + fn export_wihtout_headers() -> String { + TEST_EXPORT + .lines() + .filter(|l| !l.starts_with("-----")) + .collect() + } + + #[test] + fn test_decode() { + let export = export_wihtout_headers(); + assert!(decode(export).is_ok()); + } + + proptest! { + #[test] + fn proptest_encrypt_cycle(plaintext in prop::string::string_regex(".*").unwrap()) { + let mut plaintext_bytes = plaintext.clone().into_bytes(); + + let ciphertext = encrypt_helper(&mut plaintext_bytes, "test", 1); + let decrypted = decrypt_helper(&ciphertext, "test").unwrap(); + + prop_assert!(plaintext == decrypted); + } + } + + #[test] + fn test_encrypt_decrypt() { + let data = "It's a secret to everybody"; + let mut bytes = data.to_owned().into_bytes(); + + let encrypted = encrypt_helper(&mut bytes, PASSPHRASE, 10); + let decrypted = decrypt_helper(&encrypted, PASSPHRASE).unwrap(); + + assert_eq!(data, decrypted); + } + + #[async_test] + async fn test_session_encrypt() { + let (machine, _) = get_prepared_machine().await; + let room_id = room_id!("!test:localhost"); + + machine + .create_outnbound_group_session_with_defaults(&room_id) + .await + .unwrap(); + let export = machine + .export_keys(|s| s.room_id() == &room_id) + .await + .unwrap(); + + assert!(!export.is_empty()); + + let encrypted = encrypt_key_export(&export, "1234", 1); + let decrypted = decrypt_key_export(Cursor::new(encrypted), "1234").unwrap(); + + assert_eq!(export, decrypted); + assert_eq!(machine.import_keys(decrypted).await.unwrap(), 0); + } + + #[test] + fn test_real_decrypt() { + let reader = Cursor::new(TEST_EXPORT); + let imported = decrypt_key_export(reader, PASSPHRASE).expect("Can't decrypt key export"); + assert!(!imported.is_empty()) + } +} diff --git a/matrix_sdk_crypto/src/lib.rs b/matrix_sdk_crypto/src/lib.rs index 959ffe97..8274a145 100644 --- a/matrix_sdk_crypto/src/lib.rs +++ b/matrix_sdk_crypto/src/lib.rs @@ -29,6 +29,8 @@ mod error; mod identities; +#[allow(dead_code)] +mod key_export; mod machine; pub mod olm; mod requests; @@ -39,6 +41,7 @@ pub use error::{MegolmError, OlmError}; pub use identities::{ Device, LocalTrust, OwnUserIdentity, ReadOnlyDevice, UserDevices, UserIdentities, UserIdentity, }; +pub use key_export::{decrypt_key_export, encrypt_key_export}; pub use machine::OlmMachine; pub(crate) use olm::Account; pub use olm::EncryptionSettings; diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index a0935c27..3fe74fd8 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -57,8 +57,8 @@ use super::{ UserIdentities, UserIdentity, UserSigningPubkey, }, olm::{ - Account, EncryptionSettings, GroupSessionKey, IdentityKeys, InboundGroupSession, - OlmMessage, OutboundGroupSession, + Account, EncryptionSettings, ExportedRoomKey, GroupSessionKey, IdentityKeys, + InboundGroupSession, OlmMessage, OutboundGroupSession, }, requests::{IncomingResponse, OutgoingRequest, ToDeviceRequest}, store::{CryptoStore, MemoryStore, Result as StoreResult}, @@ -984,7 +984,7 @@ impl OlmMachine { &event.content.room_id, session_key, )?; - let _ = self.store.save_inbound_group_session(session).await?; + let _ = self.store.save_inbound_group_sessions(&[session]).await?; let event = Raw::from(AnyToDeviceEvent::RoomKey(event.clone())); Ok(Some(event)) @@ -1014,7 +1014,7 @@ impl OlmMachine { .await .map_err(|_| EventError::UnsupportedAlgorithm)?; - let _ = self.store.save_inbound_group_session(inbound).await?; + let _ = self.store.save_inbound_group_sessions(&[inbound]).await?; let _ = self .outbound_group_sessions @@ -1023,7 +1023,7 @@ impl OlmMachine { } #[cfg(test)] - async fn create_outnbound_group_session_with_defaults( + pub(crate) async fn create_outnbound_group_session_with_defaults( &self, room_id: &RoomId, ) -> OlmResult<()> { @@ -1529,6 +1529,105 @@ impl OlmMachine { device_owner_identity, }) } + + /// Import the given room keys into our store. + /// + /// # Arguments + /// + /// * `exported_keys` - A list of previously exported keys that should be + /// imported into our store. If we already have a better version of a key + /// the key will *not* be imported. + /// + /// Returns the number of sessions that were imported to the store. + /// + /// # Examples + /// ```no_run + /// # use std::io::Cursor; + /// # use matrix_sdk_crypto::{OlmMachine, decrypt_key_export}; + /// # use matrix_sdk_common::identifiers::user_id; + /// # use futures::executor::block_on; + /// # let alice = user_id!("@alice:example.org"); + /// # let machine = OlmMachine::new(&alice, "DEVICEID".into()); + /// # block_on(async { + /// # let export = Cursor::new("".to_owned()); + /// let exported_keys = decrypt_key_export(export, "1234").unwrap(); + /// machine.import_keys(exported_keys).await.unwrap(); + /// # }); + /// ``` + pub async fn import_keys(&self, mut exported_keys: Vec) -> StoreResult { + let mut sessions = Vec::new(); + + for key in exported_keys.drain(..) { + let session = InboundGroupSession::from_export(key)?; + + // Only import the session if we didn't have this session or if it's + // a better version of the same session, that is the first known + // index is lower. + if let Some(existing_session) = self + .store + .get_inbound_group_session( + &session.room_id, + &session.sender_key, + session.session_id(), + ) + .await? + { + let first_index = session.first_known_index().await; + let existing_index = existing_session.first_known_index().await; + + if first_index < existing_index { + sessions.push(session) + } + } else { + sessions.push(session) + } + } + + let num_sessions = sessions.len(); + + self.store.save_inbound_group_sessions(&sessions).await?; + + Ok(num_sessions) + } + + /// Export the keys that match the given predicate. + /// + /// + /// # Examples + /// + /// ```no_run + /// # use matrix_sdk_crypto::{OlmMachine, encrypt_key_export}; + /// # use matrix_sdk_common::identifiers::{user_id, room_id}; + /// # use futures::executor::block_on; + /// # let alice = user_id!("@alice:example.org"); + /// # let machine = OlmMachine::new(&alice, "DEVICEID".into()); + /// # block_on(async { + /// let room_id = room_id!("!test:localhost"); + /// let exported_keys = machine.export_keys(|s| s.room_id() == &room_id).await.unwrap(); + /// let encrypted_export = encrypt_key_export(&exported_keys, "1234", 1); + /// # }); + /// ``` + pub async fn export_keys( + &self, + mut predicate: impl FnMut(&InboundGroupSession) -> bool, + ) -> StoreResult> { + let mut exported = Vec::new(); + + let mut sessions: Vec = self + .store + .get_inbound_group_sessions() + .await? + .drain(..) + .filter(|s| predicate(&s)) + .collect(); + + for session in sessions.drain(..) { + let export = session.export().await; + exported.push(export); + } + + Ok(exported) + } } #[cfg(test)] @@ -1623,7 +1722,7 @@ pub(crate) mod test { content.deserialize().unwrap() } - async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) { + pub(crate) async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) { let machine = OlmMachine::new(&user_id(), &alice_device_id()); machine.account.update_uploaded_key_count(0); let request = machine diff --git a/matrix_sdk_crypto/src/olm/group_sessions.rs b/matrix_sdk_crypto/src/olm/group_sessions.rs deleted file mode 100644 index 729bf5c6..00000000 --- a/matrix_sdk_crypto/src/olm/group_sessions.rs +++ /dev/null @@ -1,561 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::{ - cmp::min, - convert::TryInto, - fmt, - sync::{ - atomic::{AtomicBool, AtomicU64, Ordering}, - Arc, - }, - time::Duration, -}; - -use matrix_sdk_common::{ - events::{ - room::{encrypted::EncryptedEventContent, encryption::EncryptionEventContent}, - AnyMessageEventContent, AnySyncRoomEvent, EventContent, SyncMessageEvent, - }, - identifiers::{DeviceId, EventEncryptionAlgorithm, RoomId}, - instant::Instant, - locks::Mutex, - Raw, -}; -use olm_rs::{ - errors::OlmGroupSessionError, inbound_group_session::OlmInboundGroupSession, - outbound_group_session::OlmOutboundGroupSession, PicklingMode, -}; -use serde::{Deserialize, Serialize}; -use serde_json::{json, Value}; -use zeroize::Zeroize; - -pub use olm_rs::{ - account::IdentityKeys, - session::{OlmMessage, PreKeyMessage}, - utility::OlmUtility, -}; - -use crate::error::{EventError, MegolmResult}; - -const ROTATION_PERIOD: Duration = Duration::from_millis(604800000); -const ROTATION_MESSAGES: u64 = 100; - -/// Settings for an encrypted room. -/// -/// This determines the algorithm and rotation periods of a group session. -#[derive(Debug)] -pub struct EncryptionSettings { - /// The encryption algorithm that should be used in the room. - pub algorithm: EventEncryptionAlgorithm, - /// How long the session should be used before changing it. - pub rotation_period: Duration, - /// How many messages should be sent before changing the session. - pub rotation_period_msgs: u64, -} - -impl Default for EncryptionSettings { - fn default() -> Self { - Self { - algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2, - rotation_period: ROTATION_PERIOD, - rotation_period_msgs: ROTATION_MESSAGES, - } - } -} - -impl From<&EncryptionEventContent> for EncryptionSettings { - fn from(content: &EncryptionEventContent) -> Self { - let rotation_period: Duration = content - .rotation_period_ms - .map_or(ROTATION_PERIOD, |r| Duration::from_millis(r.into())); - let rotation_period_msgs: u64 = content - .rotation_period_msgs - .map_or(ROTATION_MESSAGES, Into::into); - - Self { - algorithm: content.algorithm.clone(), - rotation_period, - rotation_period_msgs, - } - } -} - -/// The private session key of a group session. -/// Can be used to create a new inbound group session. -#[derive(Clone, Debug, Serialize, Zeroize)] -#[zeroize(drop)] -pub struct GroupSessionKey(pub String); - -/// Inbound group session. -/// -/// Inbound group sessions are used to exchange room messages between a group of -/// participants. Inbound group sessions are used to decrypt the room messages. -#[derive(Clone)] -pub struct InboundGroupSession { - inner: Arc>, - session_id: Arc, - pub(crate) sender_key: Arc, - pub(crate) signing_key: Arc, - pub(crate) room_id: Arc, - forwarding_chains: Arc>>>, -} - -impl InboundGroupSession { - /// Create a new inbound group session for the given room. - /// - /// These sessions are used to decrypt room messages. - /// - /// # Arguments - /// - /// * `sender_key` - The public curve25519 key of the account that - /// sent us the session - /// - /// * `signing_key` - The public ed25519 key of the account that - /// sent us the session. - /// - /// * `room_id` - The id of the room that the session is used in. - /// - /// * `session_key` - The private session key that is used to decrypt - /// messages. - pub fn new( - sender_key: &str, - signing_key: &str, - room_id: &RoomId, - session_key: GroupSessionKey, - ) -> Result { - let session = OlmInboundGroupSession::new(&session_key.0)?; - let session_id = session.session_id(); - - Ok(InboundGroupSession { - inner: Arc::new(Mutex::new(session)), - session_id: Arc::new(session_id), - sender_key: Arc::new(sender_key.to_owned()), - signing_key: Arc::new(signing_key.to_owned()), - room_id: Arc::new(room_id.clone()), - forwarding_chains: Arc::new(Mutex::new(None)), - }) - } - - /// Store the group session as a base64 encoded string. - /// - /// # Arguments - /// - /// * `pickle_mode` - The mode that was used to pickle the group session, - /// either an unencrypted mode or an encrypted using passphrase. - pub async fn pickle(&self, pickle_mode: PicklingMode) -> PickledInboundGroupSession { - let pickle = self.inner.lock().await.pickle(pickle_mode); - - PickledInboundGroupSession { - pickle: InboundGroupSessionPickle::from(pickle), - sender_key: self.sender_key.to_string(), - signing_key: self.signing_key.to_string(), - room_id: (&*self.room_id).clone(), - forwarding_chains: self.forwarding_chains.lock().await.clone(), - } - } - - /// Restore a Session from a previously pickled string. - /// - /// Returns the restored group session or a `OlmGroupSessionError` if there - /// was an error. - /// - /// # Arguments - /// - /// * `pickle` - The pickled version of the `InboundGroupSession`. - /// - /// * `pickle_mode` - The mode that was used to pickle the session, either - /// an unencrypted mode or an encrypted using passphrase. - pub fn from_pickle( - pickle: PickledInboundGroupSession, - pickle_mode: PicklingMode, - ) -> Result { - let session = OlmInboundGroupSession::unpickle(pickle.pickle.0, pickle_mode)?; - let session_id = session.session_id(); - - Ok(InboundGroupSession { - inner: Arc::new(Mutex::new(session)), - session_id: Arc::new(session_id), - sender_key: Arc::new(pickle.sender_key), - signing_key: Arc::new(pickle.signing_key), - room_id: Arc::new(pickle.room_id), - forwarding_chains: Arc::new(Mutex::new(pickle.forwarding_chains)), - }) - } - - /// Returns the unique identifier for this session. - pub fn session_id(&self) -> &str { - &self.session_id - } - - /// Get the first message index we know how to decrypt. - pub async fn first_known_index(&self) -> u32 { - self.inner.lock().await.first_known_index() - } - - /// Decrypt the given ciphertext. - /// - /// Returns the decrypted plaintext or an `OlmGroupSessionError` if - /// decryption failed. - /// - /// # Arguments - /// - /// * `message` - The message that should be decrypted. - pub async fn decrypt_helper( - &self, - message: String, - ) -> Result<(String, u32), OlmGroupSessionError> { - self.inner.lock().await.decrypt(message) - } - - /// Decrypt an event from a room timeline. - /// - /// # Arguments - /// - /// * `event` - The event that should be decrypted. - pub async fn decrypt( - &self, - event: &SyncMessageEvent, - ) -> MegolmResult<(Raw, u32)> { - let content = match &event.content { - EncryptedEventContent::MegolmV1AesSha2(c) => c, - _ => return Err(EventError::UnsupportedAlgorithm.into()), - }; - - let (plaintext, message_index) = self.decrypt_helper(content.ciphertext.clone()).await?; - - let mut decrypted_value = serde_json::from_str::(&plaintext)?; - let decrypted_object = decrypted_value - .as_object_mut() - .ok_or(EventError::NotAnObject)?; - - // TODO better number conversion here. - let server_ts = event - .origin_server_ts - .duration_since(std::time::SystemTime::UNIX_EPOCH) - .unwrap_or_default() - .as_millis(); - let server_ts: i64 = server_ts.try_into().unwrap_or_default(); - - decrypted_object.insert("sender".to_owned(), event.sender.to_string().into()); - decrypted_object.insert("event_id".to_owned(), event.event_id.to_string().into()); - decrypted_object.insert("origin_server_ts".to_owned(), server_ts.into()); - - decrypted_object.insert( - "unsigned".to_owned(), - serde_json::to_value(&event.unsigned).unwrap_or_default(), - ); - - Ok(( - serde_json::from_value::>(decrypted_value)?, - message_index, - )) - } -} - -#[cfg(not(tarpaulin_include))] -impl fmt::Debug for InboundGroupSession { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("InboundGroupSession") - .field("session_id", &self.session_id()) - .finish() - } -} - -impl PartialEq for InboundGroupSession { - fn eq(&self, other: &Self) -> bool { - self.session_id() == other.session_id() - } -} - -/// A pickled version of an `InboundGroupSession`. -/// -/// Holds all the information that needs to be stored in a database to restore -/// an InboundGroupSession. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PickledInboundGroupSession { - /// The pickle string holding the InboundGroupSession. - pub pickle: InboundGroupSessionPickle, - /// The public curve25519 key of the account that sent us the session - pub sender_key: String, - /// The public ed25519 key of the account that sent us the session. - pub signing_key: String, - /// The id of the room that the session is used in. - pub room_id: RoomId, - /// The list of claimed ed25519 that forwarded us this key. Will be None if - /// we dirrectly received this session. - pub forwarding_chains: Option>, -} - -/// The typed representation of a base64 encoded string of the GroupSession pickle. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct InboundGroupSessionPickle(String); - -impl From for InboundGroupSessionPickle { - fn from(pickle_string: String) -> Self { - InboundGroupSessionPickle(pickle_string) - } -} - -impl InboundGroupSessionPickle { - /// Get the string representation of the pickle. - pub fn as_str(&self) -> &str { - &self.0 - } -} - -/// Outbound group session. -/// -/// Outbound group sessions are used to exchange room messages between a group -/// of participants. Outbound group sessions are used to encrypt the room -/// messages. -#[derive(Clone)] -pub struct OutboundGroupSession { - inner: Arc>, - device_id: Arc>, - account_identity_keys: Arc, - session_id: Arc, - room_id: Arc, - creation_time: Arc, - message_count: Arc, - shared: Arc, - settings: Arc, -} - -impl OutboundGroupSession { - /// Create a new outbound group session for the given room. - /// - /// Outbound group sessions are used to encrypt room messages. - /// - /// # Arguments - /// - /// * `device_id` - The id of the device that created this session. - /// - /// * `identity_keys` - The identity keys of the account that created this - /// session. - /// - /// * `room_id` - The id of the room that the session is used in. - /// - /// * `settings` - Settings determining the algorithm and rotation period of - /// the outbound group session. - pub fn new( - device_id: Arc>, - identity_keys: Arc, - room_id: &RoomId, - settings: EncryptionSettings, - ) -> Self { - let session = OlmOutboundGroupSession::new(); - let session_id = session.session_id(); - - OutboundGroupSession { - inner: Arc::new(Mutex::new(session)), - room_id: Arc::new(room_id.to_owned()), - device_id, - account_identity_keys: identity_keys, - session_id: Arc::new(session_id), - creation_time: Arc::new(Instant::now()), - message_count: Arc::new(AtomicU64::new(0)), - shared: Arc::new(AtomicBool::new(false)), - settings: Arc::new(settings), - } - } - - /// Encrypt the given plaintext using this session. - /// - /// Returns the encrypted ciphertext. - /// - /// # Arguments - /// - /// * `plaintext` - The plaintext that should be encrypted. - pub(crate) async fn encrypt_helper(&self, plaintext: String) -> String { - let session = self.inner.lock().await; - self.message_count.fetch_add(1, Ordering::SeqCst); - session.encrypt(plaintext) - } - - /// Encrypt a room message for the given room. - /// - /// Beware that a group session needs to be shared before this method can be - /// called using the `share_group_session()` method. - /// - /// Since group sessions can expire or become invalid if the room membership - /// changes client authors should check with the - /// `should_share_group_session()` method if a new group session needs to - /// be shared. - /// - /// # Arguments - /// - /// * `content` - The plaintext content of the message that should be - /// encrypted. - /// - /// # Panics - /// - /// Panics if the content can't be serialized. - pub async fn encrypt(&self, content: AnyMessageEventContent) -> EncryptedEventContent { - let json_content = json!({ - "content": content, - "room_id": &*self.room_id, - "type": content.event_type(), - }); - - let plaintext = cjson::to_string(&json_content).unwrap_or_else(|_| { - panic!(format!( - "Can't serialize {} to canonical JSON", - json_content - )) - }); - - let ciphertext = self.encrypt_helper(plaintext).await; - - EncryptedEventContent::MegolmV1AesSha2( - matrix_sdk_common::events::room::encrypted::MegolmV1AesSha2ContentInit { - ciphertext, - sender_key: self.account_identity_keys.curve25519().to_owned(), - session_id: self.session_id().to_owned(), - device_id: (&*self.device_id).to_owned(), - } - .into(), - ) - } - - /// Check if the session has expired and if it should be rotated. - /// - /// A session will expire after some time or if enough messages have been - /// encrypted using it. - pub fn expired(&self) -> bool { - let count = self.message_count.load(Ordering::SeqCst); - - count >= self.settings.rotation_period_msgs - || self.creation_time.elapsed() - // Since the encryption settings are provided by users and not - // checked someone could set a really low rotation perdiod so - // clamp it at a minute. - >= min(self.settings.rotation_period, Duration::from_secs(3600)) - } - - /// Mark the session as shared. - /// - /// Messages shouldn't be encrypted with the session before it has been - /// shared. - pub fn mark_as_shared(&self) { - self.shared.store(true, Ordering::Relaxed); - } - - /// Check if the session has been marked as shared. - pub fn shared(&self) -> bool { - self.shared.load(Ordering::Relaxed) - } - - /// Get the session key of this session. - /// - /// A session key can be used to to create an `InboundGroupSession`. - pub async fn session_key(&self) -> GroupSessionKey { - let session = self.inner.lock().await; - GroupSessionKey(session.session_key()) - } - - /// Returns the unique identifier for this session. - pub fn session_id(&self) -> &str { - &self.session_id - } - - /// Get the current message index for this session. - /// - /// Each message is sent with an increasing index. This returns the - /// message index that will be used for the next encrypted message. - pub async fn message_index(&self) -> u32 { - let session = self.inner.lock().await; - session.session_message_index() - } - - /// Get the outbound group session key as a json value that can be sent as a - /// m.room_key. - pub async fn as_json(&self) -> Value { - json!({ - "algorithm": EventEncryptionAlgorithm::MegolmV1AesSha2, - "room_id": &*self.room_id, - "session_id": &*self.session_id, - "session_key": self.session_key().await, - "chain_index": self.message_index().await, - }) - } -} - -#[cfg(not(tarpaulin_include))] -impl std::fmt::Debug for OutboundGroupSession { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("OutboundGroupSession") - .field("session_id", &self.session_id) - .field("room_id", &self.room_id) - .field("creation_time", &self.creation_time) - .field("message_count", &self.message_count) - .finish() - } -} - -#[cfg(test)] -mod test { - use std::{ - sync::Arc, - time::{Duration, Instant}, - }; - - use matrix_sdk_common::{ - events::{ - room::message::{MessageEventContent, TextMessageEventContent}, - AnyMessageEventContent, - }, - identifiers::{room_id, user_id}, - }; - - use super::EncryptionSettings; - use crate::Account; - - #[tokio::test] - #[cfg(not(target_os = "macos"))] - async fn expiration() { - let settings = EncryptionSettings { - rotation_period_msgs: 1, - ..Default::default() - }; - - let account = Account::new(&user_id!("@alice:example.org"), "DEVICEID".into()); - let (session, _) = account - .create_group_session_pair(&room_id!("!test_room:example.org"), settings) - .await - .unwrap(); - - assert!(!session.expired()); - let _ = session - .encrypt(AnyMessageEventContent::RoomMessage( - MessageEventContent::Text(TextMessageEventContent::plain("Test message")), - )) - .await; - assert!(session.expired()); - - let settings = EncryptionSettings { - rotation_period: Duration::from_millis(100), - ..Default::default() - }; - - let (mut session, _) = account - .create_group_session_pair(&room_id!("!test_room:example.org"), settings) - .await - .unwrap(); - - assert!(!session.expired()); - session.creation_time = Arc::new(Instant::now() - Duration::from_secs(60 * 60)); - assert!(session.expired()); - } -} diff --git a/matrix_sdk_crypto/src/olm/group_sessions/inbound.rs b/matrix_sdk_crypto/src/olm/group_sessions/inbound.rs new file mode 100644 index 00000000..a35c6190 --- /dev/null +++ b/matrix_sdk_crypto/src/olm/group_sessions/inbound.rs @@ -0,0 +1,343 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + collections::BTreeMap, + convert::{TryFrom, TryInto}, + fmt, + sync::Arc, +}; + +use matrix_sdk_common::{ + events::{room::encrypted::EncryptedEventContent, AnySyncRoomEvent, SyncMessageEvent}, + identifiers::{DeviceKeyAlgorithm, EventEncryptionAlgorithm, RoomId}, + locks::Mutex, + Raw, +}; +use olm_rs::{ + errors::OlmGroupSessionError, inbound_group_session::OlmInboundGroupSession, PicklingMode, +}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +pub use olm_rs::{ + account::IdentityKeys, + session::{OlmMessage, PreKeyMessage}, + utility::OlmUtility, +}; + +use super::{ExportedGroupSessionKey, ExportedRoomKey, GroupSessionKey}; +use crate::error::{EventError, MegolmResult}; + +/// Inbound group session. +/// +/// Inbound group sessions are used to exchange room messages between a group of +/// participants. Inbound group sessions are used to decrypt the room messages. +#[derive(Clone)] +pub struct InboundGroupSession { + inner: Arc>, + session_id: Arc, + pub(crate) sender_key: Arc, + pub(crate) signing_key: Arc>, + pub(crate) room_id: Arc, + forwarding_chains: Arc>>>, + imported: Arc, +} + +impl InboundGroupSession { + /// Create a new inbound group session for the given room. + /// + /// These sessions are used to decrypt room messages. + /// + /// # Arguments + /// + /// * `sender_key` - The public curve25519 key of the account that + /// sent us the session + /// + /// * `signing_key` - The public ed25519 key of the account that + /// sent us the session. + /// + /// * `room_id` - The id of the room that the session is used in. + /// + /// * `session_key` - The private session key that is used to decrypt + /// messages. + pub(crate) fn new( + sender_key: &str, + signing_key: &str, + room_id: &RoomId, + session_key: GroupSessionKey, + ) -> Result { + let session = OlmInboundGroupSession::new(&session_key.0)?; + let session_id = session.session_id(); + + let mut keys: BTreeMap = BTreeMap::new(); + keys.insert(DeviceKeyAlgorithm::Ed25519, signing_key.to_owned()); + + Ok(InboundGroupSession { + inner: Arc::new(Mutex::new(session)), + session_id: Arc::new(session_id), + sender_key: Arc::new(sender_key.to_owned()), + signing_key: Arc::new(keys), + room_id: Arc::new(room_id.clone()), + forwarding_chains: Arc::new(Mutex::new(None)), + imported: Arc::new(false), + }) + } + + /// Create a InboundGroupSession from an exported version of the group + /// session. + /// + /// Most notably this can be called with an `ExportedRoomKey` from a + /// previous [`export()`] call. + /// + /// + /// [`export()`]: #method.export + pub fn from_export( + exported_session: impl Into, + ) -> Result { + Self::try_from(exported_session.into()) + } + + /// Store the group session as a base64 encoded string. + /// + /// # Arguments + /// + /// * `pickle_mode` - The mode that was used to pickle the group session, + /// either an unencrypted mode or an encrypted using passphrase. + pub async fn pickle(&self, pickle_mode: PicklingMode) -> PickledInboundGroupSession { + let pickle = self.inner.lock().await.pickle(pickle_mode); + + PickledInboundGroupSession { + pickle: InboundGroupSessionPickle::from(pickle), + sender_key: self.sender_key.to_string(), + signing_key: (&*self.signing_key).clone(), + room_id: (&*self.room_id).clone(), + forwarding_chains: self.forwarding_chains.lock().await.clone(), + imported: *self.imported, + } + } + + /// Export this session at the first known message index. + /// + /// If only a limited part of this session should be exported use + /// [`export_at_index()`](#method.export_at_index). + pub async fn export(&self) -> ExportedRoomKey { + self.export_at_index(self.first_known_index().await) + .await + .expect("Can't export at the first known index") + } + + /// Export this session at the given message index. + pub async fn export_at_index(&self, message_index: u32) -> Option { + let session_key = + ExportedGroupSessionKey(self.inner.lock().await.export(message_index).ok()?); + + Some(ExportedRoomKey { + algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2, + room_id: (&*self.room_id).clone(), + sender_key: (&*self.sender_key).to_owned(), + session_id: self.session_id().to_owned(), + forwarding_curve25519_key_chain: self + .forwarding_chains + .lock() + .await + .as_ref() + .cloned() + .unwrap_or_default(), + sender_claimed_keys: (&*self.signing_key).clone(), + session_key, + }) + } + + /// Restore a Session from a previously pickled string. + /// + /// Returns the restored group session or a `OlmGroupSessionError` if there + /// was an error. + /// + /// # Arguments + /// + /// * `pickle` - The pickled version of the `InboundGroupSession`. + /// + /// * `pickle_mode` - The mode that was used to pickle the session, either + /// an unencrypted mode or an encrypted using passphrase. + pub fn from_pickle( + pickle: PickledInboundGroupSession, + pickle_mode: PicklingMode, + ) -> Result { + let session = OlmInboundGroupSession::unpickle(pickle.pickle.0, pickle_mode)?; + let session_id = session.session_id(); + + Ok(InboundGroupSession { + inner: Arc::new(Mutex::new(session)), + session_id: Arc::new(session_id), + sender_key: Arc::new(pickle.sender_key), + signing_key: Arc::new(pickle.signing_key), + room_id: Arc::new(pickle.room_id), + forwarding_chains: Arc::new(Mutex::new(pickle.forwarding_chains)), + imported: Arc::new(pickle.imported), + }) + } + + /// The room where this session is used in. + pub fn room_id(&self) -> &RoomId { + &self.room_id + } + + /// Returns the unique identifier for this session. + pub fn session_id(&self) -> &str { + &self.session_id + } + + /// Get the first message index we know how to decrypt. + pub async fn first_known_index(&self) -> u32 { + self.inner.lock().await.first_known_index() + } + + /// Decrypt the given ciphertext. + /// + /// Returns the decrypted plaintext or an `OlmGroupSessionError` if + /// decryption failed. + /// + /// # Arguments + /// + /// * `message` - The message that should be decrypted. + pub(crate) async fn decrypt_helper( + &self, + message: String, + ) -> Result<(String, u32), OlmGroupSessionError> { + self.inner.lock().await.decrypt(message) + } + + /// Decrypt an event from a room timeline. + /// + /// # Arguments + /// + /// * `event` - The event that should be decrypted. + pub(crate) async fn decrypt( + &self, + event: &SyncMessageEvent, + ) -> MegolmResult<(Raw, u32)> { + let content = match &event.content { + EncryptedEventContent::MegolmV1AesSha2(c) => c, + _ => return Err(EventError::UnsupportedAlgorithm.into()), + }; + + let (plaintext, message_index) = self.decrypt_helper(content.ciphertext.clone()).await?; + + let mut decrypted_value = serde_json::from_str::(&plaintext)?; + let decrypted_object = decrypted_value + .as_object_mut() + .ok_or(EventError::NotAnObject)?; + + // TODO better number conversion here. + let server_ts = event + .origin_server_ts + .duration_since(std::time::SystemTime::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + let server_ts: i64 = server_ts.try_into().unwrap_or_default(); + + decrypted_object.insert("sender".to_owned(), event.sender.to_string().into()); + decrypted_object.insert("event_id".to_owned(), event.event_id.to_string().into()); + decrypted_object.insert("origin_server_ts".to_owned(), server_ts.into()); + + decrypted_object.insert( + "unsigned".to_owned(), + serde_json::to_value(&event.unsigned).unwrap_or_default(), + ); + + Ok(( + serde_json::from_value::>(decrypted_value)?, + message_index, + )) + } +} + +#[cfg(not(tarpaulin_include))] +impl fmt::Debug for InboundGroupSession { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("InboundGroupSession") + .field("session_id", &self.session_id()) + .finish() + } +} + +impl PartialEq for InboundGroupSession { + fn eq(&self, other: &Self) -> bool { + self.session_id() == other.session_id() + } +} + +/// A pickled version of an `InboundGroupSession`. +/// +/// Holds all the information that needs to be stored in a database to restore +/// an InboundGroupSession. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PickledInboundGroupSession { + /// The pickle string holding the InboundGroupSession. + pub pickle: InboundGroupSessionPickle, + /// The public curve25519 key of the account that sent us the session + pub sender_key: String, + /// The public ed25519 key of the account that sent us the session. + pub signing_key: BTreeMap, + /// The id of the room that the session is used in. + pub room_id: RoomId, + /// The list of claimed ed25519 that forwarded us this key. Will be None if + /// we dirrectly received this session. + pub forwarding_chains: Option>, + /// Flag remembering if the session was dirrectly sent to us by the sender + /// or if it was imported. + pub imported: bool, +} + +/// The typed representation of a base64 encoded string of the GroupSession pickle. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InboundGroupSessionPickle(String); + +impl From for InboundGroupSessionPickle { + fn from(pickle_string: String) -> Self { + InboundGroupSessionPickle(pickle_string) + } +} + +impl InboundGroupSessionPickle { + /// Get the string representation of the pickle. + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl TryFrom for InboundGroupSession { + type Error = OlmGroupSessionError; + + fn try_from(key: ExportedRoomKey) -> Result { + let session = OlmInboundGroupSession::import(&key.session_key.0)?; + + let forwarding_chains = if key.forwarding_curve25519_key_chain.is_empty() { + None + } else { + Some(key.forwarding_curve25519_key_chain) + }; + + Ok(InboundGroupSession { + inner: Arc::new(Mutex::new(session)), + session_id: Arc::new(key.session_id), + sender_key: Arc::new(key.sender_key), + signing_key: Arc::new(key.sender_claimed_keys), + room_id: Arc::new(key.room_id), + forwarding_chains: Arc::new(Mutex::new(forwarding_chains)), + imported: Arc::new(true), + }) + } +} diff --git a/matrix_sdk_crypto/src/olm/group_sessions/mod.rs b/matrix_sdk_crypto/src/olm/group_sessions/mod.rs new file mode 100644 index 00000000..b0e69599 --- /dev/null +++ b/matrix_sdk_crypto/src/olm/group_sessions/mod.rs @@ -0,0 +1,176 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use matrix_sdk_common::{ + events::forwarded_room_key::ForwardedRoomKeyEventContent, + identifiers::{DeviceKeyAlgorithm, EventEncryptionAlgorithm, RoomId}, +}; +use serde::{Deserialize, Serialize}; +use std::{collections::BTreeMap, convert::TryInto}; +use zeroize::Zeroize; + +mod inbound; +mod outbound; + +pub use inbound::{InboundGroupSession, InboundGroupSessionPickle, PickledInboundGroupSession}; +pub use outbound::{EncryptionSettings, OutboundGroupSession}; + +/// The private session key of a group session. +/// Can be used to create a new inbound group session. +#[derive(Clone, Debug, Serialize, Deserialize, Zeroize)] +#[zeroize(drop)] +pub struct GroupSessionKey(pub String); + +/// The exported version of an private session key of a group session. +/// Can be used to create a new inbound group session. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Zeroize)] +#[zeroize(drop)] +pub struct ExportedGroupSessionKey(pub String); + +/// An exported version of a `InboundGroupSession` +/// +/// This can be used to share the `InboundGroupSession` in an exported file. +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] +pub struct ExportedRoomKey { + /// The encryption algorithm that the session uses. + pub algorithm: EventEncryptionAlgorithm, + + /// The room where the session is used. + pub room_id: RoomId, + + /// The Curve25519 key of the device which initiated the session originally. + pub sender_key: String, + + /// The ID of the session that the key is for. + pub session_id: String, + + /// The key for the session. + pub session_key: ExportedGroupSessionKey, + + /// The Ed25519 key of the device which initiated the session originally. + pub sender_claimed_keys: BTreeMap, + + /// Chain of Curve25519 keys through which this session was forwarded, via + /// m.forwarded_room_key events. + pub forwarding_curve25519_key_chain: Vec, +} + +impl TryInto for ExportedRoomKey { + type Error = (); + + /// Convert an exported room key into a content for a forwarded room key + /// event. + /// + /// This will fail if the exported room key has multiple sender claimed keys + /// or if the algorithm of the claimed sender key isn't + /// `DeviceKeyAlgorithm::Ed25519`. + fn try_into(self) -> Result { + if self.sender_claimed_keys.len() != 1 { + Err(()) + } else { + let (algorithm, claimed_key) = self.sender_claimed_keys.iter().next().ok_or(())?; + + if algorithm != &DeviceKeyAlgorithm::Ed25519 { + return Err(()); + } + + Ok(ForwardedRoomKeyEventContent { + algorithm: self.algorithm, + room_id: self.room_id, + sender_key: self.sender_key, + session_id: self.session_id, + session_key: self.session_key.0.clone(), + sender_claimed_ed25519_key: claimed_key.to_owned(), + forwarding_curve25519_key_chain: self.forwarding_curve25519_key_chain, + }) + } + } +} + +impl From for ExportedRoomKey { + /// Convert the content of a forwarded room key into a exported room key. + fn from(forwarded_key: ForwardedRoomKeyEventContent) -> Self { + let mut sender_claimed_keys: BTreeMap = BTreeMap::new(); + sender_claimed_keys.insert( + DeviceKeyAlgorithm::Ed25519, + forwarded_key.sender_claimed_ed25519_key, + ); + + Self { + algorithm: forwarded_key.algorithm, + room_id: forwarded_key.room_id, + session_id: forwarded_key.session_id, + forwarding_curve25519_key_chain: forwarded_key.forwarding_curve25519_key_chain, + sender_claimed_keys, + sender_key: forwarded_key.sender_key, + session_key: ExportedGroupSessionKey(forwarded_key.session_key), + } + } +} + +#[cfg(test)] +mod test { + use std::{ + sync::Arc, + time::{Duration, Instant}, + }; + + use matrix_sdk_common::{ + events::{ + room::message::{MessageEventContent, TextMessageEventContent}, + AnyMessageEventContent, + }, + identifiers::{room_id, user_id}, + }; + + use super::EncryptionSettings; + use crate::Account; + + #[tokio::test] + #[cfg(not(target_os = "macos"))] + async fn expiration() { + let settings = EncryptionSettings { + rotation_period_msgs: 1, + ..Default::default() + }; + + let account = Account::new(&user_id!("@alice:example.org"), "DEVICEID".into()); + let (session, _) = account + .create_group_session_pair(&room_id!("!test_room:example.org"), settings) + .await + .unwrap(); + + assert!(!session.expired()); + let _ = session + .encrypt(AnyMessageEventContent::RoomMessage( + MessageEventContent::Text(TextMessageEventContent::plain("Test message")), + )) + .await; + assert!(session.expired()); + + let settings = EncryptionSettings { + rotation_period: Duration::from_millis(100), + ..Default::default() + }; + + let (mut session, _) = account + .create_group_session_pair(&room_id!("!test_room:example.org"), settings) + .await + .unwrap(); + + assert!(!session.expired()); + session.creation_time = Arc::new(Instant::now() - Duration::from_secs(60 * 60)); + assert!(session.expired()); + } +} diff --git a/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs b/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs new file mode 100644 index 00000000..ac521272 --- /dev/null +++ b/matrix_sdk_crypto/src/olm/group_sessions/outbound.rs @@ -0,0 +1,303 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + cmp::min, + fmt, + sync::{ + atomic::{AtomicBool, AtomicU64, Ordering}, + Arc, + }, + time::Duration, +}; + +use matrix_sdk_common::{ + events::{ + room::{encrypted::EncryptedEventContent, encryption::EncryptionEventContent}, + AnyMessageEventContent, EventContent, + }, + identifiers::{DeviceId, EventEncryptionAlgorithm, RoomId}, + instant::Instant, + locks::Mutex, +}; +use olm_rs::outbound_group_session::OlmOutboundGroupSession; +use serde_json::{json, Value}; + +pub use olm_rs::{ + account::IdentityKeys, + session::{OlmMessage, PreKeyMessage}, + utility::OlmUtility, +}; + +use super::GroupSessionKey; + +const ROTATION_PERIOD: Duration = Duration::from_millis(604800000); +const ROTATION_MESSAGES: u64 = 100; + +/// Settings for an encrypted room. +/// +/// This determines the algorithm and rotation periods of a group session. +#[derive(Debug)] +pub struct EncryptionSettings { + /// The encryption algorithm that should be used in the room. + pub algorithm: EventEncryptionAlgorithm, + /// How long the session should be used before changing it. + pub rotation_period: Duration, + /// How many messages should be sent before changing the session. + pub rotation_period_msgs: u64, +} + +impl Default for EncryptionSettings { + fn default() -> Self { + Self { + algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2, + rotation_period: ROTATION_PERIOD, + rotation_period_msgs: ROTATION_MESSAGES, + } + } +} + +impl From<&EncryptionEventContent> for EncryptionSettings { + fn from(content: &EncryptionEventContent) -> Self { + let rotation_period: Duration = content + .rotation_period_ms + .map_or(ROTATION_PERIOD, |r| Duration::from_millis(r.into())); + let rotation_period_msgs: u64 = content + .rotation_period_msgs + .map_or(ROTATION_MESSAGES, Into::into); + + Self { + algorithm: content.algorithm.clone(), + rotation_period, + rotation_period_msgs, + } + } +} +/// Outbound group session. +/// +/// Outbound group sessions are used to exchange room messages between a group +/// of participants. Outbound group sessions are used to encrypt the room +/// messages. +#[derive(Clone)] +pub struct OutboundGroupSession { + inner: Arc>, + device_id: Arc>, + account_identity_keys: Arc, + session_id: Arc, + room_id: Arc, + pub(crate) creation_time: Arc, + message_count: Arc, + shared: Arc, + settings: Arc, +} + +impl OutboundGroupSession { + /// Create a new outbound group session for the given room. + /// + /// Outbound group sessions are used to encrypt room messages. + /// + /// # Arguments + /// + /// * `device_id` - The id of the device that created this session. + /// + /// * `identity_keys` - The identity keys of the account that created this + /// session. + /// + /// * `room_id` - The id of the room that the session is used in. + /// + /// * `settings` - Settings determining the algorithm and rotation period of + /// the outbound group session. + pub fn new( + device_id: Arc>, + identity_keys: Arc, + room_id: &RoomId, + settings: EncryptionSettings, + ) -> Self { + let session = OlmOutboundGroupSession::new(); + let session_id = session.session_id(); + + OutboundGroupSession { + inner: Arc::new(Mutex::new(session)), + room_id: Arc::new(room_id.to_owned()), + device_id, + account_identity_keys: identity_keys, + session_id: Arc::new(session_id), + creation_time: Arc::new(Instant::now()), + message_count: Arc::new(AtomicU64::new(0)), + shared: Arc::new(AtomicBool::new(false)), + settings: Arc::new(settings), + } + } + + /// Encrypt the given plaintext using this session. + /// + /// Returns the encrypted ciphertext. + /// + /// # Arguments + /// + /// * `plaintext` - The plaintext that should be encrypted. + pub(crate) async fn encrypt_helper(&self, plaintext: String) -> String { + let session = self.inner.lock().await; + self.message_count.fetch_add(1, Ordering::SeqCst); + session.encrypt(plaintext) + } + + /// Encrypt a room message for the given room. + /// + /// Beware that a group session needs to be shared before this method can be + /// called using the `share_group_session()` method. + /// + /// Since group sessions can expire or become invalid if the room membership + /// changes client authors should check with the + /// `should_share_group_session()` method if a new group session needs to + /// be shared. + /// + /// # Arguments + /// + /// * `content` - The plaintext content of the message that should be + /// encrypted. + /// + /// # Panics + /// + /// Panics if the content can't be serialized. + pub async fn encrypt(&self, content: AnyMessageEventContent) -> EncryptedEventContent { + let json_content = json!({ + "content": content, + "room_id": &*self.room_id, + "type": content.event_type(), + }); + + let plaintext = cjson::to_string(&json_content).unwrap_or_else(|_| { + panic!(format!( + "Can't serialize {} to canonical JSON", + json_content + )) + }); + + let ciphertext = self.encrypt_helper(plaintext).await; + + EncryptedEventContent::MegolmV1AesSha2( + matrix_sdk_common::events::room::encrypted::MegolmV1AesSha2ContentInit { + ciphertext, + sender_key: self.account_identity_keys.curve25519().to_owned(), + session_id: self.session_id().to_owned(), + device_id: (&*self.device_id).to_owned(), + } + .into(), + ) + } + + /// Check if the session has expired and if it should be rotated. + /// + /// A session will expire after some time or if enough messages have been + /// encrypted using it. + pub fn expired(&self) -> bool { + let count = self.message_count.load(Ordering::SeqCst); + + count >= self.settings.rotation_period_msgs + || self.creation_time.elapsed() + // Since the encryption settings are provided by users and not + // checked someone could set a really low rotation perdiod so + // clamp it at a minute. + >= min(self.settings.rotation_period, Duration::from_secs(3600)) + } + + /// Mark the session as shared. + /// + /// Messages shouldn't be encrypted with the session before it has been + /// shared. + pub fn mark_as_shared(&self) { + self.shared.store(true, Ordering::Relaxed); + } + + /// Check if the session has been marked as shared. + pub fn shared(&self) -> bool { + self.shared.load(Ordering::Relaxed) + } + + /// Get the session key of this session. + /// + /// A session key can be used to to create an `InboundGroupSession`. + pub async fn session_key(&self) -> GroupSessionKey { + let session = self.inner.lock().await; + GroupSessionKey(session.session_key()) + } + + /// Returns the unique identifier for this session. + pub fn session_id(&self) -> &str { + &self.session_id + } + + /// Get the current message index for this session. + /// + /// Each message is sent with an increasing index. This returns the + /// message index that will be used for the next encrypted message. + pub async fn message_index(&self) -> u32 { + let session = self.inner.lock().await; + session.session_message_index() + } + + /// Get the outbound group session key as a json value that can be sent as a + /// m.room_key. + pub async fn as_json(&self) -> Value { + json!({ + "algorithm": EventEncryptionAlgorithm::MegolmV1AesSha2, + "room_id": &*self.room_id, + "session_id": &*self.session_id, + "session_key": self.session_key().await, + "chain_index": self.message_index().await, + }) + } +} + +#[cfg(not(tarpaulin_include))] +impl std::fmt::Debug for OutboundGroupSession { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OutboundGroupSession") + .field("session_id", &self.session_id) + .field("room_id", &self.room_id) + .field("creation_time", &self.creation_time) + .field("message_count", &self.message_count) + .finish() + } +} + +#[cfg(test)] +mod test { + use std::time::Duration; + + use matrix_sdk_common::{ + events::room::encryption::EncryptionEventContent, identifiers::EventEncryptionAlgorithm, + js_int::uint, + }; + + use super::{EncryptionSettings, ROTATION_MESSAGES, ROTATION_PERIOD}; + + #[test] + fn encryption_settings_conversion() { + let mut content = EncryptionEventContent::new(EventEncryptionAlgorithm::MegolmV1AesSha2); + let settings = EncryptionSettings::from(&content); + + assert_eq!(settings.rotation_period, ROTATION_PERIOD); + assert_eq!(settings.rotation_period_msgs, ROTATION_MESSAGES); + + content.rotation_period_ms = Some(uint!(3600)); + content.rotation_period_msgs = Some(uint!(500)); + + let settings = EncryptionSettings::from(&content); + + assert_eq!(settings.rotation_period, Duration::from_millis(3600)); + assert_eq!(settings.rotation_period_msgs, 500); + } +} diff --git a/matrix_sdk_crypto/src/olm/mod.rs b/matrix_sdk_crypto/src/olm/mod.rs index c69293c5..e88b4959 100644 --- a/matrix_sdk_crypto/src/olm/mod.rs +++ b/matrix_sdk_crypto/src/olm/mod.rs @@ -24,7 +24,8 @@ mod utility; pub use account::{Account, AccountPickle, IdentityKeys, PickledAccount}; pub use group_sessions::{ - EncryptionSettings, InboundGroupSession, InboundGroupSessionPickle, PickledInboundGroupSession, + EncryptionSettings, ExportedRoomKey, InboundGroupSession, InboundGroupSessionPickle, + PickledInboundGroupSession, }; pub(crate) use group_sessions::{GroupSessionKey, OutboundGroupSession}; pub use olm_rs::PicklingMode; @@ -37,10 +38,11 @@ pub(crate) mod test { use crate::olm::{Account, InboundGroupSession, Session}; use matrix_sdk_common::{ api::r0::keys::SignedKey, + events::forwarded_room_key::ForwardedRoomKeyEventContent, identifiers::{room_id, user_id, DeviceId, UserId}, }; use olm_rs::session::OlmMessage; - use std::collections::BTreeMap; + use std::{collections::BTreeMap, convert::TryInto}; fn alice_id() -> UserId { user_id!("@alice:example.org") @@ -221,4 +223,22 @@ pub(crate) mod test { inbound.decrypt_helper(ciphertext).await.unwrap().0 ); } + + #[tokio::test] + async fn group_session_export() { + let alice = Account::new(&alice_id(), &alice_device_id()); + let room_id = room_id!("!test:localhost"); + + let (_, inbound) = alice + .create_group_session_pair(&room_id, Default::default()) + .await + .unwrap(); + + let export = inbound.export().await; + let export: ForwardedRoomKeyEventContent = export.try_into().unwrap(); + + let imported = InboundGroupSession::from_export(export).unwrap(); + + assert_eq!(inbound.session_id(), imported.session_id()); + } } diff --git a/matrix_sdk_crypto/src/store/caches.rs b/matrix_sdk_crypto/src/store/caches.rs index 91344257..cc56c100 100644 --- a/matrix_sdk_crypto/src/store/caches.rs +++ b/matrix_sdk_crypto/src/store/caches.rs @@ -106,6 +106,19 @@ impl GroupSessionStore { .is_none() } + /// Get all the group sessions the store knows about. + pub fn get_all(&self) -> Vec { + self.entries + .iter() + .flat_map(|d| { + d.value() + .values() + .flat_map(|t| t.values().cloned().collect::>()) + .collect::>() + }) + .collect() + } + /// Get a inbound group session from our store. /// /// # Arguments diff --git a/matrix_sdk_crypto/src/store/memorystore.rs b/matrix_sdk_crypto/src/store/memorystore.rs index f7d3d753..4de015bf 100644 --- a/matrix_sdk_crypto/src/store/memorystore.rs +++ b/matrix_sdk_crypto/src/store/memorystore.rs @@ -80,8 +80,12 @@ impl CryptoStore for MemoryStore { Ok(self.sessions.get(sender_key)) } - async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result { - Ok(self.inbound_group_sessions.add(session)) + async fn save_inbound_group_sessions(&self, sessions: &[InboundGroupSession]) -> Result<()> { + for session in sessions { + self.inbound_group_sessions.add(session.clone()); + } + + Ok(()) } async fn get_inbound_group_session( @@ -95,6 +99,10 @@ impl CryptoStore for MemoryStore { .get(room_id, sender_key, session_id)) } + async fn get_inbound_group_sessions(&self) -> Result> { + Ok(self.inbound_group_sessions.get_all()) + } + fn users_for_key_query(&self) -> HashSet { #[allow(clippy::map_clone)] self.users_for_key_query.iter().map(|u| u.clone()).collect() @@ -208,7 +216,7 @@ mod test { let store = MemoryStore::new(); let _ = store - .save_inbound_group_session(inbound.clone()) + .save_inbound_group_sessions(&[inbound.clone()]) .await .unwrap(); diff --git a/matrix_sdk_crypto/src/store/mod.rs b/matrix_sdk_crypto/src/store/mod.rs index a7cce545..88d2ec90 100644 --- a/matrix_sdk_crypto/src/store/mod.rs +++ b/matrix_sdk_crypto/src/store/mod.rs @@ -157,15 +157,12 @@ pub trait CryptoStore: Debug { /// * `sender_key` - The sender key that was used to establish the sessions. async fn get_sessions(&self, sender_key: &str) -> Result>>>>; - /// Save the given inbound group session in the store. - /// - /// If the session wasn't already in the store true is returned, false - /// otherwise. + /// Save the given inbound group sessions in the store. /// /// # Arguments /// - /// * `session` - The session that should be stored. - async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result; + /// * `sessions` - The sessions that should be stored. + async fn save_inbound_group_sessions(&self, session: &[InboundGroupSession]) -> Result<()>; /// Get the inbound group session from our store. /// @@ -182,6 +179,9 @@ pub trait CryptoStore: Debug { session_id: &str, ) -> Result>; + /// Get all the inbound group sessions we have stored. + async fn get_inbound_group_sessions(&self) -> Result>; + /// Is the given user already tracked. fn is_user_tracked(&self, user_id: &UserId) -> bool; diff --git a/matrix_sdk_crypto/src/store/sqlite.rs b/matrix_sdk_crypto/src/store/sqlite.rs index 50b0b5f4..91b2a339 100644 --- a/matrix_sdk_crypto/src/store/sqlite.rs +++ b/matrix_sdk_crypto/src/store/sqlite.rs @@ -23,6 +23,7 @@ use std::{ use async_trait::async_trait; use dashmap::DashSet; use matrix_sdk_common::{ + api::r0::keys::{CrossSigningKey, KeyUsage}, identifiers::{ DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, RoomId, UserId, }, @@ -38,7 +39,7 @@ use super::{ CryptoStore, CryptoStoreError, Result, }; use crate::{ - identities::{LocalTrust, ReadOnlyDevice, UserIdentities}, + identities::{LocalTrust, OwnUserIdentity, ReadOnlyDevice, UserIdentities, UserIdentity}, olm::{ Account, AccountPickle, IdentityKeys, InboundGroupSession, InboundGroupSessionPickle, PickledAccount, PickledInboundGroupSession, PickledSession, PicklingMode, Session, @@ -71,6 +72,14 @@ struct AccountInfo { identity_keys: Arc, } +#[derive(Debug, PartialEq, Copy, Clone, sqlx::Type)] +#[repr(i32)] +enum CrosssigningKeyType { + Master = 0, + SelfSigning = 1, + UserSigning = 2, +} + static DATABASE_NAME: &str = "matrix-sdk-crypto.db"; impl SqliteStore { @@ -135,8 +144,8 @@ impl SqliteStore { passphrase: Option>, ) -> Result { let url = SqliteStore::path_to_url(path.as_ref())?; - let connection = SqliteConnection::connect(url.as_ref()).await?; + let store = SqliteStore { user_id: Arc::new(user_id.to_owned()), device_id: Arc::new(device_id.into()), @@ -151,6 +160,7 @@ impl SqliteStore { users_for_key_query: Arc::new(DashSet::new()), }; store.create_tables().await?; + Ok(store) } @@ -221,14 +231,16 @@ impl SqliteStore { .execute( r#" CREATE TABLE IF NOT EXISTS inbound_group_sessions ( - "session_id" TEXT NOT NULL PRIMARY KEY, + "id" INTEGER NOT NULL PRIMARY KEY, + "session_id" TEXT NOT NULL, "account_id" INTEGER NOT NULL, "sender_key" TEXT NOT NULL, - "signing_key" TEXT NOT NULL, "room_id" TEXT NOT NULL, "pickle" BLOB NOT NULL, + "imported" INTEGER NOT NULL, FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE + UNIQUE(account_id,session_id,sender_key) ); CREATE INDEX IF NOT EXISTS "olm_groups_sessions_account_id" ON "inbound_group_sessions" ("account_id"); @@ -236,6 +248,41 @@ impl SqliteStore { ) .await?; + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS group_session_claimed_keys ( + "id" INTEGER NOT NULL PRIMARY KEY, + "session_id" INTEGER NOT NULL, + "algorithm" TEXT NOT NULL, + "key" TEXT NOT NULL, + FOREIGN KEY ("session_id") REFERENCES "inbound_group_sessions" ("id") + ON DELETE CASCADE + UNIQUE(session_id, algorithm) + ); + + CREATE INDEX IF NOT EXISTS "group_session_claimed_keys_session_id" ON "group_session_claimed_keys" ("session_id"); + "#, + ) + .await?; + + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS group_session_chains ( + "id" INTEGER NOT NULL PRIMARY KEY, + "key" TEXT NOT NULL, + "session_id" INTEGER NOT NULL, + FOREIGN KEY ("session_id") REFERENCES "inbound_group_sessions" ("id") + ON DELETE CASCADE + UNIQUE(session_id, key) + ); + + CREATE INDEX IF NOT EXISTS "group_session_chains_session_id" ON "group_session_chains" ("session_id"); + "#, + ) + .await?; + connection .execute( r#" @@ -310,6 +357,91 @@ impl SqliteStore { ) .await?; + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS users ( + "id" INTEGER NOT NULL PRIMARY KEY, + "account_id" INTEGER NOT NULL, + "user_id" TEXT NOT NULL, + FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") + ON DELETE CASCADE + UNIQUE(account_id,user_id) + ); + + CREATE INDEX IF NOT EXISTS "users_account_id" ON "users" ("account_id"); + "#, + ) + .await?; + + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS users_trust_state ( + "id" INTEGER NOT NULL PRIMARY KEY, + "trusted" INTEGER NOT NULL, + "user_id" INTEGER NOT NULL, + FOREIGN KEY ("user_id") REFERENCES "users" ("id") + ON DELETE CASCADE + UNIQUE(user_id) + ); + "#, + ) + .await?; + + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS cross_signing_keys ( + "id" INTEGER NOT NULL PRIMARY KEY, + "key_type" INTEGER NOT NULL, + "usage" STRING NOT NULL, + "user_id" INTEGER NOT NULL, + FOREIGN KEY ("user_id") REFERENCES "users" ("id") ON DELETE CASCADE + UNIQUE(user_id, key_type) + ); + + CREATE INDEX IF NOT EXISTS "cross_signing_keys_users" ON "users" ("user_id"); + "#, + ) + .await?; + + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS user_keys ( + "id" INTEGER NOT NULL PRIMARY KEY, + "key" TEXT NOT NULL, + "key_id" TEXT NOT NULL, + "cross_signing_key" INTEGER NOT NULL, + FOREIGN KEY ("cross_signing_key") REFERENCES "cross_signing_keys" ("id") ON DELETE CASCADE + UNIQUE(cross_signing_key, key_id) + ); + + CREATE INDEX IF NOT EXISTS "cross_signing_keys_keys" ON "cross_signing_keys" ("cross_signing_key"); + "#, + ) + .await?; + + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS user_key_signatures ( + "id" INTEGER NOT NULL PRIMARY KEY, + "user_id" TEXT NOT NULL, + "key_id" INTEGER NOT NULL, + "signature" TEXT NOT NULL, + "cross_signing_key" INTEGER NOT NULL, + FOREIGN KEY ("cross_signing_key") REFERENCES "cross_signing_keys" ("id") + ON DELETE CASCADE + UNIQUE(user_id, key_id, cross_signing_key) + ); + + CREATE INDEX IF NOT EXISTS "cross_signing_keys_signatures" ON "cross_signing_keys" ("cross_signing_key"); + "#, + ) + .await?; + Ok(()) } @@ -380,45 +512,67 @@ impl SqliteStore { let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; let mut connection = self.connection.lock().await; - let mut rows: Vec<(String, String, String, String)> = query_as( - "SELECT pickle, sender_key, signing_key, room_id + let mut rows: Vec<(i64, String, String, String, bool)> = query_as( + "SELECT id, pickle, sender_key, room_id, imported FROM inbound_group_sessions WHERE account_id = ?", ) .bind(account_id) .fetch_all(&mut *connection) .await?; - let mut group_sessions = rows - .drain(..) - .map(|row| { - let pickle = row.0; - let sender_key = row.1; - let signing_key = row.2; - let room_id = row.3; + for row in rows.drain(..) { + let session_row_id = row.0; + let pickle = row.1; + let sender_key = row.2; + let room_id = row.3; + let imported = row.4; - let pickle = PickledInboundGroupSession { - pickle: InboundGroupSessionPickle::from(pickle), - sender_key, - signing_key, - room_id: RoomId::try_from(room_id)?, - // Fixme we need to store/restore these once we get support - // for key requesting/forwarding. - forwarding_chains: None, - }; + let key_rows: Vec<(String, String)> = query_as( + "SELECT algorithm, key FROM group_session_claimed_keys WHERE session_id = ?", + ) + .bind(session_row_id) + .fetch_all(&mut *connection) + .await?; - Ok(InboundGroupSession::from_pickle( + let claimed_keys: BTreeMap = key_rows + .into_iter() + .filter_map(|row| { + let algorithm = row.0.parse::().ok()?; + let key = row.1; + + Some((algorithm, key)) + }) + .collect(); + + let mut chain_rows: Vec<(String,)> = + query_as("SELECT key, key FROM group_session_chains WHERE session_id = ?") + .bind(session_row_id) + .fetch_all(&mut *connection) + .await?; + + let chains: Vec = chain_rows.drain(..).map(|r| r.0).collect(); + + let chains = if chains.is_empty() { + None + } else { + Some(chains) + }; + + let pickle = PickledInboundGroupSession { + pickle: InboundGroupSessionPickle::from(pickle), + sender_key, + signing_key: claimed_keys, + room_id: RoomId::try_from(room_id)?, + forwarding_chains: chains, + imported, + }; + + self.inbound_group_sessions + .add(InboundGroupSession::from_pickle( pickle, self.get_pickle_mode(), - )?) - }) - .collect::>>()?; - - group_sessions - .drain(..) - .map(|s| { - self.inbound_group_sessions.add(s); - }) - .for_each(drop); + )?); + } Ok(()) } @@ -670,6 +824,325 @@ impl SqliteStore { None => PicklingMode::Unencrypted, } } + + async fn save_inbound_group_session_helper( + &self, + account_id: i64, + connection: &mut SqliteConnection, + session: &InboundGroupSession, + ) -> Result<()> { + let pickle = session.pickle(self.get_pickle_mode()).await; + let session_id = session.session_id(); + + query( + "REPLACE INTO inbound_group_sessions ( + session_id, account_id, sender_key, + room_id, pickle, imported + ) VALUES (?1, ?2, ?3, ?4, ?5, ?6) + ", + ) + .bind(session_id) + .bind(account_id) + .bind(&pickle.sender_key) + .bind(pickle.room_id.as_str()) + .bind(pickle.pickle.as_str()) + .bind(pickle.imported) + .execute(&mut *connection) + .await?; + + let row: (i64,) = query_as( + "SELECT id FROM inbound_group_sessions + WHERE account_id = ? and session_id = ? and sender_key = ?", + ) + .bind(account_id) + .bind(session_id) + .bind(pickle.sender_key) + .fetch_one(&mut *connection) + .await?; + + let session_row_id = row.0; + + for (key_id, key) in pickle.signing_key { + query( + "REPLACE INTO group_session_claimed_keys ( + session_id, algorithm, key + ) VALUES (?1, ?2, ?3) + ", + ) + .bind(session_row_id) + .bind(serde_json::to_string(&key_id)?) + .bind(key) + .execute(&mut *connection) + .await?; + } + + if let Some(chains) = pickle.forwarding_chains { + for key in chains { + query( + "REPLACE INTO group_session_chains ( + session_id, key + ) VALUES (?1, ?2) + ", + ) + .bind(session_row_id) + .bind(key) + .execute(&mut *connection) + .await?; + } + } + + Ok(()) + } + + async fn load_cross_signing_key( + connection: &mut SqliteConnection, + user_id: &UserId, + user_row_id: i64, + key_type: CrosssigningKeyType, + ) -> Result { + let row: (i64, String) = + query_as("SELECT id, usage FROM cross_signing_keys WHERE user_id =? and key_type =?") + .bind(user_row_id) + .bind(key_type) + .fetch_one(&mut *connection) + .await?; + + let key_row_id = row.0; + let usage: Vec = serde_json::from_str(&row.1)?; + + let key_rows: Vec<(String, String)> = + query_as("SELECT key_id, key FROM user_keys WHERE cross_signing_key = ?") + .bind(key_row_id) + .fetch_all(&mut *connection) + .await?; + + let mut keys = BTreeMap::new(); + let mut signatures = BTreeMap::new(); + + for row in key_rows { + let key_id = row.0; + let key = row.1; + + keys.insert(key_id, key); + } + + let mut signature_rows: Vec<(String, String, String)> = query_as( + "SELECT user_id, key_id, signature FROM user_key_signatures WHERE cross_signing_key = ?", + ) + .bind(key_row_id) + .fetch_all(&mut *connection) + .await?; + + for row in signature_rows.drain(..) { + let user_id = if let Ok(u) = UserId::try_from(row.0) { + u + } else { + continue; + }; + + let key_id = row.1; + let signature = row.2; + + signatures + .entry(user_id) + .or_insert_with(BTreeMap::new) + .insert(key_id, signature); + } + + Ok(CrossSigningKey { + user_id: user_id.to_owned(), + usage, + keys, + signatures, + }) + } + + async fn load_user(&self, user_id: &UserId) -> Result> { + let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; + let mut connection = self.connection.lock().await; + + let row: Option<(i64,)> = + query_as("SELECT id FROM users WHERE account_id = ? and user_id = ?") + .bind(account_id) + .bind(user_id.as_str()) + .fetch_optional(&mut *connection) + .await?; + + let user_row_id = if let Some(row) = row { + row.0 + } else { + return Ok(None); + }; + + let master = SqliteStore::load_cross_signing_key( + &mut connection, + user_id, + user_row_id, + CrosssigningKeyType::Master, + ) + .await?; + let self_singing = SqliteStore::load_cross_signing_key( + &mut connection, + user_id, + user_row_id, + CrosssigningKeyType::SelfSigning, + ) + .await?; + + if user_id == &*self.user_id { + let user_signing = SqliteStore::load_cross_signing_key( + &mut connection, + user_id, + user_row_id, + CrosssigningKeyType::UserSigning, + ) + .await?; + + let verified: Option<(bool,)> = + query_as("SELECT trusted FROM users_trust_state WHERE user_id = ?") + .bind(user_row_id) + .fetch_optional(&mut *connection) + .await?; + + let verified = verified.map_or(false, |r| r.0); + + let identity = + OwnUserIdentity::new(master.into(), self_singing.into(), user_signing.into()) + .expect("Signature check failed on stored identity"); + + if verified { + identity.mark_as_verified(); + } + + Ok(Some(UserIdentities::Own(identity))) + } else { + Ok(Some(UserIdentities::Other( + UserIdentity::new(master.into(), self_singing.into()) + .expect("Signature check failed on stored identity"), + ))) + } + } + + async fn save_cross_signing_key( + connection: &mut SqliteConnection, + user_row_id: i64, + key_type: CrosssigningKeyType, + cross_signing_key: impl AsRef, + ) -> Result<()> { + let cross_signing_key: &CrossSigningKey = cross_signing_key.as_ref(); + + query( + "REPLACE INTO cross_signing_keys ( + user_id, key_type, usage + ) VALUES (?1, ?2, ?3) + ", + ) + .bind(user_row_id) + .bind(key_type) + .bind(serde_json::to_string(&cross_signing_key.usage)?) + .execute(&mut *connection) + .await?; + + let row: (i64,) = query_as( + "SELECT id FROM cross_signing_keys + WHERE user_id = ? and key_type = ?", + ) + .bind(user_row_id) + .bind(key_type) + .fetch_one(&mut *connection) + .await?; + + let key_row_id = row.0; + + for (key_id, key) in &cross_signing_key.keys { + query( + "REPLACE INTO user_keys ( + cross_signing_key, key_id, key + ) VALUES (?1, ?2, ?3) + ", + ) + .bind(key_row_id) + .bind(key_id.as_str()) + .bind(key) + .execute(&mut *connection) + .await?; + } + + for (user_id, signature_map) in &cross_signing_key.signatures { + for (key_id, signature) in signature_map { + query( + "REPLACE INTO user_key_signatures ( + cross_signing_key, user_id, key_id, signature + ) VALUES (?1, ?2, ?3, ?4) + ", + ) + .bind(key_row_id) + .bind(user_id.as_str()) + .bind(key_id.as_str()) + .bind(signature) + .execute(&mut *connection) + .await?; + } + } + + Ok(()) + } + + async fn save_user_helper(&self, user: &UserIdentities) -> Result<()> { + let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; + + let mut connection = self.connection.lock().await; + + query("REPLACE INTO users (account_id, user_id) VALUES (?1, ?2)") + .bind(account_id) + .bind(user.user_id().as_str()) + .execute(&mut *connection) + .await?; + + let row: (i64,) = query_as( + "SELECT id FROM users + WHERE account_id = ? and user_id = ?", + ) + .bind(account_id) + .bind(user.user_id().as_str()) + .fetch_one(&mut *connection) + .await?; + + let user_row_id = row.0; + + SqliteStore::save_cross_signing_key( + &mut connection, + user_row_id, + CrosssigningKeyType::Master, + user.master_key(), + ) + .await?; + SqliteStore::save_cross_signing_key( + &mut connection, + user_row_id, + CrosssigningKeyType::SelfSigning, + user.self_signing_key(), + ) + .await?; + + if let UserIdentities::Own(own_identity) = user { + SqliteStore::save_cross_signing_key( + &mut connection, + user_row_id, + CrosssigningKeyType::UserSigning, + own_identity.user_signing_key(), + ) + .await?; + + query("REPLACE INTO users_trust_state (user_id, trusted) VALUES (?1, ?2)") + .bind(user_row_id) + .bind(own_identity.is_verified()) + .execute(&mut *connection) + .await?; + } + + Ok(()) + } } #[async_trait] @@ -790,35 +1263,19 @@ impl CryptoStore for SqliteStore { Ok(self.get_sessions_for(sender_key).await?) } - async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result { + async fn save_inbound_group_sessions(&self, sessions: &[InboundGroupSession]) -> Result<()> { let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - let pickle = session.pickle(self.get_pickle_mode()).await; let mut connection = self.connection.lock().await; - let session_id = session.session_id(); - // FIXME we need to store/restore the forwarding chains. - // FIXME this should be converted so it accepts an array of sessions for - // the key import feature. + // FIXME use a transaction here once sqlx gets better support for them. - query( - "INSERT INTO inbound_group_sessions ( - session_id, account_id, sender_key, signing_key, - room_id, pickle - ) VALUES (?1, ?2, ?3, ?4, ?5, ?6) - ON CONFLICT(session_id) DO UPDATE SET - pickle = excluded.pickle - ", - ) - .bind(session_id) - .bind(account_id) - .bind(pickle.sender_key) - .bind(pickle.signing_key) - .bind(pickle.room_id.as_str()) - .bind(pickle.pickle.as_str()) - .execute(&mut *connection) - .await?; + for session in sessions { + self.save_inbound_group_session_helper(account_id, &mut connection, session) + .await?; + self.inbound_group_sessions.add(session.clone()); + } - Ok(self.inbound_group_sessions.add(session)) + Ok(()) } async fn get_inbound_group_session( @@ -832,6 +1289,10 @@ impl CryptoStore for SqliteStore { .get(room_id, sender_key, session_id)) } + async fn get_inbound_group_sessions(&self) -> Result> { + Ok(self.inbound_group_sessions.get_all()) + } + fn is_user_tracked(&self, user_id: &UserId) -> bool { self.tracked_users.contains(user_id) } @@ -899,11 +1360,15 @@ impl CryptoStore for SqliteStore { Ok(self.devices.user_devices(user_id)) } - async fn get_user_identity(&self, _user_id: &UserId) -> Result> { - Ok(None) + async fn get_user_identity(&self, user_id: &UserId) -> Result> { + self.load_user(user_id).await } - async fn save_user_identities(&self, _users: &[UserIdentities]) -> Result<()> { + async fn save_user_identities(&self, users: &[UserIdentities]) -> Result<()> { + for user in users { + self.save_user_helper(user).await?; + } + Ok(()) } } @@ -922,7 +1387,10 @@ impl std::fmt::Debug for SqliteStore { #[cfg(test)] mod test { use crate::{ - identities::device::test::get_device, + identities::{ + device::test::get_device, + user::test::{get_other_identity, get_own_identity}, + }, olm::{Account, GroupSessionKey, InboundGroupSession, Session}, }; use matrix_sdk_common::{ @@ -1175,14 +1643,14 @@ mod test { .expect("Can't create session"); store - .save_inbound_group_session(session) + .save_inbound_group_sessions(&[session]) .await .expect("Can't save group session"); } #[tokio::test] async fn load_inbound_group_session() { - let (account, store, _dir) = get_loaded_store().await; + let (account, store, dir) = get_loaded_store().await; let identity_keys = account.identity_keys(); let outbound_session = OlmOutboundGroupSession::new(); @@ -1194,11 +1662,22 @@ mod test { ) .expect("Can't create session"); + let mut export = session.export().await; + + export.forwarding_curve25519_key_chain = vec!["some_chain".to_owned()]; + + let session = InboundGroupSession::from_export(export).unwrap(); + store - .save_inbound_group_session(session.clone()) + .save_inbound_group_sessions(&[session.clone()]) .await .expect("Can't save group session"); + let store = SqliteStore::open(&alice_id(), &alice_device_id(), dir.path()) + .await + .expect("Can't create store"); + + store.load_account().await.unwrap(); store.load_inbound_group_sessions().await.unwrap(); let loaded_session = store @@ -1207,6 +1686,8 @@ mod test { .unwrap() .unwrap(); assert_eq!(session, loaded_session); + let export = loaded_session.export().await; + assert!(!export.forwarding_curve25519_key_chain.is_empty()) } #[tokio::test] @@ -1311,4 +1792,81 @@ mod test { assert!(loaded_device.is_none()); } + + #[tokio::test] + async fn user_saving() { + let dir = tempdir().unwrap(); + let tmpdir_path = dir.path().to_str().unwrap(); + + let user_id = user_id!("@example:localhost"); + let device_id: &DeviceId = "WSKKLTJZCL".into(); + + let store = SqliteStore::open(&user_id, &device_id, tmpdir_path) + .await + .expect("Can't create store"); + + let account = Account::new(&user_id, &device_id); + + store + .save_account(account.clone()) + .await + .expect("Can't save account"); + + let own_identity = get_own_identity(); + + store + .save_user_identities(&[own_identity.clone().into()]) + .await + .expect("Can't save identity"); + + drop(store); + + let store = SqliteStore::open(&user_id, &device_id, dir.path()) + .await + .expect("Can't create store"); + + store.load_account().await.unwrap(); + + let loaded_user = store + .get_user_identity(own_identity.user_id()) + .await + .unwrap() + .unwrap(); + + assert_eq!(loaded_user.master_key(), own_identity.master_key()); + assert_eq!( + loaded_user.self_signing_key(), + own_identity.self_signing_key() + ); + assert_eq!(loaded_user, own_identity.clone().into()); + + let other_identity = get_other_identity(); + + store + .save_user_identities(&[other_identity.clone().into()]) + .await + .unwrap(); + + let loaded_user = store + .load_user(other_identity.user_id()) + .await + .unwrap() + .unwrap(); + + assert_eq!(loaded_user.master_key(), other_identity.master_key()); + assert_eq!( + loaded_user.self_signing_key(), + other_identity.self_signing_key() + ); + assert_eq!(loaded_user, other_identity.into()); + + own_identity.mark_as_verified(); + + store + .save_user_identities(&[own_identity.into()]) + .await + .unwrap(); + let loaded_user = store.load_user(&user_id).await.unwrap().unwrap(); + assert!(loaded_user.own().unwrap().is_verified()) + } }