From 425a07d6708c8933b06304ffe272cbb047cc9cde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Fri, 16 Oct 2020 16:53:10 +0200 Subject: [PATCH] crypto: Don't load all the devices in the sqlite store. --- matrix_sdk/src/device.rs | 5 +- matrix_sdk_crypto/src/identities/device.rs | 16 +- matrix_sdk_crypto/src/identities/manager.rs | 9 +- matrix_sdk_crypto/src/store/caches.rs | 49 +--- matrix_sdk_crypto/src/store/memorystore.rs | 20 +- matrix_sdk_crypto/src/store/mod.rs | 20 +- matrix_sdk_crypto/src/store/sqlite.rs | 266 ++++++++++++-------- 7 files changed, 216 insertions(+), 169 deletions(-) diff --git a/matrix_sdk/src/device.rs b/matrix_sdk/src/device.rs index 5c7b55fb..776b9f0e 100644 --- a/matrix_sdk/src/device.rs +++ b/matrix_sdk/src/device.rs @@ -19,7 +19,8 @@ use matrix_sdk_base::crypto::{ UserDevices as BaseUserDevices, }; use matrix_sdk_common::{ - api::r0::to_device::send_event_to_device::Request as ToDeviceRequest, identifiers::DeviceId, + api::r0::to_device::send_event_to_device::Request as ToDeviceRequest, + identifiers::{DeviceId, DeviceIdBox}, }; use crate::{error::Result, http_client::HttpClient, Sas}; @@ -114,7 +115,7 @@ impl UserDevices { } /// Iterator over all the device ids of the user devices. - pub fn keys(&self) -> impl Iterator { + pub fn keys(&self) -> impl Iterator { self.inner.keys() } diff --git a/matrix_sdk_crypto/src/identities/device.rs b/matrix_sdk_crypto/src/identities/device.rs index 578e7718..a0281217 100644 --- a/matrix_sdk_crypto/src/identities/device.rs +++ b/matrix_sdk_crypto/src/identities/device.rs @@ -13,7 +13,7 @@ // limitations under the License. use std::{ - collections::BTreeMap, + collections::{BTreeMap, HashMap}, convert::{TryFrom, TryInto}, ops::Deref, sync::{ @@ -30,7 +30,9 @@ use matrix_sdk_common::{ forwarded_room_key::ForwardedRoomKeyEventContent, room::encrypted::EncryptedEventContent, EventType, }, - identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, UserId}, + identifiers::{ + DeviceId, DeviceIdBox, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, UserId, + }, locks::Mutex, }; use serde::{Deserialize, Serialize}; @@ -45,7 +47,7 @@ use crate::{ error::{EventError, OlmError, OlmResult, SignatureError}, identities::{OwnUserIdentity, UserIdentities}, olm::Utility, - store::{caches::ReadOnlyUserDevices, CryptoStore, Result as StoreResult}, + store::{CryptoStore, Result as StoreResult}, verification::VerificationMachine, Sas, ToDeviceRequest, }; @@ -168,7 +170,7 @@ impl Device { /// A read only view over all devices belonging to a user. #[derive(Debug)] pub struct UserDevices { - pub(crate) inner: ReadOnlyUserDevices, + pub(crate) inner: HashMap, pub(crate) verification_machine: VerificationMachine, pub(crate) own_identity: Option, pub(crate) device_owner_identity: Option, @@ -178,7 +180,7 @@ impl UserDevices { /// Get the specific device with the given device id. pub fn get(&self, device_id: &DeviceId) -> Option { self.inner.get(device_id).map(|d| Device { - inner: d, + inner: d.clone(), verification_machine: self.verification_machine.clone(), own_identity: self.own_identity.clone(), device_owner_identity: self.device_owner_identity.clone(), @@ -186,13 +188,13 @@ impl UserDevices { } /// Iterator over all the device ids of the user devices. - pub fn keys(&self) -> impl Iterator { + pub fn keys(&self) -> impl Iterator { self.inner.keys() } /// Iterator over all the devices of the user devices. pub fn devices(&self) -> impl Iterator + '_ { - self.inner.devices().map(move |d| Device { + self.inner.values().map(move |d| Device { inner: d.clone(), verification_machine: self.verification_machine.clone(), own_identity: self.own_identity.clone(), diff --git a/matrix_sdk_crypto/src/identities/manager.rs b/matrix_sdk_crypto/src/identities/manager.rs index ef9d60db..ba6fb749 100644 --- a/matrix_sdk_crypto/src/identities/manager.rs +++ b/matrix_sdk_crypto/src/identities/manager.rs @@ -165,18 +165,17 @@ impl IdentityManager { changed_devices.push(device); } - let current_devices: HashSet<&DeviceId> = - device_map.keys().map(|id| id.as_ref()).collect(); + let current_devices: HashSet<&DeviceIdBox> = device_map.keys().collect(); let stored_devices = self.store.get_readonly_devices(&user_id).await?; - let stored_devices_set: HashSet<&DeviceId> = stored_devices.keys().collect(); + let stored_devices_set: HashSet<&DeviceIdBox> = stored_devices.keys().collect(); let deleted_devices = stored_devices_set.difference(¤t_devices); for device_id in deleted_devices { users_with_new_or_deleted_devices.insert(user_id); - if let Some(device) = stored_devices.get(device_id) { + if let Some(device) = stored_devices.get(*device_id) { device.mark_as_deleted(); - self.store.delete_device(device).await?; + self.store.delete_device(device.clone()).await?; } } } diff --git a/matrix_sdk_crypto/src/store/caches.rs b/matrix_sdk_crypto/src/store/caches.rs index eb66198d..3b306d6b 100644 --- a/matrix_sdk_crypto/src/store/caches.rs +++ b/matrix_sdk_crypto/src/store/caches.rs @@ -19,9 +19,9 @@ use std::{collections::HashMap, sync::Arc}; -use dashmap::{DashMap, ReadOnlyView}; +use dashmap::DashMap; use matrix_sdk_common::{ - identifiers::{DeviceId, RoomId, UserId}, + identifiers::{DeviceId, DeviceIdBox, RoomId, UserId}, locks::Mutex, }; @@ -145,29 +145,6 @@ pub struct DeviceStore { entries: Arc, ReadOnlyDevice>>>, } -/// A read only view over all devices belonging to a user. -#[derive(Debug)] -pub struct ReadOnlyUserDevices { - entries: ReadOnlyView, ReadOnlyDevice>, -} - -impl ReadOnlyUserDevices { - /// Get the specific device with the given device id. - pub fn get(&self, device_id: &DeviceId) -> Option { - self.entries.get(device_id).cloned() - } - - /// Iterator over all the device ids of the user devices. - pub fn keys(&self) -> impl Iterator { - self.entries.keys().map(|id| id.as_ref()) - } - - /// Iterator over all the devices of the user devices. - pub fn devices(&self) -> impl Iterator { - self.entries.values() - } -} - impl DeviceStore { /// Create a new empty device store. pub fn new() -> Self { @@ -206,15 +183,13 @@ impl DeviceStore { } /// Get a read-only view over all devices of the given user. - pub fn user_devices(&self, user_id: &UserId) -> ReadOnlyUserDevices { - ReadOnlyUserDevices { - entries: self - .entries - .entry(user_id.clone()) - .or_insert_with(DashMap::new) - .clone() - .into_read_only(), - } + pub fn user_devices(&self, user_id: &UserId) -> HashMap { + self.entries + .entry(user_id.clone()) + .or_insert_with(DashMap::new) + .iter() + .map(|i| (i.key().to_owned(), i.value().clone())) + .collect() } } @@ -305,12 +280,12 @@ mod test { let user_devices = store.user_devices(device.user_id()); - assert_eq!(user_devices.keys().next().unwrap(), device.device_id()); - assert_eq!(user_devices.devices().next().unwrap(), &device); + assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id()); + assert_eq!(user_devices.values().next().unwrap(), &device); let loaded_device = user_devices.get(device.device_id()).unwrap(); - assert_eq!(device, loaded_device); + assert_eq!(&device, loaded_device); store.remove(device.user_id(), device.device_id()); diff --git a/matrix_sdk_crypto/src/store/memorystore.rs b/matrix_sdk_crypto/src/store/memorystore.rs index d7f6a4da..2e5a8762 100644 --- a/matrix_sdk_crypto/src/store/memorystore.rs +++ b/matrix_sdk_crypto/src/store/memorystore.rs @@ -12,17 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::HashSet, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; use dashmap::{DashMap, DashSet}; use matrix_sdk_common::{ - identifiers::{DeviceId, RoomId, UserId}, + identifiers::{DeviceId, DeviceIdBox, RoomId, UserId}, locks::Mutex, }; use matrix_sdk_common_macros::async_trait; use super::{ - caches::{DeviceStore, GroupSessionStore, ReadOnlyUserDevices, SessionStore}, + caches::{DeviceStore, GroupSessionStore, SessionStore}, CryptoStore, InboundGroupSession, ReadOnlyAccount, Result, Session, }; use crate::identities::{ReadOnlyDevice, UserIdentities}; @@ -153,7 +156,10 @@ impl CryptoStore for MemoryStore { Ok(()) } - async fn get_user_devices(&self, user_id: &UserId) -> Result { + async fn get_user_devices( + &self, + user_id: &UserId, + ) -> Result> { Ok(self.devices.user_devices(user_id)) } @@ -273,12 +279,12 @@ mod test { let user_devices = store.get_user_devices(device.user_id()).await.unwrap(); - assert_eq!(user_devices.keys().next().unwrap(), device.device_id()); - assert_eq!(user_devices.devices().next().unwrap(), &device); + assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id()); + assert_eq!(user_devices.values().next().unwrap(), &device); let loaded_device = user_devices.get(device.device_id()).unwrap(); - assert_eq!(device, loaded_device); + assert_eq!(&device, loaded_device); store.delete_device(device.clone()).await.unwrap(); assert!(store diff --git a/matrix_sdk_crypto/src/store/mod.rs b/matrix_sdk_crypto/src/store/mod.rs index 7dad4d07..b8b661ef 100644 --- a/matrix_sdk_crypto/src/store/mod.rs +++ b/matrix_sdk_crypto/src/store/mod.rs @@ -43,13 +43,19 @@ mod memorystore; #[cfg(feature = "sqlite_cryptostore")] pub(crate) mod sqlite; -use caches::ReadOnlyUserDevices; +use matrix_sdk_common::identifiers::DeviceIdBox; pub use memorystore::MemoryStore; #[cfg(not(target_arch = "wasm32"))] #[cfg(feature = "sqlite_cryptostore")] pub use sqlite::SqliteStore; -use std::{collections::HashSet, fmt::Debug, io::Error as IoError, ops::Deref, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + fmt::Debug, + io::Error as IoError, + ops::Deref, + sync::Arc, +}; use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError}; use serde::{Deserialize, Serialize}; @@ -115,7 +121,10 @@ impl Store { self.inner.get_device(user_id, device_id).await } - pub async fn get_readonly_devices(&self, user_id: &UserId) -> Result { + pub async fn get_readonly_devices( + &self, + user_id: &UserId, + ) -> Result> { self.inner.get_user_devices(user_id).await } @@ -354,7 +363,10 @@ pub trait CryptoStore: Debug { /// # Arguments /// /// * `user_id` - The user for which we should get all the devices. - async fn get_user_devices(&self, user_id: &UserId) -> Result; + async fn get_user_devices( + &self, + user_id: &UserId, + ) -> Result>; /// Save the given user identities in the store. /// diff --git a/matrix_sdk_crypto/src/store/sqlite.rs b/matrix_sdk_crypto/src/store/sqlite.rs index 6bc710dc..f591af79 100644 --- a/matrix_sdk_crypto/src/store/sqlite.rs +++ b/matrix_sdk_crypto/src/store/sqlite.rs @@ -13,7 +13,7 @@ // limitations under the License. use std::{ - collections::{BTreeMap, HashSet}, + collections::{BTreeMap, HashMap, HashSet}, convert::TryFrom, path::{Path, PathBuf}, result::Result as StdResult, @@ -25,7 +25,8 @@ use dashmap::DashSet; use matrix_sdk_common::{ api::r0::keys::{CrossSigningKey, KeyUsage}, identifiers::{ - DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, RoomId, UserId, + DeviceId, DeviceIdBox, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, RoomId, + UserId, }, instant::Duration, locks::Mutex, @@ -33,10 +34,7 @@ use matrix_sdk_common::{ use sqlx::{query, query_as, sqlite::SqliteConnectOptions, Connection, Executor, SqliteConnection}; use zeroize::Zeroizing; -use super::{ - caches::{DeviceStore, ReadOnlyUserDevices, SessionStore}, - CryptoStore, CryptoStoreError, Result, -}; +use super::{caches::SessionStore, CryptoStore, CryptoStoreError, Result}; use crate::{ identities::{LocalTrust, OwnUserIdentity, ReadOnlyDevice, UserIdentities, UserIdentity}, olm::{ @@ -56,7 +54,6 @@ pub struct SqliteStore { path: Arc, sessions: SessionStore, - devices: DeviceStore, tracked_users: Arc>, users_for_key_query: Arc>, @@ -149,7 +146,6 @@ impl SqliteStore { device_id: Arc::new(device_id.into()), account_info: Arc::new(SyncMutex::new(None)), sessions: SessionStore::new(), - devices: DeviceStore::new(), path: Arc::new(path), connection: Arc::new(Mutex::new(connection)), pickle_passphrase: Arc::new(passphrase), @@ -717,112 +713,167 @@ impl SqliteStore { Ok(()) } - async fn load_devices(&self) -> Result<()> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - let mut connection = self.connection.lock().await; + async fn load_device_data( + &self, + connection: &mut SqliteConnection, + device_row_id: i64, + user_id: &UserId, + device_id: DeviceIdBox, + trust_state: LocalTrust, + display_name: Option, + ) -> Result { + let algorithm_rows: Vec<(String,)> = + query_as("SELECT algorithm FROM algorithms WHERE device_id = ?") + .bind(device_row_id) + .fetch_all(&mut *connection) + .await?; - let rows: Vec<(i64, String, String, Option, i64)> = query_as( - "SELECT id, user_id, device_id, display_name, trust_state - FROM devices WHERE account_id = ?", + let algorithms = algorithm_rows + .iter() + .map(|row| { + let algorithm: &str = &row.0; + EventEncryptionAlgorithm::from(algorithm) + }) + .collect::>(); + + let key_rows: Vec<(String, String)> = + query_as("SELECT algorithm, key FROM device_keys WHERE device_id = ?") + .bind(device_row_id) + .fetch_all(&mut *connection) + .await?; + + let keys: BTreeMap = key_rows + .into_iter() + .filter_map(|row| { + let algorithm = row.0.parse::().ok()?; + let key = row.1; + + Some((DeviceKeyId::from_parts(algorithm, &device_id), key)) + }) + .collect(); + + let signature_rows: Vec<(String, String, String)> = query_as( + "SELECT user_id, key_algorithm, signature + FROM device_signatures WHERE device_id = ?", ) - .bind(account_id) + .bind(device_row_id) .fetch_all(&mut *connection) .await?; - for row in rows { - let device_row_id = row.0; - let user_id: &str = &row.1; - let user_id = if let Ok(u) = UserId::try_from(user_id) { + let mut signatures: BTreeMap> = BTreeMap::new(); + + for row in signature_rows { + let user_id = if let Ok(u) = UserId::try_from(&*row.0) { u } else { continue; }; - let device_id = &row.2.to_string(); - let display_name = &row.3; - let trust_state = LocalTrust::from(row.4); + let key_algorithm = if let Ok(k) = row.1.parse::() { + k + } else { + continue; + }; - let algorithm_rows: Vec<(String,)> = - query_as("SELECT algorithm FROM algorithms WHERE device_id = ?") - .bind(device_row_id) - .fetch_all(&mut *connection) - .await?; + let signature = row.2; - let algorithms = algorithm_rows - .iter() - .map(|row| { - let algorithm: &str = &row.0; - EventEncryptionAlgorithm::from(algorithm) - }) - .collect::>(); - - let key_rows: Vec<(String, String)> = - query_as("SELECT algorithm, key FROM device_keys WHERE device_id = ?") - .bind(device_row_id) - .fetch_all(&mut *connection) - .await?; - - let keys: BTreeMap = key_rows - .into_iter() - .filter_map(|row| { - let algorithm = row.0.parse::().ok()?; - let key = row.1; - - Some(( - DeviceKeyId::from_parts(algorithm, device_id.as_str().into()), - key, - )) - }) - .collect(); - - let signature_rows: Vec<(String, String, String)> = query_as( - "SELECT user_id, key_algorithm, signature - FROM device_signatures WHERE device_id = ?", - ) - .bind(device_row_id) - .fetch_all(&mut *connection) - .await?; - - let mut signatures: BTreeMap> = BTreeMap::new(); - - for row in signature_rows { - let user_id = if let Ok(u) = UserId::try_from(&*row.0) { - u - } else { - continue; - }; - - let key_algorithm = if let Ok(k) = row.1.parse::() { - k - } else { - continue; - }; - - let signature = row.2; - - signatures - .entry(user_id) - .or_insert_with(BTreeMap::new) - .insert( - DeviceKeyId::from_parts(key_algorithm, device_id.as_str().into()), - signature.to_owned(), - ); - } - - let device = ReadOnlyDevice::new( - user_id, - device_id.as_str().into(), - display_name.clone(), - trust_state, - algorithms, - keys, - signatures, - ); - - self.devices.add(device); + signatures + .entry(user_id) + .or_insert_with(BTreeMap::new) + .insert( + DeviceKeyId::from_parts(key_algorithm, device_id.as_str().into()), + signature.to_owned(), + ); } - Ok(()) + Ok(ReadOnlyDevice::new( + user_id.to_owned(), + device_id, + display_name.clone(), + trust_state, + algorithms, + keys, + signatures, + )) + } + + async fn get_single_device( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result> { + let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; + let mut connection = self.connection.lock().await; + + let row: Option<(i64, Option, i64)> = query_as( + "SELECT id, display_name, trust_state + FROM devices WHERE account_id = ? and user_id = ? and device_id = ?", + ) + .bind(account_id) + .bind(user_id.as_str()) + .bind(device_id.as_str()) + .fetch_optional(&mut *connection) + .await?; + + let row = if let Some(r) = row { + r + } else { + return Ok(None); + }; + + let device_row_id = row.0; + let display_name = row.1; + let trust_state = LocalTrust::from(row.2); + let device = self + .load_device_data( + &mut connection, + device_row_id, + user_id, + device_id.into(), + trust_state, + display_name, + ) + .await?; + + Ok(Some(device)) + } + + async fn load_devices(&self, user_id: &UserId) -> Result> { + let mut devices = HashMap::new(); + + let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; + let mut connection = self.connection.lock().await; + + let mut rows: Vec<(i64, String, Option, i64)> = query_as( + "SELECT id, device_id, display_name, trust_state + FROM devices WHERE account_id = ? and user_id = ?", + ) + .bind(account_id) + .bind(user_id.as_str()) + .fetch_all(&mut *connection) + .await?; + + for row in rows.drain(..) { + let device_row_id = row.0; + let device_id: DeviceIdBox = row.1.into(); + let display_name = row.2; + let trust_state = LocalTrust::from(row.3); + + let device = self + .load_device_data( + &mut connection, + device_row_id, + user_id, + device_id.clone(), + trust_state, + display_name, + ) + .await?; + + devices.insert(device_id, device); + } + + Ok(devices) } async fn save_device_helper( @@ -1276,7 +1327,6 @@ impl CryptoStore for SqliteStore { drop(connection); - self.load_devices().await?; self.load_tracked_users().await?; Ok(result) @@ -1424,7 +1474,6 @@ impl CryptoStore for SqliteStore { let mut transaction = connection.begin().await?; for device in devices { - self.devices.add(device.clone()); self.save_device_helper(&mut transaction, device.clone()) .await? } @@ -1457,11 +1506,14 @@ impl CryptoStore for SqliteStore { user_id: &UserId, device_id: &DeviceId, ) -> Result> { - Ok(self.devices.get(user_id, device_id)) + self.get_single_device(user_id, device_id).await } - async fn get_user_devices(&self, user_id: &UserId) -> Result { - Ok(self.devices.user_devices(user_id)) + async fn get_user_devices( + &self, + user_id: &UserId, + ) -> Result> { + Ok(self.load_devices(user_id).await?) } async fn get_user_identity(&self, user_id: &UserId) -> Result> { @@ -1925,8 +1977,8 @@ mod test { assert_eq!(device.keys(), loaded_device.keys()); let user_devices = store.get_user_devices(device.user_id()).await.unwrap(); - assert_eq!(user_devices.keys().next().unwrap(), device.device_id()); - assert_eq!(user_devices.devices().next().unwrap(), &device); + assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id()); + assert_eq!(user_devices.values().next().unwrap(), &device); } #[tokio::test(threaded_scheduler)]