From 7e95d85f175b0d451af5a7a4b1b475863447eff4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Tue, 28 Jul 2020 15:03:44 +0200 Subject: [PATCH] crypto: Move the cryptostore behind a lock. --- matrix_sdk_base/src/client.rs | 4 +- matrix_sdk_crypto/src/machine.rs | 161 ++++++++++++++++++++++++------- 2 files changed, 128 insertions(+), 37 deletions(-) diff --git a/matrix_sdk_base/src/client.rs b/matrix_sdk_base/src/client.rs index 375a6195..d26918c5 100644 --- a/matrix_sdk_base/src/client.rs +++ b/matrix_sdk_base/src/client.rs @@ -1260,7 +1260,7 @@ impl BaseClient { let olm = self.olm.lock().await; match &*olm { - Some(o) => o.should_query_keys(), + Some(o) => o.should_query_keys().await, None => false, } } @@ -1348,7 +1348,7 @@ impl BaseClient { let olm = self.olm.lock().await; match &*olm { - Some(o) => Ok(o.users_for_key_query()), + Some(o) => Ok(o.users_for_key_query().await), None => Err(()), } } diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index b569e429..6597499a 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -13,12 +13,12 @@ // limitations under the License. use std::collections::{BTreeMap, HashMap, HashSet}; -use std::convert::TryFrom; -use std::convert::TryInto; +use std::convert::{TryFrom, TryInto}; use std::mem; #[cfg(feature = "sqlite-cryptostore")] use std::path::Path; use std::result::Result as StdResult; +use std::sync::Arc; use super::error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult}; use super::olm::{ @@ -36,6 +36,7 @@ use matrix_sdk_common::events::{ EventType, SyncMessageEvent, ToDeviceEvent, }; use matrix_sdk_common::identifiers::{DeviceId, RoomId, UserId}; +use matrix_sdk_common::locks::RwLock; use matrix_sdk_common::uuid::Uuid; use matrix_sdk_common::{api, Raw}; @@ -66,7 +67,7 @@ pub struct OlmMachine { /// Store for the encryption keys. /// Persists all the encryption keys so a client can resume the session /// without the need to create new keys. - store: Box, + store: Arc>>, /// The currently active outbound group sessions. outbound_group_sessions: HashMap, } @@ -100,7 +101,7 @@ impl OlmMachine { user_id: user_id.clone(), device_id: device_id.into(), account: Account::new(user_id, &device_id), - store: Box::new(MemoryStore::new()), + store: Arc::new(RwLock::new(Box::new(MemoryStore::new()))), outbound_group_sessions: HashMap::new(), } } @@ -142,7 +143,7 @@ impl OlmMachine { user_id, device_id, account, - store, + store: Arc::new(RwLock::new(store)), outbound_group_sessions: HashMap::new(), }) } @@ -224,7 +225,11 @@ impl OlmMachine { self.update_key_count(count); self.account.mark_keys_as_published().await; - self.store.save_account(self.account.clone()).await?; + self.store + .write() + .await + .save_account(self.account.clone()) + .await?; Ok(()) } @@ -255,7 +260,7 @@ impl OlmMachine { let mut missing = BTreeMap::new(); for user_id in users { - let user_devices = self.store.get_user_devices(user_id).await?; + let user_devices = self.store.read().await.get_user_devices(user_id).await?; for device in user_devices.devices() { let sender_key = if let Some(k) = device.get_key(KeyAlgorithm::Curve25519) { @@ -264,7 +269,7 @@ impl OlmMachine { continue; }; - let sessions = self.store.get_sessions(sender_key).await?; + let sessions = self.store.write().await.get_sessions(sender_key).await?; let is_missing = if let Some(sessions) = sessions { sessions.lock().await.is_empty() @@ -301,7 +306,13 @@ impl OlmMachine { for (user_id, user_devices) in &response.one_time_keys { for (device_id, key_map) in user_devices { - let device: Device = match self.store.get_device(&user_id, device_id).await { + let device: Device = match self + .store + .read() + .await + .get_device(&user_id, device_id) + .await + { Ok(Some(d)) => d, Ok(None) => { warn!( @@ -330,7 +341,7 @@ impl OlmMachine { } }; - if let Err(e) = self.store.save_sessions(&[session]).await { + if let Err(e) = self.store.write().await.save_sessions(&[session]).await { error!("Failed to store newly created Olm session {}", e); continue; } @@ -351,7 +362,11 @@ impl OlmMachine { let mut changed_devices = Vec::new(); for (user_id, device_map) in device_keys_map { - self.store.update_tracked_user(user_id, false).await?; + self.store + .write() + .await + .update_tracked_user(user_id, false) + .await?; for (device_id, device_keys) in device_map.iter() { // We don't need our own device in the device store. @@ -367,7 +382,12 @@ impl OlmMachine { continue; } - let device = self.store.get_device(&user_id, device_id).await?; + let device = self + .store + .read() + .await + .get_device(&user_id, device_id) + .await?; let device = if let Some(mut device) = device { if let Err(e) = device.update_device(device_keys) { @@ -398,7 +418,13 @@ impl OlmMachine { let current_devices: HashSet<&DeviceId> = device_map.keys().map(|id| id.as_ref()).collect(); - let stored_devices = self.store.get_user_devices(&user_id).await.unwrap(); + let stored_devices = self + .store + .read() + .await + .get_user_devices(&user_id) + .await + .unwrap(); let stored_devices_set: HashSet<&DeviceId> = stored_devices.keys().collect(); let deleted_devices = stored_devices_set.difference(¤t_devices); @@ -406,7 +432,7 @@ impl OlmMachine { for device_id in deleted_devices { if let Some(device) = stored_devices.get(device_id) { device.mark_as_deleted(); - self.store.delete_device(device).await?; + self.store.write().await.delete_device(device).await?; } } } @@ -430,7 +456,11 @@ impl OlmMachine { let changed_devices = self .handle_devices_from_key_query(&response.device_keys) .await?; - self.store.save_devices(&changed_devices).await?; + self.store + .write() + .await + .save_devices(&changed_devices) + .await?; Ok(changed_devices) } @@ -454,7 +484,7 @@ impl OlmMachine { sender_key: &str, message: &OlmMessage, ) -> OlmResult> { - let s = self.store.get_sessions(sender_key).await?; + let s = self.store.write().await.get_sessions(sender_key).await?; // We don't have any existing sessions, return early. let sessions = if let Some(s) = s { @@ -504,7 +534,7 @@ impl OlmMachine { // Decryption was successful, save the new ratchet state of the // session that was used to decrypt the message. trace!("Saved the new session state for {}", sender); - self.store.save_sessions(&[session]).await?; + self.store.write().await.save_sessions(&[session]).await?; } Ok(plaintext) @@ -559,7 +589,11 @@ impl OlmMachine { // Save the account since we remove the one-time key that // was used to create this session. - self.store.save_account(self.account.clone()).await?; + self.store + .write() + .await + .save_account(self.account.clone()) + .await?; session } }; @@ -569,7 +603,7 @@ impl OlmMachine { let plaintext = session.decrypt(message).await?; // Save the new ratcheted state of the session. - self.store.save_sessions(&[session]).await?; + self.store.write().await.save_sessions(&[session]).await?; plaintext }; @@ -720,7 +754,12 @@ impl OlmMachine { &event.content.room_id, session_key, )?; - let _ = self.store.save_inbound_group_session(session).await?; + let _ = self + .store + .write() + .await + .save_inbound_group_session(session) + .await?; let event = Raw::from(AnyToDeviceEvent::RoomKey(event.clone())); Ok(Some(event)) @@ -742,7 +781,12 @@ impl OlmMachine { async fn create_outbound_group_session(&mut self, room_id: &RoomId) -> OlmResult<()> { let (outbound, inbound) = self.account.create_group_session_pair(room_id).await; - let _ = self.store.save_inbound_group_session(inbound).await?; + let _ = self + .store + .write() + .await + .save_inbound_group_session(inbound) + .await?; let _ = self .outbound_group_sessions @@ -819,7 +863,8 @@ impl OlmMachine { return Err(EventError::MissingSenderKey.into()); }; - let mut session = if let Some(s) = self.store.get_sessions(sender_key).await? { + let mut session = if let Some(s) = self.store.write().await.get_sessions(sender_key).await? + { let session = &s.lock().await[0]; session.clone() } else { @@ -833,7 +878,7 @@ impl OlmMachine { }; let message = session.encrypt(recipient_device, event_type, content).await; - self.store.save_sessions(&[session]).await?; + self.store.write().await.save_sessions(&[session]).await?; message } @@ -897,7 +942,14 @@ impl OlmMachine { let mut devices = Vec::new(); for user_id in users { - for device in self.store.get_user_devices(user_id).await?.devices() { + for device in self + .store + .read() + .await + .get_user_devices(user_id) + .await? + .devices() + { // TODO abort if the device isn't verified devices.push(device.clone()); } @@ -1086,6 +1138,8 @@ impl OlmMachine { let session = self .store + .write() + .await .get_inbound_group_session(room_id, &content.sender_key, &content.session_id) .await?; // TODO check if the Olm session is wedged and re-request the key. @@ -1111,8 +1165,12 @@ impl OlmMachine { /// /// Returns true if the user was queued up for a key query, false otherwise. pub async fn mark_user_as_changed(&mut self, user_id: &UserId) -> StoreResult { - if self.store.tracked_users().contains(user_id) { - self.store.update_tracked_user(user_id, true).await?; + if self.store.read().await.tracked_users().contains(user_id) { + self.store + .write() + .await + .update_tracked_user(user_id, true) + .await?; Ok(true) } else { Ok(false) @@ -1138,26 +1196,32 @@ impl OlmMachine { I: IntoIterator, { for user in users { - if self.store.tracked_users().contains(user) { + if self.store.read().await.tracked_users().contains(user) { continue; } - if let Err(e) = self.store.update_tracked_user(user, true).await { + if let Err(e) = self + .store + .write() + .await + .update_tracked_user(user, true) + .await + { warn!("Error storing users for tracking {}", e); } } } /// Should the client perform a key query request. - pub fn should_query_keys(&self) -> bool { - !self.store.users_for_key_query().is_empty() + pub async fn should_query_keys(&self) -> bool { + !self.store.read().await.users_for_key_query().is_empty() } /// Get the set of users that we need to query keys for. /// /// Returns a hash set of users that need to be queried for keys. - pub fn users_for_key_query(&self) -> HashSet { - self.store.users_for_key_query().clone() + pub async fn users_for_key_query(&self) -> HashSet { + self.store.read().await.users_for_key_query().clone() } } @@ -1277,8 +1341,19 @@ mod test { let alice_deivce = Device::from_machine(&alice).await; let bob_device = Device::from_machine(&bob).await; - alice.store.save_devices(&[bob_device]).await.unwrap(); - bob.store.save_devices(&[alice_deivce]).await.unwrap(); + alice + .store + .write() + .await + .save_devices(&[bob_device]) + .await + .unwrap(); + bob.store + .write() + .await + .save_devices(&[alice_deivce]) + .await + .unwrap(); (alice, bob, otk) } @@ -1311,6 +1386,8 @@ mod test { let bob_device = alice .store + .read() + .await .get_device(&bob.user_id, &bob.device_id) .await .unwrap() @@ -1515,7 +1592,13 @@ mod test { let alice_id = UserId::try_from("@alice:example.org").unwrap(); let alice_device_id: &DeviceId = "JLAFKJWSCS".into(); - let alice_devices = machine.store.get_user_devices(&alice_id).await.unwrap(); + let alice_devices = machine + .store + .read() + .await + .get_user_devices(&alice_id) + .await + .unwrap(); assert!(alice_devices.devices().peekable().peek().is_none()); machine @@ -1525,6 +1608,8 @@ mod test { let device = machine .store + .read() + .await .get_device(&alice_id, alice_device_id) .await .unwrap() @@ -1576,6 +1661,8 @@ mod test { let session = alice_machine .store + .write() + .await .get_sessions(bob_machine.account.identity_keys().curve25519()) .await .unwrap() @@ -1590,6 +1677,8 @@ mod test { let bob_device = alice .store + .read() + .await .get_device(&bob.user_id, &bob.device_id) .await .unwrap() @@ -1651,6 +1740,8 @@ mod test { let session = bob .store + .write() + .await .get_inbound_group_session( &room_id, alice.account.identity_keys().curve25519(),