crypto: Move the users for key query map into the store.
parent
e51e89d9d5
commit
1d9fccdc9f
|
@ -80,9 +80,6 @@ pub struct OlmMachine {
|
||||||
/// Persists all the encryption keys so a client can resume the session
|
/// Persists all the encryption keys so a client can resume the session
|
||||||
/// without the need to create new keys.
|
/// without the need to create new keys.
|
||||||
store: Box<dyn CryptoStore>,
|
store: Box<dyn CryptoStore>,
|
||||||
/// Set of users that we need to query keys for. This is a subset of
|
|
||||||
/// the tracked users in the CryptoStore.
|
|
||||||
users_for_key_query: HashSet<UserId>,
|
|
||||||
/// The currently active outbound group sessions.
|
/// The currently active outbound group sessions.
|
||||||
outbound_group_sessions: HashMap<RoomId, OutboundGroupSession>,
|
outbound_group_sessions: HashMap<RoomId, OutboundGroupSession>,
|
||||||
}
|
}
|
||||||
|
@ -122,7 +119,6 @@ impl OlmMachine {
|
||||||
account: Account::new(),
|
account: Account::new(),
|
||||||
uploaded_signed_key_count: None,
|
uploaded_signed_key_count: None,
|
||||||
store: Box::new(MemoryStore::new()),
|
store: Box::new(MemoryStore::new()),
|
||||||
users_for_key_query: HashSet::new(),
|
|
||||||
outbound_group_sessions: HashMap::new(),
|
outbound_group_sessions: HashMap::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -166,7 +162,6 @@ impl OlmMachine {
|
||||||
account,
|
account,
|
||||||
uploaded_signed_key_count: None,
|
uploaded_signed_key_count: None,
|
||||||
store: Box::new(store),
|
store: Box::new(store),
|
||||||
users_for_key_query: HashSet::new(),
|
|
||||||
outbound_group_sessions: HashMap::new(),
|
outbound_group_sessions: HashMap::new(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -461,7 +456,7 @@ impl OlmMachine {
|
||||||
let mut changed_devices = Vec::new();
|
let mut changed_devices = Vec::new();
|
||||||
|
|
||||||
for (user_id, device_map) in &response.device_keys {
|
for (user_id, device_map) in &response.device_keys {
|
||||||
self.users_for_key_query.remove(&user_id);
|
self.store.update_tracked_user(user_id, false).await?;
|
||||||
|
|
||||||
for (device_id, device_keys) in device_map.iter() {
|
for (device_id, device_keys) in device_map.iter() {
|
||||||
// We don't need our own device in the device store.
|
// We don't need our own device in the device store.
|
||||||
|
@ -1516,12 +1511,12 @@ impl OlmMachine {
|
||||||
/// key query.
|
/// key query.
|
||||||
///
|
///
|
||||||
/// Returns true if the user was queued up for a key query, false otherwise.
|
/// Returns true if the user was queued up for a key query, false otherwise.
|
||||||
pub async fn mark_user_as_changed(&mut self, user_id: &UserId) -> bool {
|
pub async fn mark_user_as_changed(&mut self, user_id: &UserId) -> StoreError<bool> {
|
||||||
if self.store.tracked_users().contains(user_id) {
|
if self.store.tracked_users().contains(user_id) {
|
||||||
self.users_for_key_query.insert(user_id.clone());
|
self.store.update_tracked_user(user_id, true).await?;
|
||||||
true
|
Ok(true)
|
||||||
} else {
|
} else {
|
||||||
false
|
Ok(false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1544,32 +1539,26 @@ impl OlmMachine {
|
||||||
I: IntoIterator<Item = &'a UserId>,
|
I: IntoIterator<Item = &'a UserId>,
|
||||||
{
|
{
|
||||||
for user in users {
|
for user in users {
|
||||||
let ret = self.store.add_user_for_tracking(user).await;
|
if self.store.tracked_users().contains(user) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
match ret {
|
if let Err(e) = self.store.update_tracked_user(user, true).await {
|
||||||
Ok(newly_added) => {
|
warn!("Error storing users for tracking {}", e);
|
||||||
if newly_added {
|
|
||||||
self.mark_user_as_changed(user).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
warn!("Error storing users for tracking {}", e);
|
|
||||||
self.users_for_key_query.insert(user.clone());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Should the client perform a key query request.
|
/// Should the client perform a key query request.
|
||||||
pub fn should_query_keys(&self) -> bool {
|
pub fn should_query_keys(&self) -> bool {
|
||||||
!self.users_for_key_query.is_empty()
|
!self.store.users_for_key_query().is_empty()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the set of users that we need to query keys for.
|
/// Get the set of users that we need to query keys for.
|
||||||
///
|
///
|
||||||
/// Returns a hash set of users that need to be queried for keys.
|
/// Returns a hash set of users that need to be queried for keys.
|
||||||
pub fn users_for_key_query(&self) -> HashSet<UserId> {
|
pub fn users_for_key_query(&self) -> HashSet<UserId> {
|
||||||
self.users_for_key_query.clone()
|
self.store.users_for_key_query().clone()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,7 @@ pub struct MemoryStore {
|
||||||
sessions: SessionStore,
|
sessions: SessionStore,
|
||||||
inbound_group_sessions: GroupSessionStore,
|
inbound_group_sessions: GroupSessionStore,
|
||||||
tracked_users: HashSet<UserId>,
|
tracked_users: HashSet<UserId>,
|
||||||
|
users_for_key_query: HashSet<UserId>,
|
||||||
devices: DeviceStore,
|
devices: DeviceStore,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -37,6 +38,7 @@ impl MemoryStore {
|
||||||
sessions: SessionStore::new(),
|
sessions: SessionStore::new(),
|
||||||
inbound_group_sessions: GroupSessionStore::new(),
|
inbound_group_sessions: GroupSessionStore::new(),
|
||||||
tracked_users: HashSet::new(),
|
tracked_users: HashSet::new(),
|
||||||
|
users_for_key_query: HashSet::new(),
|
||||||
devices: DeviceStore::new(),
|
devices: DeviceStore::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -83,7 +85,17 @@ impl CryptoStore for MemoryStore {
|
||||||
&self.tracked_users
|
&self.tracked_users
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool> {
|
fn users_for_key_query(&self) -> &HashSet<UserId> {
|
||||||
|
&self.users_for_key_query
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool> {
|
||||||
|
if dirty {
|
||||||
|
self.users_for_key_query.insert(user.clone());
|
||||||
|
} else {
|
||||||
|
self.users_for_key_query.remove(user);
|
||||||
|
}
|
||||||
|
|
||||||
Ok(self.tracked_users.insert(user.clone()))
|
Ok(self.tracked_users.insert(user.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -207,8 +219,14 @@ mod test {
|
||||||
let device = get_device();
|
let device = get_device();
|
||||||
let mut store = MemoryStore::new();
|
let mut store = MemoryStore::new();
|
||||||
|
|
||||||
assert!(store.add_user_for_tracking(device.user_id()).await.unwrap());
|
assert!(store
|
||||||
assert!(!store.add_user_for_tracking(device.user_id()).await.unwrap());
|
.update_tracked_user(device.user_id(), false)
|
||||||
|
.await
|
||||||
|
.unwrap());
|
||||||
|
assert!(!store
|
||||||
|
.update_tracked_user(device.user_id(), false)
|
||||||
|
.await
|
||||||
|
.unwrap());
|
||||||
|
|
||||||
let tracked_users = store.tracked_users();
|
let tracked_users = store.tracked_users();
|
||||||
|
|
||||||
|
|
|
@ -138,6 +138,10 @@ pub trait CryptoStore: Debug + Send + Sync {
|
||||||
/// Get the set of tracked users.
|
/// Get the set of tracked users.
|
||||||
fn tracked_users(&self) -> &HashSet<UserId>;
|
fn tracked_users(&self) -> &HashSet<UserId>;
|
||||||
|
|
||||||
|
/// Set of users that we need to query keys for. This is a subset of
|
||||||
|
/// the tracked users.
|
||||||
|
fn users_for_key_query(&self) -> &HashSet<UserId>;
|
||||||
|
|
||||||
/// Add an user for tracking.
|
/// Add an user for tracking.
|
||||||
///
|
///
|
||||||
/// Returns true if the user wasn't already tracked, false otherwise.
|
/// Returns true if the user wasn't already tracked, false otherwise.
|
||||||
|
@ -145,7 +149,9 @@ pub trait CryptoStore: Debug + Send + Sync {
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
/// * `user` - The user that should be marked as tracked.
|
/// * `user` - The user that should be marked as tracked.
|
||||||
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool>;
|
///
|
||||||
|
/// * `dirty` - Should the user be also marked for a key query.
|
||||||
|
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool>;
|
||||||
|
|
||||||
/// Save the given devices in the store.
|
/// Save the given devices in the store.
|
||||||
///
|
///
|
||||||
|
|
|
@ -45,6 +45,7 @@ pub struct SqliteStore {
|
||||||
inbound_group_sessions: GroupSessionStore,
|
inbound_group_sessions: GroupSessionStore,
|
||||||
devices: DeviceStore,
|
devices: DeviceStore,
|
||||||
tracked_users: HashSet<UserId>,
|
tracked_users: HashSet<UserId>,
|
||||||
|
users_for_key_query: HashSet<UserId>,
|
||||||
|
|
||||||
connection: Arc<Mutex<SqliteConnection>>,
|
connection: Arc<Mutex<SqliteConnection>>,
|
||||||
pickle_passphrase: Option<Zeroizing<String>>,
|
pickle_passphrase: Option<Zeroizing<String>>,
|
||||||
|
@ -121,6 +122,7 @@ impl SqliteStore {
|
||||||
connection: Arc::new(Mutex::new(connection)),
|
connection: Arc::new(Mutex::new(connection)),
|
||||||
pickle_passphrase: passphrase,
|
pickle_passphrase: passphrase,
|
||||||
tracked_users: HashSet::new(),
|
tracked_users: HashSet::new(),
|
||||||
|
users_for_key_query: HashSet::new(),
|
||||||
};
|
};
|
||||||
store.create_tables().await?;
|
store.create_tables().await?;
|
||||||
Ok(store)
|
Ok(store)
|
||||||
|
@ -169,6 +171,7 @@ impl SqliteStore {
|
||||||
"id" INTEGER NOT NULL PRIMARY KEY,
|
"id" INTEGER NOT NULL PRIMARY KEY,
|
||||||
"account_id" INTEGER NOT NULL,
|
"account_id" INTEGER NOT NULL,
|
||||||
"user_id" TEXT NOT NULL,
|
"user_id" TEXT NOT NULL,
|
||||||
|
"dirty" INTEGER NOT NULL,
|
||||||
FOREIGN KEY ("account_id") REFERENCES "accounts" ("id")
|
FOREIGN KEY ("account_id") REFERENCES "accounts" ("id")
|
||||||
ON DELETE CASCADE
|
ON DELETE CASCADE
|
||||||
UNIQUE(account_id,user_id)
|
UNIQUE(account_id,user_id)
|
||||||
|
@ -347,30 +350,33 @@ impl SqliteStore {
|
||||||
.collect::<Result<Vec<InboundGroupSession>>>()?)
|
.collect::<Result<Vec<InboundGroupSession>>>()?)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn save_tracked_user(&self, user: &UserId) -> Result<()> {
|
async fn save_tracked_user(&self, user: &UserId, dirty: bool) -> Result<()> {
|
||||||
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
||||||
let mut connection = self.connection.lock().await;
|
let mut connection = self.connection.lock().await;
|
||||||
|
|
||||||
query(
|
query(
|
||||||
"INSERT OR IGNORE INTO tracked_users (
|
"INSERT INTO tracked_users (
|
||||||
account_id, user_id
|
account_id, user_id, dirty
|
||||||
) VALUES (?1, ?2)
|
) VALUES (?1, ?2, ?3)
|
||||||
|
ON CONFLICT(account_id, user_id) DO UPDATE SET
|
||||||
|
dirty = excluded.dirty
|
||||||
",
|
",
|
||||||
)
|
)
|
||||||
.bind(account_id)
|
.bind(account_id)
|
||||||
.bind(user.to_string())
|
.bind(user.to_string())
|
||||||
|
.bind(dirty)
|
||||||
.execute(&mut *connection)
|
.execute(&mut *connection)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn load_tracked_users(&self) -> Result<HashSet<UserId>> {
|
async fn load_tracked_users(&self) -> Result<(HashSet<UserId>, HashSet<UserId>)> {
|
||||||
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
let account_id = self.account_id.ok_or(CryptoStoreError::AccountUnset)?;
|
||||||
let mut connection = self.connection.lock().await;
|
let mut connection = self.connection.lock().await;
|
||||||
|
|
||||||
let rows: Vec<(String,)> = query_as(
|
let rows: Vec<(String, bool)> = query_as(
|
||||||
"SELECT user_id
|
"SELECT user_id, dirty
|
||||||
FROM tracked_users WHERE account_id = ?",
|
FROM tracked_users WHERE account_id = ?",
|
||||||
)
|
)
|
||||||
.bind(account_id)
|
.bind(account_id)
|
||||||
|
@ -378,18 +384,23 @@ impl SqliteStore {
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let mut users = HashSet::new();
|
let mut users = HashSet::new();
|
||||||
|
let mut users_for_query = HashSet::new();
|
||||||
|
|
||||||
for row in rows {
|
for row in rows {
|
||||||
let user_id: &str = &row.0;
|
let user_id: &str = &row.0;
|
||||||
|
let dirty: bool = row.1;
|
||||||
|
|
||||||
if let Ok(u) = UserId::try_from(user_id) {
|
if let Ok(u) = UserId::try_from(user_id) {
|
||||||
users.insert(u);
|
users.insert(u.clone());
|
||||||
|
if dirty {
|
||||||
|
users_for_query.insert(u);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(users)
|
Ok((users, users_for_query))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn load_devices(&self) -> Result<DeviceStore> {
|
async fn load_devices(&self) -> Result<DeviceStore> {
|
||||||
|
@ -582,8 +593,9 @@ impl CryptoStore for SqliteStore {
|
||||||
let devices = self.load_devices().await?;
|
let devices = self.load_devices().await?;
|
||||||
mem::replace(&mut self.devices, devices);
|
mem::replace(&mut self.devices, devices);
|
||||||
|
|
||||||
let tracked_users = self.load_tracked_users().await?;
|
let (tracked_users, users_for_query) = self.load_tracked_users().await?;
|
||||||
mem::replace(&mut self.tracked_users, tracked_users);
|
mem::replace(&mut self.tracked_users, tracked_users);
|
||||||
|
mem::replace(&mut self.users_for_key_query, users_for_query);
|
||||||
|
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
@ -699,9 +711,22 @@ impl CryptoStore for SqliteStore {
|
||||||
&self.tracked_users
|
&self.tracked_users
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn add_user_for_tracking(&mut self, user: &UserId) -> Result<bool> {
|
fn users_for_key_query(&self) -> &HashSet<UserId> {
|
||||||
self.save_tracked_user(user).await?;
|
&self.users_for_key_query
|
||||||
Ok(self.tracked_users.insert(user.clone()))
|
}
|
||||||
|
|
||||||
|
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool> {
|
||||||
|
let already_added = self.tracked_users.insert(user.clone());
|
||||||
|
|
||||||
|
if dirty {
|
||||||
|
self.users_for_key_query.insert(user.clone());
|
||||||
|
} else {
|
||||||
|
self.users_for_key_query.remove(user);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.save_tracked_user(user, dirty).await?;
|
||||||
|
|
||||||
|
Ok(already_added)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn save_devices(&self, devices: &[Device]) -> Result<()> {
|
async fn save_devices(&self, devices: &[Device]) -> Result<()> {
|
||||||
|
@ -1040,12 +1065,24 @@ mod test {
|
||||||
let (_account, mut store, dir) = get_loaded_store().await;
|
let (_account, mut store, dir) = get_loaded_store().await;
|
||||||
let device = get_device();
|
let device = get_device();
|
||||||
|
|
||||||
assert!(store.add_user_for_tracking(device.user_id()).await.unwrap());
|
assert!(store
|
||||||
assert!(!store.add_user_for_tracking(device.user_id()).await.unwrap());
|
.update_tracked_user(device.user_id(), false)
|
||||||
|
.await
|
||||||
|
.unwrap());
|
||||||
|
assert!(!store
|
||||||
|
.update_tracked_user(device.user_id(), false)
|
||||||
|
.await
|
||||||
|
.unwrap());
|
||||||
|
|
||||||
let tracked_users = store.tracked_users();
|
let tracked_users = store.tracked_users();
|
||||||
|
|
||||||
assert!(tracked_users.contains(device.user_id()));
|
assert!(tracked_users.contains(device.user_id()));
|
||||||
|
assert!(!store.users_for_key_query().contains(device.user_id()));
|
||||||
|
assert!(!store
|
||||||
|
.update_tracked_user(device.user_id(), true)
|
||||||
|
.await
|
||||||
|
.unwrap());
|
||||||
|
assert!(store.users_for_key_query().contains(device.user_id()));
|
||||||
drop(store);
|
drop(store);
|
||||||
|
|
||||||
let mut store =
|
let mut store =
|
||||||
|
@ -1057,6 +1094,22 @@ mod test {
|
||||||
|
|
||||||
let tracked_users = store.tracked_users();
|
let tracked_users = store.tracked_users();
|
||||||
assert!(tracked_users.contains(device.user_id()));
|
assert!(tracked_users.contains(device.user_id()));
|
||||||
|
assert!(store.users_for_key_query().contains(device.user_id()));
|
||||||
|
|
||||||
|
store
|
||||||
|
.update_tracked_user(device.user_id(), false)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(!store.users_for_key_query().contains(device.user_id()));
|
||||||
|
|
||||||
|
let mut store =
|
||||||
|
SqliteStore::open(&UserId::try_from(USER_ID).unwrap(), DEVICE_ID, dir.path())
|
||||||
|
.await
|
||||||
|
.expect("Can't create store");
|
||||||
|
|
||||||
|
store.load_account().await.unwrap();
|
||||||
|
|
||||||
|
assert!(!store.users_for_key_query().contains(device.user_id()));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|
Loading…
Reference in New Issue