From 407f9a3da8a54ec73bf3112ca599f506ca9275b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Wed, 12 Aug 2020 15:12:51 +0200 Subject: [PATCH] matrix-sdk: Make sure to not send out multiple group share requests at once. --- matrix_sdk/Cargo.toml | 3 +- matrix_sdk/src/client.rs | 53 ++++++++++++++++++++---------- matrix_sdk_base/src/client.rs | 11 +++---- matrix_sdk_crypto/src/machine.rs | 56 ++++++++++++++++++++++++-------- 4 files changed, 85 insertions(+), 38 deletions(-) diff --git a/matrix_sdk/Cargo.toml b/matrix_sdk/Cargo.toml index faf81034..cdbb4c27 100644 --- a/matrix_sdk/Cargo.toml +++ b/matrix_sdk/Cargo.toml @@ -13,11 +13,12 @@ version = "0.1.0" [features] default = ["encryption", "sqlite-cryptostore", "messages"] messages = ["matrix-sdk-base/messages"] -encryption = ["matrix-sdk-base/encryption"] +encryption = ["matrix-sdk-base/encryption", "dashmap"] sqlite-cryptostore = ["matrix-sdk-base/sqlite-cryptostore"] [dependencies] async-trait = "0.1.36" +dashmap = { version = "3.11.10", optional = true } http = "0.2.1" # FIXME: Revert to regular dependency once 0.10.8 or 0.11.0 is released reqwest = { git = "https://github.com/seanmonstar/reqwest", rev = "dd8441fd23dae6ffb79b4cea2862e5bca0c59743" } diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index a0413162..26ab5126 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -25,6 +25,8 @@ use std::{ sync::Arc, }; +#[cfg(feature = "encryption")] +use dashmap::DashMap; use futures_timer::Delay as sleep; use reqwest::header::{HeaderValue, InvalidHeaderValue}; use url::Url; @@ -46,7 +48,7 @@ use matrix_sdk_common::{ identifiers::ServerName, instant::{Duration, Instant}, js_int::UInt, - locks::RwLock, + locks::{Mutex, RwLock}, presence::PresenceState, uuid::Uuid, FromHttpResponseError, @@ -76,6 +78,10 @@ pub struct Client { http_client: HttpClient, /// User session data. pub(crate) base_client: BaseClient, + /// Locks making sure we only have one group session sharing request in + /// flight per room. + #[cfg(feature = "encryption")] + group_session_locks: DashMap>>, } #[cfg(not(tarpaulin_include))] @@ -359,6 +365,7 @@ impl Client { homeserver, http_client, base_client, + group_session_locks: DashMap::new(), }) } @@ -1015,16 +1022,31 @@ impl Client { } if self.base_client.should_share_group_session(room_id).await { - // TODO we need to make sure that only one such request is - // in flight per room at a time. - let response = self.share_group_session(room_id).await; + #[allow(clippy::map_clone)] + if let Some(mutex) = self.group_session_locks.get(room_id).map(|m| m.clone()) { + // If a group session share request is already going on, + // await the release of the lock. + mutex.lock().await; + } else { + // Otherwise create a new lock and share the group + // session. + let mutex = Arc::new(Mutex::new(())); + self.group_session_locks + .insert(room_id.clone(), mutex.clone()); - // If one of the responses failed invalidate the group - // session as using it would end up in undecryptable - // messages. - if let Err(r) = response { - self.base_client.invalidate_group_session(room_id).await; - return Err(r); + let _guard = mutex.lock().await; + + let response = self.share_group_session(room_id).await; + + self.group_session_locks.remove(room_id); + + // If one of the responses failed invalidate the group + // session as using it would end up in undecryptable + // messages. + if let Err(r) = response { + self.base_client.invalidate_group_session(room_id).await; + return Err(r); + } } } @@ -1341,7 +1363,7 @@ impl Client { #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] #[instrument] async fn keys_upload(&self) -> Result { - let (device_keys, one_time_keys) = self + let request = self .base_client .keys_for_upload() .await @@ -1349,15 +1371,10 @@ impl Client { debug!( "Uploading encryption keys device keys: {}, one-time-keys: {}", - device_keys.is_some(), - one_time_keys.as_ref().map_or(0, |k| k.len()) + request.device_keys.is_some(), + request.one_time_keys.as_ref().map_or(0, |k| k.len()) ); - let request = upload_keys::Request { - device_keys, - one_time_keys, - }; - let response = self.send(request).await?; self.base_client .receive_keys_upload_response(&response) diff --git a/matrix_sdk_base/src/client.rs b/matrix_sdk_base/src/client.rs index c7fd3152..b243ae66 100644 --- a/matrix_sdk_base/src/client.rs +++ b/matrix_sdk_base/src/client.rs @@ -41,8 +41,9 @@ use matrix_sdk_common::{ #[cfg(feature = "encryption")] use matrix_sdk_common::{ api::r0::keys::{ - claim_keys::Response as KeysClaimResponse, get_keys::Response as KeysQueryResponse, - upload_keys::Response as KeysUploadResponse, DeviceKeys, + claim_keys::Response as KeysClaimResponse, + get_keys::Response as KeysQueryResponse, + upload_keys::{Request as KeysUploadRequest, Response as KeysUploadResponse}, }, api::r0::to_device::send_event_to_device::IncomingRequest as OwnedToDeviceRequest, events::room::{ @@ -51,7 +52,7 @@ use matrix_sdk_common::{ identifiers::{DeviceId, DeviceKeyAlgorithm}, }; #[cfg(feature = "encryption")] -use matrix_sdk_crypto::{CryptoStore, Device, OlmError, OlmMachine, OneTimeKeys, Sas}; +use matrix_sdk_crypto::{CryptoStore, Device, OlmError, OlmMachine, Sas}; use zeroize::Zeroizing; #[cfg(not(target_arch = "wasm32"))] @@ -1326,9 +1327,7 @@ impl BaseClient { /// Returns an empty error if no keys need to be uploaded. #[cfg(feature = "encryption")] #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] - pub async fn keys_for_upload( - &self, - ) -> StdResult<(Option, Option), ()> { + pub async fn keys_for_upload(&self) -> StdResult { let olm = self.olm.lock().await; match &*olm { diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 92789ea6..2bc6ef95 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -125,7 +125,7 @@ impl OlmMachine { } } - /// Create a new OlmMachine with the given `CryptoStore`. + /// Create a new OlmMachine with the given [`CryptoStore`]. /// /// The created machine will keep the encryption keys only in memory and /// once the object is dropped the keys will be lost. @@ -142,6 +142,8 @@ impl OlmMachine { /// /// * `store` - A `Cryptostore` implementation that will be used to store /// the encryption keys. + /// + /// [`Cryptostore`]: trait.CryptoStore.html pub async fn new_with_store( user_id: UserId, device_id: Box, @@ -171,8 +173,6 @@ impl OlmMachine { }) } - #[cfg(feature = "sqlite-cryptostore")] - #[instrument(skip(path, passphrase))] /// Create a new machine with the default crypto store. /// /// The default store uses a SQLite database to store the encryption keys. @@ -182,6 +182,8 @@ impl OlmMachine { /// * `user_id` - The unique id of the user that owns this machine. /// /// * `device_id` - The unique id of the device that owns this machine. + #[cfg(feature = "sqlite-cryptostore")] + #[instrument(skip(path, passphrase))] pub async fn new_with_default_store>( user_id: &UserId, device_id: &DeviceId, @@ -210,6 +212,29 @@ impl OlmMachine { } /// Should account or one-time keys be uploaded to the server. + /// + /// This needs to be checked periodically, ideally after every sync request. + /// + /// # Example + /// + /// ``` + /// # use std::convert::TryFrom; + /// # use matrix_sdk_crypto::OlmMachine; + /// # use matrix_sdk_common::identifiers::UserId; + /// # use futures::executor::block_on; + /// # let alice = UserId::try_from("@alice:example.org").unwrap(); + /// # let machine = OlmMachine::new(&alice, "DEVICEID".into()); + /// # block_on(async { + /// if machine.should_upload_keys().await { + /// let request = machine + /// .keys_for_upload() + /// .await + /// .unwrap(); + /// + /// // Upload the keys here. + /// } + /// # }); + /// ``` pub async fn should_upload_keys(&self) -> bool { self.account.should_upload_keys().await } @@ -478,10 +503,13 @@ impl OlmMachine { /// /// [`receive_keys_upload_response`]: #method.receive_keys_upload_response /// [`OlmMachine`]: struct.OlmMachine.html - pub async fn keys_for_upload( - &self, - ) -> StdResult<(Option, Option), ()> { - self.account.keys_for_upload().await + pub async fn keys_for_upload(&self) -> StdResult { + let (device_keys, one_time_keys) = self.account.keys_for_upload().await?; + + Ok(upload_keys::Request { + device_keys, + one_time_keys, + }) } /// Try to decrypt an Olm message. @@ -1369,7 +1397,7 @@ mod test { async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) { let machine = OlmMachine::new(&user_id(), &alice_device_id()); machine.account.update_uploaded_key_count(0); - let (_, otk) = machine + let request = machine .keys_for_upload() .await .expect("Can't prepare initial key upload"); @@ -1379,7 +1407,7 @@ mod test { .await .unwrap(); - (machine, otk.unwrap()) + (machine, request.one_time_keys.unwrap()) } async fn get_machine_after_query() -> (OlmMachine, OneTimeKeys) { @@ -1598,7 +1626,7 @@ mod test { let identity_keys = machine.account.identity_keys(); let ed25519_key = identity_keys.ed25519(); - let (device_keys, mut one_time_keys) = machine + let mut request = machine .keys_for_upload() .await .expect("Can't prepare initial key upload"); @@ -1607,7 +1635,7 @@ mod test { &machine.user_id, machine.device_id.as_str(), ed25519_key, - &mut json!(&mut one_time_keys.as_mut().unwrap().values_mut().next()), + &mut json!(&mut request.one_time_keys.as_mut().unwrap().values_mut().next()), ); assert!(ret.is_ok()); @@ -1615,14 +1643,16 @@ mod test { &machine.user_id, machine.device_id.as_str(), ed25519_key, - &mut json!(&mut device_keys.unwrap()), + &mut json!(&mut request.device_keys.unwrap()), ); assert!(ret.is_ok()); let mut response = keys_upload_response(); response.one_time_key_counts.insert( DeviceKeyAlgorithm::SignedCurve25519, - (one_time_keys.unwrap().len() as u64).try_into().unwrap(), + (request.one_time_keys.unwrap().len() as u64) + .try_into() + .unwrap(), ); machine