diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index 67618038..3319e5dc 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -20,7 +20,7 @@ // If we don't trust the device store an object that remembers the request and // let the users introspect that object. -use dashmap::DashMap; +use dashmap::{DashMap, DashSet}; use serde::{Deserialize, Serialize}; use serde_json::value::to_raw_value; use std::{collections::BTreeMap, sync::Arc}; @@ -34,13 +34,13 @@ use matrix_sdk_common::{ room_key_request::{Action, RequestedKeyInfo, RoomKeyRequestEventContent}, AnyToDeviceEvent, EventType, ToDeviceEvent, }, - identifiers::{DeviceIdBox, EventEncryptionAlgorithm, RoomId, UserId}, + identifiers::{DeviceId, DeviceIdBox, EventEncryptionAlgorithm, RoomId, UserId}, uuid::Uuid, Raw, }; use crate::{ - error::OlmResult, + error::{OlmError, OlmResult}, olm::{InboundGroupSession, OutboundGroupSession}, requests::{OutgoingRequest, ToDeviceRequest}, store::{CryptoStoreError, Store}, @@ -72,6 +72,10 @@ pub(crate) struct KeyRequestMachine { outgoing_to_device_requests: Arc>, incoming_key_requests: Arc>>, + requests_waiting_for_session: + Arc>>, + requests_ids_waiting: Arc>>, + users_for_key_claim: Arc>>, } #[derive(Debug, Serialize, Deserialize)] @@ -142,6 +146,9 @@ impl KeyRequestMachine { outbound_group_sessions, outgoing_to_device_requests: Arc::new(DashMap::new()), incoming_key_requests: Arc::new(DashMap::new()), + requests_waiting_for_session: Arc::new(DashMap::new()), + users_for_key_claim: Arc::new(DashMap::new()), + requests_ids_waiting: Arc::new(DashMap::new()), } } @@ -163,6 +170,7 @@ impl KeyRequestMachine { let sender = event.sender.clone(); let device_id = event.content.requesting_device_id.clone(); let request_id = event.content.request_id.clone(); + self.incoming_key_requests .insert((sender, device_id, request_id), event.clone()); } @@ -180,6 +188,50 @@ impl KeyRequestMachine { Ok(()) } + fn handle_key_share_without_session( + &self, + device: Device, + event: &ToDeviceEvent, + ) { + self.users_for_key_claim + .entry(device.user_id().to_owned()) + .or_insert_with(DashSet::new) + .insert(device.device_id().into()); + let key = ( + device.user_id().to_owned(), + device.device_id().into(), + event.content.request_id.to_owned(), + ); + self.requests_waiting_for_session.insert(key, event.clone()); + } + + /// Retry keyshares for a device that previously didn't have an Olm session + /// with us. + /// + /// # Arguments + /// + /// * `user_id` - The user id of the device that we created the Olm session + /// with. + /// + /// * `device_id` - The device id of the device that got the Olm session. + pub fn retry_keyshare(&self, user_id: &UserId, device_id: &DeviceId) { + if let Some((_, request_ids)) = self + .requests_ids_waiting + .remove(&(user_id.to_owned(), device_id.into())) + { + for id in request_ids { + let key = (user_id.to_owned(), device_id.into(), id.to_owned()); + let content = self.requests_waiting_for_session.remove(&key); + + if let Some((_, c)) = content { + if !self.incoming_key_requests.contains_key(&key) { + self.incoming_key_requests.insert(key, c); + } + } + } + } + } + /// Handle a single incoming key request. #[instrument] async fn handle_key_request( @@ -248,8 +300,15 @@ impl KeyRequestMachine { device.device_id() ); - // TODO the missing session error here. - self.share_session(session, device).await?; + if let Err(e) = self.share_session(&session, &device).await { + match e { + OlmError::MissingSession => { + self.handle_key_share_without_session(device, event); + return Ok(()); + } + e => return Err(e.into()), + } + } } } else { warn!( @@ -262,8 +321,8 @@ impl KeyRequestMachine { Ok(()) } - async fn share_session(&self, session: InboundGroupSession, device: Device) -> OlmResult<()> { - let content = device.encrypt_session(session).await?; + async fn share_session(&self, session: &InboundGroupSession, device: &Device) -> OlmResult<()> { + let content = device.encrypt_session(session.clone()).await?; let id = Uuid::new_v4(); let mut messages = BTreeMap::new(); diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index d2b5bf49..fca75050 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -466,8 +466,7 @@ impl OlmMachine { // TODO if this session was created because a previous one was // wedged queue up a dummy event to be sent out. - // TODO if this session was created because of a key request, - // mark the forwarding keys to be sent out + self.key_request_machine.retry_keyshare(&user_id, device_id); } } Ok(())