diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index 3319e5dc..b40797c5 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -14,8 +14,8 @@ // TODO // -// If we don't have a session, queue up a key claim request, once we get a -// session send out the key if we trust the device. +// handle the case where we can't create a session with a device. clearing our +// stale key share requests that we'll never be able to handle. // // If we don't trust the device store an object that remembers the request and // let the users introspect that object. @@ -72,6 +72,7 @@ pub(crate) struct KeyRequestMachine { outgoing_to_device_requests: Arc>, incoming_key_requests: Arc>>, + // TODO group these hashmaps into a logical unit. requests_waiting_for_session: Arc>>, requests_ids_waiting: Arc>>, @@ -203,11 +204,20 @@ impl KeyRequestMachine { event.content.request_id.to_owned(), ); self.requests_waiting_for_session.insert(key, event.clone()); + + let key = (device.user_id().to_owned(), device.device_id().into()); + self.requests_ids_waiting + .entry(key) + .or_insert_with(DashSet::new) + .insert(event.content.request_id.clone()); } /// Retry keyshares for a device that previously didn't have an Olm session /// with us. /// + /// This should be only called if the given user/device got a new Olm + /// session. + /// /// # Arguments /// /// * `user_id` - The user id of the device that we created the Olm session @@ -1089,4 +1099,187 @@ mod test { assert_eq!(session.session_id(), group_session.session_id()) } + + #[async_test] + async fn key_share_cycle_without_session() { + let alice_machine = get_machine().await; + let alice_account = Account { + inner: account(), + store: alice_machine.store.clone(), + }; + + let bob_machine = bob_machine(); + let bob_account = bob_account(); + + // Create Olm sessions for our two accounts. + let (alice_session, bob_session) = alice_account.create_session_for(&bob_account).await; + + let alice_device = ReadOnlyDevice::from_account(&alice_account).await; + let bob_device = ReadOnlyDevice::from_account(&bob_account).await; + + // Populate our stores with Olm sessions and a Megolm session. + + alice_machine + .store + .save_devices(&[bob_device]) + .await + .unwrap(); + bob_machine + .store + .save_devices(&[alice_device]) + .await + .unwrap(); + + let (group_session, inbound_group_session) = bob_account + .create_group_session_pair_with_defaults(&room_id()) + .await + .unwrap(); + + bob_machine + .store + .save_inbound_group_sessions(&[inbound_group_session]) + .await + .unwrap(); + + // Alice wants to request the outbound group session from bob. + alice_machine + .create_outgoing_key_request( + &room_id(), + bob_account.identity_keys.curve25519(), + group_session.session_id(), + ) + .await + .unwrap(); + group_session.mark_shared_with(&alice_id(), &alice_device_id()); + + // Put the outbound session into bobs store. + bob_machine + .outbound_group_sessions + .insert(room_id(), group_session.clone()); + + // Get the request and convert it into a event. + let request = alice_machine + .outgoing_to_device_requests + .iter() + .next() + .unwrap(); + let id = request.request_id; + let content = request + .request + .to_device() + .unwrap() + .messages + .get(&alice_id()) + .unwrap() + .get(&DeviceIdOrAllDevices::AllDevices) + .unwrap(); + let content: RoomKeyRequestEventContent = serde_json::from_str(content.get()).unwrap(); + + drop(request); + alice_machine + .mark_outgoing_request_as_sent(&id) + .await + .unwrap(); + + let event = ToDeviceEvent { + sender: alice_id(), + content, + }; + + // Bob doesn't have any outgoing requests. + assert!(bob_machine.outgoing_to_device_requests.is_empty()); + assert!(bob_machine.requests_ids_waiting.is_empty()); + assert!(bob_machine.requests_waiting_for_session.is_empty()); + + // Receive the room key request from alice. + bob_machine.receive_incoming_key_request(&event); + bob_machine.collect_incoming_key_requests().await.unwrap(); + // Bob doens't have an outgoing requests since we're lacking a session. + assert!(bob_machine.outgoing_to_device_requests.is_empty()); + assert!(!bob_machine.requests_ids_waiting.is_empty()); + assert!(!bob_machine.requests_waiting_for_session.is_empty()); + + // We create a session now. + alice_machine + .store + .save_sessions(&[alice_session]) + .await + .unwrap(); + bob_machine + .store + .save_sessions(&[bob_session]) + .await + .unwrap(); + + bob_machine.retry_keyshare(&alice_id(), &alice_device_id()); + bob_machine.collect_incoming_key_requests().await.unwrap(); + // Bob now has an outgoing requests. + assert!(!bob_machine.outgoing_to_device_requests.is_empty()); + assert!(bob_machine.requests_ids_waiting.is_empty()); + assert!(bob_machine.requests_waiting_for_session.is_empty()); + + // Get the request and convert it to a encrypted to-device event. + let request = bob_machine + .outgoing_to_device_requests + .iter() + .next() + .unwrap(); + + let id = request.request_id; + let content = request + .request + .to_device() + .unwrap() + .messages + .get(&alice_id()) + .unwrap() + .get(&DeviceIdOrAllDevices::DeviceId(alice_device_id())) + .unwrap(); + let content: EncryptedEventContent = serde_json::from_str(content.get()).unwrap(); + + drop(request); + bob_machine + .mark_outgoing_request_as_sent(&id) + .await + .unwrap(); + + let event = ToDeviceEvent { + sender: bob_id(), + content, + }; + + // Check that alice doesn't have the session. + assert!(alice_machine + .store + .get_inbound_group_session( + &room_id(), + &bob_account.identity_keys().curve25519(), + group_session.session_id() + ) + .await + .unwrap() + .is_none()); + + let (decrypted, sender_key, _) = + alice_account.decrypt_to_device_event(&event).await.unwrap(); + + if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.deserialize().unwrap() { + alice_machine + .receive_forwarded_room_key(&sender_key, &mut e) + .await + .unwrap(); + } else { + panic!("Invalid decrypted event type"); + } + + // Check that alice now does have the session. + let session = alice_machine + .store + .get_inbound_group_session(&room_id(), &sender_key, group_session.session_id()) + .await + .unwrap() + .unwrap(); + + assert_eq!(session.session_id(), group_session.session_id()) + } }