diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index 3cee36e2..ad1fc71a 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -105,7 +105,7 @@ use matrix_sdk_common::{ #[cfg(feature = "encryption")] use matrix_sdk_common::{ api::r0::{ - keys::{claim_keys, get_keys, upload_keys}, + keys::{get_keys, upload_keys}, to_device::send_event_to_device::{ Request as RumaToDeviceRequest, Response as ToDeviceResponse, }, @@ -143,6 +143,9 @@ pub struct Client { /// flight per room. #[cfg(feature = "encryption")] group_session_locks: DashMap>>, + #[cfg(feature = "encryption")] + /// Lock making sure we're only doing one key claim request at a time. + key_claim_lock: Arc>, } #[cfg(not(tarpaulin_include))] @@ -404,6 +407,8 @@ impl Client { base_client, #[cfg(feature = "encryption")] group_session_locks: DashMap::new(), + #[cfg(feature = "encryption")] + key_claim_lock: Arc::new(Mutex::new(())), }) } @@ -1017,19 +1022,16 @@ impl Client { let _guard = mutex.lock().await; - let missing_sessions = { + { let room = self.base_client.get_joined_room(room_id).await; let room = room.as_ref().unwrap().read().await; - let members = room + let mut members = room .joined_members .keys() .chain(room.invited_members.keys()); - self.base_client.get_missing_sessions(members).await? + self.claim_one_time_keys(&mut members).await?; }; - if let Some((request_id, request)) = missing_sessions { - self.claim_one_time_keys(&request_id, request).await?; - } let response = self.share_group_session(room_id).await; self.group_session_locks.remove(room_id); @@ -1520,6 +1522,10 @@ impl Client { #[cfg(feature = "encryption")] { + if let Err(e) = self.claim_one_time_keys(&mut [].iter()).await { + warn!("Error while claiming one-time keys {:?}", e); + } + for r in self.base_client.outgoing_requests().await { match r.request() { OutgoingRequests::KeysQuery(request) => { @@ -1582,23 +1588,20 @@ impl Client { /// /// * `users` - The list of user/device pairs that we should claim keys for. /// - /// # Panics - /// - /// Panics if the client isn't logged in, or if no encryption keys need to - /// be uploaded. #[cfg(feature = "encryption")] #[cfg_attr(feature = "docs", doc(cfg(encryption)))] - #[instrument] - async fn claim_one_time_keys( - &self, - request_id: &Uuid, - request: claim_keys::Request, - ) -> Result { - let response = self.send(request).await?; - self.base_client - .mark_request_as_sent(request_id, &response) - .await?; - Ok(response) + #[instrument(skip(users))] + async fn claim_one_time_keys(&self, users: &mut impl Iterator) -> Result<()> { + let _lock = self.key_claim_lock.lock().await; + + if let Some((request_id, request)) = self.base_client.get_missing_sessions(users).await? { + let response = self.send(request).await?; + self.base_client + .mark_request_as_sent(&request_id, &response) + .await?; + } + + Ok(()) } /// Share a group session for a room. diff --git a/matrix_sdk_base/src/client.rs b/matrix_sdk_base/src/client.rs index 9161e376..eb0ad8e1 100644 --- a/matrix_sdk_base/src/client.rs +++ b/matrix_sdk_base/src/client.rs @@ -1291,7 +1291,7 @@ impl BaseClient { #[cfg_attr(feature = "docs", doc(cfg(encryption)))] pub async fn get_missing_sessions( &self, - users: impl Iterator, + users: &mut impl Iterator, ) -> Result> { let olm = self.olm.lock().await; diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 90f1bb4d..931739ad 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -379,7 +379,7 @@ impl OlmMachine { /// [`mark_request_as_sent`]: #method.mark_request_as_sent pub async fn get_missing_sessions( &self, - users: impl Iterator, + users: &mut impl Iterator, ) -> OlmResult> { let mut missing = BTreeMap::new(); @@ -1456,7 +1456,7 @@ pub(crate) mod test { let alice_device = alice_device_id(); let (_, missing_sessions) = machine - .get_missing_sessions([alice.clone()].iter()) + .get_missing_sessions(&mut [alice.clone()].iter()) .await .unwrap() .unwrap();