crypto: Move the outbound session filter logic into the group session cache

master
Damir Jelić 2021-04-20 13:35:47 +02:00
parent e15f7264dc
commit bfc7434f7e
2 changed files with 23 additions and 3 deletions

View File

@ -516,11 +516,10 @@ impl KeyRequestMachine {
) -> Result<Option<u32>, KeyshareDecision> { ) -> Result<Option<u32>, KeyshareDecision> {
let outbound_session = self let outbound_session = self
.outbound_group_sessions .outbound_group_sessions
.get_or_load(session.room_id()) .get_with_id(session.room_id(), session.session_id())
.await .await
.ok() .ok()
.flatten() .flatten();
.filter(|o| session.session_id() == o.session_id());
let own_device_check = || { let own_device_check = || {
if device.trust_state() { if device.trust_state() {

View File

@ -62,6 +62,12 @@ impl GroupSessionCache {
self.sessions.insert(session.room_id().to_owned(), session); self.sessions.insert(session.room_id().to_owned(), session);
} }
/// Either get a session for the given room from the cache or load it from
/// the store.
///
/// # Arguments
///
/// * `room_id` - The id of the room this session is used for.
pub async fn get_or_load(&self, room_id: &RoomId) -> StoreResult<Option<OutboundGroupSession>> { pub async fn get_or_load(&self, room_id: &RoomId) -> StoreResult<Option<OutboundGroupSession>> {
// Get the cached session, if there isn't one load one from the store // Get the cached session, if there isn't one load one from the store
// and put it in the cache. // and put it in the cache.
@ -89,6 +95,21 @@ impl GroupSessionCache {
fn get(&self, room_id: &RoomId) -> Option<OutboundGroupSession> { fn get(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
self.sessions.get(room_id).map(|s| s.clone()) self.sessions.get(room_id).map(|s| s.clone())
} }
/// Get or load the session for the given room with the given session id.
///
/// This is the same as [get_or_load()](#method.get_or_load) but it will
/// filter out the session if it doesn't match the given session id.
pub async fn get_with_id(
&self,
room_id: &RoomId,
session_id: &str,
) -> StoreResult<Option<OutboundGroupSession>> {
Ok(self
.get_or_load(room_id)
.await?
.filter(|o| session_id == o.session_id()))
}
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]