// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. use std::{ collections::{BTreeMap, BTreeSet, HashMap, HashSet}, sync::Arc, }; use dashmap::DashMap; use futures::future::join_all; use matrix_sdk_common::{executor::spawn, uuid::Uuid}; use ruma::{ events::{ room::{encrypted::EncryptedEventContent, history_visibility::HistoryVisibility}, AnyMessageEventContent, AnyToDeviceEventContent, EventType, }, serde::Raw, to_device::DeviceIdOrAllDevices, DeviceId, DeviceIdBox, RoomId, UserId, }; use tracing::{debug, info, trace}; use crate::{ error::{EventError, MegolmResult, OlmResult}, olm::{Account, InboundGroupSession, OutboundGroupSession, Session, ShareInfo, ShareState}, store::{Changes, Result as StoreResult, Store}, Device, EncryptionSettings, OlmError, ToDeviceRequest, }; #[derive(Clone, Debug)] pub(crate) struct GroupSessionCache { store: Store, sessions: Arc>, /// A map from the request id to the group session that the request belongs /// to. Used to mark requests belonging to the session as shared. sessions_being_shared: Arc>, } impl GroupSessionCache { pub(crate) fn new(store: Store) -> Self { Self { store, sessions: Default::default(), sessions_being_shared: Default::default() } } pub(crate) fn insert(&self, session: OutboundGroupSession) { self.sessions.insert(session.room_id().to_owned(), session); } /// Either get a session for the given room from the cache or load it from /// the store. /// /// # Arguments /// /// * `room_id` - The id of the room this session is used for. pub async fn get_or_load(&self, room_id: &RoomId) -> StoreResult> { // Get the cached session, if there isn't one load one from the store // and put it in the cache. if let Some(s) = self.sessions.get(room_id) { Ok(Some(s.clone())) } else if let Some(s) = self.store.get_outbound_group_sessions(room_id).await? { for request_id in s.pending_request_ids() { self.sessions_being_shared.insert(request_id, s.clone()); } self.sessions.insert(room_id.clone(), s.clone()); Ok(Some(s)) } else { Ok(None) } } /// Get an outbound group session for a room, if one exists. /// /// # Arguments /// /// * `room_id` - The id of the room for which we should get the outbound /// group session. fn get(&self, room_id: &RoomId) -> Option { self.sessions.get(room_id).map(|s| s.clone()) } /// Get or load the session for the given room with the given session id. /// /// This is the same as [get_or_load()](#method.get_or_load) but it will /// filter out the session if it doesn't match the given session id. pub async fn get_with_id( &self, room_id: &RoomId, session_id: &str, ) -> StoreResult> { Ok(self.get_or_load(room_id).await?.filter(|o| session_id == o.session_id())) } } #[derive(Debug, Clone)] pub struct GroupSessionManager { account: Account, /// Store for the encryption keys. /// Persists all the encryption keys so a client can resume the session /// without the need to create new keys. store: Store, /// The currently active outbound group sessions. sessions: GroupSessionCache, } impl GroupSessionManager { const MAX_TO_DEVICE_MESSAGES: usize = 250; pub(crate) fn new(account: Account, store: Store) -> Self { Self { account, store: store.clone(), sessions: GroupSessionCache::new(store) } } pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult { if let Some(s) = self.sessions.get(room_id) { s.invalidate_session(); let mut changes = Changes::default(); changes.outbound_group_sessions.push(s.clone()); self.store.save_changes(changes).await?; Ok(true) } else { Ok(false) } } pub async fn mark_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> { if let Some((_, s)) = self.sessions.sessions_being_shared.remove(request_id) { s.mark_request_as_sent(request_id); let mut changes = Changes::default(); changes.outbound_group_sessions.push(s.clone()); self.store.save_changes(changes).await?; } Ok(()) } #[cfg(test)] pub fn get_outbound_group_session(&self, room_id: &RoomId) -> Option { self.sessions.get(room_id) } pub async fn encrypt( &self, room_id: &RoomId, content: AnyMessageEventContent, ) -> MegolmResult { let session = if let Some(s) = self.sessions.get(room_id) { s } else { panic!("Session wasn't created nor shared"); }; if session.expired() { panic!("Session expired"); } let content = session.encrypt(content).await; let mut changes = Changes::default(); changes.outbound_group_sessions.push(session); self.store.save_changes(changes).await?; Ok(content) } /// Create a new outbound group session. /// /// This also creates a matching inbound group session and saves that one in /// the store. pub async fn create_outbound_group_session( &self, room_id: &RoomId, settings: EncryptionSettings, ) -> OlmResult<(OutboundGroupSession, InboundGroupSession)> { let (outbound, inbound) = self .account .create_group_session_pair(room_id, settings) .await .map_err(|_| EventError::UnsupportedAlgorithm)?; self.sessions.insert(outbound.clone()); Ok((outbound, inbound)) } pub async fn get_or_create_outbound_session( &self, room_id: &RoomId, settings: EncryptionSettings, ) -> OlmResult<(OutboundGroupSession, Option)> { let outbound_session = self.sessions.get_or_load(room_id).await?; // If there is no session or the session has expired or is invalid, // create a new one. if let Some(s) = outbound_session { if s.expired() || s.invalidated() { self.create_outbound_group_session(room_id, settings) .await .map(|(o, i)| (o, i.into())) } else { Ok((s, None)) } } else { self.create_outbound_group_session(room_id, settings).await.map(|(o, i)| (o, i.into())) } } /// Encrypt the given content for the given devices and create a to-device /// requests that sends the encrypted content to them. async fn encrypt_session_for( content: AnyToDeviceEventContent, devices: Vec, message_index: u32, ) -> OlmResult<( Uuid, ToDeviceRequest, BTreeMap>, Vec, )> { let mut messages = BTreeMap::new(); let mut changed_sessions = Vec::new(); let mut share_infos = BTreeMap::new(); let encrypt = |device: Device, content: AnyToDeviceEventContent| async move { let mut message = BTreeMap::new(); let mut share_infos = BTreeMap::new(); let encrypted = device.encrypt(content.clone()).await; let used_session = match encrypted { Ok((session, encrypted)) => { message .entry(device.user_id().to_owned()) .or_insert_with(BTreeMap::new) .insert( DeviceIdOrAllDevices::DeviceId(device.device_id().into()), Raw::from(AnyToDeviceEventContent::RoomEncrypted(encrypted)), ); share_infos .entry(device.user_id().to_owned()) .or_insert_with(BTreeMap::new) .insert( device.device_id().to_owned(), ShareInfo { sender_key: session.sender_key().to_owned(), message_index, }, ); Some(session) } // TODO we'll want to create m.room_key.withheld here. Err(OlmError::MissingSession) | Err(OlmError::EventError(EventError::MissingSenderKey)) => None, Err(e) => return Err(e), }; Ok((used_session, share_infos, message)) }; let tasks: Vec<_> = devices.iter().map(|d| spawn(encrypt(d.clone(), content.clone()))).collect(); let results = join_all(tasks).await; for result in results { let (used_session, infos, message) = result.expect("Encryption task panicked")?; if let Some(session) = used_session { changed_sessions.push(session); } for (user, device_messages) in message.into_iter() { messages.entry(user).or_insert_with(BTreeMap::new).extend(device_messages); } for (user, infos) in infos.into_iter() { share_infos.entry(user).or_insert_with(BTreeMap::new).extend(infos); } } let id = Uuid::new_v4(); let request = ToDeviceRequest { event_type: EventType::RoomEncrypted, txn_id: id, messages }; trace!( recipient_count = request.message_count(), transaction_id = ?id, "Created a to-device request carrying a room_key" ); Ok((id, request, share_infos, changed_sessions)) } /// Given a list of user and an outbound session, return the list of users /// and their devices that this session should be shared with. /// /// Returns a boolean indicating whether the session needs to be rotated and /// the list of users/devices that should receive the session. pub async fn collect_session_recipients( &self, users: impl Iterator, history_visibility: HistoryVisibility, outbound: &OutboundGroupSession, ) -> OlmResult<(bool, HashMap>)> { let users: HashSet<&UserId> = users.collect(); let mut devices: HashMap> = HashMap::new(); debug!( users = ?users, history_visibility = ?history_visibility, session_id = outbound.session_id(), "Calculating group session recipients" ); let users_shared_with: HashSet = outbound.shared_with_set.iter().map(|k| k.key().clone()).collect(); let users_shared_with: HashSet<&UserId> = users_shared_with.iter().collect(); // A user left if a user is missing from the set of users that should // get the session but is in the set of users that received the session. let user_left = !users_shared_with.difference(&users).collect::>().is_empty(); let visibility_changed = outbound.settings().history_visibility != history_visibility; // To protect the room history we need to rotate the session if either: // // 1. Any user left the room. // 2. Any of the users' devices got deleted or blacklisted. // 3. The history visibility changed. // // This is calculated in the following code and stored in this variable. let mut should_rotate = user_left || visibility_changed; for user_id in users { let user_devices = self.store.get_user_devices(user_id).await?; let non_blacklisted_devices: Vec = user_devices.devices().filter(|d| !d.is_blacklisted()).collect(); // If we haven't already concluded that the session should be // rotated for other reasons, we also need to check whether any // of the devices in the session got deleted or blacklisted in the // meantime. If so, we should also rotate the session. if !should_rotate { // Device IDs that should receive this session let non_blacklisted_device_ids: HashSet<&DeviceId> = non_blacklisted_devices.iter().map(|d| d.device_id()).collect(); if let Some(shared) = outbound.shared_with_set.get(user_id) { #[allow(clippy::map_clone)] // Devices that received this session let shared: HashSet = shared.iter().map(|d| d.key().clone()).collect(); let shared: HashSet<&DeviceId> = shared.iter().map(|d| d.as_ref()).collect(); // The set difference between // // 1. Devices that had previously received the session, and // 2. Devices that would now receive the session // // represents newly deleted or blacklisted devices. If this // set is non-empty, we must rotate. let newly_deleted_or_blacklisted = shared.difference(&non_blacklisted_device_ids).collect::>(); if !newly_deleted_or_blacklisted.is_empty() { should_rotate = true; } }; } devices.entry(user_id.clone()).or_insert_with(Vec::new).extend(non_blacklisted_devices); } debug!( should_rotate = should_rotate, session_id = outbound.session_id(), "Done calculating group session recipients" ); Ok((should_rotate, devices)) } pub async fn encrypt_request( chunk: Vec, content: AnyToDeviceEventContent, outbound: OutboundGroupSession, message_index: u32, being_shared: Arc>, ) -> OlmResult> { let (id, request, share_infos, used_sessions) = Self::encrypt_session_for(content.clone(), chunk, message_index).await?; if !request.messages.is_empty() { outbound.add_request(id, request.into(), share_infos); being_shared.insert(id, outbound.clone()); } Ok(used_sessions) } pub(crate) fn session_cache(&self) -> GroupSessionCache { self.sessions.clone() } /// Get to-device requests to share a group session with users in a room. /// /// # Arguments /// /// `room_id` - The room id of the room where the group session will be /// used. /// /// `users` - The list of users that should receive the group session. /// /// `encryption_settings` - The settings that should be used for the group /// session. pub async fn share_group_session( &self, room_id: &RoomId, users: impl Iterator, encryption_settings: impl Into, ) -> OlmResult>> { debug!(room_id = room_id.as_str(), "Checking if a room key needs to be shared",); let encryption_settings = encryption_settings.into(); let history_visibility = encryption_settings.history_visibility.clone(); let mut changes = Changes::default(); let (outbound, inbound) = self.get_or_create_outbound_session(room_id, encryption_settings.clone()).await?; if let Some(inbound) = inbound { changes.outbound_group_sessions.push(outbound.clone()); changes.inbound_group_sessions.push(inbound); } let (should_rotate, devices) = self.collect_session_recipients(users, history_visibility, &outbound).await?; let outbound = if should_rotate { let old_session_id = outbound.session_id(); let (outbound, inbound) = self.create_outbound_group_session(room_id, encryption_settings).await?; changes.outbound_group_sessions.push(outbound.clone()); changes.inbound_group_sessions.push(inbound); debug!( room_id = room_id.as_str(), old_session_id = old_session_id, session_id = outbound.session_id(), "A user/device has left the group since we last sent a message, \ rotating the outbound session.", ); outbound } else { outbound }; let devices: Vec = devices .into_iter() .map(|(_, d)| { d.into_iter() .filter(|d| matches!(outbound.is_shared_with(d), ShareState::NotShared)) }) .flatten() .collect(); let key_content = outbound.as_content().await; let message_index = outbound.message_index().await; if !devices.is_empty() { let users = devices.iter().fold(BTreeMap::new(), |mut acc, d| { acc.entry(d.user_id()).or_insert_with(BTreeSet::new).insert(d.device_id()); acc }); info!( index = message_index, users = ?users, room_id = room_id.as_str(), "Sharing an outbound group session", ); } let tasks: Vec<_> = devices .chunks(Self::MAX_TO_DEVICE_MESSAGES) .map(|chunk| { spawn(Self::encrypt_request( chunk.to_vec(), key_content.clone(), outbound.clone(), message_index, self.sessions.sessions_being_shared.clone(), )) }) .collect(); for result in join_all(tasks).await { let used_sessions: OlmResult> = result.expect("Encryption task panicked"); changes.sessions.extend(used_sessions?); } let requests = outbound.pending_requests(); debug!( room_id = room_id.as_str(), session_id = outbound.session_id(), request_count = requests.len(), "Done generating to-device requests for a room key share" ); if requests.is_empty() { debug!( room_id = room_id.as_str(), session_id = outbound.session_id(), "The outbound group session doesn't need to be shared with \ anyone, marking as shared", ); outbound.mark_as_shared(); } let session_count = changes.sessions.len(); self.store.save_changes(changes).await?; debug!( room_id = room_id.as_str(), session_id = outbound.session_id(), session_count = session_count, "Stored the changed sessions after encrypting an room key" ); Ok(requests) } } #[cfg(test)] mod test { use matrix_sdk_common::uuid::Uuid; use matrix_sdk_test::response_from_file; use ruma::{ api::{ client::r0::keys::{claim_keys, get_keys}, IncomingResponse, }, room_id, user_id, DeviceIdBox, UserId, }; use serde_json::Value; use crate::{EncryptionSettings, OlmMachine}; fn alice_id() -> UserId { user_id!("@alice:example.org") } fn alice_device_id() -> DeviceIdBox { "JLAFKJWSCS".into() } fn keys_query_response() -> get_keys::Response { let data = include_bytes!("../../benches/keys_query.json"); let data: Value = serde_json::from_slice(data).unwrap(); let data = response_from_file(&data); get_keys::Response::try_from_http_response(data) .expect("Can't parse the keys upload response") } fn keys_claim_response() -> claim_keys::Response { let data = include_bytes!("../../benches/keys_claim.json"); let data: Value = serde_json::from_slice(data).unwrap(); let data = response_from_file(&data); claim_keys::Response::try_from_http_response(data) .expect("Can't parse the keys upload response") } async fn machine() -> OlmMachine { let keys_query = keys_query_response(); let keys_claim = keys_claim_response(); let uuid = Uuid::new_v4(); let machine = OlmMachine::new(&alice_id(), &alice_device_id()); machine.mark_request_as_sent(&uuid, &keys_query).await.unwrap(); machine.mark_request_as_sent(&uuid, &keys_claim).await.unwrap(); machine } #[tokio::test] async fn test_sharing() { let machine = machine().await; let room_id = room_id!("!test:localhost"); let keys_claim = keys_claim_response(); let users: Vec<_> = keys_claim.one_time_keys.keys().collect(); let requests = machine .share_group_session(&room_id, users.clone().into_iter(), EncryptionSettings::default()) .await .unwrap(); let event_count: usize = requests.iter().map(|r| r.message_count()).sum(); // The keys claim response has a couple of one-time keys with invalid // signatures, thus only 148 sessions are actually created, we check // that all 148 valid sessions get an room key. assert_eq!(event_count, 148); } }