diff --git a/matrix_sdk_crypto/src/store/sqlite.rs b/matrix_sdk_crypto/src/store/sqlite.rs index 922bebb1..f908acf3 100644 --- a/matrix_sdk_crypto/src/store/sqlite.rs +++ b/matrix_sdk_crypto/src/store/sqlite.rs @@ -162,6 +162,23 @@ impl SqliteStore { ) .await?; + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS tracked_users ( + "id" INTEGER NOT NULL PRIMARY KEY, + "account_id" INTEGER NOT NULL, + "user_id" TEXT NOT NULL, + FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") + ON DELETE CASCADE + UNIQUE(account_id,user_id) + ); + + CREATE INDEX IF NOT EXISTS "tracked_users_account_id" ON "tracked_users" ("account_id"); + "#, + ) + .await?; + connection .execute( r#" @@ -330,6 +347,51 @@ impl SqliteStore { .collect::>>()?) } + async fn save_tracked_user(&self, user: &UserId) -> Result<()> { + let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; + let mut connection = self.connection.lock().await; + + query( + "INSERT OR IGNORE INTO tracked_users ( + account_id, user_id + ) VALUES (?1, ?2) + ", + ) + .bind(account_id) + .bind(user.to_string()) + .execute(&mut *connection) + .await?; + + Ok(()) + } + + async fn load_tracked_users(&self) -> Result> { + let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; + let mut connection = self.connection.lock().await; + + let rows: Vec<(String,)> = query_as( + "SELECT user_id + FROM tracked_users WHERE account_id = ?", + ) + .bind(account_id) + .fetch_all(&mut *connection) + .await?; + + let mut users = HashSet::new(); + + for row in rows { + let user_id: &str = &row.0; + + if let Ok(u) = UserId::try_from(user_id) { + users.insert(u); + } else { + continue; + }; + } + + Ok(users) + } + async fn load_devices(&self) -> Result { let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; let mut connection = self.connection.lock().await; @@ -520,7 +582,8 @@ impl CryptoStore for SqliteStore { let devices = self.load_devices().await?; mem::replace(&mut self.devices, devices); - // TODO load the tracked users here as well. + let tracked_users = self.load_tracked_users().await?; + mem::replace(&mut self.tracked_users, tracked_users); Ok(result) } @@ -637,7 +700,7 @@ impl CryptoStore for SqliteStore { } async fn add_user_for_tracking(&mut self, user: &UserId) -> Result { - // TODO save the tracked user to the database. + self.save_tracked_user(user).await?; Ok(self.tracked_users.insert(user.clone())) } @@ -974,7 +1037,7 @@ mod test { #[tokio::test] async fn test_tracked_users() { - let (_account, mut store, _dir) = get_loaded_store().await; + let (_account, mut store, dir) = get_loaded_store().await; let device = get_device(); assert!(store.add_user_for_tracking(device.user_id()).await.unwrap()); @@ -982,7 +1045,18 @@ mod test { let tracked_users = store.tracked_users(); - tracked_users.contains(device.user_id()); + assert!(tracked_users.contains(device.user_id())); + drop(store); + + let mut store = + SqliteStore::open(&UserId::try_from(USER_ID).unwrap(), DEVICE_ID, dir.path()) + .await + .expect("Can't create store"); + + store.load_account().await.unwrap(); + + let tracked_users = store.tracked_users(); + assert!(tracked_users.contains(device.user_id())); } #[tokio::test]