crypto: Don't load all the devices in the sqlite store.

master
Damir Jelić 2020-10-16 16:53:10 +02:00
parent 4262f1d3b0
commit 425a07d670
7 changed files with 216 additions and 169 deletions

View File

@ -19,7 +19,8 @@ use matrix_sdk_base::crypto::{
UserDevices as BaseUserDevices, UserDevices as BaseUserDevices,
}; };
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::to_device::send_event_to_device::Request as ToDeviceRequest, identifiers::DeviceId, api::r0::to_device::send_event_to_device::Request as ToDeviceRequest,
identifiers::{DeviceId, DeviceIdBox},
}; };
use crate::{error::Result, http_client::HttpClient, Sas}; use crate::{error::Result, http_client::HttpClient, Sas};
@ -114,7 +115,7 @@ impl UserDevices {
} }
/// Iterator over all the device ids of the user devices. /// Iterator over all the device ids of the user devices.
pub fn keys(&self) -> impl Iterator<Item = &DeviceId> { pub fn keys(&self) -> impl Iterator<Item = &DeviceIdBox> {
self.inner.keys() self.inner.keys()
} }

View File

@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
use std::{ use std::{
collections::BTreeMap, collections::{BTreeMap, HashMap},
convert::{TryFrom, TryInto}, convert::{TryFrom, TryInto},
ops::Deref, ops::Deref,
sync::{ sync::{
@ -30,7 +30,9 @@ use matrix_sdk_common::{
forwarded_room_key::ForwardedRoomKeyEventContent, room::encrypted::EncryptedEventContent, forwarded_room_key::ForwardedRoomKeyEventContent, room::encrypted::EncryptedEventContent,
EventType, EventType,
}, },
identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, UserId}, identifiers::{
DeviceId, DeviceIdBox, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, UserId,
},
locks::Mutex, locks::Mutex,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -45,7 +47,7 @@ use crate::{
error::{EventError, OlmError, OlmResult, SignatureError}, error::{EventError, OlmError, OlmResult, SignatureError},
identities::{OwnUserIdentity, UserIdentities}, identities::{OwnUserIdentity, UserIdentities},
olm::Utility, olm::Utility,
store::{caches::ReadOnlyUserDevices, CryptoStore, Result as StoreResult}, store::{CryptoStore, Result as StoreResult},
verification::VerificationMachine, verification::VerificationMachine,
Sas, ToDeviceRequest, Sas, ToDeviceRequest,
}; };
@ -168,7 +170,7 @@ impl Device {
/// A read only view over all devices belonging to a user. /// A read only view over all devices belonging to a user.
#[derive(Debug)] #[derive(Debug)]
pub struct UserDevices { pub struct UserDevices {
pub(crate) inner: ReadOnlyUserDevices, pub(crate) inner: HashMap<DeviceIdBox, ReadOnlyDevice>,
pub(crate) verification_machine: VerificationMachine, pub(crate) verification_machine: VerificationMachine,
pub(crate) own_identity: Option<OwnUserIdentity>, pub(crate) own_identity: Option<OwnUserIdentity>,
pub(crate) device_owner_identity: Option<UserIdentities>, pub(crate) device_owner_identity: Option<UserIdentities>,
@ -178,7 +180,7 @@ impl UserDevices {
/// Get the specific device with the given device id. /// Get the specific device with the given device id.
pub fn get(&self, device_id: &DeviceId) -> Option<Device> { pub fn get(&self, device_id: &DeviceId) -> Option<Device> {
self.inner.get(device_id).map(|d| Device { self.inner.get(device_id).map(|d| Device {
inner: d, inner: d.clone(),
verification_machine: self.verification_machine.clone(), verification_machine: self.verification_machine.clone(),
own_identity: self.own_identity.clone(), own_identity: self.own_identity.clone(),
device_owner_identity: self.device_owner_identity.clone(), device_owner_identity: self.device_owner_identity.clone(),
@ -186,13 +188,13 @@ impl UserDevices {
} }
/// Iterator over all the device ids of the user devices. /// Iterator over all the device ids of the user devices.
pub fn keys(&self) -> impl Iterator<Item = &DeviceId> { pub fn keys(&self) -> impl Iterator<Item = &DeviceIdBox> {
self.inner.keys() self.inner.keys()
} }
/// Iterator over all the devices of the user devices. /// Iterator over all the devices of the user devices.
pub fn devices(&self) -> impl Iterator<Item = Device> + '_ { pub fn devices(&self) -> impl Iterator<Item = Device> + '_ {
self.inner.devices().map(move |d| Device { self.inner.values().map(move |d| Device {
inner: d.clone(), inner: d.clone(),
verification_machine: self.verification_machine.clone(), verification_machine: self.verification_machine.clone(),
own_identity: self.own_identity.clone(), own_identity: self.own_identity.clone(),

View File

@ -165,18 +165,17 @@ impl IdentityManager {
changed_devices.push(device); changed_devices.push(device);
} }
let current_devices: HashSet<&DeviceId> = let current_devices: HashSet<&DeviceIdBox> = device_map.keys().collect();
device_map.keys().map(|id| id.as_ref()).collect();
let stored_devices = self.store.get_readonly_devices(&user_id).await?; let stored_devices = self.store.get_readonly_devices(&user_id).await?;
let stored_devices_set: HashSet<&DeviceId> = stored_devices.keys().collect(); let stored_devices_set: HashSet<&DeviceIdBox> = stored_devices.keys().collect();
let deleted_devices = stored_devices_set.difference(&current_devices); let deleted_devices = stored_devices_set.difference(&current_devices);
for device_id in deleted_devices { for device_id in deleted_devices {
users_with_new_or_deleted_devices.insert(user_id); users_with_new_or_deleted_devices.insert(user_id);
if let Some(device) = stored_devices.get(device_id) { if let Some(device) = stored_devices.get(*device_id) {
device.mark_as_deleted(); device.mark_as_deleted();
self.store.delete_device(device).await?; self.store.delete_device(device.clone()).await?;
} }
} }
} }

View File

@ -19,9 +19,9 @@
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use dashmap::{DashMap, ReadOnlyView}; use dashmap::DashMap;
use matrix_sdk_common::{ use matrix_sdk_common::{
identifiers::{DeviceId, RoomId, UserId}, identifiers::{DeviceId, DeviceIdBox, RoomId, UserId},
locks::Mutex, locks::Mutex,
}; };
@ -145,29 +145,6 @@ pub struct DeviceStore {
entries: Arc<DashMap<UserId, DashMap<Box<DeviceId>, ReadOnlyDevice>>>, entries: Arc<DashMap<UserId, DashMap<Box<DeviceId>, ReadOnlyDevice>>>,
} }
/// A read only view over all devices belonging to a user.
#[derive(Debug)]
pub struct ReadOnlyUserDevices {
entries: ReadOnlyView<Box<DeviceId>, ReadOnlyDevice>,
}
impl ReadOnlyUserDevices {
/// Get the specific device with the given device id.
pub fn get(&self, device_id: &DeviceId) -> Option<ReadOnlyDevice> {
self.entries.get(device_id).cloned()
}
/// Iterator over all the device ids of the user devices.
pub fn keys(&self) -> impl Iterator<Item = &DeviceId> {
self.entries.keys().map(|id| id.as_ref())
}
/// Iterator over all the devices of the user devices.
pub fn devices(&self) -> impl Iterator<Item = &ReadOnlyDevice> {
self.entries.values()
}
}
impl DeviceStore { impl DeviceStore {
/// Create a new empty device store. /// Create a new empty device store.
pub fn new() -> Self { pub fn new() -> Self {
@ -206,15 +183,13 @@ impl DeviceStore {
} }
/// Get a read-only view over all devices of the given user. /// Get a read-only view over all devices of the given user.
pub fn user_devices(&self, user_id: &UserId) -> ReadOnlyUserDevices { pub fn user_devices(&self, user_id: &UserId) -> HashMap<DeviceIdBox, ReadOnlyDevice> {
ReadOnlyUserDevices { self.entries
entries: self
.entries
.entry(user_id.clone()) .entry(user_id.clone())
.or_insert_with(DashMap::new) .or_insert_with(DashMap::new)
.clone() .iter()
.into_read_only(), .map(|i| (i.key().to_owned(), i.value().clone()))
} .collect()
} }
} }
@ -305,12 +280,12 @@ mod test {
let user_devices = store.user_devices(device.user_id()); let user_devices = store.user_devices(device.user_id());
assert_eq!(user_devices.keys().next().unwrap(), device.device_id()); assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id());
assert_eq!(user_devices.devices().next().unwrap(), &device); assert_eq!(user_devices.values().next().unwrap(), &device);
let loaded_device = user_devices.get(device.device_id()).unwrap(); let loaded_device = user_devices.get(device.device_id()).unwrap();
assert_eq!(device, loaded_device); assert_eq!(&device, loaded_device);
store.remove(device.user_id(), device.device_id()); store.remove(device.user_id(), device.device_id());

View File

@ -12,17 +12,20 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{collections::HashSet, sync::Arc}; use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use dashmap::{DashMap, DashSet}; use dashmap::{DashMap, DashSet};
use matrix_sdk_common::{ use matrix_sdk_common::{
identifiers::{DeviceId, RoomId, UserId}, identifiers::{DeviceId, DeviceIdBox, RoomId, UserId},
locks::Mutex, locks::Mutex,
}; };
use matrix_sdk_common_macros::async_trait; use matrix_sdk_common_macros::async_trait;
use super::{ use super::{
caches::{DeviceStore, GroupSessionStore, ReadOnlyUserDevices, SessionStore}, caches::{DeviceStore, GroupSessionStore, SessionStore},
CryptoStore, InboundGroupSession, ReadOnlyAccount, Result, Session, CryptoStore, InboundGroupSession, ReadOnlyAccount, Result, Session,
}; };
use crate::identities::{ReadOnlyDevice, UserIdentities}; use crate::identities::{ReadOnlyDevice, UserIdentities};
@ -153,7 +156,10 @@ impl CryptoStore for MemoryStore {
Ok(()) Ok(())
} }
async fn get_user_devices(&self, user_id: &UserId) -> Result<ReadOnlyUserDevices> { async fn get_user_devices(
&self,
user_id: &UserId,
) -> Result<HashMap<DeviceIdBox, ReadOnlyDevice>> {
Ok(self.devices.user_devices(user_id)) Ok(self.devices.user_devices(user_id))
} }
@ -273,12 +279,12 @@ mod test {
let user_devices = store.get_user_devices(device.user_id()).await.unwrap(); let user_devices = store.get_user_devices(device.user_id()).await.unwrap();
assert_eq!(user_devices.keys().next().unwrap(), device.device_id()); assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id());
assert_eq!(user_devices.devices().next().unwrap(), &device); assert_eq!(user_devices.values().next().unwrap(), &device);
let loaded_device = user_devices.get(device.device_id()).unwrap(); let loaded_device = user_devices.get(device.device_id()).unwrap();
assert_eq!(device, loaded_device); assert_eq!(&device, loaded_device);
store.delete_device(device.clone()).await.unwrap(); store.delete_device(device.clone()).await.unwrap();
assert!(store assert!(store

View File

@ -43,13 +43,19 @@ mod memorystore;
#[cfg(feature = "sqlite_cryptostore")] #[cfg(feature = "sqlite_cryptostore")]
pub(crate) mod sqlite; pub(crate) mod sqlite;
use caches::ReadOnlyUserDevices; use matrix_sdk_common::identifiers::DeviceIdBox;
pub use memorystore::MemoryStore; pub use memorystore::MemoryStore;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
#[cfg(feature = "sqlite_cryptostore")] #[cfg(feature = "sqlite_cryptostore")]
pub use sqlite::SqliteStore; pub use sqlite::SqliteStore;
use std::{collections::HashSet, fmt::Debug, io::Error as IoError, ops::Deref, sync::Arc}; use std::{
collections::{HashMap, HashSet},
fmt::Debug,
io::Error as IoError,
ops::Deref,
sync::Arc,
};
use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError}; use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -115,7 +121,10 @@ impl Store {
self.inner.get_device(user_id, device_id).await self.inner.get_device(user_id, device_id).await
} }
pub async fn get_readonly_devices(&self, user_id: &UserId) -> Result<ReadOnlyUserDevices> { pub async fn get_readonly_devices(
&self,
user_id: &UserId,
) -> Result<HashMap<DeviceIdBox, ReadOnlyDevice>> {
self.inner.get_user_devices(user_id).await self.inner.get_user_devices(user_id).await
} }
@ -354,7 +363,10 @@ pub trait CryptoStore: Debug {
/// # Arguments /// # Arguments
/// ///
/// * `user_id` - The user for which we should get all the devices. /// * `user_id` - The user for which we should get all the devices.
async fn get_user_devices(&self, user_id: &UserId) -> Result<ReadOnlyUserDevices>; async fn get_user_devices(
&self,
user_id: &UserId,
) -> Result<HashMap<DeviceIdBox, ReadOnlyDevice>>;
/// Save the given user identities in the store. /// Save the given user identities in the store.
/// ///

View File

@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
use std::{ use std::{
collections::{BTreeMap, HashSet}, collections::{BTreeMap, HashMap, HashSet},
convert::TryFrom, convert::TryFrom,
path::{Path, PathBuf}, path::{Path, PathBuf},
result::Result as StdResult, result::Result as StdResult,
@ -25,7 +25,8 @@ use dashmap::DashSet;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::keys::{CrossSigningKey, KeyUsage}, api::r0::keys::{CrossSigningKey, KeyUsage},
identifiers::{ identifiers::{
DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, RoomId, UserId, DeviceId, DeviceIdBox, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, RoomId,
UserId,
}, },
instant::Duration, instant::Duration,
locks::Mutex, locks::Mutex,
@ -33,10 +34,7 @@ use matrix_sdk_common::{
use sqlx::{query, query_as, sqlite::SqliteConnectOptions, Connection, Executor, SqliteConnection}; use sqlx::{query, query_as, sqlite::SqliteConnectOptions, Connection, Executor, SqliteConnection};
use zeroize::Zeroizing; use zeroize::Zeroizing;
use super::{ use super::{caches::SessionStore, CryptoStore, CryptoStoreError, Result};
caches::{DeviceStore, ReadOnlyUserDevices, SessionStore},
CryptoStore, CryptoStoreError, Result,
};
use crate::{ use crate::{
identities::{LocalTrust, OwnUserIdentity, ReadOnlyDevice, UserIdentities, UserIdentity}, identities::{LocalTrust, OwnUserIdentity, ReadOnlyDevice, UserIdentities, UserIdentity},
olm::{ olm::{
@ -56,7 +54,6 @@ pub struct SqliteStore {
path: Arc<PathBuf>, path: Arc<PathBuf>,
sessions: SessionStore, sessions: SessionStore,
devices: DeviceStore,
tracked_users: Arc<DashSet<UserId>>, tracked_users: Arc<DashSet<UserId>>,
users_for_key_query: Arc<DashSet<UserId>>, users_for_key_query: Arc<DashSet<UserId>>,
@ -149,7 +146,6 @@ impl SqliteStore {
device_id: Arc::new(device_id.into()), device_id: Arc::new(device_id.into()),
account_info: Arc::new(SyncMutex::new(None)), account_info: Arc::new(SyncMutex::new(None)),
sessions: SessionStore::new(), sessions: SessionStore::new(),
devices: DeviceStore::new(),
path: Arc::new(path), path: Arc::new(path),
connection: Arc::new(Mutex::new(connection)), connection: Arc::new(Mutex::new(connection)),
pickle_passphrase: Arc::new(passphrase), pickle_passphrase: Arc::new(passphrase),
@ -717,31 +713,15 @@ impl SqliteStore {
Ok(()) Ok(())
} }
async fn load_devices(&self) -> Result<()> { async fn load_device_data(
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; &self,
let mut connection = self.connection.lock().await; connection: &mut SqliteConnection,
device_row_id: i64,
let rows: Vec<(i64, String, String, Option<String>, i64)> = query_as( user_id: &UserId,
"SELECT id, user_id, device_id, display_name, trust_state device_id: DeviceIdBox,
FROM devices WHERE account_id = ?", trust_state: LocalTrust,
) display_name: Option<String>,
.bind(account_id) ) -> Result<ReadOnlyDevice> {
.fetch_all(&mut *connection)
.await?;
for row in rows {
let device_row_id = row.0;
let user_id: &str = &row.1;
let user_id = if let Ok(u) = UserId::try_from(user_id) {
u
} else {
continue;
};
let device_id = &row.2.to_string();
let display_name = &row.3;
let trust_state = LocalTrust::from(row.4);
let algorithm_rows: Vec<(String,)> = let algorithm_rows: Vec<(String,)> =
query_as("SELECT algorithm FROM algorithms WHERE device_id = ?") query_as("SELECT algorithm FROM algorithms WHERE device_id = ?")
.bind(device_row_id) .bind(device_row_id)
@ -768,10 +748,7 @@ impl SqliteStore {
let algorithm = row.0.parse::<DeviceKeyAlgorithm>().ok()?; let algorithm = row.0.parse::<DeviceKeyAlgorithm>().ok()?;
let key = row.1; let key = row.1;
Some(( Some((DeviceKeyId::from_parts(algorithm, &device_id), key))
DeviceKeyId::from_parts(algorithm, device_id.as_str().into()),
key,
))
}) })
.collect(); .collect();
@ -809,20 +786,94 @@ impl SqliteStore {
); );
} }
let device = ReadOnlyDevice::new( Ok(ReadOnlyDevice::new(
user_id, user_id.to_owned(),
device_id.as_str().into(), device_id,
display_name.clone(), display_name.clone(),
trust_state, trust_state,
algorithms, algorithms,
keys, keys,
signatures, signatures,
); ))
self.devices.add(device);
} }
Ok(()) async fn get_single_device(
&self,
user_id: &UserId,
device_id: &DeviceId,
) -> Result<Option<ReadOnlyDevice>> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
let row: Option<(i64, Option<String>, i64)> = query_as(
"SELECT id, display_name, trust_state
FROM devices WHERE account_id = ? and user_id = ? and device_id = ?",
)
.bind(account_id)
.bind(user_id.as_str())
.bind(device_id.as_str())
.fetch_optional(&mut *connection)
.await?;
let row = if let Some(r) = row {
r
} else {
return Ok(None);
};
let device_row_id = row.0;
let display_name = row.1;
let trust_state = LocalTrust::from(row.2);
let device = self
.load_device_data(
&mut connection,
device_row_id,
user_id,
device_id.into(),
trust_state,
display_name,
)
.await?;
Ok(Some(device))
}
async fn load_devices(&self, user_id: &UserId) -> Result<HashMap<DeviceIdBox, ReadOnlyDevice>> {
let mut devices = HashMap::new();
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
let mut rows: Vec<(i64, String, Option<String>, i64)> = query_as(
"SELECT id, device_id, display_name, trust_state
FROM devices WHERE account_id = ? and user_id = ?",
)
.bind(account_id)
.bind(user_id.as_str())
.fetch_all(&mut *connection)
.await?;
for row in rows.drain(..) {
let device_row_id = row.0;
let device_id: DeviceIdBox = row.1.into();
let display_name = row.2;
let trust_state = LocalTrust::from(row.3);
let device = self
.load_device_data(
&mut connection,
device_row_id,
user_id,
device_id.clone(),
trust_state,
display_name,
)
.await?;
devices.insert(device_id, device);
}
Ok(devices)
} }
async fn save_device_helper( async fn save_device_helper(
@ -1276,7 +1327,6 @@ impl CryptoStore for SqliteStore {
drop(connection); drop(connection);
self.load_devices().await?;
self.load_tracked_users().await?; self.load_tracked_users().await?;
Ok(result) Ok(result)
@ -1424,7 +1474,6 @@ impl CryptoStore for SqliteStore {
let mut transaction = connection.begin().await?; let mut transaction = connection.begin().await?;
for device in devices { for device in devices {
self.devices.add(device.clone());
self.save_device_helper(&mut transaction, device.clone()) self.save_device_helper(&mut transaction, device.clone())
.await? .await?
} }
@ -1457,11 +1506,14 @@ impl CryptoStore for SqliteStore {
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
) -> Result<Option<ReadOnlyDevice>> { ) -> Result<Option<ReadOnlyDevice>> {
Ok(self.devices.get(user_id, device_id)) self.get_single_device(user_id, device_id).await
} }
async fn get_user_devices(&self, user_id: &UserId) -> Result<ReadOnlyUserDevices> { async fn get_user_devices(
Ok(self.devices.user_devices(user_id)) &self,
user_id: &UserId,
) -> Result<HashMap<DeviceIdBox, ReadOnlyDevice>> {
Ok(self.load_devices(user_id).await?)
} }
async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentities>> { async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentities>> {
@ -1925,8 +1977,8 @@ mod test {
assert_eq!(device.keys(), loaded_device.keys()); assert_eq!(device.keys(), loaded_device.keys());
let user_devices = store.get_user_devices(device.user_id()).await.unwrap(); let user_devices = store.get_user_devices(device.user_id()).await.unwrap();
assert_eq!(user_devices.keys().next().unwrap(), device.device_id()); assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id());
assert_eq!(user_devices.devices().next().unwrap(), &device); assert_eq!(user_devices.values().next().unwrap(), &device);
} }
#[tokio::test(threaded_scheduler)] #[tokio::test(threaded_scheduler)]