crypto: Move the key count field into the account.

master
Damir Jelić 2020-07-13 15:49:16 +02:00
parent b2ccb61864
commit a7a9ac24ed
3 changed files with 63 additions and 60 deletions

View File

@ -18,7 +18,6 @@ use std::mem;
#[cfg(feature = "sqlite-cryptostore")] #[cfg(feature = "sqlite-cryptostore")]
use std::path::Path; use std::path::Path;
use std::result::Result as StdResult; use std::result::Result as StdResult;
use std::sync::atomic::{AtomicU64, Ordering};
use super::error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult, SignatureError}; use super::error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult, SignatureError};
use super::olm::{ use super::olm::{
@ -67,11 +66,6 @@ pub struct OlmMachine {
device_id: DeviceId, device_id: DeviceId,
/// Our underlying Olm Account holding our identity keys. /// Our underlying Olm Account holding our identity keys.
account: Account, account: Account,
/// The number of signed one-time keys we have uploaded to the server. If
/// this is None, no action will be taken. After a sync request the client
/// needs to set this for us, depending on the count we will suggest the
/// client to upload new keys.
uploaded_signed_key_count: Option<AtomicU64>,
/// Store for the encryption keys. /// Store for the encryption keys.
/// Persists all the encryption keys so a client can resume the session /// Persists all the encryption keys so a client can resume the session
/// without the need to create new keys. /// without the need to create new keys.
@ -108,7 +102,6 @@ impl OlmMachine {
user_id: user_id.clone(), user_id: user_id.clone(),
device_id: device_id.to_owned(), device_id: device_id.to_owned(),
account: Account::new(user_id, &device_id), account: Account::new(user_id, &device_id),
uploaded_signed_key_count: None,
store: Box::new(MemoryStore::new()), store: Box::new(MemoryStore::new()),
outbound_group_sessions: HashMap::new(), outbound_group_sessions: HashMap::new(),
} }
@ -151,7 +144,6 @@ impl OlmMachine {
user_id, user_id,
device_id, device_id,
account, account,
uploaded_signed_key_count: None,
store, store,
outbound_group_sessions: HashMap::new(), outbound_group_sessions: HashMap::new(),
}) })
@ -201,30 +193,24 @@ impl OlmMachine {
return true; return true;
} }
let count = self.account.uploaded_key_count() as u64;
// If we have a known key count, check that we have more than // If we have a known key count, check that we have more than
// max_one_time_Keys() / 2, otherwise tell the client to upload more. // max_one_time_Keys() / 2, otherwise tell the client to upload more.
match &self.uploaded_signed_key_count { let max_keys = self.account.max_one_time_keys().await as u64;
Some(count) => { // If there are more keys already uploaded than max_key / 2
let max_keys = self.account.max_one_time_keys().await as u64; // bail out returning false, this also avoids overflow.
// If there are more keys already uploaded than max_key / 2 if count > (max_keys / 2) {
// bail out returning false, this also avoids overflow. return false;
if count.load(Ordering::Relaxed) > (max_keys / 2) {
return false;
}
let key_count = (max_keys / 2) - count.load(Ordering::Relaxed);
key_count > 0
}
None => false,
} }
let key_count = (max_keys / 2) - count;
key_count > 0
} }
/// Update the count of one-time keys that are currently on the server. /// Update the count of one-time keys that are currently on the server.
fn update_key_count(&mut self, count: u64) { fn update_key_count(&mut self, count: u64) {
match &self.uploaded_signed_key_count { self.account.update_uploaded_key_count(count);
Some(c) => c.store(count, Ordering::Relaxed),
None => self.uploaded_signed_key_count = Some(AtomicU64::new(count)),
}
} }
/// Receive a successful keys upload response. /// Receive a successful keys upload response.
@ -250,9 +236,7 @@ impl OlmMachine {
let count: u64 = one_time_key_count.map_or(0, |c| (*c).into()); let count: u64 = one_time_key_count.map_or(0, |c| (*c).into());
debug!( debug!(
"Updated uploaded one-time key count {} -> {}, marking keys as published", "Updated uploaded one-time key count {} -> {}, marking keys as published",
self.uploaded_signed_key_count self.account.uploaded_key_count(),
.as_ref()
.map_or(0, |c| c.load(Ordering::Relaxed)),
count count
); );
self.update_key_count(count); self.update_key_count(count);
@ -543,26 +527,21 @@ impl OlmMachine {
/// Returns the number of newly generated one-time keys. If no keys can be /// Returns the number of newly generated one-time keys. If no keys can be
/// generated returns an empty error. /// generated returns an empty error.
async fn generate_one_time_keys(&self) -> StdResult<u64, ()> { async fn generate_one_time_keys(&self) -> StdResult<u64, ()> {
match &self.uploaded_signed_key_count { let count = self.account.uploaded_key_count() as u64;
Some(count) => { // TODO if we store the uploaded key count with the Account all
// TODO if we store the uploaded key count with the Account all // this logic could go into the account.
// this logic could go into the account. let max_keys = self.account.max_one_time_keys().await;
let count = count.load(Ordering::Relaxed); let max_on_server = (max_keys as u64) / 2;
let max_keys = self.account.max_one_time_keys().await;
let max_on_server = (max_keys as u64) / 2;
if count >= (max_on_server) { if count >= (max_on_server) {
return Err(()); return Err(());
}
let key_count = (max_on_server) - count;
let key_count: usize = key_count.try_into().unwrap_or(max_keys);
self.account.generate_one_time_keys(key_count).await;
Ok(key_count as u64)
}
None => Err(()),
} }
let key_count = (max_on_server) - count;
let key_count: usize = key_count.try_into().unwrap_or(max_keys);
self.account.generate_one_time_keys(key_count).await;
Ok(key_count as u64)
} }
/// Generate, sign and prepare one-time keys to be uploaded. /// Generate, sign and prepare one-time keys to be uploaded.
@ -1023,8 +1002,6 @@ impl OlmMachine {
// stores the curve key of the device, if we also store the ed25519 key // stores the curve key of the device, if we also store the ed25519 key
// with the session we'll only need to pass in the account to the // with the session we'll only need to pass in the account to the
// session and all of this can live in the session. // session and all of this can live in the session.
//
// Storing a reference to the account is probably not worth the effort.
let recipient_signing_key = recipient_device let recipient_signing_key = recipient_device
.get_key(KeyAlgorithm::Ed25519) .get_key(KeyAlgorithm::Ed25519)
@ -1441,7 +1418,6 @@ mod test {
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::convert::TryFrom; use std::convert::TryFrom;
use std::convert::TryInto; use std::convert::TryInto;
use std::sync::atomic::AtomicU64;
use std::time::SystemTime; use std::time::SystemTime;
use http::Response; use http::Response;
@ -1514,7 +1490,7 @@ mod test {
async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) { async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) {
let mut machine = OlmMachine::new(&user_id(), &alice_device_id()); let mut machine = OlmMachine::new(&user_id(), &alice_device_id());
machine.uploaded_signed_key_count = Some(AtomicU64::new(0)); machine.account.update_uploaded_key_count(0);
let (_, otk) = machine let (_, otk) = machine
.keys_for_upload() .keys_for_upload()
.await .await
@ -1660,7 +1636,6 @@ mod test {
let mut response = keys_upload_response(); let mut response = keys_upload_response();
assert!(machine.should_upload_keys().await); assert!(machine.should_upload_keys().await);
assert!(machine.generate_one_time_keys().await.is_err());
machine machine
.receive_keys_upload_response(&response) .receive_keys_upload_response(&response)
@ -1729,8 +1704,8 @@ mod test {
#[tokio::test] #[tokio::test]
async fn test_one_time_key_signing() { async fn test_one_time_key_signing() {
let mut machine = OlmMachine::new(&user_id(), &alice_device_id()); let machine = OlmMachine::new(&user_id(), &alice_device_id());
machine.uploaded_signed_key_count = Some(AtomicU64::new(49)); machine.account.update_uploaded_key_count(49);
let mut one_time_keys = machine.signed_one_time_keys().await.unwrap(); let mut one_time_keys = machine.signed_one_time_keys().await.unwrap();
let identity_keys = machine.account.identity_keys(); let identity_keys = machine.account.identity_keys();
@ -1750,7 +1725,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn test_keys_for_upload() { async fn test_keys_for_upload() {
let mut machine = OlmMachine::new(&user_id(), &alice_device_id()); let mut machine = OlmMachine::new(&user_id(), &alice_device_id());
machine.uploaded_signed_key_count = Some(AtomicU64::default()); machine.account.update_uploaded_key_count(0);
let identity_keys = machine.account.identity_keys(); let identity_keys = machine.account.identity_keys();
let ed25519_key = identity_keys.ed25519(); let ed25519_key = identity_keys.ed25519();

View File

@ -13,9 +13,10 @@
// limitations under the License. // limitations under the License.
use matrix_sdk_common::instant::Instant; use matrix_sdk_common::instant::Instant;
use std::convert::TryFrom;
use std::convert::TryInto; use std::convert::TryInto;
use std::fmt; use std::fmt;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::atomic::{AtomicBool, AtomicI64, AtomicUsize, Ordering};
use std::sync::Arc; use std::sync::Arc;
use matrix_sdk_common::locks::Mutex; use matrix_sdk_common::locks::Mutex;
@ -61,6 +62,11 @@ pub struct Account {
inner: Arc<Mutex<OlmAccount>>, inner: Arc<Mutex<OlmAccount>>,
identity_keys: Arc<IdentityKeys>, identity_keys: Arc<IdentityKeys>,
shared: Arc<AtomicBool>, shared: Arc<AtomicBool>,
/// The number of signed one-time keys we have uploaded to the server. If
/// this is None, no action will be taken. After a sync request the client
/// needs to set this for us, depending on the count we will suggest the
/// client to upload new keys.
uploaded_signed_key_count: Arc<AtomicI64>,
} }
// #[cfg_attr(tarpaulin, skip)] // #[cfg_attr(tarpaulin, skip)]
@ -90,6 +96,7 @@ impl Account {
inner: Arc::new(Mutex::new(account)), inner: Arc::new(Mutex::new(account)),
identity_keys: Arc::new(identity_keys), identity_keys: Arc::new(identity_keys),
shared: Arc::new(AtomicBool::new(false)), shared: Arc::new(AtomicBool::new(false)),
uploaded_signed_key_count: Arc::new(AtomicI64::new(0)),
} }
} }
@ -98,6 +105,22 @@ impl Account {
&self.identity_keys &self.identity_keys
} }
/// Update the uploaded key count.
///
/// # Arguments
///
/// * `new_count` - The new count that was reported by the server.
pub fn update_uploaded_key_count(&self, new_count: u64) {
let key_count = i64::try_from(new_count).unwrap_or(i64::MAX);
self.uploaded_signed_key_count
.store(key_count, Ordering::Relaxed);
}
/// Get the currently known uploaded key count.
pub fn uploaded_key_count(&self) -> i64 {
self.uploaded_signed_key_count.load(Ordering::Relaxed)
}
/// Has the account been shared with the server. /// Has the account been shared with the server.
pub fn shared(&self) -> bool { pub fn shared(&self) -> bool {
self.shared.load(Ordering::Relaxed) self.shared.load(Ordering::Relaxed)
@ -165,6 +188,7 @@ impl Account {
pickle: String, pickle: String,
pickle_mode: PicklingMode, pickle_mode: PicklingMode,
shared: bool, shared: bool,
uploaded_signed_key_count: i64,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
) -> Result<Self, OlmAccountError> { ) -> Result<Self, OlmAccountError> {
@ -177,6 +201,7 @@ impl Account {
inner: Arc::new(Mutex::new(account)), inner: Arc::new(Mutex::new(account)),
identity_keys: Arc::new(identity_keys), identity_keys: Arc::new(identity_keys),
shared: Arc::new(AtomicBool::from(shared)), shared: Arc::new(AtomicBool::from(shared)),
uploaded_signed_key_count: Arc::new(AtomicI64::new(uploaded_signed_key_count)),
}) })
} }

View File

@ -144,6 +144,7 @@ impl SqliteStore {
"device_id" TEXT NOT NULL, "device_id" TEXT NOT NULL,
"pickle" BLOB NOT NULL, "pickle" BLOB NOT NULL,
"shared" INTEGER NOT NULL, "shared" INTEGER NOT NULL,
"uploaded_key_count" INTEGER NOT NULL,
UNIQUE(user_id,device_id) UNIQUE(user_id,device_id)
); );
"#, "#,
@ -564,8 +565,8 @@ impl CryptoStore for SqliteStore {
async fn load_account(&mut self) -> Result<Option<Account>> { async fn load_account(&mut self) -> Result<Option<Account>> {
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
let row: Option<(i64, String, bool)> = query_as( let row: Option<(i64, String, bool, i64)> = query_as(
"SELECT id, pickle, shared FROM accounts "SELECT id, pickle, shared, uploaded_key_count FROM accounts
WHERE user_id = ? and device_id = ?", WHERE user_id = ? and device_id = ?",
) )
.bind(self.user_id.as_str()) .bind(self.user_id.as_str())
@ -573,12 +574,13 @@ impl CryptoStore for SqliteStore {
.fetch_optional(&mut *connection) .fetch_optional(&mut *connection)
.await?; .await?;
let result = if let Some((id, pickle, shared)) = row { let result = if let Some((id, pickle, shared, uploaded_key_count)) = row {
self.account_id = Some(id); self.account_id = Some(id);
Some(Account::from_pickle( Some(Account::from_pickle(
pickle, pickle,
self.get_pickle_mode(), self.get_pickle_mode(),
shared, shared,
uploaded_key_count,
&self.user_id, &self.user_id,
&self.device_id, &self.device_id,
)?) )?)
@ -613,8 +615,8 @@ impl CryptoStore for SqliteStore {
query( query(
"INSERT INTO accounts ( "INSERT INTO accounts (
user_id, device_id, pickle, shared user_id, device_id, pickle, shared, uploaded_key_count
) VALUES (?1, ?2, ?3, ?4) ) VALUES (?1, ?2, ?3, ?4, ?5)
ON CONFLICT(user_id, device_id) DO UPDATE SET ON CONFLICT(user_id, device_id) DO UPDATE SET
pickle = excluded.pickle, pickle = excluded.pickle,
shared = excluded.shared shared = excluded.shared
@ -624,6 +626,7 @@ impl CryptoStore for SqliteStore {
.bind(&*self.device_id.to_string()) .bind(&*self.device_id.to_string())
.bind(&pickle) .bind(&pickle)
.bind(account.shared()) .bind(account.shared())
.bind(account.uploaded_key_count())
.execute(&mut *connection) .execute(&mut *connection)
.await?; .await?;