Merge branch 'key_export' into master

master
Damir Jelić 2020-09-10 16:32:41 +02:00
commit 126ac3059b
15 changed files with 2073 additions and 655 deletions

View File

@ -25,7 +25,8 @@ async-trait = "0.1.40"
matrix-sdk-common-macros = { version = "0.1.0", path = "../matrix_sdk_common_macros" } matrix-sdk-common-macros = { version = "0.1.0", path = "../matrix_sdk_common_macros" }
matrix-sdk-common = { version = "0.1.0", path = "../matrix_sdk_common" } matrix-sdk-common = { version = "0.1.0", path = "../matrix_sdk_common" }
olm-rs = { git = 'https://gitlab.gnome.org/jhaye/olm-rs/', features = ["serde"]} olm-rs = { version = "0.6.0", features = ["serde"] }
getrandom = "0.1.14"
serde = { version = "1.0.115", features = ["derive", "rc"] } serde = { version = "1.0.115", features = ["derive", "rc"] }
serde_json = "1.0.57" serde_json = "1.0.57"
cjson = "0.1.1" cjson = "0.1.1"
@ -37,6 +38,12 @@ thiserror = "1.0.20"
tracing = "0.1.19" tracing = "0.1.19"
atomic = "0.5.0" atomic = "0.5.0"
dashmap = "3.11.10" dashmap = "3.11.10"
sha2 = "0.9.1"
aes-ctr = "0.4.0"
pbkdf2 = { version = "0.5.0", default-features = false }
hmac = "0.9.0"
base64 = "0.12.3"
byteorder = "1.3.4"
[dependencies.tracing-futures] [dependencies.tracing-futures]
version = "0.2.4" version = "0.2.4"
@ -47,7 +54,7 @@ features = ["std", "std-future"]
version = "0.3.5" version = "0.3.5"
optional = true optional = true
default-features = false default-features = false
features = ["runtime-tokio", "sqlite"] features = ["runtime-tokio", "sqlite", "macros"]
[dev-dependencies] [dev-dependencies]
tokio = { version = "0.2.22", features = ["rt-threaded", "macros"] } tokio = { version = "0.2.22", features = ["rt-threaded", "macros"] }
@ -57,3 +64,4 @@ serde_json = "1.0.57"
tempfile = "3.1.0" tempfile = "3.1.0"
http = "0.2.1" http = "0.2.1"
matrix-sdk-test = { version = "0.1.0", path = "../matrix_sdk_test" } matrix-sdk-test = { version = "0.1.0", path = "../matrix_sdk_test" }
indoc = "1.0.2"

View File

