crypto: Move the users for key query map into the store.

master
Damir Jelić 2020-05-15 15:33:30 +02:00
parent e51e89d9d5
commit 1d9fccdc9f
4 changed files with 108 additions and 42 deletions

View File

@ -80,9 +80,6 @@ pub struct OlmMachine {
/// Persists all the encryption keys so a client can resume the session /// Persists all the encryption keys so a client can resume the session
/// without the need to create new keys. /// without the need to create new keys.
store: Box<dyn CryptoStore>, store: Box<dyn CryptoStore>,
/// Set of users that we need to query keys for. This is a subset of
/// the tracked users in the CryptoStore.
users_for_key_query: HashSet<UserId>,
/// The currently active outbound group sessions. /// The currently active outbound group sessions.
outbound_group_sessions: HashMap<RoomId, OutboundGroupSession>, outbound_group_sessions: HashMap<RoomId, OutboundGroupSession>,
} }
@ -122,7 +119,6 @@ impl OlmMachine {
account: Account::new(), account: Account::new(),
uploaded_signed_key_count: None, uploaded_signed_key_count: None,
store: Box::new(MemoryStore::new()), store: Box::new(MemoryStore::new()),
users_for_key_query: HashSet::new(),
outbound_group_sessions: HashMap::new(), outbound_group_sessions: HashMap::new(),
} }
} }
@ -166,7 +162,6 @@ impl OlmMachine {
account, account,
uploaded_signed_key_count: None, uploaded_signed_key_count: None,
store: Box::new(store), store: Box::new(store),
users_for_key_query: HashSet::new(),
outbound_group_sessions: HashMap::new(), outbound_group_sessions: HashMap::new(),
}) })
} }
@ -461,7 +456,7 @@ impl OlmMachine {
let mut changed_devices = Vec::new(); let mut changed_devices = Vec::new();
for (user_id, device_map) in &response.device_keys { for (user_id, device_map) in &response.device_keys {
self.users_for_key_query.remove(&user_id); self.store.update_tracked_user(user_id, false).await?;
for (device_id, device_keys) in device_map.iter() { for (device_id, device_keys) in device_map.iter() {
// We don't need our own device in the device store. // We don't need our own device in the device store.
@ -1516,12 +1511,12 @@ impl OlmMachine {
/// key query. /// key query.
/// ///
/// Returns true if the user was queued up for a key query, false otherwise. /// Returns true if the user was queued up for a key query, false otherwise.
pub async fn mark_user_as_changed(&mut self, user_id: &UserId) -> bool { pub async fn mark_user_as_changed(&mut self, user_id: &UserId) -> StoreError<bool> {
if self.store.tracked_users().contains(user_id) { if self.store.tracked_users().contains(user_id) {
self.users_for_key_query.insert(user_id.clone()); self.store.update_tracked_user(user_id, true).await?;
true Ok(true)
} else { } else {
false Ok(false)
} }
} }
@ -1544,32 +1539,26 @@ impl OlmMachine {
I: IntoIterator<Item = &'a UserId>, I: IntoIterator<Item = &'a UserId>,
{ {
for user in users { for user in users {
let ret = self.store.add_user_for_tracking(user).await; if self.store.tracked_users().contains(user) {
continue;
}
match ret { if let Err(e) = self.store.update_tracked_user(user, true).await {
Ok(newly_added) => { warn!("Error storing users for tracking {}", e);
if newly_added {
self.mark_user_as_changed(user).await;
}
}
Err(e) => {
warn!("Error storing users for tracking {}", e);
self.users_for_key_query.insert(user.clone());
}
} }
} }
} }
/// Should the client perform a key query request. /// Should the client perform a key query request.
pub fn should_query_keys(&self) -> bool { pub fn should_query_keys(&self) -> bool {
!self.users_for_key_query.is_empty() !self.store.users_for_key_query().is_empty()
} }
/// Get the set of users that we need to query keys for. /// Get the set of users that we need to query keys for.
/// ///
/// Returns a hash set of users that need to be queried for keys. /// Returns a hash set of users that need to be queried for keys.
pub fn users_for_key_query(&self) -> HashSet<UserId> { pub fn users_for_key_query(&self) -> HashSet<UserId> {
self.users_for_key_query.clone() self.store.users_for_key_query().clone()
} }
} }

View File

@ -28,6 +28,7 @@ pub struct MemoryStore {
sessions: SessionStore, sessions: SessionStore,
inbound_group_sessions: GroupSessionStore, inbound_group_sessions: GroupSessionStore,
tracked_users: HashSet<UserId>, tracked_users: HashSet<UserId>,
users_for_key_query: HashSet<UserId>,
devices: DeviceStore, devices: DeviceStore,
} }
@ -37,6 +38,7 @@ impl MemoryStore {
sessions: SessionStore::new(), sessions: SessionStore::new(),
inbound_group_sessions: GroupSessionStore::new(), inbound_group_sessions: GroupSessionStore::new(),
tracked_users: HashSet::new(), tracked_users: HashSet::new(),
users_for_key_query: HashSet::new(),
devices: DeviceStore::new(), devices: DeviceStore::new(),
} }
} }
@ -83,7 +85,17 @@ impl CryptoStore for MemoryStore {
&self.tracked_users &self.tracked_users
} }
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool> { fn users_for_key_query(&self) -> &HashSet<UserId> {
&self.users_for_key_query
}
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool> {
if dirty {
self.users_for_key_query.insert(user.clone());
} else {
self.users_for_key_query.remove(user);
}
Ok(self.tracked_users.insert(user.clone())) Ok(self.tracked_users.insert(user.clone()))
} }
@ -207,8 +219,14 @@ mod test {
let device = get_device(); let device = get_device();
let mut store = MemoryStore::new(); let mut store = MemoryStore::new();
assert!(store.add_user_for_tracking(device.user_id()).await.unwrap()); assert!(store
assert!(!store.add_user_for_tracking(device.user_id()).await.unwrap()); .update_tracked_user(device.user_id(), false)
.await
.unwrap());
assert!(!store
.update_tracked_user(device.user_id(), false)
.await
.unwrap());
let tracked_users = store.tracked_users(); let tracked_users = store.tracked_users();

View File

@ -138,6 +138,10 @@ pub trait CryptoStore: Debug + Send + Sync {
/// Get the set of tracked users. /// Get the set of tracked users.
fn tracked_users(&self) -> &HashSet<UserId>; fn tracked_users(&self) -> &HashSet<UserId>;
/// Set of users that we need to query keys for. This is a subset of
/// the tracked users.
fn users_for_key_query(&self) -> &HashSet<UserId>;
/// Add an user for tracking. /// Add an user for tracking.
/// ///
/// Returns true if the user wasn't already tracked, false otherwise. /// Returns true if the user wasn't already tracked, false otherwise.
@ -145,7 +149,9 @@ pub trait CryptoStore: Debug + Send + Sync {
/// # Arguments /// # Arguments
/// ///
/// * `user` - The user that should be marked as tracked. /// * `user` - The user that should be marked as tracked.
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool>; ///
/// * `dirty` - Should the user be also marked for a key query.
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool>;
/// Save the given devices in the store. /// Save the given devices in the store.
/// ///

View File

@ -45,6 +45,7 @@ pub struct SqliteStore {
inbound_group_sessions: GroupSessionStore, inbound_group_sessions: GroupSessionStore,
devices: DeviceStore, devices: DeviceStore,
tracked_users: HashSet<UserId>, tracked_users: HashSet<UserId>,
users_for_key_query: HashSet<UserId>,
connection: Arc<Mutex<SqliteConnection>>, connection: Arc<Mutex<SqliteConnection>>,
pickle_passphrase: Option<Zeroizing<String>>, pickle_passphrase: Option<Zeroizing<String>>,
@ -121,6 +122,7 @@ impl SqliteStore {
connection: Arc::new(Mutex::new(connection)), connection: Arc::new(Mutex::new(connection)),
pickle_passphrase: passphrase, pickle_passphrase: passphrase,
tracked_users: HashSet::new(), tracked_users: HashSet::new(),
users_for_key_query: HashSet::new(),
}; };
store.create_tables().await?; store.create_tables().await?;
Ok(store) Ok(store)
@ -169,6 +171,7 @@ impl SqliteStore {
"id" INTEGER NOT NULL PRIMARY KEY, "id" INTEGER NOT NULL PRIMARY KEY,
"account_id" INTEGER NOT NULL, "account_id" INTEGER NOT NULL,
"user_id" TEXT NOT NULL, "user_id" TEXT NOT NULL,
"dirty" INTEGER NOT NULL,
FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") FOREIGN KEY ("account_id") REFERENCES "accounts" ("id")
ON DELETE CASCADE ON DELETE CASCADE
UNIQUE(account_id,user_id) UNIQUE(account_id,user_id)
@ -347,30 +350,33 @@ impl SqliteStore {
.collect::<Result<Vec<InboundGroupSession>>>()?) .collect::<Result<Vec<InboundGroupSession>>>()?)
} }
async fn save_tracked_user(&self, user: &UserId) -> Result<()> { async fn save_tracked_user(&self, user: &UserId, dirty: bool) -> Result<()> {
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
query( query(
"INSERT OR IGNORE INTO tracked_users ( "INSERT INTO tracked_users (
account_id, user_id account_id, user_id, dirty
) VALUES (?1, ?2) ) VALUES (?1, ?2, ?3)
ON CONFLICT(account_id, user_id) DO UPDATE SET
dirty = excluded.dirty
", ",
) )
.bind(account_id) .bind(account_id)
.bind(user.to_string()) .bind(user.to_string())
.bind(dirty)
.execute(&mut *connection) .execute(&mut *connection)
.await?; .await?;
Ok(()) Ok(())
} }
async fn load_tracked_users(&self) -> Result<HashSet<UserId>> { async fn load_tracked_users(&self) -> Result<(HashSet<UserId>, HashSet<UserId>)> {
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?; let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
let rows: Vec<(String,)> = query_as( let rows: Vec<(String, bool)> = query_as(
"SELECT user_id "SELECT user_id, dirty
FROM tracked_users WHERE account_id = ?", FROM tracked_users WHERE account_id = ?",
) )
.bind(account_id) .bind(account_id)
@ -378,18 +384,23 @@ impl SqliteStore {
.await?; .await?;
let mut users = HashSet::new(); let mut users = HashSet::new();
let mut users_for_query = HashSet::new();
for row in rows { for row in rows {
let user_id: &str = &row.0; let user_id: &str = &row.0;
let dirty: bool = row.1;
if let Ok(u) = UserId::try_from(user_id) { if let Ok(u) = UserId::try_from(user_id) {
users.insert(u); users.insert(u.clone());
if dirty {
users_for_query.insert(u);
}
} else { } else {
continue; continue;
}; };
} }
Ok(users) Ok((users, users_for_query))
} }
async fn load_devices(&self) -> Result<DeviceStore> { async fn load_devices(&self) -> Result<DeviceStore> {
@ -582,8 +593,9 @@ impl CryptoStore for SqliteStore {
let devices = self.load_devices().await?; let devices = self.load_devices().await?;
mem::replace(&mut self.devices, devices); mem::replace(&mut self.devices, devices);
let tracked_users = self.load_tracked_users().await?; let (tracked_users, users_for_query) = self.load_tracked_users().await?;
mem::replace(&mut self.tracked_users, tracked_users); mem::replace(&mut self.tracked_users, tracked_users);
mem::replace(&mut self.users_for_key_query, users_for_query);
Ok(result) Ok(result)
} }
@ -699,9 +711,22 @@ impl CryptoStore for SqliteStore {
&self.tracked_users &self.tracked_users
} }
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool> { fn users_for_key_query(&self) -> &HashSet<UserId> {
self.save_tracked_user(user).await?; &self.users_for_key_query
Ok(self.tracked_users.insert(user.clone())) }
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool> {
let already_added = self.tracked_users.insert(user.clone());
if dirty {
self.users_for_key_query.insert(user.clone());
} else {
self.users_for_key_query.remove(user);
}
self.save_tracked_user(user, dirty).await?;
Ok(already_added)
} }
async fn save_devices(&self, devices: &[Device]) -> Result<()> { async fn save_devices(&self, devices: &[Device]) -> Result<()> {
@ -1040,12 +1065,24 @@ mod test {
let (_account, mut store, dir) = get_loaded_store().await; let (_account, mut store, dir) = get_loaded_store().await;
let device = get_device(); let device = get_device();
assert!(store.add_user_for_tracking(device.user_id()).await.unwrap()); assert!(store
assert!(!store.add_user_for_tracking(device.user_id()).await.unwrap()); .update_tracked_user(device.user_id(), false)
.await
.unwrap());
assert!(!store
.update_tracked_user(device.user_id(), false)
.await
.unwrap());
let tracked_users = store.tracked_users(); let tracked_users = store.tracked_users();
assert!(tracked_users.contains(device.user_id())); assert!(tracked_users.contains(device.user_id()));
assert!(!store.users_for_key_query().contains(device.user_id()));
assert!(!store
.update_tracked_user(device.user_id(), true)
.await
.unwrap());
assert!(store.users_for_key_query().contains(device.user_id()));
drop(store); drop(store);
let mut store = let mut store =
@ -1057,6 +1094,22 @@ mod test {
let tracked_users = store.tracked_users(); let tracked_users = store.tracked_users();
assert!(tracked_users.contains(device.user_id())); assert!(tracked_users.contains(device.user_id()));
assert!(store.users_for_key_query().contains(device.user_id()));
store
.update_tracked_user(device.user_id(), false)
.await
.unwrap();
assert!(!store.users_for_key_query().contains(device.user_id()));
let mut store =
SqliteStore::open(&UserId::try_from(USER_ID).unwrap(), DEVICE_ID, dir.path())
.await
.expect("Can't create store");
store.load_account().await.unwrap();
assert!(!store.users_for_key_query().contains(device.user_id()));
} }
#[tokio::test] #[tokio::test]