diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index c1dff76b..7776c709 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -18,7 +18,6 @@ use std::mem; #[cfg(feature = "sqlite-cryptostore")] use std::path::Path; use std::result::Result as StdResult; -use std::sync::atomic::{AtomicU64, Ordering}; use super::error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult, SignatureError}; use super::olm::{ @@ -67,11 +66,6 @@ pub struct OlmMachine { device_id: DeviceId, /// Our underlying Olm Account holding our identity keys. account: Account, - /// The number of signed one-time keys we have uploaded to the server. If - /// this is None, no action will be taken. After a sync request the client - /// needs to set this for us, depending on the count we will suggest the - /// client to upload new keys. - uploaded_signed_key_count: Option, /// Store for the encryption keys. /// Persists all the encryption keys so a client can resume the session /// without the need to create new keys. @@ -108,7 +102,6 @@ impl OlmMachine { user_id: user_id.clone(), device_id: device_id.to_owned(), account: Account::new(user_id, &device_id), - uploaded_signed_key_count: None, store: Box::new(MemoryStore::new()), outbound_group_sessions: HashMap::new(), } @@ -151,7 +144,6 @@ impl OlmMachine { user_id, device_id, account, - uploaded_signed_key_count: None, store, outbound_group_sessions: HashMap::new(), }) @@ -201,30 +193,24 @@ impl OlmMachine { return true; } + let count = self.account.uploaded_key_count() as u64; + // If we have a known key count, check that we have more than // max_one_time_Keys() / 2, otherwise tell the client to upload more. - match &self.uploaded_signed_key_count { - Some(count) => { - let max_keys = self.account.max_one_time_keys().await as u64; - // If there are more keys already uploaded than max_key / 2 - // bail out returning false, this also avoids overflow. - if count.load(Ordering::Relaxed) > (max_keys / 2) { - return false; - } - - let key_count = (max_keys / 2) - count.load(Ordering::Relaxed); - key_count > 0 - } - None => false, + let max_keys = self.account.max_one_time_keys().await as u64; + // If there are more keys already uploaded than max_key / 2 + // bail out returning false, this also avoids overflow. + if count > (max_keys / 2) { + return false; } + + let key_count = (max_keys / 2) - count; + key_count > 0 } /// Update the count of one-time keys that are currently on the server. fn update_key_count(&mut self, count: u64) { - match &self.uploaded_signed_key_count { - Some(c) => c.store(count, Ordering::Relaxed), - None => self.uploaded_signed_key_count = Some(AtomicU64::new(count)), - } + self.account.update_uploaded_key_count(count); } /// Receive a successful keys upload response. @@ -250,9 +236,7 @@ impl OlmMachine { let count: u64 = one_time_key_count.map_or(0, |c| (*c).into()); debug!( "Updated uploaded one-time key count {} -> {}, marking keys as published", - self.uploaded_signed_key_count - .as_ref() - .map_or(0, |c| c.load(Ordering::Relaxed)), + self.account.uploaded_key_count(), count ); self.update_key_count(count); @@ -543,26 +527,21 @@ impl OlmMachine { /// Returns the number of newly generated one-time keys. If no keys can be /// generated returns an empty error. async fn generate_one_time_keys(&self) -> StdResult { - match &self.uploaded_signed_key_count { - Some(count) => { - // TODO if we store the uploaded key count with the Account all - // this logic could go into the account. - let count = count.load(Ordering::Relaxed); - let max_keys = self.account.max_one_time_keys().await; - let max_on_server = (max_keys as u64) / 2; + let count = self.account.uploaded_key_count() as u64; + // TODO if we store the uploaded key count with the Account all + // this logic could go into the account. + let max_keys = self.account.max_one_time_keys().await; + let max_on_server = (max_keys as u64) / 2; - if count >= (max_on_server) { - return Err(()); - } - - let key_count = (max_on_server) - count; - let key_count: usize = key_count.try_into().unwrap_or(max_keys); - - self.account.generate_one_time_keys(key_count).await; - Ok(key_count as u64) - } - None => Err(()), + if count >= (max_on_server) { + return Err(()); } + + let key_count = (max_on_server) - count; + let key_count: usize = key_count.try_into().unwrap_or(max_keys); + + self.account.generate_one_time_keys(key_count).await; + Ok(key_count as u64) } /// Generate, sign and prepare one-time keys to be uploaded. @@ -1023,8 +1002,6 @@ impl OlmMachine { // stores the curve key of the device, if we also store the ed25519 key // with the session we'll only need to pass in the account to the // session and all of this can live in the session. - // - // Storing a reference to the account is probably not worth the effort. let recipient_signing_key = recipient_device .get_key(KeyAlgorithm::Ed25519) @@ -1441,7 +1418,6 @@ mod test { use std::collections::BTreeMap; use std::convert::TryFrom; use std::convert::TryInto; - use std::sync::atomic::AtomicU64; use std::time::SystemTime; use http::Response; @@ -1514,7 +1490,7 @@ mod test { async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) { let mut machine = OlmMachine::new(&user_id(), &alice_device_id()); - machine.uploaded_signed_key_count = Some(AtomicU64::new(0)); + machine.account.update_uploaded_key_count(0); let (_, otk) = machine .keys_for_upload() .await @@ -1660,7 +1636,6 @@ mod test { let mut response = keys_upload_response(); assert!(machine.should_upload_keys().await); - assert!(machine.generate_one_time_keys().await.is_err()); machine .receive_keys_upload_response(&response) @@ -1729,8 +1704,8 @@ mod test { #[tokio::test] async fn test_one_time_key_signing() { - let mut machine = OlmMachine::new(&user_id(), &alice_device_id()); - machine.uploaded_signed_key_count = Some(AtomicU64::new(49)); + let machine = OlmMachine::new(&user_id(), &alice_device_id()); + machine.account.update_uploaded_key_count(49); let mut one_time_keys = machine.signed_one_time_keys().await.unwrap(); let identity_keys = machine.account.identity_keys(); @@ -1750,7 +1725,7 @@ mod test { #[tokio::test] async fn test_keys_for_upload() { let mut machine = OlmMachine::new(&user_id(), &alice_device_id()); - machine.uploaded_signed_key_count = Some(AtomicU64::default()); + machine.account.update_uploaded_key_count(0); let identity_keys = machine.account.identity_keys(); let ed25519_key = identity_keys.ed25519(); diff --git a/matrix_sdk_crypto/src/olm.rs b/matrix_sdk_crypto/src/olm.rs index c8d31522..61888ca0 100644 --- a/matrix_sdk_crypto/src/olm.rs +++ b/matrix_sdk_crypto/src/olm.rs @@ -13,9 +13,10 @@ // limitations under the License. use matrix_sdk_common::instant::Instant; +use std::convert::TryFrom; use std::convert::TryInto; use std::fmt; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicI64, AtomicUsize, Ordering}; use std::sync::Arc; use matrix_sdk_common::locks::Mutex; @@ -61,6 +62,11 @@ pub struct Account { inner: Arc>, identity_keys: Arc, shared: Arc, + /// The number of signed one-time keys we have uploaded to the server. If + /// this is None, no action will be taken. After a sync request the client + /// needs to set this for us, depending on the count we will suggest the + /// client to upload new keys. + uploaded_signed_key_count: Arc, } // #[cfg_attr(tarpaulin, skip)] @@ -90,6 +96,7 @@ impl Account { inner: Arc::new(Mutex::new(account)), identity_keys: Arc::new(identity_keys), shared: Arc::new(AtomicBool::new(false)), + uploaded_signed_key_count: Arc::new(AtomicI64::new(0)), } } @@ -98,6 +105,22 @@ impl Account { &self.identity_keys } + /// Update the uploaded key count. + /// + /// # Arguments + /// + /// * `new_count` - The new count that was reported by the server. + pub fn update_uploaded_key_count(&self, new_count: u64) { + let key_count = i64::try_from(new_count).unwrap_or(i64::MAX); + self.uploaded_signed_key_count + .store(key_count, Ordering::Relaxed); + } + + /// Get the currently known uploaded key count. + pub fn uploaded_key_count(&self) -> i64 { + self.uploaded_signed_key_count.load(Ordering::Relaxed) + } + /// Has the account been shared with the server. pub fn shared(&self) -> bool { self.shared.load(Ordering::Relaxed) @@ -165,6 +188,7 @@ impl Account { pickle: String, pickle_mode: PicklingMode, shared: bool, + uploaded_signed_key_count: i64, user_id: &UserId, device_id: &DeviceId, ) -> Result { @@ -177,6 +201,7 @@ impl Account { inner: Arc::new(Mutex::new(account)), identity_keys: Arc::new(identity_keys), shared: Arc::new(AtomicBool::from(shared)), + uploaded_signed_key_count: Arc::new(AtomicI64::new(uploaded_signed_key_count)), }) } diff --git a/matrix_sdk_crypto/src/store/sqlite.rs b/matrix_sdk_crypto/src/store/sqlite.rs index 54fb8513..1f207515 100644 --- a/matrix_sdk_crypto/src/store/sqlite.rs +++ b/matrix_sdk_crypto/src/store/sqlite.rs @@ -144,6 +144,7 @@ impl SqliteStore { "device_id" TEXT NOT NULL, "pickle" BLOB NOT NULL, "shared" INTEGER NOT NULL, + "uploaded_key_count" INTEGER NOT NULL, UNIQUE(user_id,device_id) ); "#, @@ -564,8 +565,8 @@ impl CryptoStore for SqliteStore { async fn load_account(&mut self) -> Result> { let mut connection = self.connection.lock().await; - let row: Option<(i64, String, bool)> = query_as( - "SELECT id, pickle, shared FROM accounts + let row: Option<(i64, String, bool, i64)> = query_as( + "SELECT id, pickle, shared, uploaded_key_count FROM accounts WHERE user_id = ? and device_id = ?", ) .bind(self.user_id.as_str()) @@ -573,12 +574,13 @@ impl CryptoStore for SqliteStore { .fetch_optional(&mut *connection) .await?; - let result = if let Some((id, pickle, shared)) = row { + let result = if let Some((id, pickle, shared, uploaded_key_count)) = row { self.account_id = Some(id); Some(Account::from_pickle( pickle, self.get_pickle_mode(), shared, + uploaded_key_count, &self.user_id, &self.device_id, )?) @@ -613,8 +615,8 @@ impl CryptoStore for SqliteStore { query( "INSERT INTO accounts ( - user_id, device_id, pickle, shared - ) VALUES (?1, ?2, ?3, ?4) + user_id, device_id, pickle, shared, uploaded_key_count + ) VALUES (?1, ?2, ?3, ?4, ?5) ON CONFLICT(user_id, device_id) DO UPDATE SET pickle = excluded.pickle, shared = excluded.shared @@ -624,6 +626,7 @@ impl CryptoStore for SqliteStore { .bind(&*self.device_id.to_string()) .bind(&pickle) .bind(account.shared()) + .bind(account.uploaded_key_count()) .execute(&mut *connection) .await?;