crypto: Don't load all the devices in the sqlite store.

master
Damir Jelić 2020-10-16 16:53:10 +02:00
parent 4262f1d3b0
commit 425a07d670
7 changed files with 216 additions and 169 deletions

View File

@ -19,7 +19,8 @@ use matrix_sdk_base::crypto::{
UserDevices as BaseUserDevices,
};
use matrix_sdk_common::{
api::r0::to_device::send_event_to_device::Request as ToDeviceRequest, identifiers::DeviceId,
api::r0::to_device::send_event_to_device::Request as ToDeviceRequest,
identifiers::{DeviceId, DeviceIdBox},
};
use crate::{error::Result, http_client::HttpClient, Sas};
@ -114,7 +115,7 @@ impl UserDevices {
}
/// Iterator over all the device ids of the user devices.
pub fn keys(&self) -> impl Iterator<Item = &DeviceId> {
pub fn keys(&self) -> impl Iterator<Item = &DeviceIdBox> {
self.inner.keys()
}

View File

@ -13,7 +13,7 @@
// limitations under the License.
use std::{
collections::BTreeMap,
collections::{BTreeMap, HashMap},
convert::{TryFrom, TryInto},
ops::Deref,
sync::{
@ -30,7 +30,9 @@ use matrix_sdk_common::{
forwarded_room_key::ForwardedRoomKeyEventContent, room::encrypted::EncryptedEventContent,
EventType,
},
identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, UserId},
identifiers::{
DeviceId, DeviceIdBox, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, UserId,
},
locks::Mutex,
};
use serde::{Deserialize, Serialize};
@ -45,7 +47,7 @@ use crate::{
error::{EventError, OlmError, OlmResult, SignatureError},
identities::{OwnUserIdentity, UserIdentities},
olm::Utility,
store::{caches::ReadOnlyUserDevices, CryptoStore, Result as StoreResult},
store::{CryptoStore, Result as StoreResult},
verification::VerificationMachine,
Sas, ToDeviceRequest,
};
@ -168,7 +170,7 @@ impl Device {
/// A read only view over all devices belonging to a user.
#[derive(Debug)]
pub struct UserDevices {
pub(crate) inner: ReadOnlyUserDevices,
pub(crate) inner: HashMap<DeviceIdBox, ReadOnlyDevice>,
pub(crate) verification_machine: VerificationMachine,
pub(crate) own_identity: Option<OwnUserIdentity>,
pub(crate) device_owner_identity: Option<UserIdentities>,
@ -178,7 +180,7 @@ impl UserDevices {
/// Get the specific device with the given device id.
pub fn get(&self, device_id: &DeviceId) -> Option<Device> {
self.inner.get(device_id).map(|d| Device {
inner: d,
inner: d.clone(),
verification_machine: self.verification_machine.clone(),
own_identity: self.own_identity.clone(),
device_owner_identity: self.device_owner_identity.clone(),
@ -186,13 +188,13 @@ impl UserDevices {
}
/// Iterator over all the device ids of the user devices.
pub fn keys(&self) -> impl Iterator<Item = &DeviceId> {
pub fn keys(&self) -> impl Iterator<Item = &DeviceIdBox> {
self.inner.keys()
}
/// Iterator over all the devices of the user devices.
pub fn devices(&self) -> impl Iterator<Item = Device> + '_ {
self.inner.devices().map(move |d| Device {
self.inner.values().map(move |d| Device {
inner: d.clone(),
verification_machine: self.verification_machine.clone(),
own_identity: self.own_identity.clone(),

View File

@ -165,18 +165,17 @@ impl IdentityManager {
changed_devices.push(device);
}
let current_devices: HashSet<&DeviceId> =
device_map.keys().map(|id| id.as_ref()).collect();
let current_devices: HashSet<&DeviceIdBox> = device_map.keys().collect();
let stored_devices = self.store.get_readonly_devices(&user_id).await?;
let stored_devices_set: HashSet<&DeviceId> = stored_devices.keys().collect();
let stored_devices_set: HashSet<&DeviceIdBox> = stored_devices.keys().collect();
let deleted_devices = stored_devices_set.difference(&current_devices);
for device_id in deleted_devices {
users_with_new_or_deleted_devices.insert(user_id);
if let Some(device) = stored_devices.get(device_id) {
if let Some(device) = stored_devices.get(*device_id) {
device.mark_as_deleted();
self.store.delete_device(device).await?;
self.store.delete_device(device.clone()).await?;
}
}
}

View File

@ -19,9 +19,9 @@
use std::{collections::HashMap, sync::Arc};
use dashmap::{DashMap, ReadOnlyView};
use dashmap::DashMap;
use matrix_sdk_common::{
identifiers::{DeviceId, RoomId, UserId},
identifiers::{DeviceId, DeviceIdBox, RoomId, UserId},
locks::Mutex,
};
@ -145,29 +145,6 @@ pub struct DeviceStore {
entries: Arc<DashMap<UserId, DashMap<Box<DeviceId>, ReadOnlyDevice>>>,
}
/// A read only view over all devices belonging to a user.
#[derive(Debug)]
pub struct ReadOnlyUserDevices {
entries: ReadOnlyView<Box<DeviceId>, ReadOnlyDevice>,
}
impl ReadOnlyUserDevices {
/// Get the specific device with the given device id.
pub fn get(&self, device_id: &DeviceId) -> Option<ReadOnlyDevice> {
self.entries.get(device_id).cloned()
}
/// Iterator over all the device ids of the user devices.
pub fn keys(&self) -> impl Iterator<Item = &DeviceId> {
self.entries.keys().map(|id| id.as_ref())
}
/// Iterator over all the devices of the user devices.
pub fn devices(&self) -> impl Iterator<Item = &ReadOnlyDevice> {
self.entries.values()
}
}
impl DeviceStore {
/// Create a new empty device store.
pub fn new() -> Self {
@ -206,15 +183,13 @@ impl DeviceStore {
}
/// Get a read-only view over all devices of the given user.
pub fn user_devices(&self, user_id: &UserId) -> ReadOnlyUserDevices {
ReadOnlyUserDevices {
entries: self
.entries
pub fn user_devices(&self, user_id: &UserId) -> HashMap<DeviceIdBox, ReadOnlyDevice> {
self.entries
.entry(user_id.clone())
.or_insert_with(DashMap::new)
.clone()
.into_read_only(),
}
.iter()
.map(|i| (i.key().to_owned(), i.value().clone()))
.collect()
}
}
@ -305,12 +280,12 @@ mod test {
let user_devices = store.user_devices(device.user_id());
assert_eq!(user_devices.keys().next().unwrap(), device.device_id());
assert_eq!(user_devices.devices().next().unwrap(), &device);
assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id());
assert_eq!(user_devices.values().next().unwrap(), &device);
let loaded_device = user_devices.get(device.device_id()).unwrap();
assert_eq!(device, loaded_device);
assert_eq!(&device, loaded_device);
store.remove(device.user_id(), device.device_id());

View File

@ -12,17 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::HashSet, sync::Arc};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use dashmap::{DashMap, DashSet};
use matrix_sdk_common::{
identifiers::{DeviceId, RoomId, UserId},
identifiers::{DeviceId, DeviceIdBox, RoomId, UserId},
locks::Mutex,
};
use matrix_sdk_common_macros::async_trait;
use super::{
caches::{DeviceStore, GroupSessionStore, ReadOnlyUserDevices, SessionStore},
caches::{DeviceStore, GroupSessionStore, SessionStore},
CryptoStore, InboundGroupSession, ReadOnlyAccount, Result, Session,
};
use crate::identities::{ReadOnlyDevice, UserIdentities};
@ -153,7 +156,10 @@ impl CryptoStore for MemoryStore {
Ok(())
}
async fn get_user_devices(&self, user_id: &UserId) -> Result<ReadOnlyUserDevices> {
async fn get_user_devices(
&self,
user_id: &UserId,
) -> Result<HashMap<DeviceIdBox, ReadOnlyDevice>> {
Ok(self.devices.user_devices(user_id))
}
@ -273,12 +279,12 @@ mod test {
let user_devices = store.get_user_devices(device.user_id()).await.unwrap();
assert_eq!(user_devices.keys().next().unwrap(), device.device_id());
assert_eq!(user_devices.devices().next().unwrap(), &device);
assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id());
assert_eq!(user_devices.values().next().unwrap(), &device);
let loaded_device = user_devices.get(device.device_id()).unwrap();
assert_eq!(device, loaded_device);
assert_eq!(&device, loaded_device);
store.delete_device(device.clone()).await.unwrap();
assert!(store

View File

@ -43,13 +43,19 @@ mod memorystore;
#[cfg(feature = "sqlite_cryptostore")]
pub(crate) mod sqlite;
use caches::ReadOnlyUserDevices;
use matrix_sdk_common::identifiers::DeviceIdBox;
pub use memorystore::MemoryStore;
#[cfg(not(target_arch = "wasm32"))]
#[cfg(feature = "sqlite_cryptostore")]
pub use sqlite::SqliteStore;
use std::{collections::HashSet, fmt::Debug, io::Error as IoError, ops::Deref, sync::Arc};
use std::{
collections::{HashMap, HashSet},
fmt::Debug,
io::Error as IoError,
ops::Deref,
sync::Arc,
};
use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError};
use serde::{Deserialize, Serialize};
@ -115,7 +121,10 @@ impl Store {
self.inner.get_device(user_id, device_id).await
}
pub async fn get_readonly_devices(&self, user_id: &UserId) -> Result<ReadOnlyUserDevices> {
pub async fn get_readonly_devices(
&self,
user_id: &UserId,
) -> Result<HashMap<DeviceIdBox, ReadOnlyDevice>> {
self.inner.get_user_devices(user_id).await
}
@ -354,7 +363,10 @@ pub trait CryptoStore: Debug {
/// # Arguments
///
/// * `user_id` - The user for which we should get all the devices.
async fn get_user_devices(&self, user_id: &UserId) -> Result<ReadOnlyUserDevices>;
async fn get_user_devices(
&self,
user_id: &UserId,
) -> Result<HashMap<DeviceIdBox, ReadOnlyDevice>>;
/// Save the given user identities in the store.
///

View File

@ -13,7 +13,7 @@
// limitations under the License.
use std::{
collections::{BTreeMap, HashSet},
collections::{BTreeMap, HashMap, HashSet},
convert::TryFrom,
path::{Path, PathBuf},
result::Result as StdResult,
@ -25,7 +25,8 @@ use dashmap::DashSet;
use matrix_sdk_common::{
api::r0::keys::{CrossSigningKey, KeyUsage},
identifiers::{
DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, RoomId, UserId,
DeviceId, DeviceIdBox, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, RoomId,
UserId,
},
instant::Duration,
locks::Mutex,
@ -33,10 +34,7 @@ use matrix_sdk_common::{
use sqlx::{query, query_as, sqlite::SqliteConnectOptions, Connection, Executor, SqliteConnection};
use zeroize::Zeroizing;
use super::{
caches::{DeviceStore, ReadOnlyUserDevices, SessionStore},
CryptoStore, CryptoStoreError, Result,
};
use super::{caches::SessionStore, CryptoStore, CryptoStoreError, Result};
use crate::{
identities::{LocalTrust, OwnUserIdentity, ReadOnlyDevice, UserIdentities, UserIdentity},
olm::{
@ -56,7 +54,6 @@ pub struct SqliteStore {
path: Arc<PathBuf>,
sessions: SessionStore,
devices: DeviceStore,
tracked_users: Arc<DashSet<UserId>>,
users_for_key_query: Arc<DashSet<UserId>>,
@ -149,7 +146,6 @@ impl SqliteStore {
device_id: Arc::new(device_id.into()),
account_info: Arc::new(SyncMutex::new(None)),
sessions: SessionStore::new(),
devices: DeviceStore::new(),
path: Arc::new(path),
connection: Arc::new(Mutex::new(connection)),
pickle_passphrase: Arc::new(passphrase),
@ -717,31 +713,15 @@ impl SqliteStore {
Ok(())
}
async fn load_devices(&self) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
let rows: Vec<(i64, String, String, Option<String>, i64)> = query_as(
"SELECT id, user_id, device_id, display_name, trust_state
FROM devices WHERE account_id = ?",
)
.bind(account_id)
.fetch_all(&mut *connection)
.await?;
for row in rows {
let device_row_id = row.0;
let user_id: &str = &row.1;
let user_id = if let Ok(u) = UserId::try_from(user_id) {
u
} else {
continue;
};
let device_id = &row.2.to_string();
let display_name = &row.3;
let trust_state = LocalTrust::from(row.4);
async fn load_device_data(
&self,
connection: &mut SqliteConnection,
device_row_id: i64,
user_id: &UserId,
device_id: DeviceIdBox,
trust_state: LocalTrust,
display_name: Option<String>,
) -> Result<ReadOnlyDevice> {
let algorithm_rows: Vec<(String,)> =
query_as("SELECT algorithm FROM algorithms WHERE device_id = ?")
.bind(device_row_id)
@ -768,10 +748,7 @@ impl SqliteStore {
let algorithm = row.0.parse::<DeviceKeyAlgorithm>().ok()?;
let key = row.1;
Some((
DeviceKeyId::from_parts(algorithm, device_id.as_str().into()),
key,
))
Some((DeviceKeyId::from_parts(algorithm, &device_id), key))
})
.collect();
@ -809,20 +786,94 @@ impl SqliteStore {
);
}
let device = ReadOnlyDevice::new(
user_id,
device_id.as_str().into(),
Ok(ReadOnlyDevice::new(
user_id.to_owned(),
device_id,
display_name.clone(),
trust_state,
algorithms,
keys,
signatures,
);
self.devices.add(device);
))
}
Ok(())
async fn get_single_device(
&self,
user_id: &UserId,
device_id: &DeviceId,
) -> Result<Option<ReadOnlyDevice>> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
let row: Option<(i64, Option<String>, i64)> = query_as(
"SELECT id, display_name, trust_state
FROM devices WHERE account_id = ? and user_id = ? and device_id = ?",
)
.bind(account_id)
.bind(user_id.as_str())
.bind(device_id.as_str())
.fetch_optional(&mut *connection)
.await?;
let row = if let Some(r) = row {
r
} else {
return Ok(None);
};
let device_row_id = row.0;
let display_name = row.1;
let trust_state = LocalTrust::from(row.2);
let device = self
.load_device_data(
&mut connection,
device_row_id,
user_id,
device_id.into(),
trust_state,
display_name,
)
.await?;
Ok(Some(device))
}
async fn load_devices(&self, user_id: &UserId) -> Result<HashMap<DeviceIdBox, ReadOnlyDevice>> {
let mut devices = HashMap::new();
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
let mut rows: Vec<(i64, String, Option<String>, i64)> = query_as(
"SELECT id, device_id, display_name, trust_state
FROM devices WHERE account_id = ? and user_id = ?",
)
.bind(account_id)
.bind(user_id.as_str())
.fetch_all(&mut *connection)
.await?;
for row in rows.drain(..) {
let device_row_id = row.0;
let device_id: DeviceIdBox = row.1.into();
let display_name = row.2;
let trust_state = LocalTrust::from(row.3);
let device = self
.load_device_data(
&mut connection,
device_row_id,
user_id,
device_id.clone(),
trust_state,
display_name,
)
.await?;
devices.insert(device_id, device);
}
Ok(devices)
}
async fn save_device_helper(
@ -1276,7 +1327,6 @@ impl CryptoStore for SqliteStore {
drop(connection);
self.load_devices().await?;
self.load_tracked_users().await?;
Ok(result)
@ -1424,7 +1474,6 @@ impl CryptoStore for SqliteStore {
let mut transaction = connection.begin().await?;
for device in devices {
self.devices.add(device.clone());
self.save_device_helper(&mut transaction, device.clone())
.await?
}
@ -1457,11 +1506,14 @@ impl CryptoStore for SqliteStore {
user_id: &UserId,
device_id: &DeviceId,
) -> Result<Option<ReadOnlyDevice>> {
Ok(self.devices.get(user_id, device_id))
self.get_single_device(user_id, device_id).await
}
async fn get_user_devices(&self, user_id: &UserId) -> Result<ReadOnlyUserDevices> {
Ok(self.devices.user_devices(user_id))
async fn get_user_devices(
&self,
user_id: &UserId,
) -> Result<HashMap<DeviceIdBox, ReadOnlyDevice>> {
Ok(self.load_devices(user_id).await?)
}
async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentities>> {
@ -1925,8 +1977,8 @@ mod test {
assert_eq!(device.keys(), loaded_device.keys());
let user_devices = store.get_user_devices(device.user_id()).await.unwrap();
assert_eq!(user_devices.keys().next().unwrap(), device.device_id());
assert_eq!(user_devices.devices().next().unwrap(), &device);
assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id());
assert_eq!(user_devices.values().next().unwrap(), &device);
}
#[tokio::test(threaded_scheduler)]