crypto: Move the users for key query map into the store.
parent
e51e89d9d5
commit
1d9fccdc9f
|
@ -80,9 +80,6 @@ pub struct OlmMachine {
|
|||
/// Persists all the encryption keys so a client can resume the session
|
||||
/// without the need to create new keys.
|
||||
store: Box<dyn CryptoStore>,
|
||||
/// Set of users that we need to query keys for. This is a subset of
|
||||
/// the tracked users in the CryptoStore.
|
||||
users_for_key_query: HashSet<UserId>,
|
||||
/// The currently active outbound group sessions.
|
||||
outbound_group_sessions: HashMap<RoomId, OutboundGroupSession>,
|
||||
}
|
||||
|
@ -122,7 +119,6 @@ impl OlmMachine {
|
|||
account: Account::new(),
|
||||
uploaded_signed_key_count: None,
|
||||
store: Box::new(MemoryStore::new()),
|
||||
users_for_key_query: HashSet::new(),
|
||||
outbound_group_sessions: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
@ -166,7 +162,6 @@ impl OlmMachine {
|
|||
account,
|
||||
uploaded_signed_key_count: None,
|
||||
store: Box::new(store),
|
||||
users_for_key_query: HashSet::new(),
|
||||
outbound_group_sessions: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
@ -461,7 +456,7 @@ impl OlmMachine {
|
|||
let mut changed_devices = Vec::new();
|
||||
|
||||
for (user_id, device_map) in &response.device_keys {
|
||||
self.users_for_key_query.remove(&user_id);
|
||||
self.store.update_tracked_user(user_id, false).await?;
|
||||
|
||||
for (device_id, device_keys) in device_map.iter() {
|
||||
// We don't need our own device in the device store.
|
||||
|
@ -1516,12 +1511,12 @@ impl OlmMachine {
|
|||
/// key query.
|
||||
///
|
||||
/// Returns true if the user was queued up for a key query, false otherwise.
|
||||
pub async fn mark_user_as_changed(&mut self, user_id: &UserId) -> bool {
|
||||
pub async fn mark_user_as_changed(&mut self, user_id: &UserId) -> StoreError<bool> {
|
||||
if self.store.tracked_users().contains(user_id) {
|
||||
self.users_for_key_query.insert(user_id.clone());
|
||||
true
|
||||
self.store.update_tracked_user(user_id, true).await?;
|
||||
Ok(true)
|
||||
} else {
|
||||
false
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1544,32 +1539,26 @@ impl OlmMachine {
|
|||
I: IntoIterator<Item = &'a UserId>,
|
||||
{
|
||||
for user in users {
|
||||
let ret = self.store.add_user_for_tracking(user).await;
|
||||
if self.store.tracked_users().contains(user) {
|
||||
continue;
|
||||
}
|
||||
|
||||
match ret {
|
||||
Ok(newly_added) => {
|
||||
if newly_added {
|
||||
self.mark_user_as_changed(user).await;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
if let Err(e) = self.store.update_tracked_user(user, true).await {
|
||||
warn!("Error storing users for tracking {}", e);
|
||||
self.users_for_key_query.insert(user.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Should the client perform a key query request.
|
||||
pub fn should_query_keys(&self) -> bool {
|
||||
!self.users_for_key_query.is_empty()
|
||||
!self.store.users_for_key_query().is_empty()
|
||||
}
|
||||
|
||||
/// Get the set of users that we need to query keys for.
|
||||
///
|
||||
/// Returns a hash set of users that need to be queried for keys.
|
||||
pub fn users_for_key_query(&self) -> HashSet<UserId> {
|
||||
self.users_for_key_query.clone()
|
||||
self.store.users_for_key_query().clone()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@ pub struct MemoryStore {
|
|||
sessions: SessionStore,
|
||||
inbound_group_sessions: GroupSessionStore,
|
||||
tracked_users: HashSet<UserId>,
|
||||
users_for_key_query: HashSet<UserId>,
|
||||
devices: DeviceStore,
|
||||
}
|
||||
|
||||
|
@ -37,6 +38,7 @@ impl MemoryStore {
|
|||
sessions: SessionStore::new(),
|
||||
inbound_group_sessions: GroupSessionStore::new(),
|
||||
tracked_users: HashSet::new(),
|
||||
users_for_key_query: HashSet::new(),
|
||||
devices: DeviceStore::new(),
|
||||
}
|
||||
}
|
||||
|
@ -83,7 +85,17 @@ impl CryptoStore for MemoryStore {
|
|||
&self.tracked_users
|
||||
}
|
||||
|
||||
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool> {
|
||||
fn users_for_key_query(&self) -> &HashSet<UserId> {
|
||||
&self.users_for_key_query
|
||||
}
|
||||
|
||||
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool> {
|
||||
if dirty {
|
||||
self.users_for_key_query.insert(user.clone());
|
||||
} else {
|
||||
self.users_for_key_query.remove(user);
|
||||
}
|
||||
|
||||
Ok(self.tracked_users.insert(user.clone()))
|
||||
}
|
||||
|
||||
|
@ -207,8 +219,14 @@ mod test {
|
|||
let device = get_device();
|
||||
let mut store = MemoryStore::new();
|
||||
|
||||
assert!(store.add_user_for_tracking(device.user_id()).await.unwrap());
|
||||
assert!(!store.add_user_for_tracking(device.user_id()).await.unwrap());
|
||||
assert!(store
|
||||
.update_tracked_user(device.user_id(), false)
|
||||
.await
|
||||
.unwrap());
|
||||
assert!(!store
|
||||
.update_tracked_user(device.user_id(), false)
|
||||
.await
|
||||
.unwrap());
|
||||
|
||||
let tracked_users = store.tracked_users();
|
||||
|
||||
|
|
|
@ -138,6 +138,10 @@ pub trait CryptoStore: Debug + Send + Sync {
|
|||
/// Get the set of tracked users.
|
||||
fn tracked_users(&self) -> &HashSet<UserId>;
|
||||
|
||||
/// Set of users that we need to query keys for. This is a subset of
|
||||
/// the tracked users.
|
||||
fn users_for_key_query(&self) -> &HashSet<UserId>;
|
||||
|
||||
/// Add an user for tracking.
|
||||
///
|
||||
/// Returns true if the user wasn't already tracked, false otherwise.
|
||||
|
@ -145,7 +149,9 @@ pub trait CryptoStore: Debug + Send + Sync {
|
|||
/// # Arguments
|
||||
///
|
||||
/// * `user` - The user that should be marked as tracked.
|
||||
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool>;
|
||||
///
|
||||
/// * `dirty` - Should the user be also marked for a key query.
|
||||
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool>;
|
||||
|
||||
/// Save the given devices in the store.
|
||||
///
|
||||
|
|
|
@ -45,6 +45,7 @@ pub struct SqliteStore {
|
|||
inbound_group_sessions: GroupSessionStore,
|
||||
devices: DeviceStore,
|
||||
tracked_users: HashSet<UserId>,
|
||||
users_for_key_query: HashSet<UserId>,
|
||||
|
||||
connection: Arc<Mutex<SqliteConnection>>,
|
||||
pickle_passphrase: Option<Zeroizing<String>>,
|
||||
|
@ -121,6 +122,7 @@ impl SqliteStore {
|
|||
connection: Arc::new(Mutex::new(connection)),
|
||||
pickle_passphrase: passphrase,
|
||||
tracked_users: HashSet::new(),
|
||||
users_for_key_query: HashSet::new(),
|
||||
};
|
||||
store.create_tables().await?;
|
||||
Ok(store)
|
||||
|
@ -169,6 +171,7 @@ impl SqliteStore {
|
|||
"id" INTEGER NOT NULL PRIMARY KEY,
|
||||
"account_id" INTEGER NOT NULL,
|
||||
"user_id" TEXT NOT NULL,
|
||||
"dirty" INTEGER NOT NULL,
|
||||
FOREIGN KEY ("account_id") REFERENCES "accounts" ("id")
|
||||
ON DELETE CASCADE
|
||||
UNIQUE(account_id,user_id)
|
||||
|
@ -347,30 +350,33 @@ impl SqliteStore {
|
|||
.collect::<Result<Vec<InboundGroupSession>>>()?)
|
||||
}
|
||||
|
||||
async fn save_tracked_user(&self, user: &UserId) -> Result<()> {
|
||||
async fn save_tracked_user(&self, user: &UserId, dirty: bool) -> Result<()> {
|
||||
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
||||
let mut connection = self.connection.lock().await;
|
||||
|
||||
query(
|
||||
"INSERT OR IGNORE INTO tracked_users (
|
||||
account_id, user_id
|
||||
) VALUES (?1, ?2)
|
||||
"INSERT INTO tracked_users (
|
||||
account_id, user_id, dirty
|
||||
) VALUES (?1, ?2, ?3)
|
||||
ON CONFLICT(account_id, user_id) DO UPDATE SET
|
||||
dirty = excluded.dirty
|
||||
",
|
||||
)
|
||||
.bind(account_id)
|
||||
.bind(user.to_string())
|
||||
.bind(dirty)
|
||||
.execute(&mut *connection)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn load_tracked_users(&self) -> Result<HashSet<UserId>> {
|
||||
async fn load_tracked_users(&self) -> Result<(HashSet<UserId>, HashSet<UserId>)> {
|
||||
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
||||
let mut connection = self.connection.lock().await;
|
||||
|
||||
let rows: Vec<(String,)> = query_as(
|
||||
"SELECT user_id
|
||||
let rows: Vec<(String, bool)> = query_as(
|
||||
"SELECT user_id, dirty
|
||||
FROM tracked_users WHERE account_id = ?",
|
||||
)
|
||||
.bind(account_id)
|
||||
|
@ -378,18 +384,23 @@ impl SqliteStore {
|
|||
.await?;
|
||||
|
||||
let mut users = HashSet::new();
|
||||
let mut users_for_query = HashSet::new();
|
||||
|
||||
for row in rows {
|
||||
let user_id: &str = &row.0;
|
||||
let dirty: bool = row.1;
|
||||
|
||||
if let Ok(u) = UserId::try_from(user_id) {
|
||||
users.insert(u);
|
||||
users.insert(u.clone());
|
||||
if dirty {
|
||||
users_for_query.insert(u);
|
||||
}
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
}
|
||||
|
||||
Ok(users)
|
||||
Ok((users, users_for_query))
|
||||
}
|
||||
|
||||
async fn load_devices(&self) -> Result<DeviceStore> {
|
||||
|
@ -582,8 +593,9 @@ impl CryptoStore for SqliteStore {
|
|||
let devices = self.load_devices().await?;
|
||||
mem::replace(&mut self.devices, devices);
|
||||
|
||||
let tracked_users = self.load_tracked_users().await?;
|
||||
let (tracked_users, users_for_query) = self.load_tracked_users().await?;
|
||||
mem::replace(&mut self.tracked_users, tracked_users);
|
||||
mem::replace(&mut self.users_for_key_query, users_for_query);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
@ -699,9 +711,22 @@ impl CryptoStore for SqliteStore {
|
|||
&self.tracked_users
|
||||
}
|
||||
|
||||
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool> {
|
||||
self.save_tracked_user(user).await?;
|
||||
Ok(self.tracked_users.insert(user.clone()))
|
||||
fn users_for_key_query(&self) -> &HashSet<UserId> {
|
||||
&self.users_for_key_query
|
||||
}
|
||||
|
||||
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool> {
|
||||
let already_added = self.tracked_users.insert(user.clone());
|
||||
|
||||
if dirty {
|
||||
self.users_for_key_query.insert(user.clone());
|
||||
} else {
|
||||
self.users_for_key_query.remove(user);
|
||||
}
|
||||
|
||||
self.save_tracked_user(user, dirty).await?;
|
||||
|
||||
Ok(already_added)
|
||||
}
|
||||
|
||||
async fn save_devices(&self, devices: &[Device]) -> Result<()> {
|
||||
|
@ -1040,12 +1065,24 @@ mod test {
|
|||
let (_account, mut store, dir) = get_loaded_store().await;
|
||||
let device = get_device();
|
||||
|
||||
assert!(store.add_user_for_tracking(device.user_id()).await.unwrap());
|
||||
assert!(!store.add_user_for_tracking(device.user_id()).await.unwrap());
|
||||
assert!(store
|
||||
.update_tracked_user(device.user_id(), false)
|
||||
.await
|
||||
.unwrap());
|
||||
assert!(!store
|
||||
.update_tracked_user(device.user_id(), false)
|
||||
.await
|
||||
.unwrap());
|
||||
|
||||
let tracked_users = store.tracked_users();
|
||||
|
||||
assert!(tracked_users.contains(device.user_id()));
|
||||
assert!(!store.users_for_key_query().contains(device.user_id()));
|
||||
assert!(!store
|
||||
.update_tracked_user(device.user_id(), true)
|
||||
.await
|
||||
.unwrap());
|
||||
assert!(store.users_for_key_query().contains(device.user_id()));
|
||||
drop(store);
|
||||
|
||||
let mut store =
|
||||
|
@ -1057,6 +1094,22 @@ mod test {
|
|||
|
||||
let tracked_users = store.tracked_users();
|
||||
assert!(tracked_users.contains(device.user_id()));
|
||||
assert!(store.users_for_key_query().contains(device.user_id()));
|
||||
|
||||
store
|
||||
.update_tracked_user(device.user_id(), false)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!store.users_for_key_query().contains(device.user_id()));
|
||||
|
||||
let mut store =
|
||||
SqliteStore::open(&UserId::try_from(USER_ID).unwrap(), DEVICE_ID, dir.path())
|
||||
.await
|
||||
.expect("Can't create store");
|
||||
|
||||
store.load_account().await.unwrap();
|
||||
|
||||
assert!(!store.users_for_key_query().contains(device.user_id()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
Loading…
Reference in New Issue