@ -41,7 +41,7 @@
//! Both identity sets need to reqularly fetched from the server using the //! Both identity sets need to reqularly fetched from the server using the
//! `/keys/query` API call. //! `/keys/query` API call.
pub(crate) mod device; pub(crate) mod device;
mod user; pub(crate) mod user;
pub use device::{Device, LocalTrust, ReadOnlyDevice, UserDevices}; pub use device::{Device, LocalTrust, ReadOnlyDevice, UserDevices};
pub use user::{ pub use user::{

View File

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
use std::{ use std::{
collections::{btree_map::Iter, BTreeMap},
convert::TryFrom, convert::TryFrom,
sync::{ sync::{
atomic::{AtomicBool, Ordering}, atomic::{AtomicBool, Ordering},
@ -24,7 +25,7 @@ use serde::{Deserialize, Serialize};
use serde_json::to_value; use serde_json::to_value;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::keys::CrossSigningKey, api::r0::keys::{CrossSigningKey, KeyUsage},
identifiers::{DeviceKeyId, UserId}, identifiers::{DeviceKeyId, UserId},
}; };
@ -55,6 +56,54 @@ impl PartialEq for MasterPubkey {
} }
} }
impl PartialEq for SelfSigningPubkey {
fn eq(&self, other: &SelfSigningPubkey) -> bool {
self.0.user_id == other.0.user_id && self.0.keys == other.0.keys
}
}
impl PartialEq for UserSigningPubkey {
fn eq(&self, other: &UserSigningPubkey) -> bool {
self.0.user_id == other.0.user_id && self.0.keys == other.0.keys
}
}
impl From<CrossSigningKey> for MasterPubkey {
fn from(key: CrossSigningKey) -> Self {
Self(Arc::new(key))
}
}
impl From<CrossSigningKey> for SelfSigningPubkey {
fn from(key: CrossSigningKey) -> Self {
Self(Arc::new(key))
}
}
impl From<CrossSigningKey> for UserSigningPubkey {
fn from(key: CrossSigningKey) -> Self {
Self(Arc::new(key))
}
}
impl AsRef<CrossSigningKey> for MasterPubkey {
fn as_ref(&self) -> &CrossSigningKey {
&self.0
}
}
impl AsRef<CrossSigningKey> for SelfSigningPubkey {
fn as_ref(&self) -> &CrossSigningKey {
&self.0
}
}
impl AsRef<CrossSigningKey> for UserSigningPubkey {
fn as_ref(&self) -> &CrossSigningKey {
&self.0
}
}
impl From<&CrossSigningKey> for MasterPubkey { impl From<&CrossSigningKey> for MasterPubkey {
fn from(key: &CrossSigningKey) -> Self { fn from(key: &CrossSigningKey) -> Self {
Self(Arc::new(key.clone())) Self(Arc::new(key.clone()))
@ -117,6 +166,21 @@ impl MasterPubkey {
&self.0.user_id &self.0.user_id
} }
/// Get the keys map of containing the master keys.
pub fn keys(&self) -> &BTreeMap<String, String> {
&self.0.keys
}
/// Get the list of `KeyUsage` that is set for this key.
pub fn usage(&self) -> &[KeyUsage] {
&self.0.usage
}
/// Get the signatures map of this cross signing key.
pub fn signatures(&self) -> &BTreeMap<UserId, BTreeMap<String, String>> {
&self.0.signatures
}
/// Get the master key with the given key id. /// Get the master key with the given key id.
/// ///
/// # Arguments /// # Arguments
@ -167,12 +231,26 @@ impl MasterPubkey {
} }
} }
impl<'a> IntoIterator for &'a MasterPubkey {
type Item = (&'a String, &'a String);
type IntoIter = Iter<'a, String, String>;
fn into_iter(self) -> Self::IntoIter {
self.keys().iter()
}
}
impl UserSigningPubkey { impl UserSigningPubkey {
/// Get the user id of the user signing key's owner. /// Get the user id of the user signing key's owner.
pub fn user_id(&self) -> &UserId { pub fn user_id(&self) -> &UserId {
&self.0.user_id &self.0.user_id
} }
/// Get the keys map of containing the user signing keys.
pub fn keys(&self) -> &BTreeMap<String, String> {
&self.0.keys
}
/// Check if the given master key is signed by this user signing key. /// Check if the given master key is signed by this user signing key.
/// ///
/// # Arguments /// # Arguments
@ -202,12 +280,26 @@ impl UserSigningPubkey {
} }
} }
impl<'a> IntoIterator for &'a UserSigningPubkey {
type Item = (&'a String, &'a String);
type IntoIter = Iter<'a, String, String>;
fn into_iter(self) -> Self::IntoIter {
self.keys().iter()
}
}
impl SelfSigningPubkey { impl SelfSigningPubkey {
/// Get the user id of the self signing key's owner. /// Get the user id of the self signing key's owner.
pub fn user_id(&self) -> &UserId { pub fn user_id(&self) -> &UserId {
&self.0.user_id &self.0.user_id
} }
/// Get the keys map of containing the self signing keys.
pub fn keys(&self) -> &BTreeMap<String, String> {
&self.0.keys
}
/// Check if the given device is signed by this self signing key. /// Check if the given device is signed by this self signing key.
/// ///
/// # Arguments /// # Arguments
@ -236,6 +328,15 @@ impl SelfSigningPubkey {
} }
} }
impl<'a> IntoIterator for &'a SelfSigningPubkey {
type Item = (&'a String, &'a String);
type IntoIter = Iter<'a, String, String>;
fn into_iter(self) -> Self::IntoIter {
self.keys().iter()
}
}
/// Enum over the different user identity types we can have. /// Enum over the different user identity types we can have.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum UserIdentities { pub enum UserIdentities {
@ -245,6 +346,18 @@ pub enum UserIdentities {
Other(UserIdentity), Other(UserIdentity),
} }
impl From<OwnUserIdentity> for UserIdentities {
fn from(identity: OwnUserIdentity) -> Self {
UserIdentities::Own(identity)
}
}
impl From<UserIdentity> for UserIdentities {
fn from(identity: UserIdentity) -> Self {
UserIdentities::Other(identity)
}
}
impl UserIdentities { impl UserIdentities {
/// The unique user id of this identity. /// The unique user id of this identity.
pub fn user_id(&self) -> &UserId { pub fn user_id(&self) -> &UserId {
@ -262,6 +375,23 @@ impl UserIdentities {
} }
} }
/// Get the self-signing key of the identity.
pub fn self_signing_key(&self) -> &SelfSigningPubkey {
match self {
UserIdentities::Own(i) => &i.self_signing_key,
UserIdentities::Other(i) => &i.self_signing_key,
}
}
/// Get the user-signing key of the identity, this is only present for our
/// own user identity..
pub fn user_signing_key(&self) -> Option<&UserSigningPubkey> {
match self {
UserIdentities::Own(i) => Some(&i.user_signing_key),
UserIdentities::Other(_) => None,
}
}
/// Destructure the enum into an `OwnUserIdentity` if it's of the correct /// Destructure the enum into an `OwnUserIdentity` if it's of the correct
/// type. /// type.
pub fn own(&self) -> Option<&OwnUserIdentity> { pub fn own(&self) -> Option<&OwnUserIdentity> {
@ -324,6 +454,11 @@ impl UserIdentity {
&self.master_key &self.master_key
} }
/// Get the public self-signing key of the identity.
pub fn self_signing_key(&self) -> &SelfSigningPubkey {
&self.self_signing_key
}
/// Update the identity with a new master key and self signing key. /// Update the identity with a new master key and self signing key.
/// ///
/// # Arguments /// # Arguments
@ -424,6 +559,16 @@ impl OwnUserIdentity {
&self.master_key &self.master_key
} }
/// Get the public self-signing key of the identity.
pub fn self_signing_key(&self) -> &SelfSigningPubkey {
&self.self_signing_key
}
/// Get the public user-signing key of the identity.
pub fn user_signing_key(&self) -> &UserSigningPubkey {
&self.user_signing_key
}
/// Check if the given identity has been signed by this identity. /// Check if the given identity has been signed by this identity.
/// ///
/// # Arguments /// # Arguments
@ -504,7 +649,7 @@ impl OwnUserIdentity {
} }
#[cfg(test)] #[cfg(test)]
mod test { pub(crate) mod test {
use serde_json::json; use serde_json::json;
use std::{convert::TryFrom, sync::Arc}; use std::{convert::TryFrom, sync::Arc};
@ -697,6 +842,20 @@ mod test {
OwnUserIdentity::new(master_key.into(), self_signing.into(), user_signing.into()).unwrap() OwnUserIdentity::new(master_key.into(), self_signing.into(), user_signing.into()).unwrap()
} }
pub(crate) fn get_own_identity() -> OwnUserIdentity {
own_identity(&own_key_query())
}
pub(crate) fn get_other_identity() -> UserIdentity {
let user_id = user_id!("@example2:localhost");
let response = other_key_query();
let master_key = response.master_keys.get(&user_id).unwrap();
let self_signing = response.self_signing_keys.get(&user_id).unwrap();
UserIdentity::new(master_key.into(), self_signing.into()).unwrap()
}
#[test] #[test]
fn own_identity_create() { fn own_identity_create() {
let user_id = user_id!("@example:localhost"); let user_id = user_id!("@example:localhost");
@ -711,19 +870,13 @@ mod test {
#[test] #[test]
fn other_identity_create() { fn other_identity_create() {
let user_id = user_id!("@example2:localhost"); get_other_identity();
let response = other_key_query();
let master_key = response.master_keys.get(&user_id).unwrap();
let self_signing = response.self_signing_keys.get(&user_id).unwrap();
UserIdentity::new(master_key.into(), self_signing.into()).unwrap();
} }
#[test] #[test]
fn own_identity_check_signatures() { fn own_identity_check_signatures() {
let response = own_key_query(); let response = own_key_query();
let identity = own_identity(&response); let identity = get_own_identity();
let (first, second) = device(&response); let (first, second) = device(&response);
assert!(identity.is_device_signed(&first).is_err()); assert!(identity.is_device_signed(&first).is_err());

View File

@ -0,0 +1,295 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::io::{Cursor, Read, Seek, SeekFrom};
use base64::{decode_config, encode_config, DecodeError, STANDARD_NO_PAD};
use byteorder::{BigEndian, ReadBytesExt};
use getrandom::getrandom;
use aes_ctr::{
stream_cipher::{NewStreamCipher, SyncStreamCipher},
Aes256Ctr,
};
use hmac::{Hmac, Mac, NewMac};
use pbkdf2::pbkdf2;
use sha2::{Sha256, Sha512};
use crate::olm::ExportedRoomKey;
const SALT_SIZE: usize = 16;
const IV_SIZE: usize = 16;
const MAC_SIZE: usize = 32;
const KEY_SIZE: usize = 32;
const VERSION: u8 = 1;
const HEADER: &str = "-----BEGIN MEGOLM SESSION DATA-----";
const FOOTER: &str = "-----END MEGOLM SESSION DATA-----";
fn decode(input: impl AsRef<[u8]>) -> Result<Vec<u8>, DecodeError> {
decode_config(input, STANDARD_NO_PAD)
}
fn encode(input: impl AsRef<[u8]>) -> String {
encode_config(input, STANDARD_NO_PAD)
}
/// Try to decrypt a reader into a list of exported room keys.
///
/// # Arguments
///
/// * `passphrase` - The passphrase that was used to encrypt the exported keys.
///
/// # Examples
/// ```no_run
/// # use std::io::Cursor;
/// # use matrix_sdk_crypto::{OlmMachine, decrypt_key_export};
/// # use matrix_sdk_common::identifiers::user_id;
/// # use futures::executor::block_on;
/// # let alice = user_id!("@alice:example.org");
/// # let machine = OlmMachine::new(&alice, "DEVICEID".into());
/// # block_on(async {
/// # let export = Cursor::new("".to_owned());
/// let exported_keys = decrypt_key_export(export, "1234").unwrap();
/// machine.import_keys(exported_keys).await.unwrap();
/// # });
/// ```
pub fn decrypt_key_export(
mut input: impl Read,
passphrase: &str,
) -> Result<Vec<ExportedRoomKey>, DecodeError> {
let mut x: String = String::new();
input.read_to_string(&mut x).expect("Can't read string");
if !(x.trim_start().starts_with(HEADER) && x.trim_end().ends_with(FOOTER)) {
panic!("Invalid header/footer");
}
let payload: String = x
.lines()
.filter(|l| !(l.starts_with(HEADER) || l.starts_with(FOOTER)))
.collect();
Ok(serde_json::from_str(&decrypt_helper(&payload, passphrase)?).unwrap())
}
/// Encrypt the list of exported room keys using the given passphrase.
///
/// # Arguments
///
/// * `keys` - A list of sessions that should be encrypted.
///
/// * `passphrase` - The passphrase that will be used to encrypt the exported
/// room keys.
///
/// * `rounds` - The number of rounds that should be used for the key
/// derivation when the passphrase gets turned into an AES key. More rounds are
/// increasingly computationally intensive and as such help against bruteforce
/// attacks. Should be at least `10000`, while values in the `100000` ranges
/// should be preferred.
///
/// # Examples
/// ```no_run
/// # use matrix_sdk_crypto::{OlmMachine, encrypt_key_export};
/// # use matrix_sdk_common::identifiers::{user_id, room_id};
/// # use futures::executor::block_on;
/// # let alice = user_id!("@alice:example.org");
/// # let machine = OlmMachine::new(&alice, "DEVICEID".into());
/// # block_on(async {
/// let room_id = room_id!("!test:localhost");
/// let exported_keys = machine.export_keys(|s| s.room_id() == &room_id).await.unwrap();
/// let encrypted_export = encrypt_key_export(&exported_keys, "1234", 1);
/// # });
/// ```
pub fn encrypt_key_export(keys: &[ExportedRoomKey], passphrase: &str, rounds: u32) -> String {
let mut plaintext = serde_json::to_string(keys).unwrap().into_bytes();
let ciphertext = encrypt_helper(&mut plaintext, passphrase, rounds);
[HEADER.to_owned(), ciphertext, FOOTER.to_owned()].join("\n")
}
fn encrypt_helper(mut plaintext: &mut [u8], passphrase: &str, rounds: u32) -> String {
let mut salt = [0u8; SALT_SIZE];
let mut iv = [0u8; IV_SIZE];
let mut derived_keys = [0u8; KEY_SIZE * 2];
getrandom(&mut salt).expect("Can't generate randomness");
getrandom(&mut iv).expect("Can't generate randomness");
let mut iv = u128::from_be_bytes(iv);
iv &= !(1 << 63);
pbkdf2::<Hmac<Sha512>>(passphrase.as_bytes(), &salt, rounds, &mut derived_keys);
let (key, hmac_key) = derived_keys.split_at(KEY_SIZE);
let mut aes = Aes256Ctr::new_var(&key, &iv.to_be_bytes()).expect("Can't create AES");
aes.apply_keystream(&mut plaintext);
let mut payload: Vec<u8> = vec![];
payload.extend(&VERSION.to_be_bytes());
payload.extend(&salt);
payload.extend(&iv.to_be_bytes());
payload.extend(&rounds.to_be_bytes());
payload.extend_from_slice(&plaintext);
let mut hmac = Hmac::<Sha256>::new_varkey(hmac_key).unwrap();
hmac.update(&payload);
let mac = hmac.finalize();
payload.extend(mac.into_bytes());
encode(payload)
}
fn decrypt_helper(ciphertext: &str, passphrase: &str) -> Result<String, DecodeError> {
let decoded = decode(ciphertext)?;
let mut decoded = Cursor::new(decoded);
let mut salt = [0u8; SALT_SIZE];
let mut iv = [0u8; IV_SIZE];
let mut mac = [0u8; MAC_SIZE];
let mut derived_keys = [0u8; KEY_SIZE * 2];
let version = decoded.read_u8().unwrap();
decoded.read_exact(&mut salt).unwrap();
decoded.read_exact(&mut iv).unwrap();
let rounds = decoded.read_u32::<BigEndian>().unwrap();
let ciphertext_start = decoded.position() as usize;
decoded.seek(SeekFrom::End(-32)).unwrap();
let ciphertext_end = decoded.position() as usize;
decoded.read_exact(&mut mac).unwrap();
let mut decoded = decoded.into_inner();
if version != VERSION {
panic!("Unsupported version")
}
pbkdf2::<Hmac<Sha512>>(passphrase.as_bytes(), &salt, rounds, &mut derived_keys);
let (key, hmac_key) = derived_keys.split_at(KEY_SIZE);
let mut hmac = Hmac::<Sha256>::new_varkey(hmac_key).unwrap();
hmac.update(&decoded[0..ciphertext_end]);
hmac.verify(&mac).expect("MAC DOESN'T MATCH");
let mut ciphertext = &mut decoded[ciphertext_start..ciphertext_end];
let mut aes = Aes256Ctr::new_var(&key, &iv).expect("Can't create AES");
aes.apply_keystream(&mut ciphertext);
Ok(String::from_utf8(ciphertext.to_owned()).expect("Invalid utf-8"))
}
#[cfg(test)]
mod test {
use indoc::indoc;
use proptest::prelude::*;
use std::io::Cursor;
use matrix_sdk_common::identifiers::room_id;
use matrix_sdk_test::async_test;
use super::{decode, decrypt_helper, decrypt_key_export, encrypt_helper, encrypt_key_export};
use crate::machine::test::get_prepared_machine;
const PASSPHRASE: &str = "1234";
const TEST_EXPORT: &str = indoc! {"
-----BEGIN MEGOLM SESSION DATA-----
Af7mGhlzQ+eGvHu93u0YXd3D/+vYMs3E7gQqOhuCtkvGAAAAASH7pEdWvFyAP1JUisAcpEo
Xke2Q7Kr9hVl/SCc6jXBNeJCZcrUbUV4D/tRQIl3E9L4fOk928YI1J+3z96qiH0uE7hpsCI
CkHKwjPU+0XTzFdIk1X8H7sZ+MD/2Sg/q3y8rtUjz7uEj4GUTnb+9SCOTVmJsRfqgUpM1CU
bDLytHf1JkohY4tWEgpsCc67xdzgodjr12qYrfg/zNm3LGpxlrffJknw4rk5QFTj4kMbqbD
ZZgDTni+HxRTDGge2J620lMOiznvXX+H09Rwruqx5aJvvaaKd86jWRpiO2oSFqHn4u5ONl9
41uzm62Sj0eIm6ZbA9NQs87jQw4LxsejhZVL+NdjIg80zVSBTWhTdo0DTnbFSNP4ReOiz0U
XosOF8A5T8Vdx2nvA0GXltfcHKVKQYh/LJAkNQ7P9UYL4ae/5TtQZkhB1KxCLTRWqADCl53
uBMGpG53EMgY6G6K2DEIOkcv7sdXQF5WpemiSWZqJRWj+cjfs9BpCTbkp/rszWFl2TniWpR
RqIbT2jORlN4rTvdtF0F4z1pqP4qWyR3sLNTkXm9CFRzWADNG0RDZKxbCoo6RPvtaCTfaHo
SwfvzBS6CjfAG+FOugpV48o7+XetaUUPZ6/tZSPhCdeV8eP9q5r0QwWeXFogzoNzWt4HYx9
MdXxzD+f0mtg5gzehrrEEARwI2bCvPpHxlt/Na9oW/GBpkjwR1LSKgg4CtpRyWngPjdEKpZ
GYW19pdjg0qdXNk/eqZsQTsNWVo6A
-----END MEGOLM SESSION DATA-----
"};
fn export_wihtout_headers() -> String {
TEST_EXPORT
.lines()
.filter(|l| !l.starts_with("-----"))
.collect()
}
#[test]
fn test_decode() {
let export = export_wihtout_headers();
assert!(decode(export).is_ok());
}
proptest! {
#[test]
fn proptest_encrypt_cycle(plaintext in prop::string::string_regex(".*").unwrap()) {
let mut plaintext_bytes = plaintext.clone().into_bytes();
let ciphertext = encrypt_helper(&mut plaintext_bytes, "test", 1);
let decrypted = decrypt_helper(&ciphertext, "test").unwrap();
prop_assert!(plaintext == decrypted);
}
}
#[test]
fn test_encrypt_decrypt() {
let data = "It's a secret to everybody";
let mut bytes = data.to_owned().into_bytes();
let encrypted = encrypt_helper(&mut bytes, PASSPHRASE, 10);
let decrypted = decrypt_helper(&encrypted, PASSPHRASE).unwrap();
assert_eq!(data, decrypted);
}
#[async_test]
async fn test_session_encrypt() {
let (machine, _) = get_prepared_machine().await;
let room_id = room_id!("!test:localhost");
machine
.create_outnbound_group_session_with_defaults(&room_id)
.await
.unwrap();
let export = machine
.export_keys(|s| s.room_id() == &room_id)
.await
.unwrap();
assert!(!export.is_empty());
let encrypted = encrypt_key_export(&export, "1234", 1);
let decrypted = decrypt_key_export(Cursor::new(encrypted), "1234").unwrap();
assert_eq!(export, decrypted);
assert_eq!(machine.import_keys(decrypted).await.unwrap(), 0);
}
#[test]
fn test_real_decrypt() {
let reader = Cursor::new(TEST_EXPORT);
let imported = decrypt_key_export(reader, PASSPHRASE).expect("Can't decrypt key export");
assert!(!imported.is_empty())
}
}

View File

@ -29,6 +29,8 @@
mod error; mod error;
mod identities; mod identities;
#[allow(dead_code)]
mod key_export;
mod machine; mod machine;
pub mod olm; pub mod olm;
mod requests; mod requests;
@ -39,6 +41,7 @@ pub use error::{MegolmError, OlmError};
pub use identities::{ pub use identities::{
Device, LocalTrust, OwnUserIdentity, ReadOnlyDevice, UserDevices, UserIdentities, UserIdentity, Device, LocalTrust, OwnUserIdentity, ReadOnlyDevice, UserDevices, UserIdentities, UserIdentity,
}; };
pub use key_export::{decrypt_key_export, encrypt_key_export};
pub use machine::OlmMachine; pub use machine::OlmMachine;
pub(crate) use olm::Account; pub(crate) use olm::Account;
pub use olm::EncryptionSettings; pub use olm::EncryptionSettings;

View File

@ -57,8 +57,8 @@ use super::{
UserIdentities, UserIdentity, UserSigningPubkey, UserIdentities, UserIdentity, UserSigningPubkey,
}, },
olm::{ olm::{
Account, EncryptionSettings, GroupSessionKey, IdentityKeys, InboundGroupSession, Account, EncryptionSettings, ExportedRoomKey, GroupSessionKey, IdentityKeys,
OlmMessage, OutboundGroupSession, InboundGroupSession, OlmMessage, OutboundGroupSession,
}, },
requests::{IncomingResponse, OutgoingRequest, ToDeviceRequest}, requests::{IncomingResponse, OutgoingRequest, ToDeviceRequest},
store::{CryptoStore, MemoryStore, Result as StoreResult}, store::{CryptoStore, MemoryStore, Result as StoreResult},
@ -984,7 +984,7 @@ impl OlmMachine {
&event.content.room_id, &event.content.room_id,
session_key, session_key,
)?; )?;
let _ = self.store.save_inbound_group_session(session).await?; let _ = self.store.save_inbound_group_sessions(&[session]).await?;
let event = Raw::from(AnyToDeviceEvent::RoomKey(event.clone())); let event = Raw::from(AnyToDeviceEvent::RoomKey(event.clone()));
Ok(Some(event)) Ok(Some(event))
@ -1014,7 +1014,7 @@ impl OlmMachine {
.await .await
.map_err(|_| EventError::UnsupportedAlgorithm)?; .map_err(|_| EventError::UnsupportedAlgorithm)?;
let _ = self.store.save_inbound_group_session(inbound).await?; let _ = self.store.save_inbound_group_sessions(&[inbound]).await?;
let _ = self let _ = self
.outbound_group_sessions .outbound_group_sessions
@ -1023,7 +1023,7 @@ impl OlmMachine {
} }
#[cfg(test)] #[cfg(test)]
async fn create_outnbound_group_session_with_defaults( pub(crate) async fn create_outnbound_group_session_with_defaults(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
) -> OlmResult<()> { ) -> OlmResult<()> {
@ -1529,6 +1529,105 @@ impl OlmMachine {
device_owner_identity, device_owner_identity,
}) })
} }
/// Import the given room keys into our store.
///
/// # Arguments
///
/// * `exported_keys` - A list of previously exported keys that should be
/// imported into our store. If we already have a better version of a key
/// the key will *not* be imported.
///
/// Returns the number of sessions that were imported to the store.
///
/// # Examples
/// ```no_run
/// # use std::io::Cursor;
/// # use matrix_sdk_crypto::{OlmMachine, decrypt_key_export};
/// # use matrix_sdk_common::identifiers::user_id;
/// # use futures::executor::block_on;
/// # let alice = user_id!("@alice:example.org");
/// # let machine = OlmMachine::new(&alice, "DEVICEID".into());
/// # block_on(async {
/// # let export = Cursor::new("".to_owned());
/// let exported_keys = decrypt_key_export(export, "1234").unwrap();
/// machine.import_keys(exported_keys).await.unwrap();
/// # });
/// ```
pub async fn import_keys(&self, mut exported_keys: Vec<ExportedRoomKey>) -> StoreResult<usize> {
let mut sessions = Vec::new();
for key in exported_keys.drain(..) {
let session = InboundGroupSession::from_export(key)?;
// Only import the session if we didn't have this session or if it's
// a better version of the same session, that is the first known
// index is lower.
if let Some(existing_session) = self
.store
.get_inbound_group_session(
&session.room_id,
&session.sender_key,
session.session_id(),
)
.await?
{
let first_index = session.first_known_index().await;
let existing_index = existing_session.first_known_index().await;
if first_index < existing_index {
sessions.push(session)
}
} else {
sessions.push(session)
}
}
let num_sessions = sessions.len();
self.store.save_inbound_group_sessions(&sessions).await?;
Ok(num_sessions)
}
/// Export the keys that match the given predicate.
///
///
/// # Examples
///
/// ```no_run
/// # use matrix_sdk_crypto::{OlmMachine, encrypt_key_export};
/// # use matrix_sdk_common::identifiers::{user_id, room_id};
/// # use futures::executor::block_on;
/// # let alice = user_id!("@alice:example.org");
/// # let machine = OlmMachine::new(&alice, "DEVICEID".into());
/// # block_on(async {
/// let room_id = room_id!("!test:localhost");
/// let exported_keys = machine.export_keys(|s| s.room_id() == &room_id).await.unwrap();
/// let encrypted_export = encrypt_key_export(&exported_keys, "1234", 1);
/// # });
/// ```
pub async fn export_keys(
&self,
mut predicate: impl FnMut(&InboundGroupSession) -> bool,
) -> StoreResult<Vec<ExportedRoomKey>> {
let mut exported = Vec::new();
let mut sessions: Vec<InboundGroupSession> = self
.store
.get_inbound_group_sessions()
.await?
.drain(..)
.filter(|s| predicate(&s))
.collect();
for session in sessions.drain(..) {
let export = session.export().await;
exported.push(export);
}
Ok(exported)
}
} }
#[cfg(test)] #[cfg(test)]
@ -1623,7 +1722,7 @@ pub(crate) mod test {
content.deserialize().unwrap() content.deserialize().unwrap()
} }
async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) { pub(crate) async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) {
let machine = OlmMachine::new(&user_id(), &alice_device_id()); let machine = OlmMachine::new(&user_id(), &alice_device_id());
machine.account.update_uploaded_key_count(0); machine.account.update_uploaded_key_count(0);
let request = machine let request = machine

View File

@ -1,561 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
cmp::min,
convert::TryInto,
fmt,
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc,
},
time::Duration,
};
use matrix_sdk_common::{
events::{
room::{encrypted::EncryptedEventContent, encryption::EncryptionEventContent},
AnyMessageEventContent, AnySyncRoomEvent, EventContent, SyncMessageEvent,
},
identifiers::{DeviceId, EventEncryptionAlgorithm, RoomId},
instant::Instant,
locks::Mutex,
Raw,
};
use olm_rs::{
errors::OlmGroupSessionError, inbound_group_session::OlmInboundGroupSession,
outbound_group_session::OlmOutboundGroupSession, PicklingMode,
};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use zeroize::Zeroize;
pub use olm_rs::{
account::IdentityKeys,
session::{OlmMessage, PreKeyMessage},
utility::OlmUtility,
};
use crate::error::{EventError, MegolmResult};
const ROTATION_PERIOD: Duration = Duration::from_millis(604800000);
const ROTATION_MESSAGES: u64 = 100;
/// Settings for an encrypted room.
///
/// This determines the algorithm and rotation periods of a group session.
#[derive(Debug)]
pub struct EncryptionSettings {
/// The encryption algorithm that should be used in the room.
pub algorithm: EventEncryptionAlgorithm,
/// How long the session should be used before changing it.
pub rotation_period: Duration,
/// How many messages should be sent before changing the session.
pub rotation_period_msgs: u64,
}
impl Default for EncryptionSettings {
fn default() -> Self {
Self {
algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2,
rotation_period: ROTATION_PERIOD,
rotation_period_msgs: ROTATION_MESSAGES,
}
}
}
impl From<&EncryptionEventContent> for EncryptionSettings {
fn from(content: &EncryptionEventContent) -> Self {
let rotation_period: Duration = content
.rotation_period_ms
.map_or(ROTATION_PERIOD, |r| Duration::from_millis(r.into()));
let rotation_period_msgs: u64 = content
.rotation_period_msgs
.map_or(ROTATION_MESSAGES, Into::into);
Self {
algorithm: content.algorithm.clone(),
rotation_period,
rotation_period_msgs,
}
}
}
/// The private session key of a group session.
/// Can be used to create a new inbound group session.
#[derive(Clone, Debug, Serialize, Zeroize)]
#[zeroize(drop)]
pub struct GroupSessionKey(pub String);
/// Inbound group session.
///
/// Inbound group sessions are used to exchange room messages between a group of
/// participants. Inbound group sessions are used to decrypt the room messages.
#[derive(Clone)]
pub struct InboundGroupSession {
inner: Arc<Mutex<OlmInboundGroupSession>>,
session_id: Arc<String>,
pub(crate) sender_key: Arc<String>,
pub(crate) signing_key: Arc<String>,
pub(crate) room_id: Arc<RoomId>,
forwarding_chains: Arc<Mutex<Option<Vec<String>>>>,
}
impl InboundGroupSession {
/// Create a new inbound group session for the given room.
///
/// These sessions are used to decrypt room messages.
///
/// # Arguments
///
/// * `sender_key` - The public curve25519 key of the account that
/// sent us the session
///
/// * `signing_key` - The public ed25519 key of the account that
/// sent us the session.
///
/// * `room_id` - The id of the room that the session is used in.
///
/// * `session_key` - The private session key that is used to decrypt
/// messages.
pub fn new(
sender_key: &str,
signing_key: &str,
room_id: &RoomId,
session_key: GroupSessionKey,
) -> Result<Self, OlmGroupSessionError> {
let session = OlmInboundGroupSession::new(&session_key.0)?;
let session_id = session.session_id();
Ok(InboundGroupSession {
inner: Arc::new(Mutex::new(session)),
session_id: Arc::new(session_id),
sender_key: Arc::new(sender_key.to_owned()),
signing_key: Arc::new(signing_key.to_owned()),
room_id: Arc::new(room_id.clone()),
forwarding_chains: Arc::new(Mutex::new(None)),
})
}
/// Store the group session as a base64 encoded string.
///
/// # Arguments
///
/// * `pickle_mode` - The mode that was used to pickle the group session,
/// either an unencrypted mode or an encrypted using passphrase.
pub async fn pickle(&self, pickle_mode: PicklingMode) -> PickledInboundGroupSession {
let pickle = self.inner.lock().await.pickle(pickle_mode);
PickledInboundGroupSession {
pickle: InboundGroupSessionPickle::from(pickle),
sender_key: self.sender_key.to_string(),
signing_key: self.signing_key.to_string(),
room_id: (&*self.room_id).clone(),
forwarding_chains: self.forwarding_chains.lock().await.clone(),
}
}
/// Restore a Session from a previously pickled string.
///
/// Returns the restored group session or a `OlmGroupSessionError` if there
/// was an error.
///
/// # Arguments
///
/// * `pickle` - The pickled version of the `InboundGroupSession`.
///
/// * `pickle_mode` - The mode that was used to pickle the session, either
/// an unencrypted mode or an encrypted using passphrase.
pub fn from_pickle(
pickle: PickledInboundGroupSession,
pickle_mode: PicklingMode,
) -> Result<Self, OlmGroupSessionError> {
let session = OlmInboundGroupSession::unpickle(pickle.pickle.0, pickle_mode)?;
let session_id = session.session_id();
Ok(InboundGroupSession {
inner: Arc::new(Mutex::new(session)),
session_id: Arc::new(session_id),
sender_key: Arc::new(pickle.sender_key),
signing_key: Arc::new(pickle.signing_key),
room_id: Arc::new(pickle.room_id),
forwarding_chains: Arc::new(Mutex::new(pickle.forwarding_chains)),
})
}
/// Returns the unique identifier for this session.
pub fn session_id(&self) -> &str {
&self.session_id
}
/// Get the first message index we know how to decrypt.
pub async fn first_known_index(&self) -> u32 {
self.inner.lock().await.first_known_index()
}
/// Decrypt the given ciphertext.
///
/// Returns the decrypted plaintext or an `OlmGroupSessionError` if
/// decryption failed.
///
/// # Arguments
///
/// * `message` - The message that should be decrypted.
pub async fn decrypt_helper(
&self,
message: String,
) -> Result<(String, u32), OlmGroupSessionError> {
self.inner.lock().await.decrypt(message)
}
/// Decrypt an event from a room timeline.
///
/// # Arguments
///
/// * `event` - The event that should be decrypted.
pub async fn decrypt(
&self,
event: &SyncMessageEvent<EncryptedEventContent>,
) -> MegolmResult<(Raw<AnySyncRoomEvent>, u32)> {
let content = match &event.content {
EncryptedEventContent::MegolmV1AesSha2(c) => c,
_ => return Err(EventError::UnsupportedAlgorithm.into()),
};
let (plaintext, message_index) = self.decrypt_helper(content.ciphertext.clone()).await?;
let mut decrypted_value = serde_json::from_str::<Value>(&plaintext)?;
let decrypted_object = decrypted_value
.as_object_mut()
.ok_or(EventError::NotAnObject)?;
// TODO better number conversion here.
let server_ts = event
.origin_server_ts
.duration_since(std::time::SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_millis();
let server_ts: i64 = server_ts.try_into().unwrap_or_default();
decrypted_object.insert("sender".to_owned(), event.sender.to_string().into());
decrypted_object.insert("event_id".to_owned(), event.event_id.to_string().into());
decrypted_object.insert("origin_server_ts".to_owned(), server_ts.into());
decrypted_object.insert(
"unsigned".to_owned(),
serde_json::to_value(&event.unsigned).unwrap_or_default(),
);
Ok((
serde_json::from_value::<Raw<AnySyncRoomEvent>>(decrypted_value)?,
message_index,
))
}
}
#[cfg(not(tarpaulin_include))]
impl fmt::Debug for InboundGroupSession {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("InboundGroupSession")
.field("session_id", &self.session_id())
.finish()
}
}
impl PartialEq for InboundGroupSession {
fn eq(&self, other: &Self) -> bool {
self.session_id() == other.session_id()
}
}
/// A pickled version of an `InboundGroupSession`.
///
/// Holds all the information that needs to be stored in a database to restore
/// an InboundGroupSession.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PickledInboundGroupSession {
/// The pickle string holding the InboundGroupSession.
pub pickle: InboundGroupSessionPickle,
/// The public curve25519 key of the account that sent us the session
pub sender_key: String,
/// The public ed25519 key of the account that sent us the session.
pub signing_key: String,
/// The id of the room that the session is used in.
pub room_id: RoomId,
/// The list of claimed ed25519 that forwarded us this key. Will be None if
/// we dirrectly received this session.
pub forwarding_chains: Option<Vec<String>>,
}
/// The typed representation of a base64 encoded string of the GroupSession pickle.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InboundGroupSessionPickle(String);
impl From<String> for InboundGroupSessionPickle {
fn from(pickle_string: String) -> Self {
InboundGroupSessionPickle(pickle_string)
}
}
impl InboundGroupSessionPickle {
/// Get the string representation of the pickle.
pub fn as_str(&self) -> &str {
&self.0
}
}
/// Outbound group session.
///
/// Outbound group sessions are used to exchange room messages between a group
/// of participants. Outbound group sessions are used to encrypt the room
/// messages.
#[derive(Clone)]
pub struct OutboundGroupSession {
inner: Arc<Mutex<OlmOutboundGroupSession>>,
device_id: Arc<Box<DeviceId>>,
account_identity_keys: Arc<IdentityKeys>,
session_id: Arc<String>,
room_id: Arc<RoomId>,
creation_time: Arc<Instant>,
message_count: Arc<AtomicU64>,
shared: Arc<AtomicBool>,
settings: Arc<EncryptionSettings>,
}
impl OutboundGroupSession {
/// Create a new outbound group session for the given room.
///
/// Outbound group sessions are used to encrypt room messages.
///
/// # Arguments
///
/// * `device_id` - The id of the device that created this session.
///
/// * `identity_keys` - The identity keys of the account that created this
/// session.
///
/// * `room_id` - The id of the room that the session is used in.
///
/// * `settings` - Settings determining the algorithm and rotation period of
/// the outbound group session.
pub fn new(
device_id: Arc<Box<DeviceId>>,
identity_keys: Arc<IdentityKeys>,
room_id: &RoomId,
settings: EncryptionSettings,
) -> Self {
let session = OlmOutboundGroupSession::new();
let session_id = session.session_id();
OutboundGroupSession {
inner: Arc::new(Mutex::new(session)),
room_id: Arc::new(room_id.to_owned()),
device_id,
account_identity_keys: identity_keys,
session_id: Arc::new(session_id),
creation_time: Arc::new(Instant::now()),
message_count: Arc::new(AtomicU64::new(0)),
shared: Arc::new(AtomicBool::new(false)),
settings: Arc::new(settings),
}
}
/// Encrypt the given plaintext using this session.
///
/// Returns the encrypted ciphertext.
///
/// # Arguments
///
/// * `plaintext` - The plaintext that should be encrypted.
pub(crate) async fn encrypt_helper(&self, plaintext: String) -> String {
let session = self.inner.lock().await;
self.message_count.fetch_add(1, Ordering::SeqCst);
session.encrypt(plaintext)
}
/// Encrypt a room message for the given room.
///
/// Beware that a group session needs to be shared before this method can be
/// called using the `share_group_session()` method.
///
/// Since group sessions can expire or become invalid if the room membership
/// changes client authors should check with the
/// `should_share_group_session()` method if a new group session needs to
/// be shared.
///
/// # Arguments
///
/// * `content` - The plaintext content of the message that should be
/// encrypted.
///
/// # Panics
///
/// Panics if the content can't be serialized.
pub async fn encrypt(&self, content: AnyMessageEventContent) -> EncryptedEventContent {
let json_content = json!({
"content": content,
"room_id": &*self.room_id,
"type": content.event_type(),
});
let plaintext = cjson::to_string(&json_content).unwrap_or_else(|_| {
panic!(format!(
"Can't serialize {} to canonical JSON",
json_content
))
});
let ciphertext = self.encrypt_helper(plaintext).await;
EncryptedEventContent::MegolmV1AesSha2(
matrix_sdk_common::events::room::encrypted::MegolmV1AesSha2ContentInit {
ciphertext,
sender_key: self.account_identity_keys.curve25519().to_owned(),
session_id: self.session_id().to_owned(),
device_id: (&*self.device_id).to_owned(),
}
.into(),
)
}
/// Check if the session has expired and if it should be rotated.
///
/// A session will expire after some time or if enough messages have been
/// encrypted using it.
pub fn expired(&self) -> bool {
let count = self.message_count.load(Ordering::SeqCst);
count >= self.settings.rotation_period_msgs
|| self.creation_time.elapsed()
// Since the encryption settings are provided by users and not
// checked someone could set a really low rotation perdiod so
// clamp it at a minute.
>= min(self.settings.rotation_period, Duration::from_secs(3600))
}
/// Mark the session as shared.
///
/// Messages shouldn't be encrypted with the session before it has been
/// shared.
pub fn mark_as_shared(&self) {
self.shared.store(true, Ordering::Relaxed);
}
/// Check if the session has been marked as shared.
pub fn shared(&self) -> bool {
self.shared.load(Ordering::Relaxed)
}
/// Get the session key of this session.
///
/// A session key can be used to to create an `InboundGroupSession`.
pub async fn session_key(&self) -> GroupSessionKey {
let session = self.inner.lock().await;
GroupSessionKey(session.session_key())
}
/// Returns the unique identifier for this session.
pub fn session_id(&self) -> &str {
&self.session_id
}
/// Get the current message index for this session.
///
/// Each message is sent with an increasing index. This returns the
/// message index that will be used for the next encrypted message.
pub async fn message_index(&self) -> u32 {
let session = self.inner.lock().await;
session.session_message_index()
}
/// Get the outbound group session key as a json value that can be sent as a
/// m.room_key.
pub async fn as_json(&self) -> Value {
json!({
"algorithm": EventEncryptionAlgorithm::MegolmV1AesSha2,
"room_id": &*self.room_id,
"session_id": &*self.session_id,
"session_key": self.session_key().await,
"chain_index": self.message_index().await,
})
}
}
#[cfg(not(tarpaulin_include))]
impl std::fmt::Debug for OutboundGroupSession {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("OutboundGroupSession")
.field("session_id", &self.session_id)
.field("room_id", &self.room_id)
.field("creation_time", &self.creation_time)
.field("message_count", &self.message_count)
.finish()
}
}
#[cfg(test)]
mod test {
use std::{
sync::Arc,
time::{Duration, Instant},
};
use matrix_sdk_common::{
events::{
room::message::{MessageEventContent, TextMessageEventContent},
AnyMessageEventContent,
},
identifiers::{room_id, user_id},
};
use super::EncryptionSettings;
use crate::Account;
#[tokio::test]
#[cfg(not(target_os = "macos"))]
async fn expiration() {
let settings = EncryptionSettings {
rotation_period_msgs: 1,
..Default::default()
};
let account = Account::new(&user_id!("@alice:example.org"), "DEVICEID".into());
let (session, _) = account
.create_group_session_pair(&room_id!("!test_room:example.org"), settings)
.await
.unwrap();
assert!(!session.expired());
let _ = session
.encrypt(AnyMessageEventContent::RoomMessage(
MessageEventContent::Text(TextMessageEventContent::plain("Test message")),
))
.await;
assert!(session.expired());
let settings = EncryptionSettings {
rotation_period: Duration::from_millis(100),
..Default::default()
};
let (mut session, _) = account
.create_group_session_pair(&room_id!("!test_room:example.org"), settings)
.await
.unwrap();
assert!(!session.expired());
session.creation_time = Arc::new(Instant::now() - Duration::from_secs(60 * 60));
assert!(session.expired());
}
}

