diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index 8d76ef22..403954f0 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -1057,7 +1057,15 @@ impl Client { if self.base_client.should_share_group_session(room_id).await { // TODO we need to make sure that only one such request is // in flight per room at a time. - self.share_group_session(room_id).await?; + let response = self.share_group_session(room_id).await; + + // If one of the responses failed invalidate the group + // session as using it would end up in undecryptable + // messages. + if let Err(r) = response { + self.base_client.invalidate_group_session(room_id).await; + return Err(r); + } } raw_content = serde_json::value::to_raw_value( diff --git a/matrix_sdk_base/src/client.rs b/matrix_sdk_base/src/client.rs index 0f8826c5..ab807239 100644 --- a/matrix_sdk_base/src/client.rs +++ b/matrix_sdk_base/src/client.rs @@ -1087,6 +1087,22 @@ impl BaseClient { Ok(()) } + /// Invalidate the currently active outbound group session for the given + /// room. + /// + /// Returns true if a session was invalidated, false if there was no session + /// to invalidate. + #[cfg(feature = "encryption")] + #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] + pub async fn invalidate_group_session(&self, room_id: &RoomId) -> bool { + let mut olm = self.olm.lock().await; + + match &mut *olm { + Some(o) => o.invalidate_group_session(room_id), + None => false, + } + } + pub(crate) async fn emit_timeline_event( &self, room_id: &RoomId, @@ -1586,4 +1602,24 @@ mod test { let room = client.get_joined_room(&room_id).await; assert!(room.is_some()); } + + #[async_test] + async fn test_group_session_invalidation() { + let client = get_client(); + let room_id = get_room_id(); + + let mut sync_response = EventBuilder::default() + .add_room_event(EventsFile::Member, RoomEvent::RoomMember) + .build_sync_response(); + + client + .receive_sync_response(&mut sync_response) + .await + .unwrap(); + + assert!(client.should_share_group_session(&room_id).await); + let _ = client.share_group_session(&room_id).await.unwrap(); + assert!(!client.should_share_group_session(&room_id).await); + client.invalidate_group_session(&room_id).await; + } } diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 96a3b897..c55e1bc3 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -1219,6 +1219,15 @@ impl OlmMachine { } } + /// Invalidate the currently active outbound group session for the given + /// room. + /// + /// Returns true if a session was invalidated, false if there was no session + /// to invalidate. + pub fn invalidate_group_session(&mut self, room_id: &RoomId) -> bool { + self.outbound_group_sessions.remove(room_id).is_some() + } + // TODO accept an algorithm here /// Get to-device requests to share a group session with users in a room. /// @@ -1816,6 +1825,22 @@ mod test { assert!(ret.is_ok()); } + #[tokio::test] + async fn tests_session_invalidation() { + let mut machine = OlmMachine::new(&user_id(), DEVICE_ID); + let room_id = RoomId::try_from("!test:example.org").unwrap(); + + machine + .create_outbound_group_session(&room_id) + .await + .unwrap(); + assert!(machine.outbound_group_sessions.get(&room_id).is_some()); + + machine.invalidate_group_session(&room_id); + + assert!(machine.outbound_group_sessions.get(&room_id).is_none()); + } + #[tokio::test] async fn test_invalid_signature() { let machine = OlmMachine::new(&user_id(), DEVICE_ID);