diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 7cf3e449..e7332096 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -41,7 +41,6 @@ use matrix_sdk_common::{ Algorithm, AnySyncRoomEvent, AnyToDeviceEvent, EventType, SyncMessageEvent, ToDeviceEvent, }, identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, RoomId, UserId}, - locks::RwLock, uuid::Uuid, Raw, }; @@ -80,7 +79,7 @@ pub struct OlmMachine { /// Store for the encryption keys. /// Persists all the encryption keys so a client can resume the session /// without the need to create new keys. - store: Arc>>, + store: Arc>, /// The currently active outbound group sessions. outbound_group_sessions: Arc>, /// A state machine that is responsible to handle and keep track of SAS @@ -111,10 +110,9 @@ impl OlmMachine { /// * `user_id` - The unique id of the user that owns this machine. /// /// * `device_id` - The unique id of the device that owns this machine. - #[allow(clippy::ptr_arg)] pub fn new(user_id: &UserId, device_id: &DeviceId) -> Self { let store: Box = Box::new(MemoryStore::new()); - let store = Arc::new(RwLock::new(store)); + let store = Arc::new(store); let account = Account::new(user_id, device_id); OlmMachine { @@ -160,7 +158,7 @@ impl OlmMachine { } }; - let store = Arc::new(RwLock::new(store)); + let store = Arc::new(store); let verification_machine = VerificationMachine::new(account.clone(), store.clone()); Ok(OlmMachine { @@ -250,11 +248,7 @@ impl OlmMachine { self.update_key_count(count); self.account.mark_keys_as_published().await; - self.store - .write() - .await - .save_account(self.account.clone()) - .await?; + self.store.save_account(self.account.clone()).await?; Ok(()) } @@ -285,7 +279,7 @@ impl OlmMachine { let mut missing = BTreeMap::new(); for user_id in users { - let user_devices = self.store.read().await.get_user_devices(user_id).await?; + let user_devices = self.store.get_user_devices(user_id).await?; for device in user_devices.devices() { let sender_key = if let Some(k) = device.get_key(DeviceKeyAlgorithm::Curve25519) { @@ -294,7 +288,7 @@ impl OlmMachine { continue; }; - let sessions = self.store.write().await.get_sessions(sender_key).await?; + let sessions = self.store.get_sessions(sender_key).await?; let is_missing = if let Some(sessions) = sessions { sessions.lock().await.is_empty() @@ -333,13 +327,7 @@ impl OlmMachine { for (user_id, user_devices) in &response.one_time_keys { for (device_id, key_map) in user_devices { - let device: Device = match self - .store - .read() - .await - .get_device(&user_id, device_id) - .await - { + let device: Device = match self.store.get_device(&user_id, device_id).await { Ok(Some(d)) => d, Ok(None) => { warn!( @@ -368,7 +356,7 @@ impl OlmMachine { } }; - if let Err(e) = self.store.write().await.save_sessions(&[session]).await { + if let Err(e) = self.store.save_sessions(&[session]).await { error!("Failed to store newly created Olm session {}", e); continue; } @@ -389,11 +377,7 @@ impl OlmMachine { let mut changed_devices = Vec::new(); for (user_id, device_map) in device_keys_map { - self.store - .write() - .await - .update_tracked_user(user_id, false) - .await?; + self.store.update_tracked_user(user_id, false).await?; for (device_id, device_keys) in device_map.iter() { // We don't need our own device in the device store. @@ -409,12 +393,7 @@ impl OlmMachine { continue; } - let device = self - .store - .read() - .await - .get_device(&user_id, device_id) - .await?; + let device = self.store.get_device(&user_id, device_id).await?; let device = if let Some(mut device) = device { if let Err(e) = device.update_device(device_keys) { @@ -445,13 +424,7 @@ impl OlmMachine { let current_devices: HashSet<&DeviceId> = device_map.keys().map(|id| id.as_ref()).collect(); - let stored_devices = self - .store - .read() - .await - .get_user_devices(&user_id) - .await - .unwrap(); + let stored_devices = self.store.get_user_devices(&user_id).await.unwrap(); let stored_devices_set: HashSet<&DeviceId> = stored_devices.keys().collect(); let deleted_devices = stored_devices_set.difference(¤t_devices); @@ -459,7 +432,7 @@ impl OlmMachine { for device_id in deleted_devices { if let Some(device) = stored_devices.get(device_id) { device.mark_as_deleted(); - self.store.write().await.delete_device(device).await?; + self.store.delete_device(device).await?; } } } @@ -483,11 +456,7 @@ impl OlmMachine { let changed_devices = self .handle_devices_from_key_query(&response.device_keys) .await?; - self.store - .write() - .await - .save_devices(&changed_devices) - .await?; + self.store.save_devices(&changed_devices).await?; Ok(changed_devices) } @@ -511,7 +480,7 @@ impl OlmMachine { sender_key: &str, message: &OlmMessage, ) -> OlmResult> { - let s = self.store.write().await.get_sessions(sender_key).await?; + let s = self.store.get_sessions(sender_key).await?; // We don't have any existing sessions, return early. let sessions = if let Some(s) = s { @@ -561,7 +530,7 @@ impl OlmMachine { // Decryption was successful, save the new ratchet state of the // session that was used to decrypt the message. trace!("Saved the new session state for {}", sender); - self.store.write().await.save_sessions(&[session]).await?; + self.store.save_sessions(&[session]).await?; } Ok(plaintext) @@ -616,11 +585,7 @@ impl OlmMachine { // Save the account since we remove the one-time key that // was used to create this session. - self.store - .write() - .await - .save_account(self.account.clone()) - .await?; + self.store.save_account(self.account.clone()).await?; session } }; @@ -630,7 +595,7 @@ impl OlmMachine { let plaintext = session.decrypt(message).await?; // Save the new ratcheted state of the session. - self.store.write().await.save_sessions(&[session]).await?; + self.store.save_sessions(&[session]).await?; plaintext }; @@ -781,12 +746,7 @@ impl OlmMachine { &event.content.room_id, session_key, )?; - let _ = self - .store - .write() - .await - .save_inbound_group_session(session) - .await?; + let _ = self.store.save_inbound_group_session(session).await?; let event = Raw::from(AnyToDeviceEvent::RoomKey(event.clone())); Ok(Some(event)) @@ -808,12 +768,7 @@ impl OlmMachine { async fn create_outbound_group_session(&self, room_id: &RoomId) -> OlmResult<()> { let (outbound, inbound) = self.account.create_group_session_pair(room_id).await; - let _ = self - .store - .write() - .await - .save_inbound_group_session(inbound) - .await?; + let _ = self.store.save_inbound_group_session(inbound).await?; let _ = self .outbound_group_sessions @@ -899,8 +854,7 @@ impl OlmMachine { return Err(EventError::MissingSenderKey.into()); }; - let mut session = if let Some(s) = self.store.write().await.get_sessions(sender_key).await? - { + let mut session = if let Some(s) = self.store.get_sessions(sender_key).await? { let session = &s.lock().await[0]; session.clone() } else { @@ -914,7 +868,7 @@ impl OlmMachine { }; let message = session.encrypt(recipient_device, event_type, content).await; - self.store.write().await.save_sessions(&[session]).await?; + self.store.save_sessions(&[session]).await?; message } @@ -978,14 +932,7 @@ impl OlmMachine { let mut devices = Vec::new(); for user_id in users { - for device in self - .store - .read() - .await - .get_user_devices(user_id) - .await? - .devices() - { + for device in self.store.get_user_devices(user_id).await?.devices() { devices.push(device.clone()); } } @@ -1192,8 +1139,6 @@ impl OlmMachine { let session = self .store - .write() - .await .get_inbound_group_session(room_id, &content.sender_key, &content.session_id) .await?; // TODO check if the Olm session is wedged and re-request the key. @@ -1219,12 +1164,8 @@ impl OlmMachine { /// /// Returns true if the user was queued up for a key query, false otherwise. pub async fn mark_user_as_changed(&self, user_id: &UserId) -> StoreResult { - if self.store.read().await.is_user_tracked(user_id) { - self.store - .write() - .await - .update_tracked_user(user_id, true) - .await?; + if self.store.is_user_tracked(user_id) { + self.store.update_tracked_user(user_id, true).await?; Ok(true) } else { Ok(false) @@ -1250,17 +1191,11 @@ impl OlmMachine { I: IntoIterator, { for user in users { - if self.store.read().await.is_user_tracked(user) { + if self.store.is_user_tracked(user) { continue; } - if let Err(e) = self - .store - .write() - .await - .update_tracked_user(user, true) - .await - { + if let Err(e) = self.store.update_tracked_user(user, true).await { warn!("Error storing users for tracking {}", e); } } @@ -1268,14 +1203,14 @@ impl OlmMachine { /// Should the client perform a key query request. pub async fn should_query_keys(&self) -> bool { - self.store.read().await.has_users_for_key_query() + self.store.has_users_for_key_query() } /// Get the set of users that we need to query keys for. /// /// Returns a hash set of users that need to be queried for keys. pub async fn users_for_key_query(&self) -> HashSet { - self.store.read().await.users_for_key_query() + self.store.users_for_key_query() } } @@ -1398,19 +1333,8 @@ mod test { let alice_deivce = Device::from_machine(&alice).await; let bob_device = Device::from_machine(&bob).await; - alice - .store - .write() - .await - .save_devices(&[bob_device]) - .await - .unwrap(); - bob.store - .write() - .await - .save_devices(&[alice_deivce]) - .await - .unwrap(); + alice.store.save_devices(&[bob_device]).await.unwrap(); + bob.store.save_devices(&[alice_deivce]).await.unwrap(); (alice, bob, otk) } @@ -1443,8 +1367,6 @@ mod test { let bob_device = alice .store - .read() - .await .get_device(&bob.user_id, &bob.device_id) .await .unwrap() @@ -1649,13 +1571,7 @@ mod test { let alice_id = user_id!("@alice:example.org"); let alice_device_id: &DeviceId = "JLAFKJWSCS".into(); - let alice_devices = machine - .store - .read() - .await - .get_user_devices(&alice_id) - .await - .unwrap(); + let alice_devices = machine.store.get_user_devices(&alice_id).await.unwrap(); assert!(alice_devices.devices().peekable().peek().is_none()); machine @@ -1665,8 +1581,6 @@ mod test { let device = machine .store - .read() - .await .get_device(&alice_id, alice_device_id) .await .unwrap() @@ -1718,8 +1632,6 @@ mod test { let session = alice_machine .store - .write() - .await .get_sessions(bob_machine.account.identity_keys().curve25519()) .await .unwrap() @@ -1734,8 +1646,6 @@ mod test { let bob_device = alice .store - .read() - .await .get_device(&bob.user_id, &bob.device_id) .await .unwrap() @@ -1797,8 +1707,6 @@ mod test { let session = bob .store - .write() - .await .get_inbound_group_session( &room_id, alice.account.identity_keys().curve25519(), diff --git a/matrix_sdk_crypto/src/store/memorystore.rs b/matrix_sdk_crypto/src/store/memorystore.rs index f94b0a8b..4ad08cc0 100644 --- a/matrix_sdk_crypto/src/store/memorystore.rs +++ b/matrix_sdk_crypto/src/store/memorystore.rs @@ -26,7 +26,7 @@ use crate::{ device::Device, memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices}, }; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct MemoryStore { sessions: SessionStore, inbound_group_sessions: GroupSessionStore, diff --git a/matrix_sdk_crypto/src/store/sqlite.rs b/matrix_sdk_crypto/src/store/sqlite.rs index 17d67be5..1800a463 100644 --- a/matrix_sdk_crypto/src/store/sqlite.rs +++ b/matrix_sdk_crypto/src/store/sqlite.rs @@ -41,6 +41,7 @@ use crate::{ }; /// SQLite based implementation of a `CryptoStore`. +#[derive(Clone)] pub struct SqliteStore { user_id: Arc, device_id: Arc>, diff --git a/matrix_sdk_crypto/src/verification/machine.rs b/matrix_sdk_crypto/src/verification/machine.rs index ca3fc2f4..cea74359 100644 --- a/matrix_sdk_crypto/src/verification/machine.rs +++ b/matrix_sdk_crypto/src/verification/machine.rs @@ -22,7 +22,6 @@ use matrix_sdk_common::{ api::r0::to_device::send_event_to_device::IncomingRequest as OwnedToDeviceRequest, events::{AnyToDeviceEvent, AnyToDeviceEventContent}, identifiers::{DeviceId, UserId}, - locks::RwLock, }; use super::sas::{content_to_request, Sas}; @@ -31,13 +30,13 @@ use crate::{Account, CryptoStore, CryptoStoreError}; #[derive(Clone, Debug)] pub struct VerificationMachine { account: Account, - store: Arc>>, + store: Arc>, verifications: Arc>, outgoing_to_device_messages: Arc>, } impl VerificationMachine { - pub(crate) fn new(account: Account, store: Arc>>) -> Self { + pub(crate) fn new(account: Account, store: Arc>) -> Self { Self { account, store, @@ -112,8 +111,6 @@ impl VerificationMachine { if let Some(d) = self .store - .read() - .await .get_device(&e.sender, &e.content.from_device) .await? { @@ -179,7 +176,6 @@ mod test { use matrix_sdk_common::{ events::AnyToDeviceEventContent, identifiers::{DeviceId, UserId}, - locks::RwLock, }; use super::{Sas, VerificationMachine}; @@ -209,21 +205,18 @@ mod test { let alice = Account::new(&alice_id(), &alice_device_id()); let bob = Account::new(&bob_id(), &bob_device_id()); let store = MemoryStore::new(); - let bob_store: Arc>> = - Arc::new(RwLock::new(Box::new(MemoryStore::new()))); + let bob_store: Arc> = Arc::new(Box::new(MemoryStore::new())); let bob_device = Device::from_account(&bob).await; let alice_device = Device::from_account(&alice).await; store.save_devices(&[bob_device]).await.unwrap(); bob_store - .read() - .await .save_devices(&[alice_device.clone()]) .await .unwrap(); - let machine = VerificationMachine::new(alice, Arc::new(RwLock::new(Box::new(store)))); + let machine = VerificationMachine::new(alice, Arc::new(Box::new(store))); let (bob_sas, start_content) = Sas::start(bob, alice_device, bob_store); machine .receive_event(&mut wrap_any_to_device_content( @@ -240,7 +233,7 @@ mod test { fn create() { let alice = Account::new(&alice_id(), &alice_device_id()); let store = MemoryStore::new(); - let _ = VerificationMachine::new(alice, Arc::new(RwLock::new(Box::new(store)))); + let _ = VerificationMachine::new(alice, Arc::new(Box::new(store))); } #[tokio::test] diff --git a/matrix_sdk_crypto/src/verification/sas/helpers.rs b/matrix_sdk_crypto/src/verification/sas/helpers.rs index 1715db9b..3e0fb143 100644 --- a/matrix_sdk_crypto/src/verification/sas/helpers.rs +++ b/matrix_sdk_crypto/src/verification/sas/helpers.rs @@ -212,8 +212,6 @@ fn extra_mac_info_send(ids: &SasIds, flow_id: &str) -> String { /// /// * `flow_id` - The unique id that identifies this SAS verification process. /// -/// * `we_started` - Flag signaling if the SAS process was started on our side. -/// /// # Panics /// /// This will panic if the public key of the other side wasn't set. diff --git a/matrix_sdk_crypto/src/verification/sas/mod.rs b/matrix_sdk_crypto/src/verification/sas/mod.rs index 49c86cbf..a3d2d40e 100644 --- a/matrix_sdk_crypto/src/verification/sas/mod.rs +++ b/matrix_sdk_crypto/src/verification/sas/mod.rs @@ -31,7 +31,6 @@ use matrix_sdk_common::{ AnyToDeviceEvent, AnyToDeviceEventContent, ToDeviceEvent, }, identifiers::{DeviceId, UserId}, - locks::RwLock, }; use crate::{Account, CryptoStore, CryptoStoreError, Device, TrustState}; @@ -45,7 +44,7 @@ use sas_state::{ /// Short authentication string object. pub struct Sas { inner: Arc>, - store: Arc>>, + store: Arc>, account: Account, other_device: Device, flow_id: Arc, @@ -100,7 +99,7 @@ impl Sas { pub(crate) fn start( account: Account, other_device: Device, - store: Arc>>, + store: Arc>, ) -> (Sas, StartEventContent) { let (inner, content) = InnerSas::start(account.clone(), other_device.clone()); let flow_id = inner.verification_flow_id(); @@ -129,7 +128,7 @@ impl Sas { pub(crate) fn from_start_event( account: Account, other_device: Device, - store: Arc>>, + store: Arc>, event: &ToDeviceEvent, ) -> Result { let inner = InnerSas::from_start_event(account.clone(), other_device.clone(), event)?; @@ -184,8 +183,6 @@ impl Sas { pub(crate) async fn mark_device_as_verified(&self) -> Result { let device = self .store - .read() - .await .get_device(self.other_user_id(), self.other_device_id()) .await?; @@ -202,7 +199,7 @@ impl Sas { ); device.set_trust_state(TrustState::Verified); - self.store.read().await.save_devices(&[device]).await?; + self.store.save_devices(&[device]).await?; Ok(true) } else { @@ -560,7 +557,6 @@ mod test { use matrix_sdk_common::{ events::{EventContent, ToDeviceEvent}, identifiers::{DeviceId, UserId}, - locks::RwLock, }; use crate::{ @@ -685,14 +681,10 @@ mod test { let bob = Account::new(&bob_id(), &bob_device_id()); let bob_device = Device::from_account(&bob).await; - let alice_store: Arc>> = - Arc::new(RwLock::new(Box::new(MemoryStore::new()))); - let bob_store: Arc>> = - Arc::new(RwLock::new(Box::new(MemoryStore::new()))); + let alice_store: Arc> = Arc::new(Box::new(MemoryStore::new())); + let bob_store: Arc> = Arc::new(Box::new(MemoryStore::new())); bob_store - .read() - .await .save_devices(&[alice_device.clone()]) .await .unwrap();