matrix-sdk: Make sure to not send out multiple group share requests at once.

master
Damir Jelić 2020-08-12 15:12:51 +02:00
parent 82c3a795ff
commit 407f9a3da8
4 changed files with 85 additions and 38 deletions

View File

@ -13,11 +13,12 @@ version = "0.1.0"
[features] [features]
default = ["encryption", "sqlite-cryptostore", "messages"] default = ["encryption", "sqlite-cryptostore", "messages"]
messages = ["matrix-sdk-base/messages"] messages = ["matrix-sdk-base/messages"]
encryption = ["matrix-sdk-base/encryption"] encryption = ["matrix-sdk-base/encryption", "dashmap"]
sqlite-cryptostore = ["matrix-sdk-base/sqlite-cryptostore"] sqlite-cryptostore = ["matrix-sdk-base/sqlite-cryptostore"]
[dependencies] [dependencies]
async-trait = "0.1.36" async-trait = "0.1.36"
dashmap = { version = "3.11.10", optional = true }
http = "0.2.1" http = "0.2.1"
# FIXME: Revert to regular dependency once 0.10.8 or 0.11.0 is released # FIXME: Revert to regular dependency once 0.10.8 or 0.11.0 is released
reqwest = { git = "https://github.com/seanmonstar/reqwest", rev = "dd8441fd23dae6ffb79b4cea2862e5bca0c59743" } reqwest = { git = "https://github.com/seanmonstar/reqwest", rev = "dd8441fd23dae6ffb79b4cea2862e5bca0c59743" }

View File