View File

@ -0,0 +1,343 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
collections::BTreeMap,
convert::{TryFrom, TryInto},
fmt,
sync::Arc,
};
use matrix_sdk_common::{
events::{room::encrypted::EncryptedEventContent, AnySyncRoomEvent, SyncMessageEvent},
identifiers::{DeviceKeyAlgorithm, EventEncryptionAlgorithm, RoomId},
locks::Mutex,
Raw,
};
use olm_rs::{
errors::OlmGroupSessionError, inbound_group_session::OlmInboundGroupSession, PicklingMode,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
pub use olm_rs::{
account::IdentityKeys,
session::{OlmMessage, PreKeyMessage},
utility::OlmUtility,
};
use super::{ExportedGroupSessionKey, ExportedRoomKey, GroupSessionKey};
use crate::error::{EventError, MegolmResult};
/// Inbound group session.
///
/// Inbound group sessions are used to exchange room messages between a group of
/// participants. Inbound group sessions are used to decrypt the room messages.
#[derive(Clone)]
pub struct InboundGroupSession {
inner: Arc<Mutex<OlmInboundGroupSession>>,
session_id: Arc<String>,
pub(crate) sender_key: Arc<String>,
pub(crate) signing_key: Arc<BTreeMap<DeviceKeyAlgorithm, String>>,
pub(crate) room_id: Arc<RoomId>,
forwarding_chains: Arc<Mutex<Option<Vec<String>>>>,
imported: Arc<bool>,
}
impl InboundGroupSession {
/// Create a new inbound group session for the given room.
///
/// These sessions are used to decrypt room messages.
///
/// # Arguments
///
/// * `sender_key` - The public curve25519 key of the account that
/// sent us the session
///
/// * `signing_key` - The public ed25519 key of the account that
/// sent us the session.
///
/// * `room_id` - The id of the room that the session is used in.
///
/// * `session_key` - The private session key that is used to decrypt
/// messages.
pub(crate) fn new(
sender_key: &str,
signing_key: &str,
room_id: &RoomId,
session_key: GroupSessionKey,
) -> Result<Self, OlmGroupSessionError> {
let session = OlmInboundGroupSession::new(&session_key.0)?;
let session_id = session.session_id();
let mut keys: BTreeMap<DeviceKeyAlgorithm, String> = BTreeMap::new();
keys.insert(DeviceKeyAlgorithm::Ed25519, signing_key.to_owned());
Ok(InboundGroupSession {
inner: Arc::new(Mutex::new(session)),
session_id: Arc::new(session_id),
sender_key: Arc::new(sender_key.to_owned()),
signing_key: Arc::new(keys),
room_id: Arc::new(room_id.clone()),
forwarding_chains: Arc::new(Mutex::new(None)),
imported: Arc::new(false),
})
}
/// Create a InboundGroupSession from an exported version of the group
/// session.
///
/// Most notably this can be called with an `ExportedRoomKey` from a
/// previous [`export()`] call.
///
///
/// [`export()`]: #method.export
pub fn from_export(
exported_session: impl Into<ExportedRoomKey>,
) -> Result<Self, OlmGroupSessionError> {
Self::try_from(exported_session.into())
}
/// Store the group session as a base64 encoded string.
///
/// # Arguments
///
/// * `pickle_mode` - The mode that was used to pickle the group session,
/// either an unencrypted mode or an encrypted using passphrase.
pub async fn pickle(&self, pickle_mode: PicklingMode) -> PickledInboundGroupSession {
let pickle = self.inner.lock().await.pickle(pickle_mode);
PickledInboundGroupSession {
pickle: InboundGroupSessionPickle::from(pickle),
sender_key: self.sender_key.to_string(),
signing_key: (&*self.signing_key).clone(),
room_id: (&*self.room_id).clone(),
forwarding_chains: self.forwarding_chains.lock().await.clone(),
imported: *self.imported,
}
}
/// Export this session at the first known message index.
///
/// If only a limited part of this session should be exported use
/// [`export_at_index()`](#method.export_at_index).
pub async fn export(&self) -> ExportedRoomKey {
self.export_at_index(self.first_known_index().await)
.await
.expect("Can't export at the first known index")
}
/// Export this session at the given message index.
pub async fn export_at_index(&self, message_index: u32) -> Option<ExportedRoomKey> {
let session_key =
ExportedGroupSessionKey(self.inner.lock().await.export(message_index).ok()?);
Some(ExportedRoomKey {
algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2,
room_id: (&*self.room_id).clone(),
sender_key: (&*self.sender_key).to_owned(),
session_id: self.session_id().to_owned(),
forwarding_curve25519_key_chain: self
.forwarding_chains
.lock()
.await
.as_ref()
.cloned()
.unwrap_or_default(),
sender_claimed_keys: (&*self.signing_key).clone(),
session_key,
})
}
/// Restore a Session from a previously pickled string.
///
/// Returns the restored group session or a `OlmGroupSessionError` if there
/// was an error.
///
/// # Arguments
///
/// * `pickle` - The pickled version of the `InboundGroupSession`.
///
/// * `pickle_mode` - The mode that was used to pickle the session, either
/// an unencrypted mode or an encrypted using passphrase.
pub fn from_pickle(
pickle: PickledInboundGroupSession,
pickle_mode: PicklingMode,
) -> Result<Self, OlmGroupSessionError> {
let session = OlmInboundGroupSession::unpickle(pickle.pickle.0, pickle_mode)?;
let session_id = session.session_id();
Ok(InboundGroupSession {
inner: Arc::new(Mutex::new(session)),
session_id: Arc::new(session_id),
sender_key: Arc::new(pickle.sender_key),
signing_key: Arc::new(pickle.signing_key),
room_id: Arc::new(pickle.room_id),
forwarding_chains: Arc::new(Mutex::new(pickle.forwarding_chains)),
imported: Arc::new(pickle.imported),
})
}
/// The room where this session is used in.
pub fn room_id(&self) -> &RoomId {
&self.room_id
}
/// Returns the unique identifier for this session.
pub fn session_id(&self) -> &str {
&self.session_id
}
/// Get the first message index we know how to decrypt.
pub async fn first_known_index(&self) -> u32 {
self.inner.lock().await.first_known_index()
}
/// Decrypt the given ciphertext.
///
/// Returns the decrypted plaintext or an `OlmGroupSessionError` if
/// decryption failed.
///
/// # Arguments
///
/// * `message` - The message that should be decrypted.
pub(crate) async fn decrypt_helper(
&self,
message: String,
) -> Result<(String, u32), OlmGroupSessionError> {
self.inner.lock().await.decrypt(message)
}
/// Decrypt an event from a room timeline.
///
/// # Arguments
///
/// * `event` - The event that should be decrypted.
pub(crate) async fn decrypt(
&self,
event: &SyncMessageEvent<EncryptedEventContent>,
) -> MegolmResult<(Raw<AnySyncRoomEvent>, u32)> {
let content = match &event.content {
EncryptedEventContent::MegolmV1AesSha2(c) => c,
_ => return Err(EventError::UnsupportedAlgorithm.into()),
};
let (plaintext, message_index) = self.decrypt_helper(content.ciphertext.clone()).await?;
let mut decrypted_value = serde_json::from_str::<Value>(&plaintext)?;
let decrypted_object = decrypted_value
.as_object_mut()
.ok_or(EventError::NotAnObject)?;
// TODO better number conversion here.
let server_ts = event
.origin_server_ts
.duration_since(std::time::SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_millis();
let server_ts: i64 = server_ts.try_into().unwrap_or_default();
decrypted_object.insert("sender".to_owned(), event.sender.to_string().into());
decrypted_object.insert("event_id".to_owned(), event.event_id.to_string().into());
decrypted_object.insert("origin_server_ts".to_owned(), server_ts.into());
decrypted_object.insert(
"unsigned".to_owned(),
serde_json::to_value(&event.unsigned).unwrap_or_default(),
);
Ok((
serde_json::from_value::<Raw<AnySyncRoomEvent>>(decrypted_value)?,
message_index,
))
}
}
#[cfg(not(tarpaulin_include))]
impl fmt::Debug for InboundGroupSession {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("InboundGroupSession")
.field("session_id", &self.session_id())
.finish()
}
}
impl PartialEq for InboundGroupSession {
fn eq(&self, other: &Self) -> bool {
self.session_id() == other.session_id()
}
}
/// A pickled version of an `InboundGroupSession`.
///
/// Holds all the information that needs to be stored in a database to restore
/// an InboundGroupSession.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PickledInboundGroupSession {
/// The pickle string holding the InboundGroupSession.
pub pickle: InboundGroupSessionPickle,
/// The public curve25519 key of the account that sent us the session
pub sender_key: String,
/// The public ed25519 key of the account that sent us the session.
pub signing_key: BTreeMap<DeviceKeyAlgorithm, String>,
/// The id of the room that the session is used in.
pub room_id: RoomId,
/// The list of claimed ed25519 that forwarded us this key. Will be None if
/// we dirrectly received this session.
pub forwarding_chains: Option<Vec<String>>,
/// Flag remembering if the session was dirrectly sent to us by the sender
/// or if it was imported.
pub imported: bool,
}
/// The typed representation of a base64 encoded string of the GroupSession pickle.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InboundGroupSessionPickle(String);
impl From<String> for InboundGroupSessionPickle {
fn from(pickle_string: String) -> Self {
InboundGroupSessionPickle(pickle_string)
}
}
impl InboundGroupSessionPickle {
/// Get the string representation of the pickle.
pub fn as_str(&self) -> &str {
&self.0
}
}
impl TryFrom<ExportedRoomKey> for InboundGroupSession {
type Error = OlmGroupSessionError;
fn try_from(key: ExportedRoomKey) -> Result<Self, Self::Error> {
let session = OlmInboundGroupSession::import(&key.session_key.0)?;
let forwarding_chains = if key.forwarding_curve25519_key_chain.is_empty() {
None
} else {
Some(key.forwarding_curve25519_key_chain)
};
Ok(InboundGroupSession {
inner: Arc::new(Mutex::new(session)),
session_id: Arc::new(key.session_id),
sender_key: Arc::new(key.sender_key),
signing_key: Arc::new(key.sender_claimed_keys),
room_id: Arc::new(key.room_id),
forwarding_chains: Arc::new(Mutex::new(forwarding_chains)),
imported: Arc::new(true),
})
}
}

