crypto: Remove most mutable self borrows from the crypto-store trait.

master
Damir Jelić 2020-08-11 14:34:42 +02:00
parent ac2469d270
commit 01bcbaf063
5 changed files with 62 additions and 57 deletions

View File

@ -1268,14 +1268,14 @@ impl OlmMachine {
/// Should the client perform a key query request. /// Should the client perform a key query request.
pub async fn should_query_keys(&self) -> bool { pub async fn should_query_keys(&self) -> bool {
!self.store.read().await.users_for_key_query().is_empty() self.store.read().await.has_users_for_key_query()
} }
/// Get the set of users that we need to query keys for. /// Get the set of users that we need to query keys for.
/// ///
/// Returns a hash set of users that need to be queried for keys. /// Returns a hash set of users that need to be queried for keys.
pub async fn users_for_key_query(&self) -> HashSet<UserId> { pub async fn users_for_key_query(&self) -> HashSet<UserId> {
self.store.read().await.users_for_key_query().clone() self.store.read().await.users_for_key_query()
} }
} }

View File

@ -43,7 +43,7 @@ impl SessionStore {
/// ///
/// Returns true if the the session was added, false if the session was /// Returns true if the the session was added, false if the session was
/// already in the store. /// already in the store.
pub async fn add(&mut self, session: Session) -> bool { pub async fn add(&self, session: Session) -> bool {
if !self.entries.contains_key(&*session.sender_key) { if !self.entries.contains_key(&*session.sender_key) {
self.entries.insert( self.entries.insert(
session.sender_key.to_string(), session.sender_key.to_string(),
@ -67,7 +67,7 @@ impl SessionStore {
} }
/// Add a list of sessions belonging to the sender key. /// Add a list of sessions belonging to the sender key.
pub fn set_for_sender(&mut self, sender_key: &str, sessions: Vec<Session>) { pub fn set_for_sender(&self, sender_key: &str, sessions: Vec<Session>) {
self.entries self.entries
.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions))); .insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
} }
@ -92,7 +92,7 @@ impl GroupSessionStore {
/// ///
/// Returns true if the the session was added, false if the session was /// Returns true if the the session was added, false if the session was
/// already in the store. /// already in the store.
pub fn add(&mut self, session: InboundGroupSession) -> bool { pub fn add(&self, session: InboundGroupSession) -> bool {
if !self.entries.contains_key(&session.room_id) { if !self.entries.contains_key(&session.room_id) {
let room_id = &*session.room_id; let room_id = &*session.room_id;
self.entries.insert(room_id.clone(), HashMap::new()); self.entries.insert(room_id.clone(), HashMap::new());
@ -225,7 +225,7 @@ mod test {
async fn test_session_store() { async fn test_session_store() {
let (_, session) = get_account_and_session().await; let (_, session) = get_account_and_session().await;
let mut store = SessionStore::new(); let store = SessionStore::new();
assert!(store.add(session.clone()).await); assert!(store.add(session.clone()).await);
assert!(!store.add(session.clone()).await); assert!(!store.add(session.clone()).await);
@ -242,7 +242,7 @@ mod test {
async fn test_session_store_bulk_storing() { async fn test_session_store_bulk_storing() {
let (_, session) = get_account_and_session().await; let (_, session) = get_account_and_session().await;
let mut store = SessionStore::new(); let store = SessionStore::new();
store.set_for_sender(&session.sender_key, vec![session.clone()]); store.set_for_sender(&session.sender_key, vec![session.clone()]);
let sessions = store.get(&session.sender_key).unwrap(); let sessions = store.get(&session.sender_key).unwrap();
@ -273,7 +273,7 @@ mod test {
) )
.unwrap(); .unwrap();
let mut store = GroupSessionStore::new(); let store = GroupSessionStore::new();
store.add(inbound.clone()); store.add(inbound.clone());
let loaded_session = store let loaded_session = store

View File

@ -15,6 +15,7 @@
use std::{collections::HashSet, sync::Arc}; use std::{collections::HashSet, sync::Arc};
use async_trait::async_trait; use async_trait::async_trait;
use dashmap::DashSet;
use matrix_sdk_common::{ use matrix_sdk_common::{
identifiers::{DeviceId, RoomId, UserId}, identifiers::{DeviceId, RoomId, UserId},
locks::Mutex, locks::Mutex,
@ -29,8 +30,8 @@ use crate::{
pub struct MemoryStore { pub struct MemoryStore {
sessions: SessionStore, sessions: SessionStore,
inbound_group_sessions: GroupSessionStore, inbound_group_sessions: GroupSessionStore,
tracked_users: HashSet<UserId>, tracked_users: DashSet<UserId>,
users_for_key_query: HashSet<UserId>, users_for_key_query: DashSet<UserId>,
devices: DeviceStore, devices: DeviceStore,
} }
@ -39,8 +40,8 @@ impl MemoryStore {
MemoryStore { MemoryStore {
sessions: SessionStore::new(), sessions: SessionStore::new(),
inbound_group_sessions: GroupSessionStore::new(), inbound_group_sessions: GroupSessionStore::new(),
tracked_users: HashSet::new(), tracked_users: DashSet::new(),
users_for_key_query: HashSet::new(), users_for_key_query: DashSet::new(),
devices: DeviceStore::new(), devices: DeviceStore::new(),
} }
} }
@ -56,7 +57,7 @@ impl CryptoStore for MemoryStore {
Ok(()) Ok(())
} }
async fn save_sessions(&mut self, sessions: &[Session]) -> Result<()> { async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
for session in sessions { for session in sessions {
let _ = self.sessions.add(session.clone()).await; let _ = self.sessions.add(session.clone()).await;
} }
@ -64,16 +65,16 @@ impl CryptoStore for MemoryStore {
Ok(()) Ok(())
} }
async fn get_sessions(&mut self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> { async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
Ok(self.sessions.get(sender_key)) Ok(self.sessions.get(sender_key))
} }
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool> { async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result<bool> {
Ok(self.inbound_group_sessions.add(session)) Ok(self.inbound_group_sessions.add(session))
} }
async fn get_inbound_group_session( async fn get_inbound_group_session(
&mut self, &self,
room_id: &RoomId, room_id: &RoomId,
sender_key: &str, sender_key: &str,
session_id: &str, session_id: &str,
@ -83,15 +84,19 @@ impl CryptoStore for MemoryStore {
.get(room_id, sender_key, session_id)) .get(room_id, sender_key, session_id))
} }
fn users_for_key_query(&self) -> &HashSet<UserId> { fn users_for_key_query(&self) -> HashSet<UserId> {
&self.users_for_key_query self.users_for_key_query.iter().map(|u| u.clone()).collect()
} }
fn is_user_tracked(&self, user_id: &UserId) -> bool { fn is_user_tracked(&self, user_id: &UserId) -> bool {
self.tracked_users.contains(user_id) self.tracked_users.contains(user_id)
} }
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool> { fn has_users_for_key_query(&self) -> bool {
!self.users_for_key_query.is_empty()
}
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
if dirty { if dirty {
self.users_for_key_query.insert(user.clone()); self.users_for_key_query.insert(user.clone());
} else { } else {
@ -168,7 +173,7 @@ mod test {
) )
.unwrap(); .unwrap();
let mut store = MemoryStore::new(); let store = MemoryStore::new();
let _ = store let _ = store
.save_inbound_group_session(inbound.clone()) .save_inbound_group_session(inbound.clone())
.await .await
@ -217,7 +222,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn test_tracked_users() { async fn test_tracked_users() {
let device = get_device(); let device = get_device();
let mut store = MemoryStore::new(); let store = MemoryStore::new();
assert!(store assert!(store
.update_tracked_user(device.user_id(), false) .update_tracked_user(device.user_id(), false)

View File

@ -109,14 +109,14 @@ pub trait CryptoStore: Debug {
/// # Arguments /// # Arguments
/// ///
/// * `session` - The sessions that should be stored. /// * `session` - The sessions that should be stored.
async fn save_sessions(&mut self, session: &[Session]) -> Result<()>; async fn save_sessions(&self, session: &[Session]) -> Result<()>;
/// Get all the sessions that belong to the given sender key. /// Get all the sessions that belong to the given sender key.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `sender_key` - The sender key that was used to establish the sessions. /// * `sender_key` - The sender key that was used to establish the sessions.
async fn get_sessions(&mut self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>>; async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>>;
/// Save the given inbound group session in the store. /// Save the given inbound group session in the store.
/// ///
@ -126,7 +126,7 @@ pub trait CryptoStore: Debug {
/// # Arguments /// # Arguments
/// ///
/// * `session` - The session that should be stored. /// * `session` - The session that should be stored.
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool>; async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result<bool>;
/// Get the inbound group session from our store. /// Get the inbound group session from our store.
/// ///
@ -137,7 +137,7 @@ pub trait CryptoStore: Debug {
/// ///
/// * `session_id` - The unique id of the session. /// * `session_id` - The unique id of the session.
async fn get_inbound_group_session( async fn get_inbound_group_session(
&mut self, &self,
room_id: &RoomId, room_id: &RoomId,
sender_key: &str, sender_key: &str,
session_id: &str, session_id: &str,
@ -146,9 +146,12 @@ pub trait CryptoStore: Debug {
/// Is the given user already tracked. /// Is the given user already tracked.
fn is_user_tracked(&self, user_id: &UserId) -> bool; fn is_user_tracked(&self, user_id: &UserId) -> bool;
/// Are there any tracked users that are marked as dirty.
fn has_users_for_key_query(&self) -> bool;
/// Set of users that we need to query keys for. This is a subset of /// Set of users that we need to query keys for. This is a subset of
/// the tracked users. /// the tracked users.
fn users_for_key_query(&self) -> &HashSet<UserId>; fn users_for_key_query(&self) -> HashSet<UserId>;
/// Add an user for tracking. /// Add an user for tracking.
/// ///
@ -159,7 +162,7 @@ pub trait CryptoStore: Debug {
/// * `user` - The user that should be marked as tracked. /// * `user` - The user that should be marked as tracked.
/// ///
/// * `dirty` - Should the user be also marked for a key query. /// * `dirty` - Should the user be also marked for a key query.
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool>; async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool>;
/// Save the given devices in the store. /// Save the given devices in the store.
/// ///

View File

@ -21,6 +21,7 @@ use std::{
}; };
use async_trait::async_trait; use async_trait::async_trait;
use dashmap::DashSet;
use matrix_sdk_common::{ use matrix_sdk_common::{
events::Algorithm, events::Algorithm,
identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, RoomId, UserId}, identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, RoomId, UserId},
@ -49,8 +50,8 @@ pub struct SqliteStore {
sessions: SessionStore, sessions: SessionStore,
inbound_group_sessions: GroupSessionStore, inbound_group_sessions: GroupSessionStore,
devices: DeviceStore, devices: DeviceStore,
tracked_users: HashSet<UserId>, tracked_users: DashSet<UserId>,
users_for_key_query: HashSet<UserId>, users_for_key_query: DashSet<UserId>,
connection: Arc<Mutex<SqliteConnection>>, connection: Arc<Mutex<SqliteConnection>>,
pickle_passphrase: Option<Zeroizing<String>>, pickle_passphrase: Option<Zeroizing<String>>,
@ -137,8 +138,8 @@ impl SqliteStore {
path: path.as_ref().to_owned(), path: path.as_ref().to_owned(),
connection: Arc::new(Mutex::new(connection)), connection: Arc::new(Mutex::new(connection)),
pickle_passphrase: passphrase, pickle_passphrase: passphrase,
tracked_users: HashSet::new(), tracked_users: DashSet::new(),
users_for_key_query: HashSet::new(), users_for_key_query: DashSet::new(),
}; };
store.create_tables().await?; store.create_tables().await?;
Ok(store) Ok(store)
@ -299,7 +300,7 @@ impl SqliteStore {
Ok(()) Ok(())
} }
async fn lazy_load_sessions(&mut self, sender_key: &str) -> Result<()> { async fn lazy_load_sessions(&self, sender_key: &str) -> Result<()> {
let loaded_sessions = self.sessions.get(sender_key).is_some(); let loaded_sessions = self.sessions.get(sender_key).is_some();
if !loaded_sessions { if !loaded_sessions {
@ -313,15 +314,12 @@ impl SqliteStore {
Ok(()) Ok(())
} }
async fn get_sessions_for( async fn get_sessions_for(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
&mut self,
sender_key: &str,
) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
self.lazy_load_sessions(sender_key).await?; self.lazy_load_sessions(sender_key).await?;
Ok(self.sessions.get(sender_key)) Ok(self.sessions.get(sender_key))
} }
async fn load_sessions_for(&mut self, sender_key: &str) -> Result<Vec<Session>> { async fn load_sessions_for(&self, sender_key: &str) -> Result<Vec<Session>> {
let account_info = self let account_info = self
.account_info .account_info
.as_ref() .as_ref()
@ -417,7 +415,7 @@ impl SqliteStore {
Ok(()) Ok(())
} }
async fn load_tracked_users(&self) -> Result<(HashSet<UserId>, HashSet<UserId>)> { async fn load_tracked_users(&self) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
@ -429,24 +427,21 @@ impl SqliteStore {
.fetch_all(&mut *connection) .fetch_all(&mut *connection)
.await?; .await?;
let mut users = HashSet::new();
let mut users_for_query = HashSet::new();
for row in rows { for row in rows {
let user_id: &str = &row.0; let user_id: &str = &row.0;
let dirty: bool = row.1; let dirty: bool = row.1;
if let Ok(u) = UserId::try_from(user_id) { if let Ok(u) = UserId::try_from(user_id) {
users.insert(u.clone()); self.tracked_users.insert(u.clone());
if dirty { if dirty {
users_for_query.insert(u); self.users_for_key_query.insert(u);
} }
} else { } else {
continue; continue;
}; };
} }
Ok((users, users_for_query)) Ok(())
} }
async fn load_devices(&self) -> Result<DeviceStore> { async fn load_devices(&self) -> Result<DeviceStore> {
@ -700,9 +695,7 @@ impl CryptoStore for SqliteStore {
let devices = self.load_devices().await?; let devices = self.load_devices().await?;
self.devices = devices; self.devices = devices;
let (tracked_users, users_for_query) = self.load_tracked_users().await?; self.load_tracked_users().await?;
self.tracked_users = tracked_users;
self.users_for_key_query = users_for_query;
Ok(result) Ok(result)
} }
@ -743,7 +736,7 @@ impl CryptoStore for SqliteStore {
Ok(()) Ok(())
} }
async fn save_sessions(&mut self, sessions: &[Session]) -> Result<()> { async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
// TODO turn this into a transaction // TODO turn this into a transaction
@ -776,11 +769,11 @@ impl CryptoStore for SqliteStore {
Ok(()) Ok(())
} }
async fn get_sessions(&mut self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> { async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
Ok(self.get_sessions_for(sender_key).await?) Ok(self.get_sessions_for(sender_key).await?)
} }
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool> { async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result<bool> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let pickle = session.pickle(self.get_pickle_mode()).await; let pickle = session.pickle(self.get_pickle_mode()).await;
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
@ -808,7 +801,7 @@ impl CryptoStore for SqliteStore {
} }
async fn get_inbound_group_session( async fn get_inbound_group_session(
&mut self, &self,
room_id: &RoomId, room_id: &RoomId,
sender_key: &str, sender_key: &str,
session_id: &str, session_id: &str,
@ -822,11 +815,15 @@ impl CryptoStore for SqliteStore {
self.tracked_users.contains(user_id) self.tracked_users.contains(user_id)
} }
fn users_for_key_query(&self) -> &HashSet<UserId> { fn has_users_for_key_query(&self) -> bool {
&self.users_for_key_query !self.users_for_key_query.is_empty()
} }
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool> { fn users_for_key_query(&self) -> HashSet<UserId> {
self.users_for_key_query.iter().map(|u| u.clone()).collect()
}
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
let already_added = self.tracked_users.insert(user.clone()); let already_added = self.tracked_users.insert(user.clone());
if dirty { if dirty {
@ -1135,7 +1132,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn save_inbound_group_session() { async fn save_inbound_group_session() {
let (account, mut store, _dir) = get_loaded_store().await; let (account, store, _dir) = get_loaded_store().await;
let identity_keys = account.identity_keys(); let identity_keys = account.identity_keys();
let outbound_session = OlmOutboundGroupSession::new(); let outbound_session = OlmOutboundGroupSession::new();
@ -1155,7 +1152,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn load_inbound_group_session() { async fn load_inbound_group_session() {
let (account, mut store, _dir) = get_loaded_store().await; let (account, store, _dir) = get_loaded_store().await;
let identity_keys = account.identity_keys(); let identity_keys = account.identity_keys();
let outbound_session = OlmOutboundGroupSession::new(); let outbound_session = OlmOutboundGroupSession::new();
@ -1188,7 +1185,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn test_tracked_users() { async fn test_tracked_users() {
let (_account, mut store, dir) = get_loaded_store().await; let (_account, store, dir) = get_loaded_store().await;
let device = get_device(); let device = get_device();
assert!(store assert!(store