@ -25,6 +25,8 @@ use std::{
sync::Arc, sync::Arc,
}; };
#[cfg(feature = "encryption")]
use dashmap::DashMap;
use futures_timer::Delay as sleep; use futures_timer::Delay as sleep;
use reqwest::header::{HeaderValue, InvalidHeaderValue}; use reqwest::header::{HeaderValue, InvalidHeaderValue};
use url::Url; use url::Url;
@ -46,7 +48,7 @@ use matrix_sdk_common::{
identifiers::ServerName, identifiers::ServerName,
instant::{Duration, Instant}, instant::{Duration, Instant},
js_int::UInt, js_int::UInt,
locks::RwLock, locks::{Mutex, RwLock},
presence::PresenceState, presence::PresenceState,
uuid::Uuid, uuid::Uuid,
FromHttpResponseError, FromHttpResponseError,
@ -76,6 +78,10 @@ pub struct Client {
http_client: HttpClient, http_client: HttpClient,
/// User session data. /// User session data.
pub(crate) base_client: BaseClient, pub(crate) base_client: BaseClient,
/// Locks making sure we only have one group session sharing request in
/// flight per room.
#[cfg(feature = "encryption")]
group_session_locks: DashMap<RoomId, Arc<Mutex<()>>>,
} }
#[cfg(not(tarpaulin_include))] #[cfg(not(tarpaulin_include))]
@ -359,6 +365,7 @@ impl Client {
homeserver, homeserver,
http_client, http_client,
base_client, base_client,
group_session_locks: DashMap::new(),
}) })
} }
@ -1015,16 +1022,31 @@ impl Client {
} }
if self.base_client.should_share_group_session(room_id).await { if self.base_client.should_share_group_session(room_id).await {
// TODO we need to make sure that only one such request is #[allow(clippy::map_clone)]
// in flight per room at a time. if let Some(mutex) = self.group_session_locks.get(room_id).map(|m| m.clone()) {
let response = self.share_group_session(room_id).await; // If a group session share request is already going on,
// await the release of the lock.
mutex.lock().await;
} else {
// Otherwise create a new lock and share the group
// session.
let mutex = Arc::new(Mutex::new(()));
self.group_session_locks
.insert(room_id.clone(), mutex.clone());
// If one of the responses failed invalidate the group let _guard = mutex.lock().await;
// session as using it would end up in undecryptable
// messages. let response = self.share_group_session(room_id).await;
if let Err(r) = response {
self.base_client.invalidate_group_session(room_id).await; self.group_session_locks.remove(room_id);
return Err(r);
// If one of the responses failed invalidate the group
// session as using it would end up in undecryptable
// messages.
if let Err(r) = response {
self.base_client.invalidate_group_session(room_id).await;
return Err(r);
}
} }
} }
@ -1341,7 +1363,7 @@ impl Client {
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
#[instrument] #[instrument]
async fn keys_upload(&self) -> Result<upload_keys::Response> { async fn keys_upload(&self) -> Result<upload_keys::Response> {
let (device_keys, one_time_keys) = self let request = self
.base_client .base_client
.keys_for_upload() .keys_for_upload()
.await .await
@ -1349,15 +1371,10 @@ impl Client {
debug!( debug!(
"Uploading encryption keys device keys: {}, one-time-keys: {}", "Uploading encryption keys device keys: {}, one-time-keys: {}",
device_keys.is_some(), request.device_keys.is_some(),
one_time_keys.as_ref().map_or(0, |k| k.len()) request.one_time_keys.as_ref().map_or(0, |k| k.len())
); );
let request = upload_keys::Request {
device_keys,
one_time_keys,
};
let response = self.send(request).await?; let response = self.send(request).await?;
self.base_client self.base_client
.receive_keys_upload_response(&response) .receive_keys_upload_response(&response)

View File

@ -41,8 +41,9 @@ use matrix_sdk_common::{
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::keys::{ api::r0::keys::{
claim_keys::Response as KeysClaimResponse, get_keys::Response as KeysQueryResponse, claim_keys::Response as KeysClaimResponse,
upload_keys::Response as KeysUploadResponse, DeviceKeys, get_keys::Response as KeysQueryResponse,
upload_keys::{Request as KeysUploadRequest, Response as KeysUploadResponse},
}, },
api::r0::to_device::send_event_to_device::IncomingRequest as OwnedToDeviceRequest, api::r0::to_device::send_event_to_device::IncomingRequest as OwnedToDeviceRequest,
events::room::{ events::room::{
@ -51,7 +52,7 @@ use matrix_sdk_common::{
identifiers::{DeviceId, DeviceKeyAlgorithm}, identifiers::{DeviceId, DeviceKeyAlgorithm},
}; };
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
use matrix_sdk_crypto::{CryptoStore, Device, OlmError, OlmMachine, OneTimeKeys, Sas}; use matrix_sdk_crypto::{CryptoStore, Device, OlmError, OlmMachine, Sas};
use zeroize::Zeroizing; use zeroize::Zeroizing;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
@ -1326,9 +1327,7 @@ impl BaseClient {
/// Returns an empty error if no keys need to be uploaded. /// Returns an empty error if no keys need to be uploaded.
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
pub async fn keys_for_upload( pub async fn keys_for_upload(&self) -> StdResult<KeysUploadRequest, ()> {
&self,
) -> StdResult<(Option<DeviceKeys>, Option<OneTimeKeys>), ()> {
let olm = self.olm.lock().await; let olm = self.olm.lock().await;
match &*olm { match &*olm {

View File

@ -125,7 +125,7 @@ impl OlmMachine {
} }
} }
/// Create a new OlmMachine with the given `CryptoStore`. /// Create a new OlmMachine with the given [`CryptoStore`].
/// ///
/// The created machine will keep the encryption keys only in memory and /// The created machine will keep the encryption keys only in memory and
/// once the object is dropped the keys will be lost. /// once the object is dropped the keys will be lost.
@ -142,6 +142,8 @@ impl OlmMachine {
/// ///
/// * `store` - A `Cryptostore` implementation that will be used to store /// * `store` - A `Cryptostore` implementation that will be used to store
/// the encryption keys. /// the encryption keys.
///
/// [`Cryptostore`]: trait.CryptoStore.html
pub async fn new_with_store( pub async fn new_with_store(
user_id: UserId, user_id: UserId,
device_id: Box<DeviceId>, device_id: Box<DeviceId>,
@ -171,8 +173,6 @@ impl OlmMachine {
}) })
} }
#[cfg(feature = "sqlite-cryptostore")]
#[instrument(skip(path, passphrase))]
/// Create a new machine with the default crypto store. /// Create a new machine with the default crypto store.
/// ///
/// The default store uses a SQLite database to store the encryption keys. /// The default store uses a SQLite database to store the encryption keys.
@ -182,6 +182,8 @@ impl OlmMachine {
/// * `user_id` - The unique id of the user that owns this machine. /// * `user_id` - The unique id of the user that owns this machine.
/// ///
/// * `device_id` - The unique id of the device that owns this machine. /// * `device_id` - The unique id of the device that owns this machine.
#[cfg(feature = "sqlite-cryptostore")]
#[instrument(skip(path, passphrase))]
pub async fn new_with_default_store<P: AsRef<Path>>( pub async fn new_with_default_store<P: AsRef<Path>>(
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
@ -210,6 +212,29 @@ impl OlmMachine {
} }
/// Should account or one-time keys be uploaded to the server. /// Should account or one-time keys be uploaded to the server.
///
/// This needs to be checked periodically, ideally after every sync request.
///
/// # Example
///
/// ```
/// # use std::convert::TryFrom;
/// # use matrix_sdk_crypto::OlmMachine;
/// # use matrix_sdk_common::identifiers::UserId;
/// # use futures::executor::block_on;
/// # let alice = UserId::try_from("@alice:example.org").unwrap();
/// # let machine = OlmMachine::new(&alice, "DEVICEID".into());
/// # block_on(async {
/// if machine.should_upload_keys().await {
/// let request = machine
/// .keys_for_upload()
/// .await
/// .unwrap();
///
/// // Upload the keys here.
/// }
/// # });
/// ```
pub async fn should_upload_keys(&self) -> bool { pub async fn should_upload_keys(&self) -> bool {
self.account.should_upload_keys().await self.account.should_upload_keys().await
} }
@ -478,10 +503,13 @@ impl OlmMachine {
/// ///
/// [`receive_keys_upload_response`]: #method.receive_keys_upload_response /// [`receive_keys_upload_response`]: #method.receive_keys_upload_response
/// [`OlmMachine`]: struct.OlmMachine.html /// [`OlmMachine`]: struct.OlmMachine.html
pub async fn keys_for_upload( pub async fn keys_for_upload(&self) -> StdResult<upload_keys::Request, ()> {
&self, let (device_keys, one_time_keys) = self.account.keys_for_upload().await?;
) -> StdResult<(Option<DeviceKeys>, Option<OneTimeKeys>), ()> {
self.account.keys_for_upload().await Ok(upload_keys::Request {
device_keys,
one_time_keys,
})
} }
/// Try to decrypt an Olm message. /// Try to decrypt an Olm message.
@ -1369,7 +1397,7 @@ mod test {
async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) { async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) {
let machine = OlmMachine::new(&user_id(), &alice_device_id()); let machine = OlmMachine::new(&user_id(), &alice_device_id());
machine.account.update_uploaded_key_count(0); machine.account.update_uploaded_key_count(0);
let (_, otk) = machine let request = machine
.keys_for_upload() .keys_for_upload()
.await .await
.expect("Can't prepare initial key upload"); .expect("Can't prepare initial key upload");
@ -1379,7 +1407,7 @@ mod test {
.await .await
.unwrap(); .unwrap();
(machine, otk.unwrap()) (machine, request.one_time_keys.unwrap())
} }
async fn get_machine_after_query() -> (OlmMachine, OneTimeKeys) { async fn get_machine_after_query() -> (OlmMachine, OneTimeKeys) {
@ -1598,7 +1626,7 @@ mod test {
let identity_keys = machine.account.identity_keys(); let identity_keys = machine.account.identity_keys();
let ed25519_key = identity_keys.ed25519(); let ed25519_key = identity_keys.ed25519();
let (device_keys, mut one_time_keys) = machine let mut request = machine
.keys_for_upload() .keys_for_upload()
.await .await
.expect("Can't prepare initial key upload"); .expect("Can't prepare initial key upload");
@ -1607,7 +1635,7 @@ mod test {
&machine.user_id, &machine.user_id,
machine.device_id.as_str(), machine.device_id.as_str(),
ed25519_key, ed25519_key,
&mut json!(&mut one_time_keys.as_mut().unwrap().values_mut().next()), &mut json!(&mut request.one_time_keys.as_mut().unwrap().values_mut().next()),
); );
assert!(ret.is_ok()); assert!(ret.is_ok());
@ -1615,14 +1643,16 @@ mod test {
&machine.user_id, &machine.user_id,
machine.device_id.as_str(), machine.device_id.as_str(),
ed25519_key, ed25519_key,
&mut json!(&mut device_keys.unwrap()), &mut json!(&mut request.device_keys.unwrap()),
); );
assert!(ret.is_ok()); assert!(ret.is_ok());
let mut response = keys_upload_response(); let mut response = keys_upload_response();
response.one_time_key_counts.insert( response.one_time_key_counts.insert(
DeviceKeyAlgorithm::SignedCurve25519, DeviceKeyAlgorithm::SignedCurve25519,
(one_time_keys.unwrap().len() as u64).try_into().unwrap(), (request.one_time_keys.unwrap().len() as u64)
.try_into()
.unwrap(),
); );
machine machine