View File

@ -0,0 +1,176 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use matrix_sdk_common::{
events::forwarded_room_key::ForwardedRoomKeyEventContent,
identifiers::{DeviceKeyAlgorithm, EventEncryptionAlgorithm, RoomId},
};
use serde::{Deserialize, Serialize};
use std::{collections::BTreeMap, convert::TryInto};
use zeroize::Zeroize;
mod inbound;
mod outbound;
pub use inbound::{InboundGroupSession, InboundGroupSessionPickle, PickledInboundGroupSession};
pub use outbound::{EncryptionSettings, OutboundGroupSession};
/// The private session key of a group session.
/// Can be used to create a new inbound group session.
#[derive(Clone, Debug, Serialize, Deserialize, Zeroize)]
#[zeroize(drop)]
pub struct GroupSessionKey(pub String);
/// The exported version of an private session key of a group session.
/// Can be used to create a new inbound group session.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Zeroize)]
#[zeroize(drop)]
pub struct ExportedGroupSessionKey(pub String);
/// An exported version of a `InboundGroupSession`
///
/// This can be used to share the `InboundGroupSession` in an exported file.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ExportedRoomKey {
/// The encryption algorithm that the session uses.
pub algorithm: EventEncryptionAlgorithm,
/// The room where the session is used.
pub room_id: RoomId,
/// The Curve25519 key of the device which initiated the session originally.
pub sender_key: String,
/// The ID of the session that the key is for.
pub session_id: String,
/// The key for the session.
pub session_key: ExportedGroupSessionKey,
/// The Ed25519 key of the device which initiated the session originally.
pub sender_claimed_keys: BTreeMap<DeviceKeyAlgorithm, String>,
/// Chain of Curve25519 keys through which this session was forwarded, via
/// m.forwarded_room_key events.
pub forwarding_curve25519_key_chain: Vec<String>,
}
impl TryInto<ForwardedRoomKeyEventContent> for ExportedRoomKey {
type Error = ();
/// Convert an exported room key into a content for a forwarded room key
/// event.
///
/// This will fail if the exported room key has multiple sender claimed keys
/// or if the algorithm of the claimed sender key isn't
/// `DeviceKeyAlgorithm::Ed25519`.
fn try_into(self) -> Result<ForwardedRoomKeyEventContent, Self::Error> {
if self.sender_claimed_keys.len() != 1 {
Err(())
} else {
let (algorithm, claimed_key) = self.sender_claimed_keys.iter().next().ok_or(())?;
if algorithm != &DeviceKeyAlgorithm::Ed25519 {
return Err(());
}
Ok(ForwardedRoomKeyEventContent {
algorithm: self.algorithm,
room_id: self.room_id,
sender_key: self.sender_key,
session_id: self.session_id,
session_key: self.session_key.0.clone(),
sender_claimed_ed25519_key: claimed_key.to_owned(),
forwarding_curve25519_key_chain: self.forwarding_curve25519_key_chain,
})
}
}
}
impl From<ForwardedRoomKeyEventContent> for ExportedRoomKey {
/// Convert the content of a forwarded room key into a exported room key.
fn from(forwarded_key: ForwardedRoomKeyEventContent) -> Self {
let mut sender_claimed_keys: BTreeMap<DeviceKeyAlgorithm, String> = BTreeMap::new();
sender_claimed_keys.insert(
DeviceKeyAlgorithm::Ed25519,
forwarded_key.sender_claimed_ed25519_key,
);
Self {
algorithm: forwarded_key.algorithm,
room_id: forwarded_key.room_id,
session_id: forwarded_key.session_id,
forwarding_curve25519_key_chain: forwarded_key.forwarding_curve25519_key_chain,
sender_claimed_keys,
sender_key: forwarded_key.sender_key,
session_key: ExportedGroupSessionKey(forwarded_key.session_key),
}
}
}
#[cfg(test)]
mod test {
use std::{
sync::Arc,
time::{Duration, Instant},
};
use matrix_sdk_common::{
events::{
room::message::{MessageEventContent, TextMessageEventContent},
AnyMessageEventContent,
},
identifiers::{room_id, user_id},
};
use super::EncryptionSettings;
use crate::Account;
#[tokio::test]
#[cfg(not(target_os = "macos"))]
async fn expiration() {
let settings = EncryptionSettings {
rotation_period_msgs: 1,
..Default::default()
};
let account = Account::new(&user_id!("@alice:example.org"), "DEVICEID".into());
let (session, _) = account
.create_group_session_pair(&room_id!("!test_room:example.org"), settings)
.await
.unwrap();
assert!(!session.expired());
let _ = session
.encrypt(AnyMessageEventContent::RoomMessage(
MessageEventContent::Text(TextMessageEventContent::plain("Test message")),
))
.await;
assert!(session.expired());
let settings = EncryptionSettings {
rotation_period: Duration::from_millis(100),
..Default::default()
};
let (mut session, _) = account
.create_group_session_pair(&room_id!("!test_room:example.org"), settings)
.await
.unwrap();
assert!(!session.expired());
session.creation_time = Arc::new(Instant::now() - Duration::from_secs(60 * 60));
assert!(session.expired());
}
}

