diff --git a/src/crypto/machine.rs b/src/crypto/machine.rs index 1037bc1d..2e95daf2 100644 --- a/src/crypto/machine.rs +++ b/src/crypto/machine.rs @@ -68,7 +68,7 @@ pub struct OlmMachine { /// The unique device id of the device that holds this account. device_id: DeviceId, /// Our underlying Olm Account holding our identity keys. - account: Arc>, + account: Account, /// The number of signed one-time keys we have uploaded to the server. If /// this is None, no action will be taken. After a sync request the client /// needs to set this for us, depending on the count we will suggest the @@ -98,7 +98,7 @@ impl OlmMachine { Ok(OlmMachine { user_id: user_id.clone(), device_id: device_id.to_owned(), - account: Arc::new(Mutex::new(Account::new())), + account: Account::new(), uploaded_signed_key_count: None, store: Box::new(MemoryStore::new()), users_for_key_query: HashSet::new(), @@ -132,7 +132,7 @@ impl OlmMachine { Ok(OlmMachine { user_id: user_id.clone(), device_id: device_id.to_owned(), - account: Arc::new(Mutex::new(account)), + account, uploaded_signed_key_count: None, store: Box::new(store), users_for_key_query: HashSet::new(), @@ -142,7 +142,7 @@ impl OlmMachine { /// Should account or one-time keys be uploaded to the server. pub async fn should_upload_keys(&self) -> bool { - if !self.account.lock().await.shared() { + if !self.account.shared() { return true; } @@ -150,7 +150,7 @@ impl OlmMachine { // max_one_time_Keys() / 2, otherwise tell the client to upload more. match self.uploaded_signed_key_count { Some(count) => { - let max_keys = self.account.lock().await.max_one_time_keys() as u64; + let max_keys = self.account.max_one_time_keys().await as u64; let key_count = (max_keys / 2) - count; key_count > 0 } @@ -169,11 +169,10 @@ impl OlmMachine { &mut self, response: &keys::upload_keys::Response, ) -> Result<()> { - let mut account = self.account.lock().await; - if !account.shared { + if !self.account.shared() { debug!("Marking account as shared"); } - account.shared = true; + self.account.mark_as_shared(); let one_time_key_count = response .one_time_key_counts @@ -187,10 +186,9 @@ impl OlmMachine { ); self.uploaded_signed_key_count = Some(count); - account.mark_keys_as_published(); - drop(account); + self.account.mark_keys_as_published().await; - self.store.save_account(self.account.clone()).await?; + // self.store.save_account(self.account.clone()).await?; Ok(()) } @@ -317,9 +315,8 @@ impl OlmMachine { let session = match self .account - .lock() - .await .create_outbound_session(curve_key, &one_time_key) + .await { Ok(s) => s, Err(e) => { @@ -441,10 +438,9 @@ impl OlmMachine { /// Returns the number of newly generated one-time keys. If no keys can be /// generated returns an empty error. async fn generate_one_time_keys(&self) -> StdResult { - let account = self.account.lock().await; match self.uploaded_signed_key_count { Some(count) => { - let max_keys = account.max_one_time_keys() as u64; + let max_keys = self.account.max_one_time_keys().await as u64; let max_on_server = max_keys / 2; if count >= (max_on_server) { @@ -453,11 +449,11 @@ impl OlmMachine { let key_count = (max_on_server) - count; - let key_count: usize = key_count - .try_into() - .unwrap_or_else(|_| account.max_one_time_keys()); + let max_keys = self.account.max_one_time_keys().await; - account.generate_one_time_keys(key_count); + let key_count: usize = key_count.try_into().unwrap_or(max_keys); + + self.account.generate_one_time_keys(key_count).await; Ok(key_count as u64) } None => Err(()), @@ -466,7 +462,7 @@ impl OlmMachine { /// Sign the device keys and return a JSON Value to upload them. async fn device_keys(&self) -> DeviceKeys { - let identity_keys = self.account.lock().await.identity_keys(); + let identity_keys = self.account.identity_keys(); let mut keys = HashMap::new(); @@ -513,7 +509,7 @@ impl OlmMachine { /// If no one-time keys need to be uploaded returns an empty error. async fn signed_one_time_keys(&self) -> StdResult { let _ = self.generate_one_time_keys().await?; - let one_time_keys = self.account.lock().await.one_time_keys(); + let one_time_keys = self.account.one_time_keys().await; let mut one_time_key_map = HashMap::new(); for (key_id, key) in one_time_keys.curve25519().iter() { @@ -555,10 +551,9 @@ impl OlmMachine { /// * `json` - The value that should be converted into a canonical JSON /// string. async fn sign_json(&self, json: &Value) -> String { - let account = self.account.lock().await; let canonical_json = cjson::to_string(json) .unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", json))); - account.sign(&canonical_json) + self.account.sign(&canonical_json).await } /// Verify a signed JSON object. @@ -637,7 +632,7 @@ impl OlmMachine { return Err(()); } - let shared = self.account.lock().await.shared(); + let shared = self.account.shared(); let device_keys = if !shared { Some(self.device_keys().await) @@ -702,8 +697,9 @@ impl OlmMachine { let mut session = match &message { OlmMessage::Message(_) => return Err(OlmError::SessionWedged), OlmMessage::PreKey(m) => { - let account = self.account.lock().await; - account.create_inbound_session(sender_key, m.clone())? + self.account + .create_inbound_session(sender_key, m.clone()) + .await? } }; @@ -740,7 +736,7 @@ impl OlmMachine { return Err(OlmError::UnsupportedAlgorithm); }; - let identity_keys = self.account.lock().await.identity_keys(); + let identity_keys = self.account.identity_keys(); let own_key = identity_keys.curve25519(); let own_ciphertext = content.ciphertext.get(own_key); @@ -799,8 +795,7 @@ impl OlmMachine { async fn create_outbound_group_session(&mut self, room_id: &RoomId) -> Result<()> { let session = OutboundGroupSession::new(room_id); - let account = self.account.lock().await; - let identity_keys = account.identity_keys(); + let identity_keys = self.account.identity_keys(); let sender_key = identity_keys.curve25519(); let signing_key = identity_keys.ed25519(); @@ -855,13 +850,7 @@ impl OlmMachine { Ok(MegolmV1AesSha2Content { algorithm: Algorithm::MegolmV1AesSha2, ciphertext, - sender_key: self - .account - .lock() - .await - .identity_keys() - .curve25519() - .to_owned(), + sender_key: self.account.identity_keys().curve25519().to_owned(), session_id: session.session_id().to_owned(), device_id: self.device_id.to_owned(), }) @@ -874,7 +863,7 @@ impl OlmMachine { event_type: EventType, content: Value, ) -> Result { - let identity_keys = self.account.lock().await.identity_keys(); + let identity_keys = self.account.identity_keys(); let recipient_signing_key = recipient_device .keys(&KeyAlgorithm::Ed25519) @@ -1326,7 +1315,7 @@ mod test { let machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap(); let mut device_keys = machine.device_keys().await; - let identity_keys = machine.account.lock().await.identity_keys(); + let identity_keys = machine.account.identity_keys(); let ed25519_key = identity_keys.ed25519(); let ret = machine.verify_json( @@ -1359,7 +1348,7 @@ mod test { machine.uploaded_signed_key_count = Some(49); let mut one_time_keys = machine.signed_one_time_keys().await.unwrap(); - let identity_keys = machine.account.lock().await.identity_keys(); + let identity_keys = machine.account.identity_keys(); let ed25519_key = identity_keys.ed25519(); let mut one_time_key = one_time_keys.values_mut().nth(0).unwrap(); @@ -1378,7 +1367,7 @@ mod test { let mut machine = OlmMachine::new(&user_id(), DEVICE_ID).unwrap(); machine.uploaded_signed_key_count = Some(0); - let identity_keys = machine.account.lock().await.identity_keys(); + let identity_keys = machine.account.identity_keys(); let ed25519_key = identity_keys.ed25519(); let (device_keys, mut one_time_keys) = machine diff --git a/src/crypto/olm.rs b/src/crypto/olm.rs index 8520b4f1..f9fc99bc 100644 --- a/src/crypto/olm.rs +++ b/src/crypto/olm.rs @@ -33,9 +33,11 @@ use crate::identifiers::RoomId; /// The Olm account. /// An account is the central identity for encrypted communication between two /// devices. It holds the two identity key pairs for a device. +#[derive(Clone)] pub struct Account { - inner: OlmAccount, - pub(crate) shared: bool, + inner: Arc>, + identity_keys: Arc, + pub(crate) shared: Arc, } impl fmt::Debug for Account { @@ -44,7 +46,7 @@ impl fmt::Debug for Account { f, "Olm Account: {:?}, shared: {}", self.identity_keys(), - self.shared + self.shared() ) } } @@ -52,49 +54,61 @@ impl fmt::Debug for Account { impl Account { /// Create a new account. pub fn new() -> Self { + let account = OlmAccount::new(); + let identity_keys = account.parsed_identity_keys(); + Account { - inner: OlmAccount::new(), - shared: false, + inner: Arc::new(Mutex::new(account)), + identity_keys: Arc::new(identity_keys), + shared: Arc::new(AtomicBool::new(false)), } } /// Get the public parts of the identity keys for the account. - pub fn identity_keys(&self) -> IdentityKeys { - self.inner.parsed_identity_keys() + pub fn identity_keys(&self) -> &IdentityKeys { + &self.identity_keys } /// Has the account been shared with the server. pub fn shared(&self) -> bool { - self.shared + self.shared.load(Ordering::Relaxed) + } + + /// Mark the account as shared. + /// + /// Messages shouldn't be encrypted with the session before it has been + /// shared. + pub fn mark_as_shared(&self) { + self.shared.store(true, Ordering::Relaxed); } /// Get the one-time keys of the account. /// /// This can be empty, keys need to be generated first. - pub fn one_time_keys(&self) -> OneTimeKeys { - self.inner.parsed_one_time_keys() + pub async fn one_time_keys(&self) -> OneTimeKeys { + self.inner.lock().await.parsed_one_time_keys() } /// Generate count number of one-time keys. - pub fn generate_one_time_keys(&self, count: usize) { - self.inner.generate_one_time_keys(count); + pub async fn generate_one_time_keys(&self, count: usize) { + self.inner.lock().await.generate_one_time_keys(count); } /// Get the maximum number of one-time keys the account can hold. - pub fn max_one_time_keys(&self) -> usize { - self.inner.max_number_of_one_time_keys() + pub async fn max_one_time_keys(&self) -> usize { + self.inner.lock().await.max_number_of_one_time_keys() } /// Mark the current set of one-time keys as being published. - pub fn mark_keys_as_published(&self) { - self.inner.mark_keys_as_published(); + pub async fn mark_keys_as_published(&self) { + self.inner.lock().await.mark_keys_as_published(); } /// Sign the given string using the accounts signing key. /// /// Returns the signature as a base64 encoded string. - pub fn sign(&self, string: &str) -> String { - self.inner.sign(string) + pub async fn sign(&self, string: &str) -> String { + self.inner.lock().await.sign(string) } /// Store the account as a base64 encoded string. @@ -103,8 +117,8 @@ impl Account { /// /// * `pickle_mode` - The mode that was used to pickle the account, either an /// unencrypted mode or an encrypted using passphrase. - pub fn pickle(&self, pickle_mode: PicklingMode) -> String { - self.inner.pickle(pickle_mode) + pub async fn pickle(&self, pickle_mode: PicklingMode) -> String { + self.inner.lock().await.pickle(pickle_mode) } /// Restore an account from a previously pickled string. @@ -123,8 +137,14 @@ impl Account { pickle_mode: PicklingMode, shared: bool, ) -> Result { - let acc = OlmAccount::unpickle(pickle, pickle_mode)?; - Ok(Account { inner: acc, shared }) + let account = OlmAccount::unpickle(pickle, pickle_mode)?; + let identity_keys = account.parsed_identity_keys(); + + Ok(Account { + inner: Arc::new(Mutex::new(account)), + identity_keys: Arc::new(identity_keys), + shared: Arc::new(AtomicBool::from(shared)), + }) } /// Create a new session with another account given a one-time key. @@ -137,13 +157,15 @@ impl Account { /// /// * `their_one_time_key` - A signed one-time key that the other account /// created and shared with us. - pub fn create_outbound_session( + pub async fn create_outbound_session( &self, their_identity_key: &str, their_one_time_key: &SignedKey, ) -> Result { let session = self .inner + .lock() + .await .create_outbound_session(their_identity_key, &their_one_time_key.key)?; let now = Instant::now(); @@ -166,13 +188,15 @@ impl Account { /// /// * `message` - A pre-key Olm message that was sent to us by the other /// account. - pub fn create_inbound_session( + pub async fn create_inbound_session( &self, their_identity_key: &str, message: PreKeyMessage, ) -> Result { let session = self .inner + .lock() + .await .create_inbound_session_from(their_identity_key, message)?; let now = Instant::now(); @@ -188,7 +212,7 @@ impl Account { impl PartialEq for Account { fn eq(&self, other: &Self) -> bool { - self.identity_keys() == other.identity_keys() && self.shared == other.shared + self.identity_keys() == other.identity_keys() && self.shared() == other.shared() } } @@ -566,16 +590,16 @@ mod test { assert!(!identyty_keys.curve25519().is_empty()); } - #[test] - fn one_time_keys_creation() { + #[tokio::test] + async fn one_time_keys_creation() { let account = Account::new(); - let one_time_keys = account.one_time_keys(); + let one_time_keys = account.one_time_keys().await; assert!(one_time_keys.curve25519().is_empty()); - assert_ne!(account.max_one_time_keys(), 0); + assert_ne!(account.max_one_time_keys().await, 0); - account.generate_one_time_keys(10); - let one_time_keys = account.one_time_keys(); + account.generate_one_time_keys(10).await; + let one_time_keys = account.one_time_keys().await; assert!(!one_time_keys.curve25519().is_empty()); assert_ne!(one_time_keys.values().len(), 0); @@ -588,21 +612,19 @@ mod test { one_time_keys.get("curve25519").unwrap() ); - account.mark_keys_as_published(); - let one_time_keys = account.one_time_keys(); + account.mark_keys_as_published().await; + let one_time_keys = account.one_time_keys().await; assert!(one_time_keys.curve25519().is_empty()); } - #[test] - fn session_creation() { + #[tokio::test] + async fn session_creation() { let alice = Account::new(); let bob = Account::new(); let alice_keys = alice.identity_keys(); - let one_time_keys = alice.one_time_keys(); - - alice.generate_one_time_keys(1); - let one_time_keys = alice.one_time_keys(); - alice.mark_keys_as_published(); + alice.generate_one_time_keys(1).await; + let one_time_keys = alice.one_time_keys().await; + alice.mark_keys_as_published().await; let one_time_key = one_time_keys .curve25519() @@ -619,6 +641,7 @@ mod test { let mut bob_session = bob .create_outbound_session(alice_keys.curve25519(), &one_time_key) + .await .unwrap(); let plaintext = "Hello world"; @@ -633,6 +656,7 @@ mod test { let bob_keys = bob.identity_keys(); let mut alice_session = alice .create_inbound_session(bob_keys.curve25519(), prekey_message) + .await .unwrap(); assert_eq!(bob_session.session_id(), alice_session.session_id()); diff --git a/src/crypto/store/memorystore.rs b/src/crypto/store/memorystore.rs index 2ab3bc97..23116b67 100644 --- a/src/crypto/store/memorystore.rs +++ b/src/crypto/store/memorystore.rs @@ -48,7 +48,7 @@ impl CryptoStore for MemoryStore { Ok(None) } - async fn save_account(&mut self, _: Arc>) -> Result<()> { + async fn save_account(&mut self, _: Account) -> Result<()> { Ok(()) } diff --git a/src/crypto/store/mod.rs b/src/crypto/store/mod.rs index aa485f82..1b41f63b 100644 --- a/src/crypto/store/mod.rs +++ b/src/crypto/store/mod.rs @@ -66,7 +66,7 @@ pub type Result = std::result::Result; #[async_trait] pub trait CryptoStore: Debug + Send + Sync { async fn load_account(&mut self) -> Result>; - async fn save_account(&mut self, account: Arc>) -> Result<()>; + async fn save_account(&mut self, account: Account) -> Result<()>; async fn save_session(&mut self, session: Arc>) -> Result<()>; async fn add_and_save_session(&mut self, session: Session) -> Result<()>; async fn get_sessions( diff --git a/src/crypto/store/sqlite.rs b/src/crypto/store/sqlite.rs index 32742326..a629214c 100644 --- a/src/crypto/store/sqlite.rs +++ b/src/crypto/store/sqlite.rs @@ -288,9 +288,8 @@ impl CryptoStore for SqliteStore { Ok(result) } - async fn save_account(&mut self, account: Arc>) -> Result<()> { - let acc = account.lock().await; - let pickle = acc.pickle(self.get_pickle_mode()); + async fn save_account(&mut self, account: Account) -> Result<()> { + let pickle = account.pickle(self.get_pickle_mode()).await; let mut connection = self.connection.lock().await; query( @@ -307,7 +306,7 @@ impl CryptoStore for SqliteStore { .bind(&*self.user_id.to_string()) .bind(&*self.device_id.to_string()) .bind(&pickle) - .bind(acc.shared) + .bind(account.shared()) .execute(&mut *connection) .await?; @@ -460,7 +459,7 @@ mod test { .expect("Can't create store") } - async fn get_loaded_store() -> (Arc>, SqliteStore) { + async fn get_loaded_store() -> (Account, SqliteStore) { let mut store = get_store().await; let account = get_account(); store @@ -471,19 +470,19 @@ mod test { (account, store) } - fn get_account() -> Arc> { - let account = Account::new(); - Arc::new(Mutex::new(account)) + fn get_account() -> Account { + Account::new() } - fn get_account_and_session() -> (Arc>, Session) { + async fn get_account_and_session() -> (Account, Session) { let alice = Account::new(); let bob = Account::new(); - bob.generate_one_time_keys(1); + bob.generate_one_time_keys(1).await; let one_time_key = bob .one_time_keys() + .await .curve25519() .iter() .nth(0) @@ -497,9 +496,10 @@ mod test { let sender_key = bob.identity_keys().curve25519().to_owned(); let session = alice .create_outbound_session(&sender_key, &one_time_key) + .await .unwrap(); - (Arc::new(Mutex::new(alice)), session) + (alice, session) } #[tokio::test] @@ -532,11 +532,10 @@ mod test { .await .expect("Can't save account"); - let acc = account.lock().await; let loaded_account = store.load_account().await.expect("Can't load account"); let loaded_account = loaded_account.unwrap(); - assert_eq!(*acc, loaded_account); + assert_eq!(account, loaded_account); } #[tokio::test] @@ -549,7 +548,7 @@ mod test { .await .expect("Can't save account"); - account.lock().await.shared = true; + account.mark_as_shared(); store .save_account(account.clone()) @@ -558,15 +557,14 @@ mod test { let loaded_account = store.load_account().await.expect("Can't load account"); let loaded_account = loaded_account.unwrap(); - let acc = account.lock().await; - assert_eq!(*acc, loaded_account); + assert_eq!(account, loaded_account); } #[tokio::test] async fn save_session() { let mut store = get_store().await; - let (account, session) = get_account_and_session(); + let (account, session) = get_account_and_session().await; let session = Arc::new(Mutex::new(session)); assert!(store.save_session(session.clone()).await.is_err()); @@ -582,7 +580,7 @@ mod test { #[tokio::test] async fn load_sessions() { let mut store = get_store().await; - let (account, session) = get_account_and_session(); + let (account, session) = get_account_and_session().await; let session = Arc::new(Mutex::new(session)); store .save_account(account.clone()) @@ -604,7 +602,7 @@ mod test { #[tokio::test] async fn add_and_save_session() { let mut store = get_store().await; - let (account, session) = get_account_and_session(); + let (account, session) = get_account_and_session().await; let sender_key = session.sender_key.to_owned(); let session_id = session.session_id(); @@ -625,8 +623,7 @@ mod test { async fn save_inbound_group_session() { let (account, mut store) = get_loaded_store().await; - let acc = account.lock().await; - let identity_keys = acc.identity_keys(); + let identity_keys = account.identity_keys(); let outbound_session = OlmOutboundGroupSession::new(); let session = InboundGroupSession::new( identity_keys.curve25519(), @@ -646,8 +643,7 @@ mod test { async fn load_inbound_group_session() { let (account, mut store) = get_loaded_store().await; - let acc = account.lock().await; - let identity_keys = acc.identity_keys(); + let identity_keys = account.identity_keys(); let outbound_session = OlmOutboundGroupSession::new(); let session = InboundGroupSession::new( identity_keys.curve25519(),