diff --git a/Cargo.toml b/Cargo.toml index c4925020..fc993d0a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,3 @@ members = [ "matrix_sdk_common", "matrix_sdk_common_macros", ] - -[patch.crates-io] -olm-rs = { git = 'https://gitlab.gnome.org/jhaye/olm-rs/'} -olm-sys = { git = 'https://gitlab.gnome.org/BrainBlasted/olm-sys' } diff --git a/matrix_sdk/Cargo.toml b/matrix_sdk/Cargo.toml index b72669af..faf81034 100644 --- a/matrix_sdk/Cargo.toml +++ b/matrix_sdk/Cargo.toml @@ -21,9 +21,9 @@ async-trait = "0.1.36" http = "0.2.1" # FIXME: Revert to regular dependency once 0.10.8 or 0.11.0 is released reqwest = { git = "https://github.com/seanmonstar/reqwest", rev = "dd8441fd23dae6ffb79b4cea2862e5bca0c59743" } -serde_json = "1.0.56" +serde_json = "1.0.57" thiserror = "1.0.20" -tracing = "0.1.16" +tracing = "0.1.19" url = "2.1.1" matrix-sdk-common-macros = { version = "0.1.0", path = "../matrix_sdk_common_macros" } @@ -50,11 +50,11 @@ features = ["wasm-bindgen"] async-trait = "0.1.36" dirs = "3.0.1" matrix-sdk-test = { version = "0.1.0", path = "../matrix_sdk_test" } -tokio = { version = "0.2.21", features = ["rt-threaded", "macros"] } -serde_json = "1.0.56" -tracing-subscriber = "0.2.7" +tokio = { version = "0.2.22", features = ["rt-threaded", "macros"] } +serde_json = "1.0.57" +tracing-subscriber = "0.2.11" tempfile = "3.1.0" -mockito = "0.26.0" +mockito = "0.27.0" lazy_static = "1.4.0" futures = "0.3.5" diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index d6f43919..e7233fa5 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -73,7 +73,7 @@ pub struct Client { pub(crate) base_client: BaseClient, } -// #[cfg_attr(tarpaulin, skip)] +#[cfg(not(tarpaulin_include))] impl Debug for Client { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> StdResult<(), fmt::Error> { write!(fmt, "Client {{ homeserver: {} }}", self.homeserver) @@ -115,7 +115,7 @@ pub struct ClientConfig { pub(crate) client: Option>, } -// #[cfg_attr(tarpaulin, skip)] +#[cfg(not(tarpaulin_include))] impl Debug for ClientConfig { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { let mut res = fmt.debug_struct("ClientConfig"); diff --git a/matrix_sdk_base/Cargo.toml b/matrix_sdk_base/Cargo.toml index 5c16e448..b1e919d4 100644 --- a/matrix_sdk_base/Cargo.toml +++ b/matrix_sdk_base/Cargo.toml @@ -18,10 +18,10 @@ sqlite-cryptostore = ["matrix-sdk-crypto/sqlite-cryptostore"] [dependencies] async-trait = "0.1.36" -serde = "1.0.114" -serde_json = "1.0.56" +serde = "1.0.115" +serde_json = "1.0.57" zeroize = "1.1.0" -tracing = "0.1.16" +tracing = "0.1.19" matrix-sdk-common-macros = { version = "0.1.0", path = "../matrix_sdk_common_macros" } matrix-sdk-common = { version = "0.1.0", path = "../matrix_sdk_common" } @@ -31,19 +31,19 @@ matrix-sdk-crypto = { version = "0.1.0", path = "../matrix_sdk_crypto", optional thiserror = "1.0.20" [target.'cfg(not(target_arch = "wasm32"))'.dependencies.tokio] -version = "0.2.21" +version = "0.2.22" default-features = false features = ["sync", "fs"] [dev-dependencies] matrix-sdk-test = { version = "0.1.0", path = "../matrix_sdk_test" } http = "0.2.1" -tracing-subscriber = "0.2.7" +tracing-subscriber = "0.2.11" tempfile = "3.1.0" -mockito = "0.26.0" +mockito = "0.27.0" [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] -tokio = { version = "0.2.21", features = ["rt-threaded", "macros"] } +tokio = { version = "0.2.22", features = ["rt-threaded", "macros"] } [target.'cfg(target_arch = "wasm32")'.dev-dependencies] -wasm-bindgen-test = "0.3.15" +wasm-bindgen-test = "0.3.17" diff --git a/matrix_sdk_base/src/client.rs b/matrix_sdk_base/src/client.rs index 21168bf0..fc26e0a7 100644 --- a/matrix_sdk_base/src/client.rs +++ b/matrix_sdk_base/src/client.rs @@ -212,7 +212,7 @@ pub struct BaseClient { store_passphrase: Arc>, } -// #[cfg_attr(tarpaulin, skip)] +#[cfg(not(tarpaulin_include))] impl fmt::Debug for BaseClient { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Client") @@ -246,7 +246,7 @@ pub struct BaseClientConfig { passphrase: Option>, } -// #[cfg_attr(tarpaulin, skip)] +#[cfg(not(tarpaulin_include))] impl std::fmt::Debug for BaseClientConfig { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> StdResult<(), std::fmt::Error> { fmt.debug_struct("BaseClientConfig").finish() diff --git a/matrix_sdk_common/Cargo.toml b/matrix_sdk_common/Cargo.toml index 71f5ac16..bb8aff84 100644 --- a/matrix_sdk_common/Cargo.toml +++ b/matrix_sdk_common/Cargo.toml @@ -12,7 +12,7 @@ version = "0.1.0" [dependencies] instant = { version = "0.1.6", features = ["wasm-bindgen", "now"] } -js_int = "0.1.8" +js_int = "0.1.9" [dependencies.ruma] version = "0.0.1" @@ -24,7 +24,7 @@ features = ["client-api"] uuid = { version = "0.8.1", features = ["v4"] } [target.'cfg(not(target_arch = "wasm32"))'.dependencies.tokio] -version = "0.2.21" +version = "0.2.22" default-features = false features = ["sync", "time", "fs"] diff --git a/matrix_sdk_crypto/Cargo.toml b/matrix_sdk_crypto/Cargo.toml index 509a3ee3..64f2d2f3 100644 --- a/matrix_sdk_crypto/Cargo.toml +++ b/matrix_sdk_crypto/Cargo.toml @@ -20,18 +20,18 @@ async-trait = "0.1.36" matrix-sdk-common-macros = { version = "0.1.0", path = "../matrix_sdk_common_macros" } matrix-sdk-common = { version = "0.1.0", path = "../matrix_sdk_common" } -olm-rs = { version = "0.5.0", features = ["serde"] } -serde = { version = "1.0.114", features = ["derive"] } -serde_json = "1.0.56" +olm-rs = { git = 'https://gitlab.gnome.org/jhaye/olm-rs/', features = ["serde"]} +serde = { version = "1.0.115", features = ["derive"] } +serde_json = "1.0.57" cjson = "0.1.1" zeroize = { version = "1.1.0", features = ["zeroize_derive"] } url = "2.1.1" # Misc dependencies thiserror = "1.0.20" -tracing = "0.1.16" -atomic = "0.4.6" -dashmap = "3.11.7" +tracing = "0.1.19" +atomic = "0.5.0" +dashmap = "3.11.10" [dependencies.tracing-futures] version = "0.2.4" diff --git a/matrix_sdk_crypto/src/device.rs b/matrix_sdk_crypto/src/device.rs index 62a54bb8..eafb193b 100644 --- a/matrix_sdk_crypto/src/device.rs +++ b/matrix_sdk_crypto/src/device.rs @@ -198,7 +198,7 @@ impl Device { #[cfg(test)] pub async fn from_machine(machine: &OlmMachine) -> Device { - Device::from_account(&machine.account).await + Device::from_account(machine.account()).await } #[cfg(test)] diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 340af00f..73adde35 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -41,7 +41,6 @@ use matrix_sdk_common::{ Algorithm, AnySyncRoomEvent, AnyToDeviceEvent, EventType, SyncMessageEvent, ToDeviceEvent, }, identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, RoomId, UserId}, - locks::RwLock, uuid::Uuid, Raw, }; @@ -76,11 +75,11 @@ pub struct OlmMachine { /// The unique device id of the device that holds this account. device_id: Box, /// Our underlying Olm Account holding our identity keys. - pub(crate) account: Account, + account: Account, /// Store for the encryption keys. /// Persists all the encryption keys so a client can resume the session /// without the need to create new keys. - store: Arc>>, + store: Arc>, /// The currently active outbound group sessions. outbound_group_sessions: Arc>, /// A state machine that is responsible to handle and keep track of SAS @@ -88,7 +87,7 @@ pub struct OlmMachine { verification_machine: VerificationMachine, } -// #[cfg_attr(tarpaulin, skip)] +#[cfg(not(tarpaulin_include))] impl std::fmt::Debug for OlmMachine { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("OlmMachine") @@ -111,10 +110,9 @@ impl OlmMachine { /// * `user_id` - The unique id of the user that owns this machine. /// /// * `device_id` - The unique id of the device that owns this machine. - #[allow(clippy::ptr_arg)] pub fn new(user_id: &UserId, device_id: &DeviceId) -> Self { let store: Box = Box::new(MemoryStore::new()); - let store = Arc::new(RwLock::new(store)); + let store = Arc::new(store); let account = Account::new(user_id, device_id); OlmMachine { @@ -147,7 +145,7 @@ impl OlmMachine { pub async fn new_with_store( user_id: UserId, device_id: Box, - mut store: Box, + store: Box, ) -> StoreResult { let account = match store.load_account().await? { Some(a) => { @@ -160,7 +158,7 @@ impl OlmMachine { } }; - let store = Arc::new(RwLock::new(store)); + let store = Arc::new(store); let verification_machine = VerificationMachine::new(account.clone(), store.clone()); Ok(OlmMachine { @@ -216,6 +214,12 @@ impl OlmMachine { self.account.should_upload_keys().await } + /// Get the underlying Olm account of the machine. + #[cfg(test)] + pub(crate) fn account(&self) -> &Account { + &self.account + } + /// Update the count of one-time keys that are currently on the server. fn update_key_count(&self, count: u64) { self.account.update_uploaded_key_count(count); @@ -250,11 +254,7 @@ impl OlmMachine { self.update_key_count(count); self.account.mark_keys_as_published().await; - self.store - .write() - .await - .save_account(self.account.clone()) - .await?; + self.store.save_account(self.account.clone()).await?; Ok(()) } @@ -285,7 +285,7 @@ impl OlmMachine { let mut missing = BTreeMap::new(); for user_id in users { - let user_devices = self.store.read().await.get_user_devices(user_id).await?; + let user_devices = self.store.get_user_devices(user_id).await?; for device in user_devices.devices() { let sender_key = if let Some(k) = device.get_key(DeviceKeyAlgorithm::Curve25519) { @@ -294,7 +294,7 @@ impl OlmMachine { continue; }; - let sessions = self.store.write().await.get_sessions(sender_key).await?; + let sessions = self.store.get_sessions(sender_key).await?; let is_missing = if let Some(sessions) = sessions { sessions.lock().await.is_empty() @@ -333,13 +333,7 @@ impl OlmMachine { for (user_id, user_devices) in &response.one_time_keys { for (device_id, key_map) in user_devices { - let device: Device = match self - .store - .read() - .await - .get_device(&user_id, device_id) - .await - { + let device: Device = match self.store.get_device(&user_id, device_id).await { Ok(Some(d)) => d, Ok(None) => { warn!( @@ -368,7 +362,7 @@ impl OlmMachine { } }; - if let Err(e) = self.store.write().await.save_sessions(&[session]).await { + if let Err(e) = self.store.save_sessions(&[session]).await { error!("Failed to store newly created Olm session {}", e); continue; } @@ -389,11 +383,7 @@ impl OlmMachine { let mut changed_devices = Vec::new(); for (user_id, device_map) in device_keys_map { - self.store - .write() - .await - .update_tracked_user(user_id, false) - .await?; + self.store.update_tracked_user(user_id, false).await?; for (device_id, device_keys) in device_map.iter() { // We don't need our own device in the device store. @@ -409,12 +399,7 @@ impl OlmMachine { continue; } - let device = self - .store - .read() - .await - .get_device(&user_id, device_id) - .await?; + let device = self.store.get_device(&user_id, device_id).await?; let device = if let Some(mut device) = device { if let Err(e) = device.update_device(device_keys) { @@ -445,13 +430,7 @@ impl OlmMachine { let current_devices: HashSet<&DeviceId> = device_map.keys().map(|id| id.as_ref()).collect(); - let stored_devices = self - .store - .read() - .await - .get_user_devices(&user_id) - .await - .unwrap(); + let stored_devices = self.store.get_user_devices(&user_id).await.unwrap(); let stored_devices_set: HashSet<&DeviceId> = stored_devices.keys().collect(); let deleted_devices = stored_devices_set.difference(¤t_devices); @@ -459,7 +438,7 @@ impl OlmMachine { for device_id in deleted_devices { if let Some(device) = stored_devices.get(device_id) { device.mark_as_deleted(); - self.store.write().await.delete_device(device).await?; + self.store.delete_device(device).await?; } } } @@ -483,11 +462,7 @@ impl OlmMachine { let changed_devices = self .handle_devices_from_key_query(&response.device_keys) .await?; - self.store - .write() - .await - .save_devices(&changed_devices) - .await?; + self.store.save_devices(&changed_devices).await?; Ok(changed_devices) } @@ -511,7 +486,7 @@ impl OlmMachine { sender_key: &str, message: &OlmMessage, ) -> OlmResult> { - let s = self.store.write().await.get_sessions(sender_key).await?; + let s = self.store.get_sessions(sender_key).await?; // We don't have any existing sessions, return early. let sessions = if let Some(s) = s { @@ -561,7 +536,7 @@ impl OlmMachine { // Decryption was successful, save the new ratchet state of the // session that was used to decrypt the message. trace!("Saved the new session state for {}", sender); - self.store.write().await.save_sessions(&[session]).await?; + self.store.save_sessions(&[session]).await?; } Ok(plaintext) @@ -616,11 +591,7 @@ impl OlmMachine { // Save the account since we remove the one-time key that // was used to create this session. - self.store - .write() - .await - .save_account(self.account.clone()) - .await?; + self.store.save_account(self.account.clone()).await?; session } }; @@ -630,7 +601,7 @@ impl OlmMachine { let plaintext = session.decrypt(message).await?; // Save the new ratcheted state of the session. - self.store.write().await.save_sessions(&[session]).await?; + self.store.save_sessions(&[session]).await?; plaintext }; @@ -781,12 +752,7 @@ impl OlmMachine { &event.content.room_id, session_key, )?; - let _ = self - .store - .write() - .await - .save_inbound_group_session(session) - .await?; + let _ = self.store.save_inbound_group_session(session).await?; let event = Raw::from(AnyToDeviceEvent::RoomKey(event.clone())); Ok(Some(event)) @@ -808,12 +774,7 @@ impl OlmMachine { async fn create_outbound_group_session(&self, room_id: &RoomId) -> OlmResult<()> { let (outbound, inbound) = self.account.create_group_session_pair(room_id).await; - let _ = self - .store - .write() - .await - .save_inbound_group_session(inbound) - .await?; + let _ = self.store.save_inbound_group_session(inbound).await?; let _ = self .outbound_group_sessions @@ -899,8 +860,7 @@ impl OlmMachine { return Err(EventError::MissingSenderKey.into()); }; - let mut session = if let Some(s) = self.store.write().await.get_sessions(sender_key).await? - { + let mut session = if let Some(s) = self.store.get_sessions(sender_key).await? { let session = &s.lock().await[0]; session.clone() } else { @@ -914,7 +874,7 @@ impl OlmMachine { }; let message = session.encrypt(recipient_device, event_type, content).await; - self.store.write().await.save_sessions(&[session]).await?; + self.store.save_sessions(&[session]).await?; message } @@ -969,7 +929,7 @@ impl OlmMachine { panic!("Session is already shared"); } - // TODO don't mark the session as shared automatically only, when all + // TODO don't mark the session as shared automatically, only when all // the requests are done, failure to send these requests will likely end // up in wedged sessions. We'll need to store the requests and let the // caller mark them as sent using an UUID. @@ -978,15 +938,7 @@ impl OlmMachine { let mut devices = Vec::new(); for user_id in users { - for device in self - .store - .read() - .await - .get_user_devices(user_id) - .await? - .devices() - { - // TODO abort if the device isn't verified + for device in self.store.get_user_devices(user_id).await?.devices() { devices.push(device.clone()); } } @@ -1193,8 +1145,6 @@ impl OlmMachine { let session = self .store - .write() - .await .get_inbound_group_session(room_id, &content.sender_key, &content.session_id) .await?; // TODO check if the Olm session is wedged and re-request the key. @@ -1220,12 +1170,8 @@ impl OlmMachine { /// /// Returns true if the user was queued up for a key query, false otherwise. pub async fn mark_user_as_changed(&self, user_id: &UserId) -> StoreResult { - if self.store.read().await.tracked_users().contains(user_id) { - self.store - .write() - .await - .update_tracked_user(user_id, true) - .await?; + if self.store.is_user_tracked(user_id) { + self.store.update_tracked_user(user_id, true).await?; Ok(true) } else { Ok(false) @@ -1251,17 +1197,11 @@ impl OlmMachine { I: IntoIterator, { for user in users { - if self.store.read().await.tracked_users().contains(user) { + if self.store.is_user_tracked(user) { continue; } - if let Err(e) = self - .store - .write() - .await - .update_tracked_user(user, true) - .await - { + if let Err(e) = self.store.update_tracked_user(user, true).await { warn!("Error storing users for tracking {}", e); } } @@ -1269,14 +1209,14 @@ impl OlmMachine { /// Should the client perform a key query request. pub async fn should_query_keys(&self) -> bool { - !self.store.read().await.users_for_key_query().is_empty() + self.store.has_users_for_key_query() } /// Get the set of users that we need to query keys for. /// /// Returns a hash set of users that need to be queried for keys. pub async fn users_for_key_query(&self) -> HashSet { - self.store.read().await.users_for_key_query().clone() + self.store.users_for_key_query() } } @@ -1399,19 +1339,8 @@ mod test { let alice_deivce = Device::from_machine(&alice).await; let bob_device = Device::from_machine(&bob).await; - alice - .store - .write() - .await - .save_devices(&[bob_device]) - .await - .unwrap(); - bob.store - .write() - .await - .save_devices(&[alice_deivce]) - .await - .unwrap(); + alice.store.save_devices(&[bob_device]).await.unwrap(); + bob.store.save_devices(&[alice_deivce]).await.unwrap(); (alice, bob, otk) } @@ -1444,8 +1373,6 @@ mod test { let bob_device = alice .store - .read() - .await .get_device(&bob.user_id, &bob.device_id) .await .unwrap() @@ -1650,13 +1577,7 @@ mod test { let alice_id = user_id!("@alice:example.org"); let alice_device_id: &DeviceId = "JLAFKJWSCS".into(); - let alice_devices = machine - .store - .read() - .await - .get_user_devices(&alice_id) - .await - .unwrap(); + let alice_devices = machine.store.get_user_devices(&alice_id).await.unwrap(); assert!(alice_devices.devices().peekable().peek().is_none()); machine @@ -1666,8 +1587,6 @@ mod test { let device = machine .store - .read() - .await .get_device(&alice_id, alice_device_id) .await .unwrap() @@ -1719,8 +1638,6 @@ mod test { let session = alice_machine .store - .write() - .await .get_sessions(bob_machine.account.identity_keys().curve25519()) .await .unwrap() @@ -1735,8 +1652,6 @@ mod test { let bob_device = alice .store - .read() - .await .get_device(&bob.user_id, &bob.device_id) .await .unwrap() @@ -1798,8 +1713,6 @@ mod test { let session = bob .store - .write() - .await .get_inbound_group_session( &room_id, alice.account.identity_keys().curve25519(), diff --git a/matrix_sdk_crypto/src/memory_stores.rs b/matrix_sdk_crypto/src/memory_stores.rs index c6828a78..1e1f94fa 100644 --- a/matrix_sdk_crypto/src/memory_stores.rs +++ b/matrix_sdk_crypto/src/memory_stores.rs @@ -43,7 +43,7 @@ impl SessionStore { /// /// Returns true if the the session was added, false if the session was /// already in the store. - pub async fn add(&mut self, session: Session) -> bool { + pub async fn add(&self, session: Session) -> bool { if !self.entries.contains_key(&*session.sender_key) { self.entries.insert( session.sender_key.to_string(), @@ -62,11 +62,12 @@ impl SessionStore { /// Get all the sessions that belong to the given sender key. pub fn get(&self, sender_key: &str) -> Option>>> { + #[allow(clippy::map_clone)] self.entries.get(sender_key).map(|s| s.clone()) } /// Add a list of sessions belonging to the sender key. - pub fn set_for_sender(&mut self, sender_key: &str, sessions: Vec) { + pub fn set_for_sender(&self, sender_key: &str, sessions: Vec) { self.entries .insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions))); } @@ -75,6 +76,7 @@ impl SessionStore { #[derive(Debug, Default, Clone)] /// In-memory store that holds inbound group sessions. pub struct GroupSessionStore { + #[allow(clippy::type_complexity)] entries: Arc>>>, } @@ -90,7 +92,7 @@ impl GroupSessionStore { /// /// Returns true if the the session was added, false if the session was /// already in the store. - pub fn add(&mut self, session: InboundGroupSession) -> bool { + pub fn add(&self, session: InboundGroupSession) -> bool { if !self.entries.contains_key(&session.room_id) { let room_id = &*session.room_id; self.entries.insert(room_id.clone(), HashMap::new()); @@ -223,7 +225,7 @@ mod test { async fn test_session_store() { let (_, session) = get_account_and_session().await; - let mut store = SessionStore::new(); + let store = SessionStore::new(); assert!(store.add(session.clone()).await); assert!(!store.add(session.clone()).await); @@ -240,7 +242,7 @@ mod test { async fn test_session_store_bulk_storing() { let (_, session) = get_account_and_session().await; - let mut store = SessionStore::new(); + let store = SessionStore::new(); store.set_for_sender(&session.sender_key, vec![session.clone()]); let sessions = store.get(&session.sender_key).unwrap(); @@ -271,7 +273,7 @@ mod test { ) .unwrap(); - let mut store = GroupSessionStore::new(); + let store = GroupSessionStore::new(); store.add(inbound.clone()); let loaded_session = store diff --git a/matrix_sdk_crypto/src/olm/account.rs b/matrix_sdk_crypto/src/olm/account.rs index a8cd18ee..5ee89216 100644 --- a/matrix_sdk_crypto/src/olm/account.rs +++ b/matrix_sdk_crypto/src/olm/account.rs @@ -63,7 +63,7 @@ pub struct Account { uploaded_signed_key_count: Arc, } -// #[cfg_attr(tarpaulin, skip)] +#[cfg(not(tarpaulin_include))] impl fmt::Debug for Account { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Account") diff --git a/matrix_sdk_crypto/src/olm/group_sessions.rs b/matrix_sdk_crypto/src/olm/group_sessions.rs index bbb36227..717877b6 100644 --- a/matrix_sdk_crypto/src/olm/group_sessions.rs +++ b/matrix_sdk_crypto/src/olm/group_sessions.rs @@ -222,7 +222,7 @@ impl InboundGroupSession { } } -// #[cfg_attr(tarpaulin, skip)] +#[cfg(not(tarpaulin_include))] impl fmt::Debug for InboundGroupSession { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("InboundGroupSession") @@ -401,7 +401,7 @@ impl OutboundGroupSession { } } -// #[cfg_attr(tarpaulin, skip)] +#[cfg(not(tarpaulin_include))] impl std::fmt::Debug for OutboundGroupSession { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("OutboundGroupSession") diff --git a/matrix_sdk_crypto/src/olm/session.rs b/matrix_sdk_crypto/src/olm/session.rs index 07a3339a..a4f31a32 100644 --- a/matrix_sdk_crypto/src/olm/session.rs +++ b/matrix_sdk_crypto/src/olm/session.rs @@ -51,7 +51,7 @@ pub struct Session { pub(crate) last_use_time: Arc, } -// #[cfg_attr(tarpaulin, skip)] +#[cfg(not(tarpaulin_include))] impl fmt::Debug for Session { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Session") diff --git a/matrix_sdk_crypto/src/store/memorystore.rs b/matrix_sdk_crypto/src/store/memorystore.rs index 1149766d..d901e0b5 100644 --- a/matrix_sdk_crypto/src/store/memorystore.rs +++ b/matrix_sdk_crypto/src/store/memorystore.rs @@ -15,6 +15,7 @@ use std::{collections::HashSet, sync::Arc}; use async_trait::async_trait; +use dashmap::DashSet; use matrix_sdk_common::{ identifiers::{DeviceId, RoomId, UserId}, locks::Mutex, @@ -25,12 +26,12 @@ use crate::{ device::Device, memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices}, }; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct MemoryStore { sessions: SessionStore, inbound_group_sessions: GroupSessionStore, - tracked_users: HashSet, - users_for_key_query: HashSet, + tracked_users: Arc>, + users_for_key_query: Arc>, devices: DeviceStore, } @@ -39,8 +40,8 @@ impl MemoryStore { MemoryStore { sessions: SessionStore::new(), inbound_group_sessions: GroupSessionStore::new(), - tracked_users: HashSet::new(), - users_for_key_query: HashSet::new(), + tracked_users: Arc::new(DashSet::new()), + users_for_key_query: Arc::new(DashSet::new()), devices: DeviceStore::new(), } } @@ -48,15 +49,15 @@ impl MemoryStore { #[async_trait] impl CryptoStore for MemoryStore { - async fn load_account(&mut self) -> Result> { + async fn load_account(&self) -> Result> { Ok(None) } - async fn save_account(&mut self, _: Account) -> Result<()> { + async fn save_account(&self, _: Account) -> Result<()> { Ok(()) } - async fn save_sessions(&mut self, sessions: &[Session]) -> Result<()> { + async fn save_sessions(&self, sessions: &[Session]) -> Result<()> { for session in sessions { let _ = self.sessions.add(session.clone()).await; } @@ -64,16 +65,16 @@ impl CryptoStore for MemoryStore { Ok(()) } - async fn get_sessions(&mut self, sender_key: &str) -> Result>>>> { + async fn get_sessions(&self, sender_key: &str) -> Result>>>> { Ok(self.sessions.get(sender_key)) } - async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result { + async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result { Ok(self.inbound_group_sessions.add(session)) } async fn get_inbound_group_session( - &mut self, + &self, room_id: &RoomId, sender_key: &str, session_id: &str, @@ -83,15 +84,20 @@ impl CryptoStore for MemoryStore { .get(room_id, sender_key, session_id)) } - fn tracked_users(&self) -> &HashSet { - &self.tracked_users + fn users_for_key_query(&self) -> HashSet { + #[allow(clippy::map_clone)] + self.users_for_key_query.iter().map(|u| u.clone()).collect() } - fn users_for_key_query(&self) -> &HashSet { - &self.users_for_key_query + fn is_user_tracked(&self, user_id: &UserId) -> bool { + self.tracked_users.contains(user_id) } - async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result { + fn has_users_for_key_query(&self) -> bool { + !self.users_for_key_query.is_empty() + } + + async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result { if dirty { self.users_for_key_query.insert(user.clone()); } else { @@ -135,7 +141,7 @@ mod test { #[tokio::test] async fn test_session_store() { let (account, session) = get_account_and_session().await; - let mut store = MemoryStore::new(); + let store = MemoryStore::new(); assert!(store.load_account().await.unwrap().is_none()); store.save_account(account).await.unwrap(); @@ -168,7 +174,7 @@ mod test { ) .unwrap(); - let mut store = MemoryStore::new(); + let store = MemoryStore::new(); let _ = store .save_inbound_group_session(inbound.clone()) .await @@ -217,7 +223,7 @@ mod test { #[tokio::test] async fn test_tracked_users() { let device = get_device(); - let mut store = MemoryStore::new(); + let store = MemoryStore::new(); assert!(store .update_tracked_user(device.user_id(), false) @@ -228,8 +234,6 @@ mod test { .await .unwrap()); - let tracked_users = store.tracked_users(); - - let _ = tracked_users.contains(device.user_id()); + assert!(store.is_user_tracked(device.user_id())); } } diff --git a/matrix_sdk_crypto/src/store/mod.rs b/matrix_sdk_crypto/src/store/mod.rs index 0f2dca0d..9b035450 100644 --- a/matrix_sdk_crypto/src/store/mod.rs +++ b/matrix_sdk_crypto/src/store/mod.rs @@ -95,28 +95,28 @@ pub type Result = std::result::Result; /// keys. pub trait CryptoStore: Debug { /// Load an account that was previously stored. - async fn load_account(&mut self) -> Result>; + async fn load_account(&self) -> Result>; /// Save the given account in the store. /// /// # Arguments /// /// * `account` - The account that should be stored. - async fn save_account(&mut self, account: Account) -> Result<()>; + async fn save_account(&self, account: Account) -> Result<()>; /// Save the given sessions in the store. /// /// # Arguments /// /// * `session` - The sessions that should be stored. - async fn save_sessions(&mut self, session: &[Session]) -> Result<()>; + async fn save_sessions(&self, session: &[Session]) -> Result<()>; /// Get all the sessions that belong to the given sender key. /// /// # Arguments /// /// * `sender_key` - The sender key that was used to establish the sessions. - async fn get_sessions(&mut self, sender_key: &str) -> Result>>>>; + async fn get_sessions(&self, sender_key: &str) -> Result>>>>; /// Save the given inbound group session in the store. /// @@ -126,7 +126,7 @@ pub trait CryptoStore: Debug { /// # Arguments /// /// * `session` - The session that should be stored. - async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result; + async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result; /// Get the inbound group session from our store. /// @@ -137,18 +137,21 @@ pub trait CryptoStore: Debug { /// /// * `session_id` - The unique id of the session. async fn get_inbound_group_session( - &mut self, + &self, room_id: &RoomId, sender_key: &str, session_id: &str, ) -> Result>; - /// Get the set of tracked users. - fn tracked_users(&self) -> &HashSet; + /// Is the given user already tracked. + fn is_user_tracked(&self, user_id: &UserId) -> bool; + + /// Are there any tracked users that are marked as dirty. + fn has_users_for_key_query(&self) -> bool; /// Set of users that we need to query keys for. This is a subset of /// the tracked users. - fn users_for_key_query(&self) -> &HashSet; + fn users_for_key_query(&self) -> HashSet; /// Add an user for tracking. /// @@ -159,7 +162,7 @@ pub trait CryptoStore: Debug { /// * `user` - The user that should be marked as tracked. /// /// * `dirty` - Should the user be also marked for a key query. - async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result; + async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result; /// Save the given devices in the store. /// diff --git a/matrix_sdk_crypto/src/store/sqlite.rs b/matrix_sdk_crypto/src/store/sqlite.rs index b0cbdd0f..ede11f5c 100644 --- a/matrix_sdk_crypto/src/store/sqlite.rs +++ b/matrix_sdk_crypto/src/store/sqlite.rs @@ -17,10 +17,11 @@ use std::{ convert::TryFrom, path::{Path, PathBuf}, result::Result as StdResult, - sync::Arc, + sync::{Arc, Mutex as SyncMutex}, }; use async_trait::async_trait; +use dashmap::DashSet; use matrix_sdk_common::{ events::Algorithm, identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, RoomId, UserId}, @@ -40,22 +41,24 @@ use crate::{ }; /// SQLite based implementation of a `CryptoStore`. +#[derive(Clone)] pub struct SqliteStore { user_id: Arc, device_id: Arc>, - account_info: Option, - path: PathBuf, + account_info: Arc>>, + path: Arc, sessions: SessionStore, inbound_group_sessions: GroupSessionStore, devices: DeviceStore, - tracked_users: HashSet, - users_for_key_query: HashSet, + tracked_users: Arc>, + users_for_key_query: Arc>, connection: Arc>, - pickle_passphrase: Option>, + pickle_passphrase: Arc>>, } +#[derive(Clone)] struct AccountInfo { account_id: i64, identity_keys: Arc, @@ -130,22 +133,26 @@ impl SqliteStore { let store = SqliteStore { user_id: Arc::new(user_id.to_owned()), device_id: Arc::new(device_id.into()), - account_info: None, + account_info: Arc::new(SyncMutex::new(None)), sessions: SessionStore::new(), inbound_group_sessions: GroupSessionStore::new(), devices: DeviceStore::new(), - path: path.as_ref().to_owned(), + path: Arc::new(path.as_ref().to_owned()), connection: Arc::new(Mutex::new(connection)), - pickle_passphrase: passphrase, - tracked_users: HashSet::new(), - users_for_key_query: HashSet::new(), + pickle_passphrase: Arc::new(passphrase), + tracked_users: Arc::new(DashSet::new()), + users_for_key_query: Arc::new(DashSet::new()), }; store.create_tables().await?; Ok(store) } fn account_id(&self) -> Option { - self.account_info.as_ref().map(|i| i.account_id) + self.account_info + .lock() + .unwrap() + .as_ref() + .map(|i| i.account_id) } async fn create_tables(&self) -> Result<()> { @@ -299,7 +306,7 @@ impl SqliteStore { Ok(()) } - async fn lazy_load_sessions(&mut self, sender_key: &str) -> Result<()> { + async fn lazy_load_sessions(&self, sender_key: &str) -> Result<()> { let loaded_sessions = self.sessions.get(sender_key).is_some(); if !loaded_sessions { @@ -313,18 +320,17 @@ impl SqliteStore { Ok(()) } - async fn get_sessions_for( - &mut self, - sender_key: &str, - ) -> Result>>>> { + async fn get_sessions_for(&self, sender_key: &str) -> Result>>>> { self.lazy_load_sessions(sender_key).await?; Ok(self.sessions.get(sender_key)) } - async fn load_sessions_for(&mut self, sender_key: &str) -> Result> { + async fn load_sessions_for(&self, sender_key: &str) -> Result> { let account_info = self .account_info - .as_ref() + .lock() + .unwrap() + .clone() .ok_or(CryptoStoreError::AccountUnset)?; let mut connection = self.connection.lock().await; @@ -365,7 +371,7 @@ impl SqliteStore { .collect::>>()?) } - async fn load_inbound_group_sessions(&self) -> Result> { + async fn load_inbound_group_sessions(&self) -> Result<()> { let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; let mut connection = self.connection.lock().await; @@ -377,7 +383,7 @@ impl SqliteStore { .fetch_all(&mut *connection) .await?; - Ok(rows + let mut group_sessions = rows .iter() .map(|row| { let pickle = &row.0; @@ -393,7 +399,16 @@ impl SqliteStore { RoomId::try_from(room_id.as_str()).unwrap(), )?) }) - .collect::>>()?) + .collect::>>()?; + + group_sessions + .drain(..) + .map(|s| { + self.inbound_group_sessions.add(s); + }) + .for_each(drop); + + Ok(()) } async fn save_tracked_user(&self, user: &UserId, dirty: bool) -> Result<()> { @@ -417,7 +432,7 @@ impl SqliteStore { Ok(()) } - async fn load_tracked_users(&self) -> Result<(HashSet, HashSet)> { + async fn load_tracked_users(&self) -> Result<()> { let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; let mut connection = self.connection.lock().await; @@ -429,27 +444,24 @@ impl SqliteStore { .fetch_all(&mut *connection) .await?; - let mut users = HashSet::new(); - let mut users_for_query = HashSet::new(); - for row in rows { let user_id: &str = &row.0; let dirty: bool = row.1; if let Ok(u) = UserId::try_from(user_id) { - users.insert(u.clone()); + self.tracked_users.insert(u.clone()); if dirty { - users_for_query.insert(u); + self.users_for_key_query.insert(u); } } else { continue; }; } - Ok((users, users_for_query)) + Ok(()) } - async fn load_devices(&self) -> Result { + async fn load_devices(&self) -> Result<()> { let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; let mut connection = self.connection.lock().await; @@ -461,8 +473,6 @@ impl SqliteStore { .fetch_all(&mut *connection) .await?; - let store = DeviceStore::new(); - for row in rows { let device_row_id = row.0; let user_id: &str = &row.1; @@ -555,10 +565,10 @@ impl SqliteStore { signatures, ); - store.add(device); + self.devices.add(device); } - Ok(store) + Ok(()) } async fn save_device_helper(&self, device: Device) -> Result<()> { @@ -643,7 +653,7 @@ impl SqliteStore { } fn get_pickle_mode(&self) -> PicklingMode { - match &self.pickle_passphrase { + match &*self.pickle_passphrase { Some(p) => PicklingMode::Encrypted { key: p.as_bytes().to_vec(), }, @@ -654,7 +664,7 @@ impl SqliteStore { #[async_trait] impl CryptoStore for SqliteStore { - async fn load_account(&mut self) -> Result> { + async fn load_account(&self) -> Result> { let mut connection = self.connection.lock().await; let row: Option<(i64, String, bool, i64)> = query_as( @@ -676,7 +686,7 @@ impl CryptoStore for SqliteStore { &self.device_id, )?; - self.account_info = Some(AccountInfo { + *self.account_info.lock().unwrap() = Some(AccountInfo { account_id: id, identity_keys: account.identity_keys.clone(), }); @@ -688,26 +698,14 @@ impl CryptoStore for SqliteStore { drop(connection); - let mut group_sessions = self.load_inbound_group_sessions().await?; - - group_sessions - .drain(..) - .map(|s| { - self.inbound_group_sessions.add(s); - }) - .for_each(drop); - - let devices = self.load_devices().await?; - self.devices = devices; - - let (tracked_users, users_for_query) = self.load_tracked_users().await?; - self.tracked_users = tracked_users; - self.users_for_key_query = users_for_query; + self.load_inbound_group_sessions().await?; + self.load_devices().await?; + self.load_tracked_users().await?; Ok(result) } - async fn save_account(&mut self, account: Account) -> Result<()> { + async fn save_account(&self, account: Account) -> Result<()> { let pickle = account.pickle(self.get_pickle_mode()).await; let mut connection = self.connection.lock().await; @@ -735,7 +733,7 @@ impl CryptoStore for SqliteStore { .fetch_one(&mut *connection) .await?; - self.account_info = Some(AccountInfo { + *self.account_info.lock().unwrap() = Some(AccountInfo { account_id: account_id.0, identity_keys: account.identity_keys.clone(), }); @@ -743,7 +741,7 @@ impl CryptoStore for SqliteStore { Ok(()) } - async fn save_sessions(&mut self, sessions: &[Session]) -> Result<()> { + async fn save_sessions(&self, sessions: &[Session]) -> Result<()> { let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; // TODO turn this into a transaction @@ -776,11 +774,11 @@ impl CryptoStore for SqliteStore { Ok(()) } - async fn get_sessions(&mut self, sender_key: &str) -> Result>>>> { + async fn get_sessions(&self, sender_key: &str) -> Result>>>> { Ok(self.get_sessions_for(sender_key).await?) } - async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result { + async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result { let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; let pickle = session.pickle(self.get_pickle_mode()).await; let mut connection = self.connection.lock().await; @@ -808,7 +806,7 @@ impl CryptoStore for SqliteStore { } async fn get_inbound_group_session( - &mut self, + &self, room_id: &RoomId, sender_key: &str, session_id: &str, @@ -818,15 +816,20 @@ impl CryptoStore for SqliteStore { .get(room_id, sender_key, session_id)) } - fn tracked_users(&self) -> &HashSet { - &self.tracked_users + fn is_user_tracked(&self, user_id: &UserId) -> bool { + self.tracked_users.contains(user_id) } - fn users_for_key_query(&self) -> &HashSet { - &self.users_for_key_query + fn has_users_for_key_query(&self) -> bool { + !self.users_for_key_query.is_empty() } - async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result { + fn users_for_key_query(&self) -> HashSet { + #[allow(clippy::map_clone)] + self.users_for_key_query.iter().map(|u| u.clone()).collect() + } + + async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result { let already_added = self.tracked_users.insert(user.clone()); if dirty { @@ -878,7 +881,7 @@ impl CryptoStore for SqliteStore { } } -// #[cfg_attr(tarpaulin, skip)] +#[cfg(not(tarpaulin_include))] impl std::fmt::Debug for SqliteStore { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> StdResult<(), std::fmt::Error> { fmt.debug_struct("SqliteStore") @@ -931,7 +934,7 @@ mod test { } async fn get_loaded_store() -> (Account, SqliteStore, tempfile::TempDir) { - let (mut store, dir) = get_store(None).await; + let (store, dir) = get_store(None).await; let account = get_account(); store .save_account(account.clone()) @@ -999,7 +1002,7 @@ mod test { #[tokio::test] async fn save_account() { - let (mut store, _dir) = get_store(None).await; + let (store, _dir) = get_store(None).await; assert!(store.load_account().await.unwrap().is_none()); let account = get_account(); @@ -1011,7 +1014,7 @@ mod test { #[tokio::test] async fn load_account() { - let (mut store, _dir) = get_store(None).await; + let (store, _dir) = get_store(None).await; let account = get_account(); store @@ -1027,7 +1030,7 @@ mod test { #[tokio::test] async fn load_account_with_passphrase() { - let (mut store, _dir) = get_store(Some("secret_passphrase")).await; + let (store, _dir) = get_store(Some("secret_passphrase")).await; let account = get_account(); store @@ -1043,7 +1046,7 @@ mod test { #[tokio::test] async fn save_and_share_account() { - let (mut store, _dir) = get_store(None).await; + let (store, _dir) = get_store(None).await; let account = get_account(); store @@ -1066,7 +1069,7 @@ mod test { #[tokio::test] async fn save_session() { - let (mut store, _dir) = get_store(None).await; + let (store, _dir) = get_store(None).await; let (account, session) = get_account_and_session().await; assert!(store.save_sessions(&[session.clone()]).await.is_err()); @@ -1081,7 +1084,7 @@ mod test { #[tokio::test] async fn load_sessions() { - let (mut store, _dir) = get_store(None).await; + let (store, _dir) = get_store(None).await; let (account, session) = get_account_and_session().await; store .save_account(account.clone()) @@ -1100,7 +1103,7 @@ mod test { #[tokio::test] async fn add_and_save_session() { - let (mut store, dir) = get_store(None).await; + let (store, dir) = get_store(None).await; let (account, session) = get_account_and_session().await; let sender_key = session.sender_key.to_owned(); let session_id = session.session_id().to_owned(); @@ -1119,7 +1122,7 @@ mod test { drop(store); - let mut store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path()) + let store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path()) .await .expect("Can't create store"); @@ -1135,7 +1138,7 @@ mod test { #[tokio::test] async fn save_inbound_group_session() { - let (account, mut store, _dir) = get_loaded_store().await; + let (account, store, _dir) = get_loaded_store().await; let identity_keys = account.identity_keys(); let outbound_session = OlmOutboundGroupSession::new(); @@ -1155,7 +1158,7 @@ mod test { #[tokio::test] async fn load_inbound_group_session() { - let (account, mut store, _dir) = get_loaded_store().await; + let (account, store, _dir) = get_loaded_store().await; let identity_keys = account.identity_keys(); let outbound_session = OlmOutboundGroupSession::new(); @@ -1167,16 +1170,12 @@ mod test { ) .expect("Can't create session"); - let session_id = session.session_id().to_owned(); - store .save_inbound_group_session(session.clone()) .await .expect("Can't save group session"); - let sessions = store.load_inbound_group_sessions().await.unwrap(); - - assert_eq!(session_id, sessions[0].session_id()); + store.load_inbound_group_sessions().await.unwrap(); let loaded_session = store .get_inbound_group_session(&session.room_id, &session.sender_key, session.session_id()) @@ -1188,7 +1187,7 @@ mod test { #[tokio::test] async fn test_tracked_users() { - let (_account, mut store, dir) = get_loaded_store().await; + let (_account, store, dir) = get_loaded_store().await; let device = get_device(); assert!(store @@ -1200,9 +1199,7 @@ mod test { .await .unwrap()); - let tracked_users = store.tracked_users(); - - assert!(tracked_users.contains(device.user_id())); + assert!(store.is_user_tracked(device.user_id())); assert!(!store.users_for_key_query().contains(device.user_id())); assert!(!store .update_tracked_user(device.user_id(), true) @@ -1211,14 +1208,13 @@ mod test { assert!(store.users_for_key_query().contains(device.user_id())); drop(store); - let mut store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path()) + let store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path()) .await .expect("Can't create store"); store.load_account().await.unwrap(); - let tracked_users = store.tracked_users(); - assert!(tracked_users.contains(device.user_id())); + assert!(store.is_user_tracked(device.user_id())); assert!(store.users_for_key_query().contains(device.user_id())); store @@ -1227,7 +1223,7 @@ mod test { .unwrap(); assert!(!store.users_for_key_query().contains(device.user_id())); - let mut store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path()) + let store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path()) .await .expect("Can't create store"); @@ -1245,7 +1241,7 @@ mod test { drop(store); - let mut store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path()) + let store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path()) .await .expect("Can't create store"); @@ -1278,7 +1274,7 @@ mod test { store.save_devices(&[device.clone()]).await.unwrap(); store.delete_device(device.clone()).await.unwrap(); - let mut store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path()) + let store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path()) .await .expect("Can't create store"); diff --git a/matrix_sdk_crypto/src/verification/machine.rs b/matrix_sdk_crypto/src/verification/machine.rs index ca3fc2f4..cea74359 100644 --- a/matrix_sdk_crypto/src/verification/machine.rs +++ b/matrix_sdk_crypto/src/verification/machine.rs @@ -22,7 +22,6 @@ use matrix_sdk_common::{ api::r0::to_device::send_event_to_device::IncomingRequest as OwnedToDeviceRequest, events::{AnyToDeviceEvent, AnyToDeviceEventContent}, identifiers::{DeviceId, UserId}, - locks::RwLock, }; use super::sas::{content_to_request, Sas}; @@ -31,13 +30,13 @@ use crate::{Account, CryptoStore, CryptoStoreError}; #[derive(Clone, Debug)] pub struct VerificationMachine { account: Account, - store: Arc>>, + store: Arc>, verifications: Arc>, outgoing_to_device_messages: Arc>, } impl VerificationMachine { - pub(crate) fn new(account: Account, store: Arc>>) -> Self { + pub(crate) fn new(account: Account, store: Arc>) -> Self { Self { account, store, @@ -112,8 +111,6 @@ impl VerificationMachine { if let Some(d) = self .store - .read() - .await .get_device(&e.sender, &e.content.from_device) .await? { @@ -179,7 +176,6 @@ mod test { use matrix_sdk_common::{ events::AnyToDeviceEventContent, identifiers::{DeviceId, UserId}, - locks::RwLock, }; use super::{Sas, VerificationMachine}; @@ -209,21 +205,18 @@ mod test { let alice = Account::new(&alice_id(), &alice_device_id()); let bob = Account::new(&bob_id(), &bob_device_id()); let store = MemoryStore::new(); - let bob_store: Arc>> = - Arc::new(RwLock::new(Box::new(MemoryStore::new()))); + let bob_store: Arc> = Arc::new(Box::new(MemoryStore::new())); let bob_device = Device::from_account(&bob).await; let alice_device = Device::from_account(&alice).await; store.save_devices(&[bob_device]).await.unwrap(); bob_store - .read() - .await .save_devices(&[alice_device.clone()]) .await .unwrap(); - let machine = VerificationMachine::new(alice, Arc::new(RwLock::new(Box::new(store)))); + let machine = VerificationMachine::new(alice, Arc::new(Box::new(store))); let (bob_sas, start_content) = Sas::start(bob, alice_device, bob_store); machine .receive_event(&mut wrap_any_to_device_content( @@ -240,7 +233,7 @@ mod test { fn create() { let alice = Account::new(&alice_id(), &alice_device_id()); let store = MemoryStore::new(); - let _ = VerificationMachine::new(alice, Arc::new(RwLock::new(Box::new(store)))); + let _ = VerificationMachine::new(alice, Arc::new(Box::new(store))); } #[tokio::test] diff --git a/matrix_sdk_crypto/src/verification/mod.rs b/matrix_sdk_crypto/src/verification/mod.rs index 2f2f23a2..db955382 100644 --- a/matrix_sdk_crypto/src/verification/mod.rs +++ b/matrix_sdk_crypto/src/verification/mod.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#[allow(dead_code)] mod machine; #[allow(dead_code)] mod sas; diff --git a/matrix_sdk_crypto/src/verification/sas/helpers.rs b/matrix_sdk_crypto/src/verification/sas/helpers.rs index 1715db9b..f5dfa24f 100644 --- a/matrix_sdk_crypto/src/verification/sas/helpers.rs +++ b/matrix_sdk_crypto/src/verification/sas/helpers.rs @@ -1,3 +1,17 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + use std::{collections::BTreeMap, convert::TryInto}; use tracing::trace; @@ -212,8 +226,6 @@ fn extra_mac_info_send(ids: &SasIds, flow_id: &str) -> String { /// /// * `flow_id` - The unique id that identifies this SAS verification process. /// -/// * `we_started` - Flag signaling if the SAS process was started on our side. -/// /// # Panics /// /// This will panic if the public key of the other side wasn't set. diff --git a/matrix_sdk_crypto/src/verification/sas/mod.rs b/matrix_sdk_crypto/src/verification/sas/mod.rs index 49c86cbf..a3d2d40e 100644 --- a/matrix_sdk_crypto/src/verification/sas/mod.rs +++ b/matrix_sdk_crypto/src/verification/sas/mod.rs @@ -31,7 +31,6 @@ use matrix_sdk_common::{ AnyToDeviceEvent, AnyToDeviceEventContent, ToDeviceEvent, }, identifiers::{DeviceId, UserId}, - locks::RwLock, }; use crate::{Account, CryptoStore, CryptoStoreError, Device, TrustState}; @@ -45,7 +44,7 @@ use sas_state::{ /// Short authentication string object. pub struct Sas { inner: Arc>, - store: Arc>>, + store: Arc>, account: Account, other_device: Device, flow_id: Arc, @@ -100,7 +99,7 @@ impl Sas { pub(crate) fn start( account: Account, other_device: Device, - store: Arc>>, + store: Arc>, ) -> (Sas, StartEventContent) { let (inner, content) = InnerSas::start(account.clone(), other_device.clone()); let flow_id = inner.verification_flow_id(); @@ -129,7 +128,7 @@ impl Sas { pub(crate) fn from_start_event( account: Account, other_device: Device, - store: Arc>>, + store: Arc>, event: &ToDeviceEvent, ) -> Result { let inner = InnerSas::from_start_event(account.clone(), other_device.clone(), event)?; @@ -184,8 +183,6 @@ impl Sas { pub(crate) async fn mark_device_as_verified(&self) -> Result { let device = self .store - .read() - .await .get_device(self.other_user_id(), self.other_device_id()) .await?; @@ -202,7 +199,7 @@ impl Sas { ); device.set_trust_state(TrustState::Verified); - self.store.read().await.save_devices(&[device]).await?; + self.store.save_devices(&[device]).await?; Ok(true) } else { @@ -560,7 +557,6 @@ mod test { use matrix_sdk_common::{ events::{EventContent, ToDeviceEvent}, identifiers::{DeviceId, UserId}, - locks::RwLock, }; use crate::{ @@ -685,14 +681,10 @@ mod test { let bob = Account::new(&bob_id(), &bob_device_id()); let bob_device = Device::from_account(&bob).await; - let alice_store: Arc>> = - Arc::new(RwLock::new(Box::new(MemoryStore::new()))); - let bob_store: Arc>> = - Arc::new(RwLock::new(Box::new(MemoryStore::new()))); + let alice_store: Arc> = Arc::new(Box::new(MemoryStore::new())); + let bob_store: Arc> = Arc::new(Box::new(MemoryStore::new())); bob_store - .read() - .await .save_devices(&[alice_device.clone()]) .await .unwrap(); diff --git a/matrix_sdk_crypto/src/verification/sas/sas_state.rs b/matrix_sdk_crypto/src/verification/sas/sas_state.rs index eaeb6bf3..80f3032e 100644 --- a/matrix_sdk_crypto/src/verification/sas/sas_state.rs +++ b/matrix_sdk_crypto/src/verification/sas/sas_state.rs @@ -98,6 +98,7 @@ impl TryFrom for AcceptedProtocols { } } +#[cfg(not(tarpaulin_include))] impl Default for AcceptedProtocols { fn default() -> Self { AcceptedProtocols { @@ -146,6 +147,7 @@ pub struct SasState { state: Arc, } +#[cfg(not(tarpaulin_include))] impl std::fmt::Debug for SasState { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SasState") diff --git a/matrix_sdk_test/Cargo.toml b/matrix_sdk_test/Cargo.toml index 17ca6f84..4c58f638 100644 --- a/matrix_sdk_test/Cargo.toml +++ b/matrix_sdk_test/Cargo.toml @@ -11,9 +11,9 @@ repository = "https://github.com/matrix-org/matrix-rust-sdk" version = "0.1.0" [dependencies] -serde_json = "1.0.56" +serde_json = "1.0.57" http = "0.2.1" matrix-sdk-common = { version = "0.1.0", path = "../matrix_sdk_common" } matrix-sdk-test-macros = { version = "0.1.0", path = "../matrix_sdk_test_macros" } lazy_static = "1.4.0" -serde = "1.0.114" +serde = "1.0.115"