View File

@ -0,0 +1,303 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
cmp::min,
fmt,
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc,
},
time::Duration,
};
use matrix_sdk_common::{
events::{
room::{encrypted::EncryptedEventContent, encryption::EncryptionEventContent},
AnyMessageEventContent, EventContent,
},
identifiers::{DeviceId, EventEncryptionAlgorithm, RoomId},
instant::Instant,
locks::Mutex,
};
use olm_rs::outbound_group_session::OlmOutboundGroupSession;
use serde_json::{json, Value};
pub use olm_rs::{
account::IdentityKeys,
session::{OlmMessage, PreKeyMessage},
utility::OlmUtility,
};
use super::GroupSessionKey;
const ROTATION_PERIOD: Duration = Duration::from_millis(604800000);
const ROTATION_MESSAGES: u64 = 100;
/// Settings for an encrypted room.
///
/// This determines the algorithm and rotation periods of a group session.
#[derive(Debug)]
pub struct EncryptionSettings {
/// The encryption algorithm that should be used in the room.
pub algorithm: EventEncryptionAlgorithm,
/// How long the session should be used before changing it.
pub rotation_period: Duration,
/// How many messages should be sent before changing the session.
pub rotation_period_msgs: u64,
}
impl Default for EncryptionSettings {
fn default() -> Self {
Self {
algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2,
rotation_period: ROTATION_PERIOD,
rotation_period_msgs: ROTATION_MESSAGES,
}
}
}
impl From<&EncryptionEventContent> for EncryptionSettings {
fn from(content: &EncryptionEventContent) -> Self {
let rotation_period: Duration = content
.rotation_period_ms
.map_or(ROTATION_PERIOD, |r| Duration::from_millis(r.into()));
let rotation_period_msgs: u64 = content
.rotation_period_msgs
.map_or(ROTATION_MESSAGES, Into::into);
Self {
algorithm: content.algorithm.clone(),
rotation_period,
rotation_period_msgs,
}
}
}
/// Outbound group session.
///
/// Outbound group sessions are used to exchange room messages between a group
/// of participants. Outbound group sessions are used to encrypt the room
/// messages.
#[derive(Clone)]
pub struct OutboundGroupSession {
inner: Arc<Mutex<OlmOutboundGroupSession>>,
device_id: Arc<Box<DeviceId>>,
account_identity_keys: Arc<IdentityKeys>,
session_id: Arc<String>,
room_id: Arc<RoomId>,
pub(crate) creation_time: Arc<Instant>,
message_count: Arc<AtomicU64>,
shared: Arc<AtomicBool>,
settings: Arc<EncryptionSettings>,
}
impl OutboundGroupSession {
/// Create a new outbound group session for the given room.
///
/// Outbound group sessions are used to encrypt room messages.
///
/// # Arguments
///
/// * `device_id` - The id of the device that created this session.
///
/// * `identity_keys` - The identity keys of the account that created this
/// session.
///
/// * `room_id` - The id of the room that the session is used in.
///
/// * `settings` - Settings determining the algorithm and rotation period of
/// the outbound group session.
pub fn new(
device_id: Arc<Box<DeviceId>>,
identity_keys: Arc<IdentityKeys>,
room_id: &RoomId,
settings: EncryptionSettings,
) -> Self {
let session = OlmOutboundGroupSession::new();
let session_id = session.session_id();
OutboundGroupSession {
inner: Arc::new(Mutex::new(session)),
room_id: Arc::new(room_id.to_owned()),
device_id,
account_identity_keys: identity_keys,
session_id: Arc::new(session_id),
creation_time: Arc::new(Instant::now()),
message_count: Arc::new(AtomicU64::new(0)),
shared: Arc::new(AtomicBool::new(false)),
settings: Arc::new(settings),
}
}
/// Encrypt the given plaintext using this session.
///
/// Returns the encrypted ciphertext.
///
/// # Arguments
///
/// * `plaintext` - The plaintext that should be encrypted.
pub(crate) async fn encrypt_helper(&self, plaintext: String) -> String {
let session = self.inner.lock().await;
self.message_count.fetch_add(1, Ordering::SeqCst);
session.encrypt(plaintext)
}
/// Encrypt a room message for the given room.
///
/// Beware that a group session needs to be shared before this method can be
/// called using the `share_group_session()` method.
///
/// Since group sessions can expire or become invalid if the room membership
/// changes client authors should check with the
/// `should_share_group_session()` method if a new group session needs to
/// be shared.
///
/// # Arguments
///
/// * `content` - The plaintext content of the message that should be
/// encrypted.
///
/// # Panics
///
/// Panics if the content can't be serialized.
pub async fn encrypt(&self, content: AnyMessageEventContent) -> EncryptedEventContent {
let json_content = json!({
"content": content,
"room_id": &*self.room_id,
"type": content.event_type(),
});
let plaintext = cjson::to_string(&json_content).unwrap_or_else(|_| {
panic!(format!(
"Can't serialize {} to canonical JSON",
json_content
))
});
let ciphertext = self.encrypt_helper(plaintext).await;
EncryptedEventContent::MegolmV1AesSha2(
matrix_sdk_common::events::room::encrypted::MegolmV1AesSha2ContentInit {
ciphertext,
sender_key: self.account_identity_keys.curve25519().to_owned(),
session_id: self.session_id().to_owned(),
device_id: (&*self.device_id).to_owned(),
}
.into(),
)
}
/// Check if the session has expired and if it should be rotated.
///
/// A session will expire after some time or if enough messages have been
/// encrypted using it.
pub fn expired(&self) -> bool {
let count = self.message_count.load(Ordering::SeqCst);
count >= self.settings.rotation_period_msgs
|| self.creation_time.elapsed()
// Since the encryption settings are provided by users and not
// checked someone could set a really low rotation perdiod so
// clamp it at a minute.
>= min(self.settings.rotation_period, Duration::from_secs(3600))
}
/// Mark the session as shared.
///
/// Messages shouldn't be encrypted with the session before it has been
/// shared.
pub fn mark_as_shared(&self) {
self.shared.store(true, Ordering::Relaxed);
}
/// Check if the session has been marked as shared.
pub fn shared(&self) -> bool {
self.shared.load(Ordering::Relaxed)
}
/// Get the session key of this session.
///
/// A session key can be used to to create an `InboundGroupSession`.
pub async fn session_key(&self) -> GroupSessionKey {
let session = self.inner.lock().await;
GroupSessionKey(session.session_key())
}
/// Returns the unique identifier for this session.
pub fn session_id(&self) -> &str {
&self.session_id
}
/// Get the current message index for this session.
///
/// Each message is sent with an increasing index. This returns the
/// message index that will be used for the next encrypted message.
pub async fn message_index(&self) -> u32 {
let session = self.inner.lock().await;
session.session_message_index()
}
/// Get the outbound group session key as a json value that can be sent as a
/// m.room_key.
pub async fn as_json(&self) -> Value {
json!({
"algorithm": EventEncryptionAlgorithm::MegolmV1AesSha2,
"room_id": &*self.room_id,
"session_id": &*self.session_id,
"session_key": self.session_key().await,
"chain_index": self.message_index().await,
})
}
}
#[cfg(not(tarpaulin_include))]
impl std::fmt::Debug for OutboundGroupSession {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("OutboundGroupSession")
.field("session_id", &self.session_id)
.field("room_id", &self.room_id)
.field("creation_time", &self.creation_time)
.field("message_count", &self.message_count)
.finish()
}
}
#[cfg(test)]
mod test {
use std::time::Duration;
use matrix_sdk_common::{
events::room::encryption::EncryptionEventContent, identifiers::EventEncryptionAlgorithm,
js_int::uint,
};
use super::{EncryptionSettings, ROTATION_MESSAGES, ROTATION_PERIOD};
#[test]
fn encryption_settings_conversion() {
let mut content = EncryptionEventContent::new(EventEncryptionAlgorithm::MegolmV1AesSha2);
let settings = EncryptionSettings::from(&content);
assert_eq!(settings.rotation_period, ROTATION_PERIOD);
assert_eq!(settings.rotation_period_msgs, ROTATION_MESSAGES);
content.rotation_period_ms = Some(uint!(3600));
content.rotation_period_msgs = Some(uint!(500));
let settings = EncryptionSettings::from(&content);
assert_eq!(settings.rotation_period, Duration::from_millis(3600));
assert_eq!(settings.rotation_period_msgs, 500);
}
}

