sqlite: Store the tracked users in the database.

This commit is contained in:
Damir Jelić 2020-05-14 17:25:46 +02:00
parent 8c6c34e01a
commit b2e48d8eae

View file

@ -162,6 +162,23 @@ impl SqliteStore {
)
.await?;
connection
.execute(
r#"
CREATE TABLE IF NOT EXISTS tracked_users (
"id" INTEGER NOT NULL PRIMARY KEY,
"account_id" INTEGER NOT NULL,
"user_id" TEXT NOT NULL,
FOREIGN KEY ("account_id") REFERENCES "accounts" ("id")
ON DELETE CASCADE
UNIQUE(account_id,user_id)
);
CREATE INDEX IF NOT EXISTS "tracked_users_account_id" ON "tracked_users" ("account_id");
"#,
)
.await?;
connection
.execute(
r#"
@ -330,6 +347,51 @@ impl SqliteStore {
.collect::<Result<Vec<InboundGroupSession>>>()?)
}
async fn save_tracked_user(&self, user: &UserId) -> Result<()> {
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
query(
"INSERT OR IGNORE INTO tracked_users (
account_id, user_id
) VALUES (?1, ?2)
",
)
.bind(account_id)
.bind(user.to_string())
.execute(&mut *connection)
.await?;
Ok(())
}
async fn load_tracked_users(&self) -> Result<HashSet<UserId>> {
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
let rows: Vec<(String,)> = query_as(
"SELECT user_id
FROM tracked_users WHERE account_id = ?",
)
.bind(account_id)
.fetch_all(&mut *connection)
.await?;
let mut users = HashSet::new();
for row in rows {
let user_id: &str = &row.0;
if let Ok(u) = UserId::try_from(user_id) {
users.insert(u);
} else {
continue;
};
}
Ok(users)
}
async fn load_devices(&self) -> Result<DeviceStore> {
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
@ -520,7 +582,8 @@ impl CryptoStore for SqliteStore {
let devices = self.load_devices().await?;
mem::replace(&mut self.devices, devices);
// TODO load the tracked users here as well.
let tracked_users = self.load_tracked_users().await?;
mem::replace(&mut self.tracked_users, tracked_users);
Ok(result)
}
@ -637,7 +700,7 @@ impl CryptoStore for SqliteStore {
}
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool> {
// TODO save the tracked user to the database.
self.save_tracked_user(user).await?;
Ok(self.tracked_users.insert(user.clone()))
}
@ -974,7 +1037,7 @@ mod test {
#[tokio::test]
async fn test_tracked_users() {
let (_account, mut store, _dir) = get_loaded_store().await;
let (_account, mut store, dir) = get_loaded_store().await;
let device = get_device();
assert!(store.add_user_for_tracking(device.user_id()).await.unwrap());
@ -982,7 +1045,18 @@ mod test {
let tracked_users = store.tracked_users();
tracked_users.contains(device.user_id());
assert!(tracked_users.contains(device.user_id()));
drop(store);
let mut store =
SqliteStore::open(&UserId::try_from(USER_ID).unwrap(), DEVICE_ID, dir.path())
.await
.expect("Can't create store");
store.load_account().await.unwrap();
let tracked_users = store.tracked_users();
assert!(tracked_users.contains(device.user_id()));
}
#[tokio::test]