crypto: Use an atomic int for the key count.

This commit is contained in:
Damir Jelić 2020-04-28 15:47:49 +02:00
parent 350578739c
commit cf9ecbd0e8

View file

@ -18,6 +18,7 @@ use std::mem;
#[cfg(feature = "sqlite-cryptostore")]
use std::path::Path;
use std::result::Result as StdResult;
use std::sync::atomic::{AtomicU64, Ordering};
use uuid::Uuid;
use super::error::{OlmError, Result, SignatureError, VerificationResult};
@ -70,7 +71,7 @@ pub struct OlmMachine {
/// this is None, no action will be taken. After a sync request the client
/// needs to set this for us, depending on the count we will suggest the
/// client to upload new keys.
uploaded_signed_key_count: Option<u64>,
uploaded_signed_key_count: Option<AtomicU64>,
/// Store for the encryption keys.
/// Persists all the encrytpion keys so a client can resume the session
/// without the need to create new keys.
@ -170,16 +171,23 @@ impl OlmMachine {
// If we have a known key count, check that we have more than
// max_one_time_Keys() / 2, otherwise tell the client to upload more.
match self.uploaded_signed_key_count {
match &self.uploaded_signed_key_count {
Some(count) => {
let max_keys = self.account.max_one_time_keys().await as u64;
let key_count = (max_keys / 2) - count;
let key_count = (max_keys / 2) - count.load(Ordering::Relaxed);
key_count > 0
}
None => false,
}
}
fn update_key_count(&mut self, count: u64) {
match &self.uploaded_signed_key_count {
Some(c) => c.store(count, Ordering::Relaxed),
None => self.uploaded_signed_key_count = Some(AtomicU64::new(count)),
}
}
/// Receive a successful keys upload response.
///
/// # Arguments
@ -203,10 +211,12 @@ impl OlmMachine {
let count: u64 = one_time_key_count.map_or(0, |c| (*c).into());
debug!(
"Updated uploaded one-time key count {} -> {}, marking keys as published",
self.uploaded_signed_key_count.as_ref().map_or(0, |c| *c),
self.uploaded_signed_key_count
.as_ref()
.map_or(0, |c| c.load(Ordering::Relaxed)),
count
);
self.uploaded_signed_key_count = Some(count);
self.update_key_count(count);
self.account.mark_keys_as_published().await;
self.store.save_account(self.account.clone()).await?;
@ -470,8 +480,9 @@ impl OlmMachine {
/// Returns the number of newly generated one-time keys. If no keys can be
/// generated returns an empty error.
async fn generate_one_time_keys(&self) -> StdResult<u64, ()> {
match self.uploaded_signed_key_count {
match &self.uploaded_signed_key_count {
Some(count) => {
let count = count.load(Ordering::Relaxed);
let max_keys = self.account.max_one_time_keys().await as u64;
let max_on_server = max_keys / 2;
@ -1244,7 +1255,7 @@ impl OlmMachine {
.get(&keys::KeyAlgorithm::SignedCurve25519);
let count: u64 = one_time_key_count.map_or(0, |c| (*c).into());
self.uploaded_signed_key_count = Some(count);
self.update_key_count(count);
for event_result in &mut response.to_device.events {
let event = if let Ok(e) = event_result.deserialize() {
@ -1389,6 +1400,7 @@ mod test {
use std::convert::TryFrom;
use std::fs::File;
use std::io::prelude::*;
use std::sync::atomic::AtomicU64;
use std::time::SystemTime;
use serde_json::json;
@ -1463,7 +1475,7 @@ mod test {
async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) {
let mut machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
machine.uploaded_signed_key_count = Some(0);
machine.uploaded_signed_key_count = Some(AtomicU64::new(0));
let (_, otk) = machine
.keys_for_upload()
.await
@ -1666,7 +1678,7 @@ mod test {
#[tokio::test]
async fn test_one_time_key_signing() {
let mut machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
machine.uploaded_signed_key_count = Some(49);
machine.uploaded_signed_key_count = Some(AtomicU64::new(49));
let mut one_time_keys = machine.signed_one_time_keys().await.unwrap();
let identity_keys = machine.account.identity_keys();
@ -1686,7 +1698,7 @@ mod test {
#[tokio::test]
async fn test_keys_for_upload() {
let mut machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap();
machine.uploaded_signed_key_count = Some(0);
machine.uploaded_signed_key_count = Some(AtomicU64::default());
let identity_keys = machine.account.identity_keys();
let ed25519_key = identity_keys.ed25519();