diff --git a/src/crypto/machine.rs b/src/crypto/machine.rs index f5db37a5..ab97dd8a 100644 --- a/src/crypto/machine.rs +++ b/src/crypto/machine.rs @@ -32,7 +32,7 @@ use ruma_client_api::r0::keys::{ use ruma_events::Algorithm; use ruma_identifiers::{DeviceId, UserId}; -pub type SignedOneTimeKeys = HashMap; +pub type OneTimeKeys = HashMap; struct OlmMachine { /// The unique user id that owns this account. @@ -177,7 +177,7 @@ impl OlmMachine { /// Generate, sign and prepare one-time keys to be uploaded. /// /// If no one-time keys need to be uploaded returns an empty error. - fn signed_one_time_keys(&self) -> Result { + fn signed_one_time_keys(&self) -> Result { let _ = self.generate_one_time_keys()?; let one_time_keys = self.account.one_time_keys(); @@ -292,6 +292,25 @@ impl OlmMachine { ret } + + /// Get a tuple of device and one-time keys that need to be uploaded. + /// + /// Returns an empty error if no keys need to be uploaded. + pub fn keys_for_upload(&self) -> Result<(Option, Option), ()> { + if !self.should_upload_keys() { + return Err(()); + } + + let device_keys = if !self.account.shared() { + Some(self.device_keys()) + } else { + None + }; + + let one_time_keys: Option = self.signed_one_time_keys().ok(); + + Ok((device_keys, one_time_keys)) + } } #[cfg(test)] @@ -437,4 +456,44 @@ mod test { ); assert!(ret.is_ok()); } + + #[async_std::test] + async fn test_keys_for_upload() { + let mut machine = OlmMachine::new(user_id(), DEVICE_ID); + machine.uploaded_signed_key_count = Some(0); + + let identity_keys = machine.account.identity_keys(); + let ed25519_key = identity_keys.ed25519(); + + let (device_keys, mut one_time_keys) = machine + .keys_for_upload() + .expect("Can't prepare initial key upload"); + + let ret = machine.verify_json( + &machine.user_id, + &machine.device_id, + ed25519_key, + &mut json!(&mut one_time_keys.as_mut().unwrap().values_mut().nth(0)), + ); + assert!(ret.is_ok()); + + let ret = machine.verify_json( + &machine.user_id, + &machine.device_id, + ed25519_key, + &mut json!(&mut device_keys.unwrap()), + ); + assert!(ret.is_ok()); + + let mut response = keys_upload_response(); + response.one_time_key_counts.insert( + keys::KeyAlgorithm::SignedCurve25519, + UInt::new_wrapping(one_time_keys.unwrap().len() as u64), + ); + + machine.receive_keys_upload_response(&response).await; + + let ret = machine.keys_for_upload(); + assert!(ret.is_err()); + } }