sqlite: Store the tracked users in the database.
parent
8c6c34e01a
commit
b2e48d8eae
|
@ -162,6 +162,23 @@ impl SqliteStore {
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
connection
|
||||||
|
.execute(
|
||||||
|
r#"
|
||||||
|
CREATE TABLE IF NOT EXISTS tracked_users (
|
||||||
|
"id" INTEGER NOT NULL PRIMARY KEY,
|
||||||
|
"account_id" INTEGER NOT NULL,
|
||||||
|
"user_id" TEXT NOT NULL,
|
||||||
|
FOREIGN KEY ("account_id") REFERENCES "accounts" ("id")
|
||||||
|
ON DELETE CASCADE
|
||||||
|
UNIQUE(account_id,user_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS "tracked_users_account_id" ON "tracked_users" ("account_id");
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
connection
|
connection
|
||||||
.execute(
|
.execute(
|
||||||
r#"
|
r#"
|
||||||
|
@ -330,6 +347,51 @@ impl SqliteStore {
|
||||||
.collect::<Result<Vec<InboundGroupSession>>>()?)
|
.collect::<Result<Vec<InboundGroupSession>>>()?)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn save_tracked_user(&self, user: &UserId) -> Result<()> {
|
||||||
|
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
||||||
|
let mut connection = self.connection.lock().await;
|
||||||
|
|
||||||
|
query(
|
||||||
|
"INSERT OR IGNORE INTO tracked_users (
|
||||||
|
account_id, user_id
|
||||||
|
) VALUES (?1, ?2)
|
||||||
|
",
|
||||||
|
)
|
||||||
|
.bind(account_id)
|
||||||
|
.bind(user.to_string())
|
||||||
|
.execute(&mut *connection)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn load_tracked_users(&self) -> Result<HashSet<UserId>> {
|
||||||
|
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
||||||
|
let mut connection = self.connection.lock().await;
|
||||||
|
|
||||||
|
let rows: Vec<(String,)> = query_as(
|
||||||
|
"SELECT user_id
|
||||||
|
FROM tracked_users WHERE account_id = ?",
|
||||||
|
)
|
||||||
|
.bind(account_id)
|
||||||
|
.fetch_all(&mut *connection)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let mut users = HashSet::new();
|
||||||
|
|
||||||
|
for row in rows {
|
||||||
|
let user_id: &str = &row.0;
|
||||||
|
|
||||||
|
if let Ok(u) = UserId::try_from(user_id) {
|
||||||
|
users.insert(u);
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(users)
|
||||||
|
}
|
||||||
|
|
||||||
async fn load_devices(&self) -> Result<DeviceStore> {
|
async fn load_devices(&self) -> Result<DeviceStore> {
|
||||||
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
||||||
let mut connection = self.connection.lock().await;
|
let mut connection = self.connection.lock().await;
|
||||||
|
@ -520,7 +582,8 @@ impl CryptoStore for SqliteStore {
|
||||||
let devices = self.load_devices().await?;
|
let devices = self.load_devices().await?;
|
||||||
mem::replace(&mut self.devices, devices);
|
mem::replace(&mut self.devices, devices);
|
||||||
|
|
||||||
// TODO load the tracked users here as well.
|
let tracked_users = self.load_tracked_users().await?;
|
||||||
|
mem::replace(&mut self.tracked_users, tracked_users);
|
||||||
|
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
@ -637,7 +700,7 @@ impl CryptoStore for SqliteStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool> {
|
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool> {
|
||||||
// TODO save the tracked user to the database.
|
self.save_tracked_user(user).await?;
|
||||||
Ok(self.tracked_users.insert(user.clone()))
|
Ok(self.tracked_users.insert(user.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -974,7 +1037,7 @@ mod test {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_tracked_users() {
|
async fn test_tracked_users() {
|
||||||
let (_account, mut store, _dir) = get_loaded_store().await;
|
let (_account, mut store, dir) = get_loaded_store().await;
|
||||||
let device = get_device();
|
let device = get_device();
|
||||||
|
|
||||||
assert!(store.add_user_for_tracking(device.user_id()).await.unwrap());
|
assert!(store.add_user_for_tracking(device.user_id()).await.unwrap());
|
||||||
|
@ -982,7 +1045,18 @@ mod test {
|
||||||
|
|
||||||
let tracked_users = store.tracked_users();
|
let tracked_users = store.tracked_users();
|
||||||
|
|
||||||
tracked_users.contains(device.user_id());
|
assert!(tracked_users.contains(device.user_id()));
|
||||||
|
drop(store);
|
||||||
|
|
||||||
|
let mut store =
|
||||||
|
SqliteStore::open(&UserId::try_from(USER_ID).unwrap(), DEVICE_ID, dir.path())
|
||||||
|
.await
|
||||||
|
.expect("Can't create store");
|
||||||
|
|
||||||
|
store.load_account().await.unwrap();
|
||||||
|
|
||||||
|
let tracked_users = store.tracked_users();
|
||||||
|
assert!(tracked_users.contains(device.user_id()));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|
Loading…
Reference in New Issue