diff --git a/matrix_sdk_crypto/src/session_manager/sessions.rs b/matrix_sdk_crypto/src/session_manager/sessions.rs index 8742fb10..9417c3ea 100644 --- a/matrix_sdk_crypto/src/session_manager/sessions.rs +++ b/matrix_sdk_crypto/src/session_manager/sessions.rs @@ -195,9 +195,9 @@ impl SessionManager { // Add the list of devices that the user wishes to establish sessions // right now. for user_id in users { - let user_devices = self.store.get_user_devices(user_id).await?; + let user_devices = self.store.get_readonly_devices(user_id).await?; - for device in user_devices.devices() { + for (device_id, device) in user_devices.into_iter() { let sender_key = if let Some(k) = device.get_key(DeviceKeyAlgorithm::Curve25519) { k } else { @@ -216,10 +216,7 @@ impl SessionManager { missing .entry(user_id.to_owned()) .or_insert_with(BTreeMap::new) - .insert( - device.device_id().into(), - DeviceKeyAlgorithm::SignedCurve25519, - ); + .insert(device_id, DeviceKeyAlgorithm::SignedCurve25519); } } } diff --git a/matrix_sdk_crypto/src/store/sled.rs b/matrix_sdk_crypto/src/store/sled.rs index f833b770..121b8843 100644 --- a/matrix_sdk_crypto/src/store/sled.rs +++ b/matrix_sdk_crypto/src/store/sled.rs @@ -16,11 +16,11 @@ use std::{ collections::{HashMap, HashSet}, convert::TryFrom, path::{Path, PathBuf}, - sync::Arc, + sync::{Arc, RwLock}, }; use dashmap::DashSet; -use olm_rs::PicklingMode; +use olm_rs::{account::IdentityKeys, PicklingMode}; pub use sled::Error; use sled::{ transaction::{ConflictableTransactionError, TransactionError}, @@ -95,9 +95,17 @@ impl EncodeKey for (&str, &str, &str) { } } +#[derive(Clone, Debug)] +pub struct AccountInfo { + user_id: Arc, + device_id: Arc, + identity_keys: Arc, +} + /// An in-memory only store that will forget all the E2EE key once it's dropped. #[derive(Clone)] pub struct SledStore { + account_info: Arc>>, path: Option, inner: Db, pickle_key: Arc, @@ -159,6 +167,10 @@ impl SledStore { SledStore::open_helper(db, None, passphrase) } + fn get_account_info(&self) -> Option { + self.account_info.read().unwrap().clone() + } + fn open_helper(db: Db, path: Option, passphrase: Option<&str>) -> Result { let account = db.open_tree("account")?; let private_identity = db.open_tree("private_identity")?; @@ -184,6 +196,7 @@ impl SledStore { }; Ok(Self { + account_info: RwLock::new(None).into(), path, inner: db, pickle_key: pickle_key.into(), @@ -249,13 +262,12 @@ impl SledStore { &self, room_id: &RoomId, ) -> Result> { - let account = self - .load_account() - .await? + let account_info = self + .get_account_info() .ok_or(CryptoStoreError::AccountUnset)?; - let device_id: Arc = account.device_id().to_owned().into(); - let identity_keys = account.identity_keys; + let device_id: Arc = account_info.device_id.clone(); + let identity_keys = account_info.identity_keys.clone(); self.outbound_group_sessions .get(room_id.encode())? @@ -430,16 +442,31 @@ impl CryptoStore for SledStore { self.load_tracked_users().await?; - Ok(Some(ReadOnlyAccount::from_pickle( - pickle, - self.get_pickle_mode(), - )?)) + let account = ReadOnlyAccount::from_pickle(pickle, self.get_pickle_mode())?; + + let account_info = AccountInfo { + user_id: account.user_id.clone(), + device_id: account.device_id.clone(), + identity_keys: account.identity_keys.clone(), + }; + + *self.account_info.write().unwrap() = Some(account_info); + + Ok(Some(account)) } else { Ok(None) } } async fn save_account(&self, account: ReadOnlyAccount) -> Result<()> { + let account_info = AccountInfo { + user_id: account.user_id.clone(), + device_id: account.device_id.clone(), + identity_keys: account.identity_keys.clone(), + }; + + *self.account_info.write().unwrap() = Some(account_info); + let changes = Changes { account: Some(account), ..Default::default() @@ -453,11 +480,14 @@ impl CryptoStore for SledStore { } async fn get_sessions(&self, sender_key: &str) -> Result>>>> { - let account = self - .load_account() - .await? + let account_info = self + .get_account_info() .ok_or(CryptoStoreError::AccountUnset)?; + let user_id: Arc = account_info.user_id.clone(); + let device_id: Arc = account_info.device_id.clone(); + let identity_keys = account_info.identity_keys.clone(); + if self.session_cache.get(sender_key).is_none() { let sessions: Result> = self .sessions @@ -465,9 +495,9 @@ impl CryptoStore for SledStore { .map(|s| serde_json::from_slice(&s?.1).map_err(CryptoStoreError::Serialization)) .map(|p| { Session::from_pickle( - account.user_id.clone(), - account.device_id.clone(), - account.identity_keys.clone(), + user_id.clone(), + device_id.clone(), + identity_keys.clone(), p?, self.get_pickle_mode(), )