From 1d9fccdc9fcac93c45eaa483251c63ac5d6d98af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Fri, 15 May 2020 15:33:30 +0200 Subject: [PATCH] crypto: Move the users for key query map into the store. --- matrix_sdk_crypto/src/machine.rs | 35 ++++----- matrix_sdk_crypto/src/store/memorystore.rs | 24 ++++++- matrix_sdk_crypto/src/store/mod.rs | 8 ++- matrix_sdk_crypto/src/store/sqlite.rs | 83 ++++++++++++++++++---- 4 files changed, 108 insertions(+), 42 deletions(-) diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 2a27f01f..b46ac957 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -80,9 +80,6 @@ pub struct OlmMachine { /// Persists all the encryption keys so a client can resume the session /// without the need to create new keys. store: Box, - /// Set of users that we need to query keys for. This is a subset of - /// the tracked users in the CryptoStore. - users_for_key_query: HashSet, /// The currently active outbound group sessions. outbound_group_sessions: HashMap, } @@ -122,7 +119,6 @@ impl OlmMachine { account: Account::new(), uploaded_signed_key_count: None, store: Box::new(MemoryStore::new()), - users_for_key_query: HashSet::new(), outbound_group_sessions: HashMap::new(), } } @@ -166,7 +162,6 @@ impl OlmMachine { account, uploaded_signed_key_count: None, store: Box::new(store), - users_for_key_query: HashSet::new(), outbound_group_sessions: HashMap::new(), }) } @@ -461,7 +456,7 @@ impl OlmMachine { let mut changed_devices = Vec::new(); for (user_id, device_map) in &response.device_keys { - self.users_for_key_query.remove(&user_id); + self.store.update_tracked_user(user_id, false).await?; for (device_id, device_keys) in device_map.iter() { // We don't need our own device in the device store. @@ -1516,12 +1511,12 @@ impl OlmMachine { /// key query. /// /// Returns true if the user was queued up for a key query, false otherwise. - pub async fn mark_user_as_changed(&mut self, user_id: &UserId) -> bool { + pub async fn mark_user_as_changed(&mut self, user_id: &UserId) -> StoreError { if self.store.tracked_users().contains(user_id) { - self.users_for_key_query.insert(user_id.clone()); - true + self.store.update_tracked_user(user_id, true).await?; + Ok(true) } else { - false + Ok(false) } } @@ -1544,32 +1539,26 @@ impl OlmMachine { I: IntoIterator, { for user in users { - let ret = self.store.add_user_for_tracking(user).await; + if self.store.tracked_users().contains(user) { + continue; + } - match ret { - Ok(newly_added) => { - if newly_added { - self.mark_user_as_changed(user).await; - } - } - Err(e) => { - warn!("Error storing users for tracking {}", e); - self.users_for_key_query.insert(user.clone()); - } + if let Err(e) = self.store.update_tracked_user(user, true).await { + warn!("Error storing users for tracking {}", e); } } } /// Should the client perform a key query request. pub fn should_query_keys(&self) -> bool { - !self.users_for_key_query.is_empty() + !self.store.users_for_key_query().is_empty() } /// Get the set of users that we need to query keys for. /// /// Returns a hash set of users that need to be queried for keys. pub fn users_for_key_query(&self) -> HashSet { - self.users_for_key_query.clone() + self.store.users_for_key_query().clone() } } diff --git a/matrix_sdk_crypto/src/store/memorystore.rs b/matrix_sdk_crypto/src/store/memorystore.rs index 08c1f83f..71a303ba 100644 --- a/matrix_sdk_crypto/src/store/memorystore.rs +++ b/matrix_sdk_crypto/src/store/memorystore.rs @@ -28,6 +28,7 @@ pub struct MemoryStore { sessions: SessionStore, inbound_group_sessions: GroupSessionStore, tracked_users: HashSet, + users_for_key_query: HashSet, devices: DeviceStore, } @@ -37,6 +38,7 @@ impl MemoryStore { sessions: SessionStore::new(), inbound_group_sessions: GroupSessionStore::new(), tracked_users: HashSet::new(), + users_for_key_query: HashSet::new(), devices: DeviceStore::new(), } } @@ -83,7 +85,17 @@ impl CryptoStore for MemoryStore { &self.tracked_users } - async fn add_user_for_tracking(&mut self, user: &UserId) -> Result { + fn users_for_key_query(&self) -> &HashSet { + &self.users_for_key_query + } + + async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result { + if dirty { + self.users_for_key_query.insert(user.clone()); + } else { + self.users_for_key_query.remove(user); + } + Ok(self.tracked_users.insert(user.clone())) } @@ -207,8 +219,14 @@ mod test { let device = get_device(); let mut store = MemoryStore::new(); - assert!(store.add_user_for_tracking(device.user_id()).await.unwrap()); - assert!(!store.add_user_for_tracking(device.user_id()).await.unwrap()); + assert!(store + .update_tracked_user(device.user_id(), false) + .await + .unwrap()); + assert!(!store + .update_tracked_user(device.user_id(), false) + .await + .unwrap()); let tracked_users = store.tracked_users(); diff --git a/matrix_sdk_crypto/src/store/mod.rs b/matrix_sdk_crypto/src/store/mod.rs index 916a5a29..80839b74 100644 --- a/matrix_sdk_crypto/src/store/mod.rs +++ b/matrix_sdk_crypto/src/store/mod.rs @@ -138,6 +138,10 @@ pub trait CryptoStore: Debug + Send + Sync { /// Get the set of tracked users. fn tracked_users(&self) -> &HashSet; + /// Set of users that we need to query keys for. This is a subset of + /// the tracked users. + fn users_for_key_query(&self) -> &HashSet; + /// Add an user for tracking. /// /// Returns true if the user wasn't already tracked, false otherwise. @@ -145,7 +149,9 @@ pub trait CryptoStore: Debug + Send + Sync { /// # Arguments /// /// * `user` - The user that should be marked as tracked. - async fn add_user_for_tracking(&mut self, user: &UserId) -> Result; + /// + /// * `dirty` - Should the user be also marked for a key query. + async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result; /// Save the given devices in the store. /// diff --git a/matrix_sdk_crypto/src/store/sqlite.rs b/matrix_sdk_crypto/src/store/sqlite.rs index f908acf3..bd6d1ff7 100644 --- a/matrix_sdk_crypto/src/store/sqlite.rs +++ b/matrix_sdk_crypto/src/store/sqlite.rs @@ -45,6 +45,7 @@ pub struct SqliteStore { inbound_group_sessions: GroupSessionStore, devices: DeviceStore, tracked_users: HashSet, + users_for_key_query: HashSet, connection: Arc>, pickle_passphrase: Option>, @@ -121,6 +122,7 @@ impl SqliteStore { connection: Arc::new(Mutex::new(connection)), pickle_passphrase: passphrase, tracked_users: HashSet::new(), + users_for_key_query: HashSet::new(), }; store.create_tables().await?; Ok(store) @@ -169,6 +171,7 @@ impl SqliteStore { "id" INTEGER NOT NULL PRIMARY KEY, "account_id" INTEGER NOT NULL, "user_id" TEXT NOT NULL, + "dirty" INTEGER NOT NULL, FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") ON DELETE CASCADE UNIQUE(account_id,user_id) @@ -347,30 +350,33 @@ impl SqliteStore { .collect::>>()?) } - async fn save_tracked_user(&self, user: &UserId) -> Result<()> { + async fn save_tracked_user(&self, user: &UserId, dirty: bool) -> Result<()> { let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; let mut connection = self.connection.lock().await; query( - "INSERT OR IGNORE INTO tracked_users ( - account_id, user_id - ) VALUES (?1, ?2) + "INSERT INTO tracked_users ( + account_id, user_id, dirty + ) VALUES (?1, ?2, ?3) + ON CONFLICT(account_id, user_id) DO UPDATE SET + dirty = excluded.dirty ", ) .bind(account_id) .bind(user.to_string()) + .bind(dirty) .execute(&mut *connection) .await?; Ok(()) } - async fn load_tracked_users(&self) -> Result> { + async fn load_tracked_users(&self) -> Result<(HashSet, HashSet)> { let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; let mut connection = self.connection.lock().await; - let rows: Vec<(String,)> = query_as( - "SELECT user_id + let rows: Vec<(String, bool)> = query_as( + "SELECT user_id, dirty FROM tracked_users WHERE account_id = ?", ) .bind(account_id) @@ -378,18 +384,23 @@ impl SqliteStore { .await?; let mut users = HashSet::new(); + let mut users_for_query = HashSet::new(); for row in rows { let user_id: &str = &row.0; + let dirty: bool = row.1; if let Ok(u) = UserId::try_from(user_id) { - users.insert(u); + users.insert(u.clone()); + if dirty { + users_for_query.insert(u); + } } else { continue; }; } - Ok(users) + Ok((users, users_for_query)) } async fn load_devices(&self) -> Result { @@ -582,8 +593,9 @@ impl CryptoStore for SqliteStore { let devices = self.load_devices().await?; mem::replace(&mut self.devices, devices); - let tracked_users = self.load_tracked_users().await?; + let (tracked_users, users_for_query) = self.load_tracked_users().await?; mem::replace(&mut self.tracked_users, tracked_users); + mem::replace(&mut self.users_for_key_query, users_for_query); Ok(result) } @@ -699,9 +711,22 @@ impl CryptoStore for SqliteStore { &self.tracked_users } - async fn add_user_for_tracking(&mut self, user: &UserId) -> Result { - self.save_tracked_user(user).await?; - Ok(self.tracked_users.insert(user.clone())) + fn users_for_key_query(&self) -> &HashSet { + &self.users_for_key_query + } + + async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result { + let already_added = self.tracked_users.insert(user.clone()); + + if dirty { + self.users_for_key_query.insert(user.clone()); + } else { + self.users_for_key_query.remove(user); + } + + self.save_tracked_user(user, dirty).await?; + + Ok(already_added) } async fn save_devices(&self, devices: &[Device]) -> Result<()> { @@ -1040,12 +1065,24 @@ mod test { let (_account, mut store, dir) = get_loaded_store().await; let device = get_device(); - assert!(store.add_user_for_tracking(device.user_id()).await.unwrap()); - assert!(!store.add_user_for_tracking(device.user_id()).await.unwrap()); + assert!(store + .update_tracked_user(device.user_id(), false) + .await + .unwrap()); + assert!(!store + .update_tracked_user(device.user_id(), false) + .await + .unwrap()); let tracked_users = store.tracked_users(); assert!(tracked_users.contains(device.user_id())); + assert!(!store.users_for_key_query().contains(device.user_id())); + assert!(!store + .update_tracked_user(device.user_id(), true) + .await + .unwrap()); + assert!(store.users_for_key_query().contains(device.user_id())); drop(store); let mut store = @@ -1057,6 +1094,22 @@ mod test { let tracked_users = store.tracked_users(); assert!(tracked_users.contains(device.user_id())); + assert!(store.users_for_key_query().contains(device.user_id())); + + store + .update_tracked_user(device.user_id(), false) + .await + .unwrap(); + assert!(!store.users_for_key_query().contains(device.user_id())); + + let mut store = + SqliteStore::open(&UserId::try_from(USER_ID).unwrap(), DEVICE_ID, dir.path()) + .await + .expect("Can't create store"); + + store.load_account().await.unwrap(); + + assert!(!store.users_for_key_query().contains(device.user_id())); } #[tokio::test]