crypto: Move the users for key query map into the store.

master
Damir Jelić 2020-05-15 15:33:30 +02:00
parent e51e89d9d5
commit 1d9fccdc9f
4 changed files with 108 additions and 42 deletions

View File

@ -80,9 +80,6 @@ pub struct OlmMachine {
/// Persists all the encryption keys so a client can resume the session
/// without the need to create new keys.
store: Box<dyn CryptoStore>,
/// Set of users that we need to query keys for. This is a subset of
/// the tracked users in the CryptoStore.
users_for_key_query: HashSet<UserId>,
/// The currently active outbound group sessions.
outbound_group_sessions: HashMap<RoomId, OutboundGroupSession>,
}
@ -122,7 +119,6 @@ impl OlmMachine {
account: Account::new(),
uploaded_signed_key_count: None,
store: Box::new(MemoryStore::new()),
users_for_key_query: HashSet::new(),
outbound_group_sessions: HashMap::new(),
}
}
@ -166,7 +162,6 @@ impl OlmMachine {
account,
uploaded_signed_key_count: None,
store: Box::new(store),
users_for_key_query: HashSet::new(),
outbound_group_sessions: HashMap::new(),
})
}
@ -461,7 +456,7 @@ impl OlmMachine {
let mut changed_devices = Vec::new();
for (user_id, device_map) in &response.device_keys {
self.users_for_key_query.remove(&user_id);
self.store.update_tracked_user(user_id, false).await?;
for (device_id, device_keys) in device_map.iter() {
// We don't need our own device in the device store.
@ -1516,12 +1511,12 @@ impl OlmMachine {
/// key query.
///
/// Returns true if the user was queued up for a key query, false otherwise.
pub async fn mark_user_as_changed(&mut self, user_id: &UserId) -> bool {
pub async fn mark_user_as_changed(&mut self, user_id: &UserId) -> StoreError<bool> {
if self.store.tracked_users().contains(user_id) {
self.users_for_key_query.insert(user_id.clone());
true
self.store.update_tracked_user(user_id, true).await?;
Ok(true)
} else {
false
Ok(false)
}
}
@ -1544,32 +1539,26 @@ impl OlmMachine {
I: IntoIterator<Item = &'a UserId>,
{
for user in users {
let ret = self.store.add_user_for_tracking(user).await;
if self.store.tracked_users().contains(user) {
continue;
}
match ret {
Ok(newly_added) => {
if newly_added {
self.mark_user_as_changed(user).await;
}
}
Err(e) => {
if let Err(e) = self.store.update_tracked_user(user, true).await {
warn!("Error storing users for tracking {}", e);
self.users_for_key_query.insert(user.clone());
}
}
}
}
/// Should the client perform a key query request.
pub fn should_query_keys(&self) -> bool {
!self.users_for_key_query.is_empty()
!self.store.users_for_key_query().is_empty()
}
/// Get the set of users that we need to query keys for.
///
/// Returns a hash set of users that need to be queried for keys.
pub fn users_for_key_query(&self) -> HashSet<UserId> {
self.users_for_key_query.clone()
self.store.users_for_key_query().clone()
}
}

View File

@ -28,6 +28,7 @@ pub struct MemoryStore {
sessions: SessionStore,
inbound_group_sessions: GroupSessionStore,
tracked_users: HashSet<UserId>,
users_for_key_query: HashSet<UserId>,
devices: DeviceStore,
}
@ -37,6 +38,7 @@ impl MemoryStore {
sessions: SessionStore::new(),
inbound_group_sessions: GroupSessionStore::new(),
tracked_users: HashSet::new(),
users_for_key_query: HashSet::new(),
devices: DeviceStore::new(),
}
}
@ -83,7 +85,17 @@ impl CryptoStore for MemoryStore {
&self.tracked_users
}
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool> {
fn users_for_key_query(&self) -> &HashSet<UserId> {
&self.users_for_key_query
}
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool> {
if dirty {
self.users_for_key_query.insert(user.clone());
} else {
self.users_for_key_query.remove(user);
}
Ok(self.tracked_users.insert(user.clone()))
}
@ -207,8 +219,14 @@ mod test {
let device = get_device();
let mut store = MemoryStore::new();
assert!(store.add_user_for_tracking(device.user_id()).await.unwrap());
assert!(!store.add_user_for_tracking(device.user_id()).await.unwrap());
assert!(store
.update_tracked_user(device.user_id(), false)
.await
.unwrap());
assert!(!store
.update_tracked_user(device.user_id(), false)
.await
.unwrap());
let tracked_users = store.tracked_users();

View File

@ -138,6 +138,10 @@ pub trait CryptoStore: Debug + Send + Sync {
/// Get the set of tracked users.
fn tracked_users(&self) -> &HashSet<UserId>;
/// Set of users that we need to query keys for. This is a subset of
/// the tracked users.
fn users_for_key_query(&self) -> &HashSet<UserId>;
/// Add an user for tracking.
///
/// Returns true if the user wasn't already tracked, false otherwise.
@ -145,7 +149,9 @@ pub trait CryptoStore: Debug + Send + Sync {
/// # Arguments
///
/// * `user` - The user that should be marked as tracked.
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool>;
///
/// * `dirty` - Should the user be also marked for a key query.
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool>;
/// Save the given devices in the store.
///

View File

@ -45,6 +45,7 @@ pub struct SqliteStore {
inbound_group_sessions: GroupSessionStore,
devices: DeviceStore,
tracked_users: HashSet<UserId>,
users_for_key_query: HashSet<UserId>,
connection: Arc<Mutex<SqliteConnection>>,
pickle_passphrase: Option<Zeroizing<String>>,
@ -121,6 +122,7 @@ impl SqliteStore {
connection: Arc::new(Mutex::new(connection)),
pickle_passphrase: passphrase,
tracked_users: HashSet::new(),
users_for_key_query: HashSet::new(),
};
store.create_tables().await?;
Ok(store)
@ -169,6 +171,7 @@ impl SqliteStore {
"id" INTEGER NOT NULL PRIMARY KEY,
"account_id" INTEGER NOT NULL,
"user_id" TEXT NOT NULL,
"dirty" INTEGER NOT NULL,
FOREIGN KEY ("account_id") REFERENCES "accounts" ("id")
ON DELETE CASCADE
UNIQUE(account_id,user_id)
@ -347,30 +350,33 @@ impl SqliteStore {
.collect::<Result<Vec<InboundGroupSession>>>()?)
}
async fn save_tracked_user(&self, user: &UserId) -> Result<()> {
async fn save_tracked_user(&self, user: &UserId, dirty: bool) -> Result<()> {
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
query(
"INSERT OR IGNORE INTO tracked_users (
account_id, user_id
) VALUES (?1, ?2)
"INSERT INTO tracked_users (
account_id, user_id, dirty
) VALUES (?1, ?2, ?3)
ON CONFLICT(account_id, user_id) DO UPDATE SET
dirty = excluded.dirty
",
)
.bind(account_id)
.bind(user.to_string())
.bind(dirty)
.execute(&mut *connection)
.await?;
Ok(())
}
async fn load_tracked_users(&self) -> Result<HashSet<UserId>> {
async fn load_tracked_users(&self) -> Result<(HashSet<UserId>, HashSet<UserId>)> {
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
let rows: Vec<(String,)> = query_as(
"SELECT user_id
let rows: Vec<(String, bool)> = query_as(
"SELECT user_id, dirty
FROM tracked_users WHERE account_id = ?",
)
.bind(account_id)
@ -378,18 +384,23 @@ impl SqliteStore {
.await?;
let mut users = HashSet::new();
let mut users_for_query = HashSet::new();
for row in rows {
let user_id: &str = &row.0;
let dirty: bool = row.1;
if let Ok(u) = UserId::try_from(user_id) {
users.insert(u);
users.insert(u.clone());
if dirty {
users_for_query.insert(u);
}
} else {
continue;
};
}
Ok(users)
Ok((users, users_for_query))
}
async fn load_devices(&self) -> Result<DeviceStore> {
@ -582,8 +593,9 @@ impl CryptoStore for SqliteStore {
let devices = self.load_devices().await?;
mem::replace(&mut self.devices, devices);
let tracked_users = self.load_tracked_users().await?;
let (tracked_users, users_for_query) = self.load_tracked_users().await?;
mem::replace(&mut self.tracked_users, tracked_users);
mem::replace(&mut self.users_for_key_query, users_for_query);
Ok(result)
}
@ -699,9 +711,22 @@ impl CryptoStore for SqliteStore {
&self.tracked_users
}
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool> {
self.save_tracked_user(user).await?;
Ok(self.tracked_users.insert(user.clone()))
fn users_for_key_query(&self) -> &HashSet<UserId> {
&self.users_for_key_query
}
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool> {
let already_added = self.tracked_users.insert(user.clone());
if dirty {
self.users_for_key_query.insert(user.clone());
} else {
self.users_for_key_query.remove(user);
}
self.save_tracked_user(user, dirty).await?;
Ok(already_added)
}
async fn save_devices(&self, devices: &[Device]) -> Result<()> {
@ -1040,12 +1065,24 @@ mod test {
let (_account, mut store, dir) = get_loaded_store().await;
let device = get_device();
assert!(store.add_user_for_tracking(device.user_id()).await.unwrap());
assert!(!store.add_user_for_tracking(device.user_id()).await.unwrap());
assert!(store
.update_tracked_user(device.user_id(), false)
.await
.unwrap());
assert!(!store
.update_tracked_user(device.user_id(), false)
.await
.unwrap());
let tracked_users = store.tracked_users();
assert!(tracked_users.contains(device.user_id()));
assert!(!store.users_for_key_query().contains(device.user_id()));
assert!(!store
.update_tracked_user(device.user_id(), true)
.await
.unwrap());
assert!(store.users_for_key_query().contains(device.user_id()));
drop(store);
let mut store =
@ -1057,6 +1094,22 @@ mod test {
let tracked_users = store.tracked_users();
assert!(tracked_users.contains(device.user_id()));
assert!(store.users_for_key_query().contains(device.user_id()));
store
.update_tracked_user(device.user_id(), false)
.await
.unwrap();
assert!(!store.users_for_key_query().contains(device.user_id()));
let mut store =
SqliteStore::open(&UserId::try_from(USER_ID).unwrap(), DEVICE_ID, dir.path())
.await
.expect("Can't create store");
store.load_account().await.unwrap();
assert!(!store.users_for_key_query().contains(device.user_id()));
}
#[tokio::test]