Merge branch 'master' of https://github.com/matrix-org/matrix-rust-sdk into state-store

master
Devin R 2020-04-13 10:21:44 -04:00
commit a0973dec85
8 changed files with 190 additions and 154 deletions

View File

@ -12,7 +12,7 @@ version = "0.1.0"
[features] [features]
default = [] default = []
encryption = ["olm-rs", "serde/derive", "serde_json", "cjson"] encryption = ["olm-rs", "serde/derive", "serde_json", "cjson", "zeroize"]
sqlite-cryptostore = ["sqlx", "zeroize"] sqlite-cryptostore = ["sqlx", "zeroize"]
[dependencies] [dependencies]
@ -35,7 +35,7 @@ olm-rs = { git = "https://gitlab.gnome.org/poljar/olm-rs", optional = true, feat
serde = { version = "1.0.106", optional = true, features = ["derive"] } serde = { version = "1.0.106", optional = true, features = ["derive"] }
serde_json = { version = "1.0.51", optional = true } serde_json = { version = "1.0.51", optional = true }
cjson = { version = "0.1.0", optional = true } cjson = { version = "0.1.0", optional = true }
zeroize = { version = "1.1.0", optional = true } zeroize = { version = "1.1.0", optional = true, features = ["zeroize_derive"] }
# Misc dependencies # Misc dependencies
thiserror = "1.0.14" thiserror = "1.0.14"

View File

@ -1,7 +1,7 @@
[![Build Status](https://img.shields.io/travis/matrix-org/matrix-rust-sdk.svg?style=flat-square)](https://travis-ci.org/matrix-org/matrix-rust-sdk) [![Build Status](https://img.shields.io/travis/matrix-org/matrix-rust-sdk.svg?style=flat-square)](https://travis-ci.org/matrix-org/matrix-rust-sdk)
[![codecov](https://img.shields.io/codecov/c/github/matrix-org/matrix-rust-sdk/master.svg?style=flat-square)](https://codecov.io/gh/matrix-org/matrix-rust-sdk) [![codecov](https://img.shields.io/codecov/c/github/matrix-org/matrix-rust-sdk/master.svg?style=flat-square)](https://codecov.io/gh/matrix-org/matrix-rust-sdk)
[![License](https://img.shields.io/badge/License-Apache%202.0-yellowgreen.svg?style=flat-square)](https://opensource.org/licenses/Apache-2.0) [![License](https://img.shields.io/badge/License-Apache%202.0-yellowgreen.svg?style=flat-square)](https://opensource.org/licenses/Apache-2.0)
[![#matrix-rust-sdk](https://img.shields.io/badge/matrix-%23matrix--rust--sdk-blue?style=flat-square)](https://matrix.to/#/!iYnZafYUoXkeVPOSQh:matrix.org?via=matrix.org&via=matrix.ffslfl.net&via=raim.ist) [![#matrix-rust-sdk](https://img.shields.io/badge/matrix-%23matrix--rust--sdk-blue?style=flat-square)](https://matrix.to/#/#matrix-rust-sdk:matrix.org)
# matrix-rust-sdk # matrix-rust-sdk

View File

@ -14,6 +14,7 @@
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::convert::TryInto; use std::convert::TryInto;
use std::mem;
#[cfg(feature = "sqlite-cryptostore")] #[cfg(feature = "sqlite-cryptostore")]
use std::path::Path; use std::path::Path;
use std::result::Result as StdResult; use std::result::Result as StdResult;
@ -68,7 +69,7 @@ pub struct OlmMachine {
/// The unique device id of the device that holds this account. /// The unique device id of the device that holds this account.
device_id: DeviceId, device_id: DeviceId,
/// Our underlying Olm Account holding our identity keys. /// Our underlying Olm Account holding our identity keys.
account: Arc<Mutex<Account>>, account: Account,
/// The number of signed one-time keys we have uploaded to the server. If /// The number of signed one-time keys we have uploaded to the server. If
/// this is None, no action will be taken. After a sync request the client /// this is None, no action will be taken. After a sync request the client
/// needs to set this for us, depending on the count we will suggest the /// needs to set this for us, depending on the count we will suggest the
@ -98,7 +99,7 @@ impl OlmMachine {
Ok(OlmMachine { Ok(OlmMachine {
user_id: user_id.clone(), user_id: user_id.clone(),
device_id: device_id.to_owned(), device_id: device_id.to_owned(),
account: Arc::new(Mutex::new(Account::new())), account: Account::new(),
uploaded_signed_key_count: None, uploaded_signed_key_count: None,
store: Box::new(MemoryStore::new()), store: Box::new(MemoryStore::new()),
users_for_key_query: HashSet::new(), users_for_key_query: HashSet::new(),
@ -132,7 +133,7 @@ impl OlmMachine {
Ok(OlmMachine { Ok(OlmMachine {
user_id: user_id.clone(), user_id: user_id.clone(),
device_id: device_id.to_owned(), device_id: device_id.to_owned(),
account: Arc::new(Mutex::new(account)), account,
uploaded_signed_key_count: None, uploaded_signed_key_count: None,
store: Box::new(store), store: Box::new(store),
users_for_key_query: HashSet::new(), users_for_key_query: HashSet::new(),
@ -142,7 +143,7 @@ impl OlmMachine {
/// Should account or one-time keys be uploaded to the server. /// Should account or one-time keys be uploaded to the server.
pub async fn should_upload_keys(&self) -> bool { pub async fn should_upload_keys(&self) -> bool {
if !self.account.lock().await.shared() { if !self.account.shared() {
return true; return true;
} }
@ -150,7 +151,7 @@ impl OlmMachine {
// max_one_time_Keys() / 2, otherwise tell the client to upload more. // max_one_time_Keys() / 2, otherwise tell the client to upload more.
match self.uploaded_signed_key_count { match self.uploaded_signed_key_count {
Some(count) => { Some(count) => {
let max_keys = self.account.lock().await.max_one_time_keys() as u64; let max_keys = self.account.max_one_time_keys().await as u64;
let key_count = (max_keys / 2) - count; let key_count = (max_keys / 2) - count;
key_count > 0 key_count > 0
} }
@ -169,11 +170,10 @@ impl OlmMachine {
&mut self, &mut self,
response: &keys::upload_keys::Response, response: &keys::upload_keys::Response,
) -> Result<()> { ) -> Result<()> {
let mut account = self.account.lock().await; if !self.account.shared() {
if !account.shared {
debug!("Marking account as shared"); debug!("Marking account as shared");
} }
account.shared = true; self.account.mark_as_shared();
let one_time_key_count = response let one_time_key_count = response
.one_time_key_counts .one_time_key_counts
@ -187,9 +187,7 @@ impl OlmMachine {
); );
self.uploaded_signed_key_count = Some(count); self.uploaded_signed_key_count = Some(count);
account.mark_keys_as_published(); self.account.mark_keys_as_published().await;
drop(account);
self.store.save_account(self.account.clone()).await?; self.store.save_account(self.account.clone()).await?;
Ok(()) Ok(())
@ -317,9 +315,8 @@ impl OlmMachine {
let session = match self let session = match self
.account .account
.lock()
.await
.create_outbound_session(curve_key, &one_time_key) .create_outbound_session(curve_key, &one_time_key)
.await
{ {
Ok(s) => s, Ok(s) => s,
Err(e) => { Err(e) => {
@ -441,10 +438,9 @@ impl OlmMachine {
/// Returns the number of newly generated one-time keys. If no keys can be /// Returns the number of newly generated one-time keys. If no keys can be
/// generated returns an empty error. /// generated returns an empty error.
async fn generate_one_time_keys(&self) -> StdResult<u64, ()> { async fn generate_one_time_keys(&self) -> StdResult<u64, ()> {
let account = self.account.lock().await;
match self.uploaded_signed_key_count { match self.uploaded_signed_key_count {
Some(count) => { Some(count) => {
let max_keys = account.max_one_time_keys() as u64; let max_keys = self.account.max_one_time_keys().await as u64;
let max_on_server = max_keys / 2; let max_on_server = max_keys / 2;
if count >= (max_on_server) { if count >= (max_on_server) {
@ -453,11 +449,11 @@ impl OlmMachine {
let key_count = (max_on_server) - count; let key_count = (max_on_server) - count;
let key_count: usize = key_count let max_keys = self.account.max_one_time_keys().await;
.try_into()
.unwrap_or_else(|_| account.max_one_time_keys());
account.generate_one_time_keys(key_count); let key_count: usize = key_count.try_into().unwrap_or(max_keys);
self.account.generate_one_time_keys(key_count).await;
Ok(key_count as u64) Ok(key_count as u64)
} }
None => Err(()), None => Err(()),
@ -466,7 +462,7 @@ impl OlmMachine {
/// Sign the device keys and return a JSON Value to upload them. /// Sign the device keys and return a JSON Value to upload them.
async fn device_keys(&self) -> DeviceKeys { async fn device_keys(&self) -> DeviceKeys {
let identity_keys = self.account.lock().await.identity_keys(); let identity_keys = self.account.identity_keys();
let mut keys = HashMap::new(); let mut keys = HashMap::new();
@ -513,7 +509,7 @@ impl OlmMachine {
/// If no one-time keys need to be uploaded returns an empty error. /// If no one-time keys need to be uploaded returns an empty error.
async fn signed_one_time_keys(&self) -> StdResult<OneTimeKeys, ()> { async fn signed_one_time_keys(&self) -> StdResult<OneTimeKeys, ()> {
let _ = self.generate_one_time_keys().await?; let _ = self.generate_one_time_keys().await?;
let one_time_keys = self.account.lock().await.one_time_keys(); let one_time_keys = self.account.one_time_keys().await;
let mut one_time_key_map = HashMap::new(); let mut one_time_key_map = HashMap::new();
for (key_id, key) in one_time_keys.curve25519().iter() { for (key_id, key) in one_time_keys.curve25519().iter() {
@ -555,10 +551,9 @@ impl OlmMachine {
/// * `json` - The value that should be converted into a canonical JSON /// * `json` - The value that should be converted into a canonical JSON
/// string. /// string.
async fn sign_json(&self, json: &Value) -> String { async fn sign_json(&self, json: &Value) -> String {
let account = self.account.lock().await;
let canonical_json = cjson::to_string(json) let canonical_json = cjson::to_string(json)
.unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", json))); .unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", json)));
account.sign(&canonical_json) self.account.sign(&canonical_json).await
} }
/// Verify a signed JSON object. /// Verify a signed JSON object.
@ -637,7 +632,7 @@ impl OlmMachine {
return Err(()); return Err(());
} }
let shared = self.account.lock().await.shared(); let shared = self.account.shared();
let device_keys = if !shared { let device_keys = if !shared {
Some(self.device_keys().await) Some(self.device_keys().await)
@ -702,8 +697,12 @@ impl OlmMachine {
let mut session = match &message { let mut session = match &message {
OlmMessage::Message(_) => return Err(OlmError::SessionWedged), OlmMessage::Message(_) => return Err(OlmError::SessionWedged),
OlmMessage::PreKey(m) => { OlmMessage::PreKey(m) => {
let account = self.account.lock().await; let session = self
account.create_inbound_session(sender_key, m.clone())? .account
.create_inbound_session(sender_key, m.clone())
.await?;
self.store.save_account(self.account.clone()).await?;
session
} }
}; };
@ -740,7 +739,7 @@ impl OlmMachine {
return Err(OlmError::UnsupportedAlgorithm); return Err(OlmError::UnsupportedAlgorithm);
}; };
let identity_keys = self.account.lock().await.identity_keys(); let identity_keys = self.account.identity_keys();
let own_key = identity_keys.curve25519(); let own_key = identity_keys.curve25519();
let own_ciphertext = content.ciphertext.get(own_key); let own_ciphertext = content.ciphertext.get(own_key);
@ -753,11 +752,11 @@ impl OlmMachine {
OlmMessage::from_type_and_ciphertext(message_type.into(), ciphertext.body.clone()) OlmMessage::from_type_and_ciphertext(message_type.into(), ciphertext.body.clone())
.map_err(|_| OlmError::UnsupportedOlmType)?; .map_err(|_| OlmError::UnsupportedOlmType)?;
let decrypted_event = self let mut decrypted_event = self
.decrypt_olm_message(&event.sender.to_string(), &content.sender_key, message) .decrypt_olm_message(&event.sender.to_string(), &content.sender_key, message)
.await?; .await?;
debug!("Decrypted a to-device event {:?}", decrypted_event); debug!("Decrypted a to-device event {:?}", decrypted_event);
self.handle_decrypted_to_device_event(&content.sender_key, &decrypted_event) self.handle_decrypted_to_device_event(&content.sender_key, &mut decrypted_event)
.await?; .await?;
Ok(decrypted_event) Ok(decrypted_event)
@ -767,7 +766,7 @@ impl OlmMachine {
} }
} }
async fn add_room_key(&mut self, sender_key: &str, event: &ToDeviceRoomKey) -> Result<()> { async fn add_room_key(&mut self, sender_key: &str, event: &mut ToDeviceRoomKey) -> Result<()> {
match event.content.algorithm { match event.content.algorithm {
Algorithm::MegolmV1AesSha2 => { Algorithm::MegolmV1AesSha2 => {
// TODO check for all the valid fields. // TODO check for all the valid fields.
@ -776,7 +775,7 @@ impl OlmMachine {
.get("ed25519") .get("ed25519")
.ok_or(OlmError::MissingSigningKey)?; .ok_or(OlmError::MissingSigningKey)?;
let session_key = GroupSessionKey(event.content.session_key.to_owned()); let session_key = GroupSessionKey(mem::take(&mut event.content.session_key));
let session = InboundGroupSession::new( let session = InboundGroupSession::new(
sender_key, sender_key,
@ -799,8 +798,7 @@ impl OlmMachine {
async fn create_outbound_group_session(&mut self, room_id: &RoomId) -> Result<()> { async fn create_outbound_group_session(&mut self, room_id: &RoomId) -> Result<()> {
let session = OutboundGroupSession::new(room_id); let session = OutboundGroupSession::new(room_id);
let account = self.account.lock().await; let identity_keys = self.account.identity_keys();
let identity_keys = account.identity_keys();
let sender_key = identity_keys.curve25519(); let sender_key = identity_keys.curve25519();
let signing_key = identity_keys.ed25519(); let signing_key = identity_keys.ed25519();
@ -855,13 +853,7 @@ impl OlmMachine {
Ok(MegolmV1AesSha2Content { Ok(MegolmV1AesSha2Content {
algorithm: Algorithm::MegolmV1AesSha2, algorithm: Algorithm::MegolmV1AesSha2,
ciphertext, ciphertext,
sender_key: self sender_key: self.account.identity_keys().curve25519().to_owned(),
.account
.lock()
.await
.identity_keys()
.curve25519()
.to_owned(),
session_id: session.session_id().to_owned(), session_id: session.session_id().to_owned(),
device_id: self.device_id.to_owned(), device_id: self.device_id.to_owned(),
}) })
@ -874,7 +866,7 @@ impl OlmMachine {
event_type: EventType, event_type: EventType,
content: Value, content: Value,
) -> Result<OlmV1Curve25519AesSha2Content> { ) -> Result<OlmV1Curve25519AesSha2Content> {
let identity_keys = self.account.lock().await.identity_keys(); let identity_keys = self.account.identity_keys();
let recipient_signing_key = recipient_device let recipient_signing_key = recipient_device
.keys(&KeyAlgorithm::Ed25519) .keys(&KeyAlgorithm::Ed25519)
@ -1047,7 +1039,7 @@ impl OlmMachine {
async fn handle_decrypted_to_device_event( async fn handle_decrypted_to_device_event(
&mut self, &mut self,
sender_key: &str, sender_key: &str,
event: &EventResult<ToDeviceEvent>, event: &mut EventResult<ToDeviceEvent>,
) -> Result<()> { ) -> Result<()> {
let event = if let EventResult::Ok(e) = event { let event = if let EventResult::Ok(e) = event {
e e
@ -1150,7 +1142,7 @@ impl OlmMachine {
// TODO check if the olm session is wedged and re-request the key. // TODO check if the olm session is wedged and re-request the key.
let session = session.ok_or(OlmError::MissingSession)?; let session = session.ok_or(OlmError::MissingSession)?;
let (plaintext, _) = session.lock().await.decrypt(content.ciphertext.clone())?; let (plaintext, _) = session.decrypt(content.ciphertext.clone()).await?;
// TODO check the message index. // TODO check the message index.
// TODO check if this is from a verified device. // TODO check if this is from a verified device.
@ -1326,7 +1318,7 @@ mod test {
let machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap(); let machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
let mut device_keys = machine.device_keys().await; let mut device_keys = machine.device_keys().await;
let identity_keys = machine.account.lock().await.identity_keys(); let identity_keys = machine.account.identity_keys();
let ed25519_key = identity_keys.ed25519(); let ed25519_key = identity_keys.ed25519();
let ret = machine.verify_json( let ret = machine.verify_json(
@ -1359,7 +1351,7 @@ mod test {
machine.uploaded_signed_key_count = Some(49); machine.uploaded_signed_key_count = Some(49);
let mut one_time_keys = machine.signed_one_time_keys().await.unwrap(); let mut one_time_keys = machine.signed_one_time_keys().await.unwrap();
let identity_keys = machine.account.lock().await.identity_keys(); let identity_keys = machine.account.identity_keys();
let ed25519_key = identity_keys.ed25519(); let ed25519_key = identity_keys.ed25519();
let mut one_time_key = one_time_keys.values_mut().nth(0).unwrap(); let mut one_time_key = one_time_keys.values_mut().nth(0).unwrap();
@ -1378,7 +1370,7 @@ mod test {
let mut machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap(); let mut machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
machine.uploaded_signed_key_count = Some(0); machine.uploaded_signed_key_count = Some(0);
let identity_keys = machine.account.lock().await.identity_keys(); let identity_keys = machine.account.identity_keys();
let ed25519_key = identity_keys.ed25519(); let ed25519_key = identity_keys.ed25519();
let (device_keys, mut one_time_keys) = machine let (device_keys, mut one_time_keys) = machine

View File

@ -60,7 +60,7 @@ impl SessionStore {
#[derive(Debug)] #[derive(Debug)]
pub struct GroupSessionStore { pub struct GroupSessionStore {
entries: HashMap<RoomId, HashMap<String, HashMap<String, Arc<Mutex<InboundGroupSession>>>>>, entries: HashMap<RoomId, HashMap<String, HashMap<String, InboundGroupSession>>>,
} }
impl GroupSessionStore { impl GroupSessionStore {
@ -72,18 +72,19 @@ impl GroupSessionStore {
pub fn add(&mut self, session: InboundGroupSession) -> bool { pub fn add(&mut self, session: InboundGroupSession) -> bool {
if !self.entries.contains_key(&session.room_id) { if !self.entries.contains_key(&session.room_id) {
self.entries let room_id = &*session.room_id;
.insert(session.room_id.to_owned(), HashMap::new()); self.entries.insert(room_id.clone(), HashMap::new());
} }
let room_map = self.entries.get_mut(&session.room_id).unwrap(); let room_map = self.entries.get_mut(&session.room_id).unwrap();
if !room_map.contains_key(&session.sender_key) { if !room_map.contains_key(&*session.sender_key) {
room_map.insert(session.sender_key.to_owned(), HashMap::new()); let sender_key = &*session.sender_key;
room_map.insert(sender_key.to_owned(), HashMap::new());
} }
let sender_map = room_map.get_mut(&session.sender_key).unwrap(); let sender_map = room_map.get_mut(&*session.sender_key).unwrap();
let ret = sender_map.insert(session.session_id(), Arc::new(Mutex::new(session))); let ret = sender_map.insert(session.session_id().to_owned(), session);
ret.is_some() ret.is_some()
} }
@ -93,7 +94,7 @@ impl GroupSessionStore {
room_id: &RoomId, room_id: &RoomId,
sender_key: &str, sender_key: &str,
session_id: &str, session_id: &str,
) -> Option<Arc<Mutex<InboundGroupSession>>> { ) -> Option<InboundGroupSession> {
self.entries self.entries
.get(room_id) .get(room_id)
.and_then(|m| m.get(sender_key).and_then(|m| m.get(session_id).cloned())) .and_then(|m| m.get(sender_key).and_then(|m| m.get(session_id).cloned()))

View File

@ -19,6 +19,7 @@ use std::time::Instant;
use serde::Serialize; use serde::Serialize;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use zeroize::Zeroize;
use olm_rs::account::{IdentityKeys, OlmAccount, OneTimeKeys}; use olm_rs::account::{IdentityKeys, OlmAccount, OneTimeKeys};
use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError}; use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
@ -33,9 +34,11 @@ use crate::identifiers::RoomId;
/// The Olm account. /// The Olm account.
/// An account is the central identity for encrypted communication between two /// An account is the central identity for encrypted communication between two
/// devices. It holds the two identity key pairs for a device. /// devices. It holds the two identity key pairs for a device.
#[derive(Clone)]
pub struct Account { pub struct Account {
inner: OlmAccount, inner: Arc<Mutex<OlmAccount>>,
pub(crate) shared: bool, identity_keys: Arc<IdentityKeys>,
pub(crate) shared: Arc<AtomicBool>,
} }
impl fmt::Debug for Account { impl fmt::Debug for Account {
@ -44,7 +47,7 @@ impl fmt::Debug for Account {
f, f,
"Olm Account: {:?}, shared: {}", "Olm Account: {:?}, shared: {}",
self.identity_keys(), self.identity_keys(),
self.shared self.shared()
) )
} }
} }
@ -52,49 +55,61 @@ impl fmt::Debug for Account {
impl Account { impl Account {
/// Create a new account. /// Create a new account.
pub fn new() -> Self { pub fn new() -> Self {
let account = OlmAccount::new();
let identity_keys = account.parsed_identity_keys();
Account { Account {
inner: OlmAccount::new(), inner: Arc::new(Mutex::new(account)),
shared: false, identity_keys: Arc::new(identity_keys),
shared: Arc::new(AtomicBool::new(false)),
} }
} }
/// Get the public parts of the identity keys for the account. /// Get the public parts of the identity keys for the account.
pub fn identity_keys(&self) -> IdentityKeys { pub fn identity_keys(&self) -> &IdentityKeys {
self.inner.parsed_identity_keys() &self.identity_keys
} }
/// Has the account been shared with the server. /// Has the account been shared with the server.
pub fn shared(&self) -> bool { pub fn shared(&self) -> bool {
self.shared self.shared.load(Ordering::Relaxed)
}
/// Mark the account as shared.
///
/// Messages shouldn't be encrypted with the session before it has been
/// shared.
pub fn mark_as_shared(&self) {
self.shared.store(true, Ordering::Relaxed);
} }
/// Get the one-time keys of the account. /// Get the one-time keys of the account.
/// ///
/// This can be empty, keys need to be generated first. /// This can be empty, keys need to be generated first.
pub fn one_time_keys(&self) -> OneTimeKeys { pub async fn one_time_keys(&self) -> OneTimeKeys {
self.inner.parsed_one_time_keys() self.inner.lock().await.parsed_one_time_keys()
} }
/// Generate count number of one-time keys. /// Generate count number of one-time keys.
pub fn generate_one_time_keys(&self, count: usize) { pub async fn generate_one_time_keys(&self, count: usize) {
self.inner.generate_one_time_keys(count); self.inner.lock().await.generate_one_time_keys(count);
} }
/// Get the maximum number of one-time keys the account can hold. /// Get the maximum number of one-time keys the account can hold.
pub fn max_one_time_keys(&self) -> usize { pub async fn max_one_time_keys(&self) -> usize {
self.inner.max_number_of_one_time_keys() self.inner.lock().await.max_number_of_one_time_keys()
} }
/// Mark the current set of one-time keys as being published. /// Mark the current set of one-time keys as being published.
pub fn mark_keys_as_published(&self) { pub async fn mark_keys_as_published(&self) {
self.inner.mark_keys_as_published(); self.inner.lock().await.mark_keys_as_published();
} }
/// Sign the given string using the accounts signing key. /// Sign the given string using the accounts signing key.
/// ///
/// Returns the signature as a base64 encoded string. /// Returns the signature as a base64 encoded string.
pub fn sign(&self, string: &str) -> String { pub async fn sign(&self, string: &str) -> String {
self.inner.sign(string) self.inner.lock().await.sign(string)
} }
/// Store the account as a base64 encoded string. /// Store the account as a base64 encoded string.
@ -103,8 +118,8 @@ impl Account {
/// ///
/// * `pickle_mode` - The mode that was used to pickle the account, either an /// * `pickle_mode` - The mode that was used to pickle the account, either an
/// unencrypted mode or an encrypted using passphrase. /// unencrypted mode or an encrypted using passphrase.
pub fn pickle(&self, pickle_mode: PicklingMode) -> String { pub async fn pickle(&self, pickle_mode: PicklingMode) -> String {
self.inner.pickle(pickle_mode) self.inner.lock().await.pickle(pickle_mode)
} }
/// Restore an account from a previously pickled string. /// Restore an account from a previously pickled string.
@ -123,8 +138,14 @@ impl Account {
pickle_mode: PicklingMode, pickle_mode: PicklingMode,
shared: bool, shared: bool,
) -> Result<Self, OlmAccountError> { ) -> Result<Self, OlmAccountError> {
let acc = OlmAccount::unpickle(pickle, pickle_mode)?; let account = OlmAccount::unpickle(pickle, pickle_mode)?;
Ok(Account { inner: acc, shared }) let identity_keys = account.parsed_identity_keys();
Ok(Account {
inner: Arc::new(Mutex::new(account)),
identity_keys: Arc::new(identity_keys),
shared: Arc::new(AtomicBool::from(shared)),
})
} }
/// Create a new session with another account given a one-time key. /// Create a new session with another account given a one-time key.
@ -137,13 +158,15 @@ impl Account {
/// ///
/// * `their_one_time_key` - A signed one-time key that the other account /// * `their_one_time_key` - A signed one-time key that the other account
/// created and shared with us. /// created and shared with us.
pub fn create_outbound_session( pub async fn create_outbound_session(
&self, &self,
their_identity_key: &str, their_identity_key: &str,
their_one_time_key: &SignedKey, their_one_time_key: &SignedKey,
) -> Result<Session, OlmSessionError> { ) -> Result<Session, OlmSessionError> {
let session = self let session = self
.inner .inner
.lock()
.await
.create_outbound_session(their_identity_key, &their_one_time_key.key)?; .create_outbound_session(their_identity_key, &their_one_time_key.key)?;
let now = Instant::now(); let now = Instant::now();
@ -166,15 +189,25 @@ impl Account {
/// ///
/// * `message` - A pre-key Olm message that was sent to us by the other /// * `message` - A pre-key Olm message that was sent to us by the other
/// account. /// account.
pub fn create_inbound_session( pub async fn create_inbound_session(
&self, &self,
their_identity_key: &str, their_identity_key: &str,
message: PreKeyMessage, message: PreKeyMessage,
) -> Result<Session, OlmSessionError> { ) -> Result<Session, OlmSessionError> {
let session = self let session = self
.inner .inner
.lock()
.await
.create_inbound_session_from(their_identity_key, message)?; .create_inbound_session_from(their_identity_key, message)?;
self.inner
.lock()
.await
.remove_one_time_keys(&session)
.expect(
"Session was successfully created but the account doesn't hold a matching one-time key",
);
let now = Instant::now(); let now = Instant::now();
Ok(Session { Ok(Session {
@ -188,7 +221,7 @@ impl Account {
impl PartialEq for Account { impl PartialEq for Account {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.identity_keys() == other.identity_keys() && self.shared == other.shared self.identity_keys() == other.identity_keys() && self.shared() == other.shared()
} }
} }
@ -312,19 +345,22 @@ impl PartialEq for Session {
/// The private session key of a group session. /// The private session key of a group session.
/// Can be used to create a new inbound group session. /// Can be used to create a new inbound group session.
#[derive(Clone, Serialize)] #[derive(Clone, Serialize, Zeroize)]
#[zeroize(drop)]
pub struct GroupSessionKey(pub String); pub struct GroupSessionKey(pub String);
/// Inbound group session. /// Inbound group session.
/// ///
/// Inbound group sessions are used to exchange room messages between a group of /// Inbound group sessions are used to exchange room messages between a group of
/// participants. Inbound group sessions are used to decrypt the room messages. /// participants. Inbound group sessions are used to decrypt the room messages.
#[derive(Clone)]
pub struct InboundGroupSession { pub struct InboundGroupSession {
inner: OlmInboundGroupSession, inner: Arc<Mutex<OlmInboundGroupSession>>,
pub(crate) sender_key: String, session_id: Arc<String>,
pub(crate) signing_key: String, pub(crate) sender_key: Arc<String>,
pub(crate) room_id: RoomId, pub(crate) signing_key: Arc<String>,
forwarding_chains: Option<Vec<String>>, pub(crate) room_id: Arc<RoomId>,
forwarding_chains: Arc<Mutex<Option<Vec<String>>>>,
} }
impl InboundGroupSession { impl InboundGroupSession {
@ -350,12 +386,16 @@ impl InboundGroupSession {
room_id: &RoomId, room_id: &RoomId,
session_key: GroupSessionKey, session_key: GroupSessionKey,
) -> Result<Self, OlmGroupSessionError> { ) -> Result<Self, OlmGroupSessionError> {
let session = OlmInboundGroupSession::new(&session_key.0)?;
let session_id = session.session_id();
Ok(InboundGroupSession { Ok(InboundGroupSession {
inner: OlmInboundGroupSession::new(&session_key.0)?, inner: Arc::new(Mutex::new(session)),
sender_key: sender_key.to_owned(), session_id: Arc::new(session_id),
signing_key: signing_key.to_owned(), sender_key: Arc::new(sender_key.to_owned()),
room_id: room_id.clone(), signing_key: Arc::new(signing_key.to_owned()),
forwarding_chains: None, room_id: Arc::new(room_id.clone()),
forwarding_chains: Arc::new(Mutex::new(None)),
}) })
} }
@ -365,8 +405,8 @@ impl InboundGroupSession {
/// ///
/// * `pickle_mode` - The mode that was used to pickle the group session, /// * `pickle_mode` - The mode that was used to pickle the group session,
/// either an unencrypted mode or an encrypted using passphrase. /// either an unencrypted mode or an encrypted using passphrase.
pub fn pickle(&self, pickle_mode: PicklingMode) -> String { pub async fn pickle(&self, pickle_mode: PicklingMode) -> String {
self.inner.pickle(pickle_mode) self.inner.lock().await.pickle(pickle_mode)
} }
/// Restore a Session from a previously pickled string. /// Restore a Session from a previously pickled string.
@ -396,23 +436,26 @@ impl InboundGroupSession {
room_id: RoomId, room_id: RoomId,
) -> Result<Self, OlmGroupSessionError> { ) -> Result<Self, OlmGroupSessionError> {
let session = OlmInboundGroupSession::unpickle(pickle, pickle_mode)?; let session = OlmInboundGroupSession::unpickle(pickle, pickle_mode)?;
let session_id = session.session_id();
Ok(InboundGroupSession { Ok(InboundGroupSession {
inner: session, inner: Arc::new(Mutex::new(session)),
sender_key, session_id: Arc::new(session_id),
signing_key, sender_key: Arc::new(sender_key),
room_id, signing_key: Arc::new(signing_key),
forwarding_chains: None, room_id: Arc::new(room_id),
forwarding_chains: Arc::new(Mutex::new(None)),
}) })
} }
/// Returns the unique identifier for this session. /// Returns the unique identifier for this session.
pub fn session_id(&self) -> String { pub fn session_id(&self) -> &str {
self.inner.session_id() &self.session_id
} }
/// Get the first message index we know how to decrypt. /// Get the first message index we know how to decrypt.
pub fn first_known_index(&self) -> u32 { pub async fn first_known_index(&self) -> u32 {
self.inner.first_known_index() self.inner.lock().await.first_known_index()
} }
/// Decrypt the given ciphertext. /// Decrypt the given ciphertext.
@ -423,8 +466,8 @@ impl InboundGroupSession {
/// # Arguments /// # Arguments
/// ///
/// * `message` - The message that should be decrypted. /// * `message` - The message that should be decrypted.
pub fn decrypt(&self, message: String) -> Result<(String, u32), OlmGroupSessionError> { pub async fn decrypt(&self, message: String) -> Result<(String, u32), OlmGroupSessionError> {
self.inner.decrypt(message) self.inner.lock().await.decrypt(message)
} }
} }
@ -566,16 +609,16 @@ mod test {
assert!(!identyty_keys.curve25519().is_empty()); assert!(!identyty_keys.curve25519().is_empty());
} }
#[test] #[tokio::test]
fn one_time_keys_creation() { async fn one_time_keys_creation() {
let account = Account::new(); let account = Account::new();
let one_time_keys = account.one_time_keys(); let one_time_keys = account.one_time_keys().await;
assert!(one_time_keys.curve25519().is_empty()); assert!(one_time_keys.curve25519().is_empty());
assert_ne!(account.max_one_time_keys(), 0); assert_ne!(account.max_one_time_keys().await, 0);
account.generate_one_time_keys(10); account.generate_one_time_keys(10).await;
let one_time_keys = account.one_time_keys(); let one_time_keys = account.one_time_keys().await;
assert!(!one_time_keys.curve25519().is_empty()); assert!(!one_time_keys.curve25519().is_empty());
assert_ne!(one_time_keys.values().len(), 0); assert_ne!(one_time_keys.values().len(), 0);
@ -588,21 +631,19 @@ mod test {
one_time_keys.get("curve25519").unwrap() one_time_keys.get("curve25519").unwrap()
); );
account.mark_keys_as_published(); account.mark_keys_as_published().await;
let one_time_keys = account.one_time_keys(); let one_time_keys = account.one_time_keys().await;
assert!(one_time_keys.curve25519().is_empty()); assert!(one_time_keys.curve25519().is_empty());
} }
#[test] #[tokio::test]
fn session_creation() { async fn session_creation() {
let alice = Account::new(); let alice = Account::new();
let bob = Account::new(); let bob = Account::new();
let alice_keys = alice.identity_keys(); let alice_keys = alice.identity_keys();
let one_time_keys = alice.one_time_keys(); alice.generate_one_time_keys(1).await;
let one_time_keys = alice.one_time_keys().await;
alice.generate_one_time_keys(1); alice.mark_keys_as_published().await;
let one_time_keys = alice.one_time_keys();
alice.mark_keys_as_published();
let one_time_key = one_time_keys let one_time_key = one_time_keys
.curve25519() .curve25519()
@ -619,6 +660,7 @@ mod test {
let mut bob_session = bob let mut bob_session = bob
.create_outbound_session(alice_keys.curve25519(), &one_time_key) .create_outbound_session(alice_keys.curve25519(), &one_time_key)
.await
.unwrap(); .unwrap();
let plaintext = "Hello world"; let plaintext = "Hello world";
@ -633,6 +675,7 @@ mod test {
let bob_keys = bob.identity_keys(); let bob_keys = bob.identity_keys();
let mut alice_session = alice let mut alice_session = alice
.create_inbound_session(bob_keys.curve25519(), prekey_message) .create_inbound_session(bob_keys.curve25519(), prekey_message)
.await
.unwrap(); .unwrap();
assert_eq!(bob_session.session_id(), alice_session.session_id()); assert_eq!(bob_session.session_id(), alice_session.session_id());

View File

@ -48,7 +48,7 @@ impl CryptoStore for MemoryStore {
Ok(None) Ok(None)
} }
async fn save_account(&mut self, _: Arc<Mutex<Account>>) -> Result<()> { async fn save_account(&mut self, _: Account) -> Result<()> {
Ok(()) Ok(())
} }
@ -77,7 +77,7 @@ impl CryptoStore for MemoryStore {
room_id: &RoomId, room_id: &RoomId,
sender_key: &str, sender_key: &str,
session_id: &str, session_id: &str,
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>> { ) -> Result<Option<InboundGroupSession>> {
Ok(self Ok(self
.inbound_group_sessions .inbound_group_sessions
.get(room_id, sender_key, session_id)) .get(room_id, sender_key, session_id))

View File

@ -66,22 +66,26 @@ pub type Result<T> = std::result::Result<T, CryptoStoreError>;
#[async_trait] #[async_trait]
pub trait CryptoStore: Debug + Send + Sync { pub trait CryptoStore: Debug + Send + Sync {
async fn load_account(&mut self) -> Result<Option<Account>>; async fn load_account(&mut self) -> Result<Option<Account>>;
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()>; async fn save_account(&mut self, account: Account) -> Result<()>;
async fn save_session(&mut self, session: Arc<Mutex<Session>>) -> Result<()>; async fn save_session(&mut self, session: Arc<Mutex<Session>>) -> Result<()>;
async fn add_and_save_session(&mut self, session: Session) -> Result<()>; async fn add_and_save_session(&mut self, session: Session) -> Result<()>;
async fn get_sessions( async fn get_sessions(
&mut self, &mut self,
sender_key: &str, sender_key: &str,
) -> Result<Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>>; ) -> Result<Option<Arc<Mutex<Vec<Arc<Mutex<Session>>>>>>>;
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool>; async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool>;
async fn get_inbound_group_session( async fn get_inbound_group_session(
&mut self, &mut self,
room_id: &RoomId, room_id: &RoomId,
sender_key: &str, sender_key: &str,
session_id: &str, session_id: &str,
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>>; ) -> Result<Option<InboundGroupSession>>;
fn tracked_users(&self) -> &HashSet<UserId>; fn tracked_users(&self) -> &HashSet<UserId>;
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool>; async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool>;
async fn save_device(&self, device: Device) -> Result<()>; async fn save_device(&self, device: Device) -> Result<()>;
async fn get_device(&self, user_id: &UserId, device_id: &str) -> Result<Option<Device>>; async fn get_device(&self, user_id: &UserId, device_id: &str) -> Result<Option<Device>>;
async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices>; async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices>;

View File

@ -288,9 +288,8 @@ impl CryptoStore for SqliteStore {
Ok(result) Ok(result)
} }
async fn save_account(&mut self, account: Arc<Mutex<Account>>) -> Result<()> { async fn save_account(&mut self, account: Account) -> Result<()> {
let acc = account.lock().await; let pickle = account.pickle(self.get_pickle_mode()).await;
let pickle = acc.pickle(self.get_pickle_mode());
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
query( query(
@ -307,7 +306,7 @@ impl CryptoStore for SqliteStore {
.bind(&*self.user_id.to_string()) .bind(&*self.user_id.to_string())
.bind(&*self.device_id.to_string()) .bind(&*self.device_id.to_string())
.bind(&pickle) .bind(&pickle)
.bind(acc.shared) .bind(account.shared())
.execute(&mut *connection) .execute(&mut *connection)
.await?; .await?;
@ -367,7 +366,7 @@ impl CryptoStore for SqliteStore {
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool> { async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool> {
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
let pickle = session.pickle(self.get_pickle_mode()); let pickle = session.pickle(self.get_pickle_mode()).await;
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
let session_id = session.session_id(); let session_id = session.session_id();
@ -383,9 +382,9 @@ impl CryptoStore for SqliteStore {
) )
.bind(session_id) .bind(session_id)
.bind(account_id) .bind(account_id)
.bind(&session.sender_key) .bind(&*session.sender_key)
.bind(&session.signing_key) .bind(&*session.signing_key)
.bind(&session.room_id.to_string()) .bind(&*session.room_id.to_string())
.bind(&pickle) .bind(&pickle)
.execute(&mut *connection) .execute(&mut *connection)
.await?; .await?;
@ -398,7 +397,7 @@ impl CryptoStore for SqliteStore {
room_id: &RoomId, room_id: &RoomId,
sender_key: &str, sender_key: &str,
session_id: &str, session_id: &str,
) -> Result<Option<Arc<Mutex<InboundGroupSession>>>> { ) -> Result<Option<InboundGroupSession>> {
Ok(self Ok(self
.inbound_group_sessions .inbound_group_sessions
.get(room_id, sender_key, session_id)) .get(room_id, sender_key, session_id))
@ -460,7 +459,7 @@ mod test {
.expect("Can't create store") .expect("Can't create store")
} }
async fn get_loaded_store() -> (Arc<Mutex<Account>>, SqliteStore) { async fn get_loaded_store() -> (Account, SqliteStore) {
let mut store = get_store().await; let mut store = get_store().await;
let account = get_account(); let account = get_account();
store store
@ -471,19 +470,19 @@ mod test {
(account, store) (account, store)
} }
fn get_account() -> Arc<Mutex<Account>> { fn get_account() -> Account {
let account = Account::new(); Account::new()
Arc::new(Mutex::new(account))
} }
fn get_account_and_session() -> (Arc<Mutex<Account>>, Session) { async fn get_account_and_session() -> (Account, Session) {
let alice = Account::new(); let alice = Account::new();
let bob = Account::new(); let bob = Account::new();
bob.generate_one_time_keys(1); bob.generate_one_time_keys(1).await;
let one_time_key = bob let one_time_key = bob
.one_time_keys() .one_time_keys()
.await
.curve25519() .curve25519()
.iter() .iter()
.nth(0) .nth(0)
@ -497,9 +496,10 @@ mod test {
let sender_key = bob.identity_keys().curve25519().to_owned(); let sender_key = bob.identity_keys().curve25519().to_owned();
let session = alice let session = alice
.create_outbound_session(&sender_key, &one_time_key) .create_outbound_session(&sender_key, &one_time_key)
.await
.unwrap(); .unwrap();
(Arc::new(Mutex::new(alice)), session) (alice, session)
} }
#[tokio::test] #[tokio::test]
@ -532,11 +532,10 @@ mod test {
.await .await
.expect("Can't save account"); .expect("Can't save account");
let acc = account.lock().await;
let loaded_account = store.load_account().await.expect("Can't load account"); let loaded_account = store.load_account().await.expect("Can't load account");
let loaded_account = loaded_account.unwrap(); let loaded_account = loaded_account.unwrap();
assert_eq!(*acc, loaded_account); assert_eq!(account, loaded_account);
} }
#[tokio::test] #[tokio::test]
@ -549,7 +548,7 @@ mod test {
.await .await
.expect("Can't save account"); .expect("Can't save account");
account.lock().await.shared = true; account.mark_as_shared();
store store
.save_account(account.clone()) .save_account(account.clone())
@ -558,15 +557,14 @@ mod test {
let loaded_account = store.load_account().await.expect("Can't load account"); let loaded_account = store.load_account().await.expect("Can't load account");
let loaded_account = loaded_account.unwrap(); let loaded_account = loaded_account.unwrap();
let acc = account.lock().await;
assert_eq!(*acc, loaded_account); assert_eq!(account, loaded_account);
} }
#[tokio::test] #[tokio::test]
async fn save_session() { async fn save_session() {
let mut store = get_store().await; let mut store = get_store().await;
let (account, session) = get_account_and_session(); let (account, session) = get_account_and_session().await;
let session = Arc::new(Mutex::new(session)); let session = Arc::new(Mutex::new(session));
assert!(store.save_session(session.clone()).await.is_err()); assert!(store.save_session(session.clone()).await.is_err());
@ -582,7 +580,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn load_sessions() { async fn load_sessions() {
let mut store = get_store().await; let mut store = get_store().await;
let (account, session) = get_account_and_session(); let (account, session) = get_account_and_session().await;
let session = Arc::new(Mutex::new(session)); let session = Arc::new(Mutex::new(session));
store store
.save_account(account.clone()) .save_account(account.clone())
@ -604,7 +602,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn add_and_save_session() { async fn add_and_save_session() {
let mut store = get_store().await; let mut store = get_store().await;
let (account, session) = get_account_and_session(); let (account, session) = get_account_and_session().await;
let sender_key = session.sender_key.to_owned(); let sender_key = session.sender_key.to_owned();
let session_id = session.session_id(); let session_id = session.session_id();
@ -625,8 +623,7 @@ mod test {
async fn save_inbound_group_session() { async fn save_inbound_group_session() {
let (account, mut store) = get_loaded_store().await; let (account, mut store) = get_loaded_store().await;
let acc = account.lock().await; let identity_keys = account.identity_keys();
let identity_keys = acc.identity_keys();
let outbound_session = OlmOutboundGroupSession::new(); let outbound_session = OlmOutboundGroupSession::new();
let session = InboundGroupSession::new( let session = InboundGroupSession::new(
identity_keys.curve25519(), identity_keys.curve25519(),
@ -646,8 +643,7 @@ mod test {
async fn load_inbound_group_session() { async fn load_inbound_group_session() {
let (account, mut store) = get_loaded_store().await; let (account, mut store) = get_loaded_store().await;
let acc = account.lock().await; let identity_keys = account.identity_keys();
let identity_keys = acc.identity_keys();
let outbound_session = OlmOutboundGroupSession::new(); let outbound_session = OlmOutboundGroupSession::new();
let session = InboundGroupSession::new( let session = InboundGroupSession::new(
identity_keys.curve25519(), identity_keys.curve25519(),
@ -657,7 +653,7 @@ mod test {
) )
.expect("Can't create session"); .expect("Can't create session");
let session_id = session.session_id(); let session_id = session.session_id().to_owned();
store store
.save_inbound_group_session(session) .save_inbound_group_session(session)