View File

@ -24,7 +24,8 @@ mod utility;
pub use account::{Account, AccountPickle, IdentityKeys, PickledAccount}; pub use account::{Account, AccountPickle, IdentityKeys, PickledAccount};
pub use group_sessions::{ pub use group_sessions::{
EncryptionSettings, InboundGroupSession, InboundGroupSessionPickle, PickledInboundGroupSession, EncryptionSettings, ExportedRoomKey, InboundGroupSession, InboundGroupSessionPickle,
PickledInboundGroupSession,
}; };
pub(crate) use group_sessions::{GroupSessionKey, OutboundGroupSession}; pub(crate) use group_sessions::{GroupSessionKey, OutboundGroupSession};
pub use olm_rs::PicklingMode; pub use olm_rs::PicklingMode;
@ -37,10 +38,11 @@ pub(crate) mod test {
use crate::olm::{Account, InboundGroupSession, Session}; use crate::olm::{Account, InboundGroupSession, Session};
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::keys::SignedKey, api::r0::keys::SignedKey,
events::forwarded_room_key::ForwardedRoomKeyEventContent,
identifiers::{room_id, user_id, DeviceId, UserId}, identifiers::{room_id, user_id, DeviceId, UserId},
}; };
use olm_rs::session::OlmMessage; use olm_rs::session::OlmMessage;
use std::collections::BTreeMap; use std::{collections::BTreeMap, convert::TryInto};
fn alice_id() -> UserId { fn alice_id() -> UserId {
user_id!("@alice:example.org") user_id!("@alice:example.org")
@ -221,4 +223,22 @@ pub(crate) mod test {
inbound.decrypt_helper(ciphertext).await.unwrap().0 inbound.decrypt_helper(ciphertext).await.unwrap().0
); );
} }
#[tokio::test]
async fn group_session_export() {
let alice = Account::new(&alice_id(), &alice_device_id());
let room_id = room_id!("!test:localhost");
let (_, inbound) = alice
.create_group_session_pair(&room_id, Default::default())
.await
.unwrap();
let export = inbound.export().await;
let export: ForwardedRoomKeyEventContent = export.try_into().unwrap();
let imported = InboundGroupSession::from_export(export).unwrap();
assert_eq!(inbound.session_id(), imported.session_id());
}
} }

View File

@ -106,6 +106,19 @@ impl GroupSessionStore {
.is_none() .is_none()
} }
/// Get all the group sessions the store knows about.
pub fn get_all(&self) -> Vec<InboundGroupSession> {
self.entries
.iter()
.flat_map(|d| {
d.value()
.values()
.flat_map(|t| t.values().cloned().collect::<Vec<InboundGroupSession>>())
.collect::<Vec<InboundGroupSession>>()
})
.collect()
}
/// Get a inbound group session from our store. /// Get a inbound group session from our store.
/// ///
/// # Arguments /// # Arguments

View File

