diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 7776c709..a9b0f199 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -189,23 +189,7 @@ impl OlmMachine { /// Should account or one-time keys be uploaded to the server. pub async fn should_upload_keys(&self) -> bool { - if !self.account.shared() { - return true; - } - - let count = self.account.uploaded_key_count() as u64; - - // If we have a known key count, check that we have more than - // max_one_time_Keys() / 2, otherwise tell the client to upload more. - let max_keys = self.account.max_one_time_keys().await as u64; - // If there are more keys already uploaded than max_key / 2 - // bail out returning false, this also avoids overflow. - if count > (max_keys / 2) { - return false; - } - - let key_count = (max_keys / 2) - count; - key_count > 0 + self.account.should_upload_keys().await } /// Update the count of one-time keys that are currently on the server. @@ -522,36 +506,6 @@ impl OlmMachine { Ok(changed_devices) } - /// Generate new one-time keys. - /// - /// Returns the number of newly generated one-time keys. If no keys can be - /// generated returns an empty error. - async fn generate_one_time_keys(&self) -> StdResult { - let count = self.account.uploaded_key_count() as u64; - // TODO if we store the uploaded key count with the Account all - // this logic could go into the account. - let max_keys = self.account.max_one_time_keys().await; - let max_on_server = (max_keys as u64) / 2; - - if count >= (max_on_server) { - return Err(()); - } - - let key_count = (max_on_server) - count; - let key_count: usize = key_count.try_into().unwrap_or(max_keys); - - self.account.generate_one_time_keys(key_count).await; - Ok(key_count as u64) - } - - /// Generate, sign and prepare one-time keys to be uploaded. - /// - /// If no one-time keys need to be uploaded returns an empty error. - async fn signed_one_time_keys(&self) -> StdResult { - let _ = self.generate_one_time_keys().await?; - Ok(self.account.signed_one_time_keys().await) - } - /// Verify a signed JSON object. /// /// The object must have a signatures key associated with an object of the @@ -624,21 +578,7 @@ impl OlmMachine { pub async fn keys_for_upload( &self, ) -> StdResult<(Option, Option), ()> { - if !self.should_upload_keys().await { - return Err(()); - } - - let shared = self.account.shared(); - - let device_keys = if !shared { - Some(self.account.device_keys().await) - } else { - None - }; - - let one_time_keys: Option = self.signed_one_time_keys().await.ok(); - - Ok((device_keys, one_time_keys)) + self.account.keys_for_upload().await } /// Try to decrypt an Olm message. @@ -1642,7 +1582,7 @@ mod test { .await .unwrap(); assert!(machine.should_upload_keys().await); - assert!(machine.generate_one_time_keys().await.is_ok()); + assert!(machine.account.generate_one_time_keys().await.is_ok()); response .one_time_key_counts @@ -1651,7 +1591,7 @@ mod test { .receive_keys_upload_response(&response) .await .unwrap(); - assert!(machine.generate_one_time_keys().await.is_err()); + assert!(machine.account.generate_one_time_keys().await.is_err()); } #[tokio::test] @@ -1707,7 +1647,7 @@ mod test { let machine = OlmMachine::new(&user_id(), &alice_device_id()); machine.account.update_uploaded_key_count(49); - let mut one_time_keys = machine.signed_one_time_keys().await.unwrap(); + let mut one_time_keys = machine.account.signed_one_time_keys().await.unwrap(); let identity_keys = machine.account.identity_keys(); let ed25519_key = identity_keys.ed25519(); diff --git a/matrix_sdk_crypto/src/olm.rs b/matrix_sdk_crypto/src/olm.rs index 61888ca0..e1e33a20 100644 --- a/matrix_sdk_crypto/src/olm.rs +++ b/matrix_sdk_crypto/src/olm.rs @@ -110,7 +110,7 @@ impl Account { /// # Arguments /// /// * `new_count` - The new count that was reported by the server. - pub fn update_uploaded_key_count(&self, new_count: u64) { + pub(crate) fn update_uploaded_key_count(&self, new_count: u64) { let key_count = i64::try_from(new_count).unwrap_or(i64::MAX); self.uploaded_signed_key_count .store(key_count, Ordering::Relaxed); @@ -130,29 +130,96 @@ impl Account { /// /// Messages shouldn't be encrypted with the session before it has been /// shared. - pub fn mark_as_shared(&self) { + pub(crate) fn mark_as_shared(&self) { self.shared.store(true, Ordering::Relaxed); } /// Get the one-time keys of the account. /// /// This can be empty, keys need to be generated first. - pub async fn one_time_keys(&self) -> OneTimeKeys { + pub(crate) async fn one_time_keys(&self) -> OneTimeKeys { self.inner.lock().await.parsed_one_time_keys() } /// Generate count number of one-time keys. - pub async fn generate_one_time_keys(&self, count: usize) { + pub(crate) async fn generate_one_time_keys_helper(&self, count: usize) { self.inner.lock().await.generate_one_time_keys(count); } /// Get the maximum number of one-time keys the account can hold. - pub async fn max_one_time_keys(&self) -> usize { + pub(crate) async fn max_one_time_keys(&self) -> usize { self.inner.lock().await.max_number_of_one_time_keys() } + /// Get a tuple of device and one-time keys that need to be uploaded. + /// + /// Returns an empty error if no keys need to be uploaded. + pub(crate) async fn generate_one_time_keys(&self) -> Result { + let count = self.uploaded_key_count() as u64; + let max_keys = self.max_one_time_keys().await; + let max_on_server = (max_keys as u64) / 2; + + if count >= (max_on_server) { + return Err(()); + } + + let key_count = (max_on_server) - count; + let key_count: usize = key_count.try_into().unwrap_or(max_keys); + + self.generate_one_time_keys_helper(key_count).await; + Ok(key_count as u64) + } + + /// Should account or one-time keys be uploaded to the server. + pub(crate) async fn should_upload_keys(&self) -> bool { + if !self.shared() { + return true; + } + + let count = self.uploaded_key_count() as u64; + + // If we have a known key count, check that we have more than + // max_one_time_Keys() / 2, otherwise tell the client to upload more. + let max_keys = self.max_one_time_keys().await as u64; + // If there are more keys already uploaded than max_key / 2 + // bail out returning false, this also avoids overflow. + if count > (max_keys / 2) { + return false; + } + + let key_count = (max_keys / 2) - count; + key_count > 0 + } + + /// Get a tuple of device and one-time keys that need to be uploaded. + /// + /// Returns an empty error if no keys need to be uploaded. + pub(crate) async fn keys_for_upload( + &self, + ) -> Result< + ( + Option, + Option>, + ), + (), + > { + if !self.should_upload_keys().await { + return Err(()); + } + + let device_keys = if !self.shared() { + Some(self.device_keys().await) + } else { + None + }; + + let one_time_keys = self.signed_one_time_keys().await.ok(); + + Ok((device_keys, one_time_keys)) + } + /// Mark the current set of one-time keys as being published. - pub async fn mark_keys_as_published(&self) { + pub(crate) async fn mark_keys_as_published(&self) { self.inner.lock().await.mark_keys_as_published(); } @@ -207,7 +274,7 @@ impl Account { /// Sign the device keys of the account and return them so they can be /// uploaded. - pub async fn device_keys(&self) -> DeviceKeys { + pub(crate) async fn device_keys(&self) -> DeviceKeys { let identity_keys = self.identity_keys(); let mut keys = BTreeMap::new(); @@ -257,6 +324,10 @@ impl Account { /// /// * `json` - The value that should be converted into a canonical JSON /// string. + /// + /// # Panic + /// + /// Panics if the json value can't be serialized. pub async fn sign_json(&self, json: &Value) -> String { let canonical_json = cjson::to_string(json) .unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", json))); @@ -266,7 +337,11 @@ impl Account { /// Generate, sign and prepare one-time keys to be uploaded. /// /// If no one-time keys need to be uploaded returns an empty error. - pub async fn signed_one_time_keys(&self) -> BTreeMap { + pub(crate) async fn signed_one_time_keys( + &self, + ) -> Result, ()> { + let _ = self.generate_one_time_keys().await?; + let one_time_keys = self.one_time_keys().await; let mut one_time_key_map = BTreeMap::new(); @@ -298,7 +373,7 @@ impl Account { ); } - one_time_key_map + Ok(one_time_key_map) } /// Create a new session with another account given a one-time key. @@ -311,7 +386,7 @@ impl Account { /// /// * `their_one_time_key` - A signed one-time key that the other account /// created and shared with us. - pub async fn create_outbound_session( + pub(crate) async fn create_outbound_session( &self, their_identity_key: &str, their_one_time_key: &SignedKey, @@ -344,7 +419,7 @@ impl Account { /// /// * `message` - A pre-key Olm message that was sent to us by the other /// account. - pub async fn create_inbound_session( + pub(crate) async fn create_inbound_session( &self, their_identity_key: &str, message: PreKeyMessage, @@ -386,7 +461,7 @@ impl Account { /// # Arguments /// /// * `room_id` - The ID of the room where the group session will be used. - pub async fn create_group_session_pair( + pub(crate) async fn create_group_session_pair( &self, room_id: &RoomId, ) -> (OutboundGroupSession, InboundGroupSession) { @@ -923,7 +998,7 @@ pub(crate) mod test { let alice = Account::new(&alice_id(), &alice_device_id()); let bob = Account::new(&bob_id(), &bob_device_id()); - bob.generate_one_time_keys(1).await; + bob.generate_one_time_keys_helper(1).await; let one_time_key = bob .one_time_keys() .await @@ -975,7 +1050,7 @@ pub(crate) mod test { assert!(one_time_keys.curve25519().is_empty()); assert_ne!(account.max_one_time_keys().await, 0); - account.generate_one_time_keys(10).await; + account.generate_one_time_keys_helper(10).await; let one_time_keys = account.one_time_keys().await; assert!(!one_time_keys.curve25519().is_empty()); @@ -999,7 +1074,7 @@ pub(crate) mod test { let alice = Account::new(&alice_id(), &alice_device_id()); let bob = Account::new(&bob_id(), &bob_device_id()); let alice_keys = alice.identity_keys(); - alice.generate_one_time_keys(1).await; + alice.generate_one_time_keys_helper(1).await; let one_time_keys = alice.one_time_keys().await; alice.mark_keys_as_published().await; diff --git a/matrix_sdk_crypto/src/store/sqlite.rs b/matrix_sdk_crypto/src/store/sqlite.rs index 1f207515..dd9bbb5d 100644 --- a/matrix_sdk_crypto/src/store/sqlite.rs +++ b/matrix_sdk_crypto/src/store/sqlite.rs @@ -857,7 +857,7 @@ mod test { let alice = Account::new(&alice_id(), &alice_device_id()); let bob = Account::new(&bob_id(), &bob_device_id()); - bob.generate_one_time_keys(1).await; + bob.generate_one_time_keys_helper(1).await; let one_time_key = bob .one_time_keys() .await