crypto: Remove the last mutable self borrows in the Olm machine methods.

master
Damir Jelić 2020-08-11 12:22:14 +02:00
parent 72168ce084
commit 528483ef0e
2 changed files with 29 additions and 17 deletions

View File

@ -1288,9 +1288,9 @@ impl BaseClient {
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
pub async fn share_group_session(&self, room_id: &RoomId) -> Result<Vec<OwnedToDeviceRequest>> { pub async fn share_group_session(&self, room_id: &RoomId) -> Result<Vec<OwnedToDeviceRequest>> {
let room = self.get_joined_room(room_id).await.expect("No room found"); let room = self.get_joined_room(room_id).await.expect("No room found");
let mut olm = self.olm.lock().await; let olm = self.olm.lock().await;
match &mut *olm { match &*olm {
Some(o) => { Some(o) => {
let room = room.write().await; let room = room.write().await;
@ -1417,9 +1417,9 @@ impl BaseClient {
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
#[cfg_attr(docsrs, doc(cfg(feature = "encryption")))] #[cfg_attr(docsrs, doc(cfg(feature = "encryption")))]
pub async fn invalidate_group_session(&self, room_id: &RoomId) -> bool { pub async fn invalidate_group_session(&self, room_id: &RoomId) -> bool {
let mut olm = self.olm.lock().await; let olm = self.olm.lock().await;
match &mut *olm { match &*olm {
Some(o) => o.invalidate_group_session(room_id), Some(o) => o.invalidate_group_session(room_id),
None => false, None => false,
} }

View File

@ -15,13 +15,15 @@
#[cfg(feature = "sqlite-cryptostore")] #[cfg(feature = "sqlite-cryptostore")]
use std::path::Path; use std::path::Path;
use std::{ use std::{
collections::{BTreeMap, HashMap, HashSet}, collections::{BTreeMap, HashSet},
convert::{TryFrom, TryInto}, convert::{TryFrom, TryInto},
mem, mem,
result::Result as StdResult, result::Result as StdResult,
sync::Arc, sync::Arc,
}; };
use dashmap::DashMap;
use api::r0::{ use api::r0::{
keys::{claim_keys, get_keys, upload_keys, DeviceKeys, OneTimeKey}, keys::{claim_keys, get_keys, upload_keys, DeviceKeys, OneTimeKey},
sync::sync_events::Response as SyncResponse, sync::sync_events::Response as SyncResponse,
@ -67,6 +69,7 @@ pub type OneTimeKeys = BTreeMap<DeviceKeyId, OneTimeKey>;
/// State machine implementation of the Olm/Megolm encryption protocol used for /// State machine implementation of the Olm/Megolm encryption protocol used for
/// Matrix end to end encryption. /// Matrix end to end encryption.
#[derive(Clone)]
pub struct OlmMachine { pub struct OlmMachine {
/// The unique user id that owns this account. /// The unique user id that owns this account.
user_id: UserId, user_id: UserId,
@ -79,7 +82,7 @@ pub struct OlmMachine {
/// without the need to create new keys. /// without the need to create new keys.
store: Arc<RwLock<Box<dyn CryptoStore>>>, store: Arc<RwLock<Box<dyn CryptoStore>>>,
/// The currently active outbound group sessions. /// The currently active outbound group sessions.
outbound_group_sessions: HashMap<RoomId, OutboundGroupSession>, outbound_group_sessions: Arc<DashMap<RoomId, OutboundGroupSession>>,
/// A state machine that is responsible to handle and keep track of SAS /// A state machine that is responsible to handle and keep track of SAS
/// verification flows. /// verification flows.
verification_machine: VerificationMachine, verification_machine: VerificationMachine,
@ -119,7 +122,7 @@ impl OlmMachine {
device_id: device_id.into(), device_id: device_id.into(),
account: account.clone(), account: account.clone(),
store: store.clone(), store: store.clone(),
outbound_group_sessions: HashMap::new(), outbound_group_sessions: Arc::new(DashMap::new()),
verification_machine: VerificationMachine::new(account, store), verification_machine: VerificationMachine::new(account, store),
} }
} }
@ -165,7 +168,7 @@ impl OlmMachine {
device_id, device_id,
account, account,
store, store,
outbound_group_sessions: HashMap::new(), outbound_group_sessions: Arc::new(DashMap::new()),
verification_machine, verification_machine,
}) })
} }
@ -802,7 +805,7 @@ impl OlmMachine {
/// ///
/// This also creates a matching inbound group session and saves that one in /// This also creates a matching inbound group session and saves that one in
/// the store. /// the store.
async fn create_outbound_group_session(&mut self, room_id: &RoomId) -> OlmResult<()> { async fn create_outbound_group_session(&self, room_id: &RoomId) -> OlmResult<()> {
let (outbound, inbound) = self.account.create_group_session_pair(room_id).await; let (outbound, inbound) = self.account.create_group_session_pair(room_id).await;
let _ = self let _ = self
@ -818,6 +821,17 @@ impl OlmMachine {
Ok(()) Ok(())
} }
/// Get an outbound group session for a room, if one exists.
///
/// # Arguments
///
/// * `room_id` - The id of the room for which we should get the outbound
/// group session.
fn get_outbound_group_session(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
#[allow(clippy::map_clone)]
self.outbound_group_sessions.get(room_id).map(|s| s.clone())
}
/// Encrypt a room message for the given room. /// Encrypt a room message for the given room.
/// ///
/// Beware that a group session needs to be shared before this method can be /// Beware that a group session needs to be shared before this method can be
@ -844,9 +858,7 @@ impl OlmMachine {
room_id: &RoomId, room_id: &RoomId,
content: MessageEventContent, content: MessageEventContent,
) -> MegolmResult<EncryptedEventContent> { ) -> MegolmResult<EncryptedEventContent> {
let session = self.outbound_group_sessions.get(room_id); let session = if let Some(s) = self.get_outbound_group_session(room_id) {
let session = if let Some(s) = session {
s s
} else { } else {
panic!("Session wasn't created nor shared"); panic!("Session wasn't created nor shared");
@ -929,7 +941,7 @@ impl OlmMachine {
/// ///
/// Returns true if a session was invalidated, false if there was no session /// Returns true if a session was invalidated, false if there was no session
/// to invalidate. /// to invalidate.
pub fn invalidate_group_session(&mut self, room_id: &RoomId) -> bool { pub fn invalidate_group_session(&self, room_id: &RoomId) -> bool {
self.outbound_group_sessions.remove(room_id).is_some() self.outbound_group_sessions.remove(room_id).is_some()
} }
@ -943,7 +955,7 @@ impl OlmMachine {
/// ///
/// `users` - The list of users that should receive the group session. /// `users` - The list of users that should receive the group session.
pub async fn share_group_session<'a, I>( pub async fn share_group_session<'a, I>(
&mut self, &self,
room_id: &RoomId, room_id: &RoomId,
users: I, users: I,
) -> OlmResult<Vec<OwnedToDeviceRequest>> ) -> OlmResult<Vec<OwnedToDeviceRequest>>
@ -1538,7 +1550,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn tests_session_invalidation() { async fn tests_session_invalidation() {
let mut machine = OlmMachine::new(&user_id(), &alice_device_id()); let machine = OlmMachine::new(&user_id(), &alice_device_id());
let room_id = room_id!("!test:example.org"); let room_id = room_id!("!test:example.org");
machine machine
@ -1754,7 +1766,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn test_room_key_sharing() { async fn test_room_key_sharing() {
let (mut alice, bob) = get_machine_pair_with_session().await; let (alice, bob) = get_machine_pair_with_session().await;
let room_id = room_id!("!test:example.org"); let room_id = room_id!("!test:example.org");
@ -1800,7 +1812,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn test_megolm_encryption() { async fn test_megolm_encryption() {
let (mut alice, bob) = get_machine_pair_with_setup_sessions().await; let (alice, bob) = get_machine_pair_with_setup_sessions().await;
let room_id = room_id!("!test:example.org"); let room_id = room_id!("!test:example.org");
let to_device_requests = alice let to_device_requests = alice