@ -80,8 +80,12 @@ impl CryptoStore for MemoryStore {
Ok(self.sessions.get(sender_key)) Ok(self.sessions.get(sender_key))
} }
async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result<bool> { async fn save_inbound_group_sessions(&self, sessions: &[InboundGroupSession]) -> Result<()> {
Ok(self.inbound_group_sessions.add(session)) for session in sessions {
self.inbound_group_sessions.add(session.clone());
}
Ok(())
} }
async fn get_inbound_group_session( async fn get_inbound_group_session(
@ -95,6 +99,10 @@ impl CryptoStore for MemoryStore {
.get(room_id, sender_key, session_id)) .get(room_id, sender_key, session_id))
} }
async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
Ok(self.inbound_group_sessions.get_all())
}
fn users_for_key_query(&self) -> HashSet<UserId> { fn users_for_key_query(&self) -> HashSet<UserId> {
#[allow(clippy::map_clone)] #[allow(clippy::map_clone)]
self.users_for_key_query.iter().map(|u| u.clone()).collect() self.users_for_key_query.iter().map(|u| u.clone()).collect()
@ -208,7 +216,7 @@ mod test {
let store = MemoryStore::new(); let store = MemoryStore::new();
let _ = store let _ = store
.save_inbound_group_session(inbound.clone()) .save_inbound_group_sessions(&[inbound.clone()])
.await .await
.unwrap(); .unwrap();

View File

@ -157,15 +157,12 @@ pub trait CryptoStore: Debug {
/// * `sender_key` - The sender key that was used to establish the sessions. /// * `sender_key` - The sender key that was used to establish the sessions.
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>>; async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>>;
/// Save the given inbound group session in the store. /// Save the given inbound group sessions in the store.
///
/// If the session wasn't already in the store true is returned, false
/// otherwise.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `session` - The session that should be stored. /// * `sessions` - The sessions that should be stored.
async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result<bool>; async fn save_inbound_group_sessions(&self, session: &[InboundGroupSession]) -> Result<()>;
/// Get the inbound group session from our store. /// Get the inbound group session from our store.
/// ///
@ -182,6 +179,9 @@ pub trait CryptoStore: Debug {
session_id: &str, session_id: &str,
) -> Result<Option<InboundGroupSession>>; ) -> Result<Option<InboundGroupSession>>;
/// Get all the inbound group sessions we have stored.
async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>>;
/// Is the given user already tracked. /// Is the given user already tracked.
fn is_user_tracked(&self, user_id: &UserId) -> bool; fn is_user_tracked(&self, user_id: &UserId) -> bool;

View File

@ -23,6 +23,7 @@ use std::{
use async_trait::async_trait; use async_trait::async_trait;
use dashmap::DashSet; use dashmap::DashSet;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::keys::{CrossSigningKey, KeyUsage},
identifiers::{ identifiers::{
DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, RoomId, UserId, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, RoomId, UserId,
}, },
@ -38,7 +39,7 @@ use super::{
CryptoStore, CryptoStoreError, Result, CryptoStore, CryptoStoreError, Result,
}; };
use crate::{ use crate::{
identities::{LocalTrust, ReadOnlyDevice, UserIdentities}, identities::{LocalTrust, OwnUserIdentity, ReadOnlyDevice, UserIdentities, UserIdentity},
olm::{ olm::{
Account, AccountPickle, IdentityKeys, InboundGroupSession, InboundGroupSessionPickle, Account, AccountPickle, IdentityKeys, InboundGroupSession, InboundGroupSessionPickle,
PickledAccount, PickledInboundGroupSession, PickledSession, PicklingMode, Session, PickledAccount, PickledInboundGroupSession, PickledSession, PicklingMode, Session,
@ -71,6 +72,14 @@ struct AccountInfo {
identity_keys: Arc<IdentityKeys>, identity_keys: Arc<IdentityKeys>,
} }
#[derive(Debug, PartialEq, Copy, Clone, sqlx::Type)]
#[repr(i32)]
enum CrosssigningKeyType {
Master = 0,
SelfSigning = 1,
UserSigning = 2,
}
static DATABASE_NAME: &str = "matrix-sdk-crypto.db"; static DATABASE_NAME: &str = "matrix-sdk-crypto.db";
impl SqliteStore { impl SqliteStore {
@ -135,8 +144,8 @@ impl SqliteStore {
passphrase: Option<Zeroizing<String>>, passphrase: Option<Zeroizing<String>>,
) -> Result<SqliteStore> { ) -> Result<SqliteStore> {
let url = SqliteStore::path_to_url(path.as_ref())?; let url = SqliteStore::path_to_url(path.as_ref())?;
let connection = SqliteConnection::connect(url.as_ref()).await?; let connection = SqliteConnection::connect(url.as_ref()).await?;
let store = SqliteStore { let store = SqliteStore {
user_id: Arc::new(user_id.to_owned()), user_id: Arc::new(user_id.to_owned()),
device_id: Arc::new(device_id.into()), device_id: Arc::new(device_id.into()),
@ -151,6 +160,7 @@ impl SqliteStore {
users_for_key_query: Arc::new(DashSet::new()), users_for_key_query: Arc::new(DashSet::new()),
}; };
store.create_tables().await?; store.create_tables().await?;
Ok(store) Ok(store)
} }
@ -221,14 +231,16 @@ impl SqliteStore {
.execute( .execute(
r#" r#"
CREATE TABLE IF NOT EXISTS inbound_group_sessions ( CREATE TABLE IF NOT EXISTS inbound_group_sessions (
"session_id" TEXT NOT NULL PRIMARY KEY, "id" INTEGER NOT NULL PRIMARY KEY,
"session_id" TEXT NOT NULL,
"account_id" INTEGER NOT NULL, "account_id" INTEGER NOT NULL,
"sender_key" TEXT NOT NULL, "sender_key" TEXT NOT NULL,
"signing_key" TEXT NOT NULL,
"room_id" TEXT NOT NULL, "room_id" TEXT NOT NULL,
"pickle" BLOB NOT NULL, "pickle" BLOB NOT NULL,
"imported" INTEGER NOT NULL,
FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") FOREIGN KEY ("account_id") REFERENCES "accounts" ("id")
ON DELETE CASCADE ON DELETE CASCADE
UNIQUE(account_id,session_id,sender_key)
); );
CREATE INDEX IF NOT EXISTS "olm_groups_sessions_account_id" ON "inbound_group_sessions" ("account_id"); CREATE INDEX IF NOT EXISTS "olm_groups_sessions_account_id" ON "inbound_group_sessions" ("account_id");
@ -236,6 +248,41 @@ impl SqliteStore {
) )
.await?; .await?;
connection
.execute(
r#"
CREATE TABLE IF NOT EXISTS group_session_claimed_keys (
"id" INTEGER NOT NULL PRIMARY KEY,
"session_id" INTEGER NOT NULL,
"algorithm" TEXT NOT NULL,
"key" TEXT NOT NULL,
FOREIGN KEY ("session_id") REFERENCES "inbound_group_sessions" ("id")
ON DELETE CASCADE
UNIQUE(session_id, algorithm)
);
CREATE INDEX IF NOT EXISTS "group_session_claimed_keys_session_id" ON "group_session_claimed_keys" ("session_id");
"#,
)
.await?;
connection
.execute(
r#"
CREATE TABLE IF NOT EXISTS group_session_chains (
"id" INTEGER NOT NULL PRIMARY KEY,
"key" TEXT NOT NULL,
"session_id" INTEGER NOT NULL,
FOREIGN KEY ("session_id") REFERENCES "inbound_group_sessions" ("id")
ON DELETE CASCADE
UNIQUE(session_id, key)
);
CREATE INDEX IF NOT EXISTS "group_session_chains_session_id" ON "group_session_chains" ("session_id");
"#,
)
.await?;
connection connection
.execute( .execute(
r#" r#"
@ -310,6 +357,91 @@ impl SqliteStore {
) )
.await?; .await?;
connection
.execute(
r#"
CREATE TABLE IF NOT EXISTS users (
"id" INTEGER NOT NULL PRIMARY KEY,
"account_id" INTEGER NOT NULL,
"user_id" TEXT NOT NULL,
FOREIGN KEY ("account_id") REFERENCES "accounts" ("id")
ON DELETE CASCADE
UNIQUE(account_id,user_id)
);
CREATE INDEX IF NOT EXISTS "users_account_id" ON "users" ("account_id");
"#,
)
.await?;
connection
.execute(
r#"
CREATE TABLE IF NOT EXISTS users_trust_state (
"id" INTEGER NOT NULL PRIMARY KEY,
"trusted" INTEGER NOT NULL,
"user_id" INTEGER NOT NULL,
FOREIGN KEY ("user_id") REFERENCES "users" ("id")
ON DELETE CASCADE
UNIQUE(user_id)
);
"#,
)
.await?;
connection
.execute(
r#"
CREATE TABLE IF NOT EXISTS cross_signing_keys (
"id" INTEGER NOT NULL PRIMARY KEY,
"key_type" INTEGER NOT NULL,
"usage" STRING NOT NULL,
"user_id" INTEGER NOT NULL,
FOREIGN KEY ("user_id") REFERENCES "users" ("id") ON DELETE CASCADE
UNIQUE(user_id, key_type)
);
CREATE INDEX IF NOT EXISTS "cross_signing_keys_users" ON "users" ("user_id");
"#,
)
.await?;
connection
.execute(
r#"
CREATE TABLE IF NOT EXISTS user_keys (
"id" INTEGER NOT NULL PRIMARY KEY,
"key" TEXT NOT NULL,
"key_id" TEXT NOT NULL,
"cross_signing_key" INTEGER NOT NULL,
FOREIGN KEY ("cross_signing_key") REFERENCES "cross_signing_keys" ("id") ON DELETE CASCADE
UNIQUE(cross_signing_key, key_id)
);
CREATE INDEX IF NOT EXISTS "cross_signing_keys_keys" ON "cross_signing_keys" ("cross_signing_key");
"#,
)
.await?;
connection
.execute(
r#"
CREATE TABLE IF NOT EXISTS user_key_signatures (
"id" INTEGER NOT NULL PRIMARY KEY,
"user_id" TEXT NOT NULL,
"key_id" INTEGER NOT NULL,
"signature" TEXT NOT NULL,
"cross_signing_key" INTEGER NOT NULL,
FOREIGN KEY ("cross_signing_key") REFERENCES "cross_signing_keys" ("id")
ON DELETE CASCADE
UNIQUE(user_id, key_id, cross_signing_key)
);
CREATE INDEX IF NOT EXISTS "cross_signing_keys_signatures" ON "cross_signing_keys" ("cross_signing_key");
"#,
)
.await?;
Ok(()) Ok(())
} }
@ -380,45 +512,67 @@ impl SqliteStore {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
let mut rows: Vec<(String, String, String, String)> = query_as( let mut rows: Vec<(i64, String, String, String, bool)> = query_as(
"SELECT pickle, sender_key, signing_key, room_id "SELECT id, pickle, sender_key, room_id, imported
FROM inbound_group_sessions WHERE account_id = ?", FROM inbound_group_sessions WHERE account_id = ?",
) )
.bind(account_id) .bind(account_id)
.fetch_all(&mut *connection) .fetch_all(&mut *connection)
.await?; .await?;
let mut group_sessions = rows for row in rows.drain(..) {
.drain(..) let session_row_id = row.0;
.map(|row| { let pickle = row.1;
let pickle = row.0; let sender_key = row.2;
let sender_key = row.1;
let signing_key = row.2;
let room_id = row.3; let room_id = row.3;
let imported = row.4;
let key_rows: Vec<(String, String)> = query_as(
"SELECT algorithm, key FROM group_session_claimed_keys WHERE session_id = ?",
)
.bind(session_row_id)
.fetch_all(&mut *connection)
.await?;
let claimed_keys: BTreeMap<DeviceKeyAlgorithm, String> = key_rows
.into_iter()
.filter_map(|row| {
let algorithm = row.0.parse::<DeviceKeyAlgorithm>().ok()?;
let key = row.1;
Some((algorithm, key))
})
.collect();
let mut chain_rows: Vec<(String,)> =
query_as("SELECT key, key FROM group_session_chains WHERE session_id = ?")
.bind(session_row_id)
.fetch_all(&mut *connection)
.await?;
let chains: Vec<String> = chain_rows.drain(..).map(|r| r.0).collect();
let chains = if chains.is_empty() {
None
} else {
Some(chains)
};
let pickle = PickledInboundGroupSession { let pickle = PickledInboundGroupSession {
pickle: InboundGroupSessionPickle::from(pickle), pickle: InboundGroupSessionPickle::from(pickle),
sender_key, sender_key,
signing_key, signing_key: claimed_keys,
room_id: RoomId::try_from(room_id)?, room_id: RoomId::try_from(room_id)?,
// Fixme we need to store/restore these once we get support forwarding_chains: chains,
// for key requesting/forwarding. imported,
forwarding_chains: None,
}; };
Ok(InboundGroupSession::from_pickle( self.inbound_group_sessions
.add(InboundGroupSession::from_pickle(
pickle, pickle,
self.get_pickle_mode(), self.get_pickle_mode(),
)?) )?);
}) }
.collect::<Result<Vec<InboundGroupSession>>>()?;
group_sessions
.drain(..)
.map(|s| {
self.inbound_group_sessions.add(s);
})
.for_each(drop);
Ok(()) Ok(())
} }
@ -670,6 +824,325 @@ impl SqliteStore {
None => PicklingMode::Unencrypted, None => PicklingMode::Unencrypted,
} }
} }
async fn save_inbound_group_session_helper(
&self,
account_id: i64,
connection: &mut SqliteConnection,
session: &InboundGroupSession,
) -> Result<()> {
let pickle = session.pickle(self.get_pickle_mode()).await;
let session_id = session.session_id();
query(
"REPLACE INTO inbound_group_sessions (
session_id, account_id, sender_key,
room_id, pickle, imported
) VALUES (?1, ?2, ?3, ?4, ?5, ?6)
",
)
.bind(session_id)
.bind(account_id)
.bind(&pickle.sender_key)
.bind(pickle.room_id.as_str())
.bind(pickle.pickle.as_str())
.bind(pickle.imported)
.execute(&mut *connection)
.await?;
let row: (i64,) = query_as(
"SELECT id FROM inbound_group_sessions
WHERE account_id = ? and session_id = ? and sender_key = ?",
)
.bind(account_id)
.bind(session_id)
.bind(pickle.sender_key)
.fetch_one(&mut *connection)
.await?;
let session_row_id = row.0;
for (key_id, key) in pickle.signing_key {
query(
"REPLACE INTO group_session_claimed_keys (
session_id, algorithm, key
) VALUES (?1, ?2, ?3)
",
)
.bind(session_row_id)
.bind(serde_json::to_string(&key_id)?)
.bind(key)
.execute(&mut *connection)
.await?;
}
if let Some(chains) = pickle.forwarding_chains {
for key in chains {
query(
"REPLACE INTO group_session_chains (
session_id, key
) VALUES (?1, ?2)
",
)
.bind(session_row_id)
.bind(key)
.execute(&mut *connection)
.await?;
}
}
Ok(())
}
async fn load_cross_signing_key(
connection: &mut SqliteConnection,
user_id: &UserId,
user_row_id: i64,
key_type: CrosssigningKeyType,
) -> Result<CrossSigningKey> {
let row: (i64, String) =
query_as("SELECT id, usage FROM cross_signing_keys WHERE user_id =? and key_type =?")
.bind(user_row_id)
.bind(key_type)
.fetch_one(&mut *connection)
.await?;
let key_row_id = row.0;
let usage: Vec<KeyUsage> = serde_json::from_str(&row.1)?;
let key_rows: Vec<(String, String)> =
query_as("SELECT key_id, key FROM user_keys WHERE cross_signing_key = ?")
.bind(key_row_id)
.fetch_all(&mut *connection)
.await?;
let mut keys = BTreeMap::new();
let mut signatures = BTreeMap::new();
for row in key_rows {
let key_id = row.0;
let key = row.1;
keys.insert(key_id, key);
}
let mut signature_rows: Vec<(String, String, String)> = query_as(
"SELECT user_id, key_id, signature FROM user_key_signatures WHERE cross_signing_key = ?",
)
.bind(key_row_id)
.fetch_all(&mut *connection)
.await?;
for row in signature_rows.drain(..) {
let user_id = if let Ok(u) = UserId::try_from(row.0) {
u
} else {
continue;
};
let key_id = row.1;
let signature = row.2;
signatures
.entry(user_id)
.or_insert_with(BTreeMap::new)
.insert(key_id, signature);
}
Ok(CrossSigningKey {
user_id: user_id.to_owned(),
usage,
keys,
signatures,
})
}
async fn load_user(&self, user_id: &UserId) -> Result<Option<UserIdentities>> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
let row: Option<(i64,)> =
query_as("SELECT id FROM users WHERE account_id = ? and user_id = ?")
.bind(account_id)
.bind(user_id.as_str())
.fetch_optional(&mut *connection)
.await?;
let user_row_id = if let Some(row) = row {
row.0
} else {
return Ok(None);
};
let master = SqliteStore::load_cross_signing_key(
&mut connection,
user_id,
user_row_id,
CrosssigningKeyType::Master,
)
.await?;
let self_singing = SqliteStore::load_cross_signing_key(
&mut connection,
user_id,
user_row_id,
CrosssigningKeyType::SelfSigning,
)
.await?;
if user_id == &*self.user_id {
let user_signing = SqliteStore::load_cross_signing_key(
&mut connection,
user_id,
user_row_id,
CrosssigningKeyType::UserSigning,
)
.await?;
let verified: Option<(bool,)> =
query_as("SELECT trusted FROM users_trust_state WHERE user_id = ?")
.bind(user_row_id)
.fetch_optional(&mut *connection)
.await?;
let verified = verified.map_or(false, |r| r.0);
let identity =
OwnUserIdentity::new(master.into(), self_singing.into(), user_signing.into())
.expect("Signature check failed on stored identity");
if verified {
identity.mark_as_verified();
}
Ok(Some(UserIdentities::Own(identity)))
} else {
Ok(Some(UserIdentities::Other(
UserIdentity::new(master.into(), self_singing.into())
.expect("Signature check failed on stored identity"),
)))
}
}
async fn save_cross_signing_key(
connection: &mut SqliteConnection,
user_row_id: i64,
key_type: CrosssigningKeyType,
cross_signing_key: impl AsRef<CrossSigningKey>,
) -> Result<()> {
let cross_signing_key: &CrossSigningKey = cross_signing_key.as_ref();
query(
"REPLACE INTO cross_signing_keys (
user_id, key_type, usage
) VALUES (?1, ?2, ?3)
",
)
.bind(user_row_id)
.bind(key_type)
.bind(serde_json::to_string(&cross_signing_key.usage)?)
.execute(&mut *connection)
.await?;
let row: (i64,) = query_as(
"SELECT id FROM cross_signing_keys
WHERE user_id = ? and key_type = ?",
)
.bind(user_row_id)
.bind(key_type)
.fetch_one(&mut *connection)
.await?;
let key_row_id = row.0;
for (key_id, key) in &cross_signing_key.keys {
query(
"REPLACE INTO user_keys (
cross_signing_key, key_id, key
) VALUES (?1, ?2, ?3)
",
)
.bind(key_row_id)
.bind(key_id.as_str())
.bind(key)
.execute(&mut *connection)
.await?;
}
for (user_id, signature_map) in &cross_signing_key.signatures {
for (key_id, signature) in signature_map {
query(
"REPLACE INTO user_key_signatures (
cross_signing_key, user_id, key_id, signature
) VALUES (?1, ?2, ?3, ?4)
",
)
.bind(key_row_id)
.bind(user_id.as_str())
.bind(key_id.as_str())
.bind(signature)
.execute(&mut *connection)
.await?;
}
}
Ok(())
}
async fn save_user_helper(&self, user: &UserIdentities) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
query("REPLACE INTO users (account_id, user_id) VALUES (?1, ?2)")
.bind(account_id)
.bind(user.user_id().as_str())
.execute(&mut *connection)
.await?;
let row: (i64,) = query_as(
"SELECT id FROM users
WHERE account_id = ? and user_id = ?",
)
.bind(account_id)
.bind(user.user_id().as_str())
.fetch_one(&mut *connection)
.await?;
let user_row_id = row.0;
SqliteStore::save_cross_signing_key(
&mut connection,
user_row_id,
CrosssigningKeyType::Master,
user.master_key(),
)
.await?;
SqliteStore::save_cross_signing_key(
&mut connection,
user_row_id,
CrosssigningKeyType::SelfSigning,
user.self_signing_key(),
)
.await?;
if let UserIdentities::Own(own_identity) = user {
SqliteStore::save_cross_signing_key(
&mut connection,
user_row_id,
CrosssigningKeyType::UserSigning,
own_identity.user_signing_key(),
)
.await?;
query("REPLACE INTO users_trust_state (user_id, trusted) VALUES (?1, ?2)")
.bind(user_row_id)
.bind(own_identity.is_verified())
.execute(&mut *connection)
.await?;
}
Ok(())
}
} }
#[async_trait] #[async_trait]
@ -790,35 +1263,19 @@ impl CryptoStore for SqliteStore {
Ok(self.get_sessions_for(sender_key).await?) Ok(self.get_sessions_for(sender_key).await?)
} }
async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result<bool> { async fn save_inbound_group_sessions(&self, sessions: &[InboundGroupSession]) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let pickle = session.pickle(self.get_pickle_mode()).await;
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
let session_id = session.session_id();
// FIXME we need to store/restore the forwarding chains. // FIXME use a transaction here once sqlx gets better support for them.
// FIXME this should be converted so it accepts an array of sessions for
// the key import feature.
query( for session in sessions {
"INSERT INTO inbound_group_sessions ( self.save_inbound_group_session_helper(account_id, &mut connection, session)
session_id, account_id, sender_key, signing_key,
room_id, pickle
) VALUES (?1, ?2, ?3, ?4, ?5, ?6)
ON CONFLICT(session_id) DO UPDATE SET
pickle = excluded.pickle
",
)
.bind(session_id)
.bind(account_id)
.bind(pickle.sender_key)
.bind(pickle.signing_key)
.bind(pickle.room_id.as_str())
.bind(pickle.pickle.as_str())
.execute(&mut *connection)
.await?; .await?;
self.inbound_group_sessions.add(session.clone());
}
Ok(self.inbound_group_sessions.add(session)) Ok(())
} }
async fn get_inbound_group_session( async fn get_inbound_group_session(
@ -832,6 +1289,10 @@ impl CryptoStore for SqliteStore {
.get(room_id, sender_key, session_id)) .get(room_id, sender_key, session_id))
} }
async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
Ok(self.inbound_group_sessions.get_all())
}
fn is_user_tracked(&self, user_id: &UserId) -> bool { fn is_user_tracked(&self, user_id: &UserId) -> bool {
self.tracked_users.contains(user_id) self.tracked_users.contains(user_id)
} }
@ -899,11 +1360,15 @@ impl CryptoStore for SqliteStore {
Ok(self.devices.user_devices(user_id)) Ok(self.devices.user_devices(user_id))
} }
async fn get_user_identity(&self, _user_id: &UserId) -> Result<Option<UserIdentities>> { async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentities>> {
Ok(None) self.load_user(user_id).await
}
async fn save_user_identities(&self, users: &[UserIdentities]) -> Result<()> {
for user in users {
self.save_user_helper(user).await?;
} }
async fn save_user_identities(&self, _users: &[UserIdentities]) -> Result<()> {
Ok(()) Ok(())
} }
} }
@ -922,7 +1387,10 @@ impl std::fmt::Debug for SqliteStore {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::{ use crate::{
identities::device::test::get_device, identities::{
device::test::get_device,
user::test::{get_other_identity, get_own_identity},
},
olm::{Account, GroupSessionKey, InboundGroupSession, Session}, olm::{Account, GroupSessionKey, InboundGroupSession, Session},
}; };
use matrix_sdk_common::{ use matrix_sdk_common::{
@ -1175,14 +1643,14 @@ mod test {
.expect("Can't create session"); .expect("Can't create session");
store store
.save_inbound_group_session(session) .save_inbound_group_sessions(&[session])
.await .await
.expect("Can't save group session"); .expect("Can't save group session");
} }
#[tokio::test] #[tokio::test]
async fn load_inbound_group_session() { async fn load_inbound_group_session() {
let (account, store, _dir) = get_loaded_store().await; let (account, store, dir) = get_loaded_store().await;
let identity_keys = account.identity_keys(); let identity_keys = account.identity_keys();
let outbound_session = OlmOutboundGroupSession::new(); let outbound_session = OlmOutboundGroupSession::new();
@ -1194,11 +1662,22 @@ mod test {
) )
.expect("Can't create session"); .expect("Can't create session");
let mut export = session.export().await;
export.forwarding_curve25519_key_chain = vec!["some_chain".to_owned()];
let session = InboundGroupSession::from_export(export).unwrap();
store store
.save_inbound_group_session(session.clone()) .save_inbound_group_sessions(&[session.clone()])
.await .await
.expect("Can't save group session"); .expect("Can't save group session");
let store = SqliteStore::open(&alice_id(), &alice_device_id(), dir.path())
.await
.expect("Can't create store");
store.load_account().await.unwrap();
store.load_inbound_group_sessions().await.unwrap(); store.load_inbound_group_sessions().await.unwrap();
let loaded_session = store let loaded_session = store
@ -1207,6 +1686,8 @@ mod test {
.unwrap() .unwrap()
.unwrap(); .unwrap();
assert_eq!(session, loaded_session); assert_eq!(session, loaded_session);
let export = loaded_session.export().await;
assert!(!export.forwarding_curve25519_key_chain.is_empty())
} }
#[tokio::test] #[tokio::test]
@ -1311,4 +1792,81 @@ mod test {
assert!(loaded_device.is_none()); assert!(loaded_device.is_none());
} }
#[tokio::test]
async fn user_saving() {
let dir = tempdir().unwrap();
let tmpdir_path = dir.path().to_str().unwrap();
let user_id = user_id!("@example:localhost");
let device_id: &DeviceId = "WSKKLTJZCL".into();
let store = SqliteStore::open(&user_id, &device_id, tmpdir_path)
.await
.expect("Can't create store");
let account = Account::new(&user_id, &device_id);
store
.save_account(account.clone())
.await
.expect("Can't save account");
let own_identity = get_own_identity();
store
.save_user_identities(&[own_identity.clone().into()])
.await
.expect("Can't save identity");
drop(store);
let store = SqliteStore::open(&user_id, &device_id, dir.path())
.await
.expect("Can't create store");
store.load_account().await.unwrap();
let loaded_user = store
.get_user_identity(own_identity.user_id())
.await
.unwrap()
.unwrap();
assert_eq!(loaded_user.master_key(), own_identity.master_key());
assert_eq!(
loaded_user.self_signing_key(),
own_identity.self_signing_key()
);
assert_eq!(loaded_user, own_identity.clone().into());
let other_identity = get_other_identity();
store
.save_user_identities(&[other_identity.clone().into()])
.await
.unwrap();
let loaded_user = store
.load_user(other_identity.user_id())
.await
.unwrap()
.unwrap();
assert_eq!(loaded_user.master_key(), other_identity.master_key());
assert_eq!(
loaded_user.self_signing_key(),
other_identity.self_signing_key()
);
assert_eq!(loaded_user, other_identity.into());
own_identity.mark_as_verified();
store
.save_user_identities(&[own_identity.into()])
.await
.unwrap();
let loaded_user = store.load_user(&user_id).await.unwrap().unwrap();
assert!(loaded_user.own().unwrap().is_verified())
}
} }