crypto: Remove most mutable self borrows from the crypto-store trait.
This commit is contained in:
parent
ac2469d270
commit
01bcbaf063
5 changed files with 62 additions and 57 deletions
|
@ -1268,14 +1268,14 @@ impl OlmMachine {
|
|||
|
||||
/// Should the client perform a key query request.
|
||||
pub async fn should_query_keys(&self) -> bool {
|
||||
!self.store.read().await.users_for_key_query().is_empty()
|
||||
self.store.read().await.has_users_for_key_query()
|
||||
}
|
||||
|
||||
/// Get the set of users that we need to query keys for.
|
||||
///
|
||||
/// Returns a hash set of users that need to be queried for keys.
|
||||
pub async fn users_for_key_query(&self) -> HashSet<UserId> {
|
||||
self.store.read().await.users_for_key_query().clone()
|
||||
self.store.read().await.users_for_key_query()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ impl SessionStore {
|
|||
///
|
||||
/// Returns true if the the session was added, false if the session was
|
||||
/// already in the store.
|
||||
pub async fn add(&mut self, session: Session) -> bool {
|
||||
pub async fn add(&self, session: Session) -> bool {
|
||||
if !self.entries.contains_key(&*session.sender_key) {
|
||||
self.entries.insert(
|
||||
session.sender_key.to_string(),
|
||||
|
@ -67,7 +67,7 @@ impl SessionStore {
|
|||
}
|
||||
|
||||
/// Add a list of sessions belonging to the sender key.
|
||||
pub fn set_for_sender(&mut self, sender_key: &str, sessions: Vec<Session>) {
|
||||
pub fn set_for_sender(&self, sender_key: &str, sessions: Vec<Session>) {
|
||||
self.entries
|
||||
.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
|
||||
}
|
||||
|
@ -92,7 +92,7 @@ impl GroupSessionStore {
|
|||
///
|
||||
/// Returns true if the the session was added, false if the session was
|
||||
/// already in the store.
|
||||
pub fn add(&mut self, session: InboundGroupSession) -> bool {
|
||||
pub fn add(&self, session: InboundGroupSession) -> bool {
|
||||
if !self.entries.contains_key(&session.room_id) {
|
||||
let room_id = &*session.room_id;
|
||||
self.entries.insert(room_id.clone(), HashMap::new());
|
||||
|
@ -225,7 +225,7 @@ mod test {
|
|||
async fn test_session_store() {
|
||||
let (_, session) = get_account_and_session().await;
|
||||
|
||||
let mut store = SessionStore::new();
|
||||
let store = SessionStore::new();
|
||||
|
||||
assert!(store.add(session.clone()).await);
|
||||
assert!(!store.add(session.clone()).await);
|
||||
|
@ -242,7 +242,7 @@ mod test {
|
|||
async fn test_session_store_bulk_storing() {
|
||||
let (_, session) = get_account_and_session().await;
|
||||
|
||||
let mut store = SessionStore::new();
|
||||
let store = SessionStore::new();
|
||||
store.set_for_sender(&session.sender_key, vec![session.clone()]);
|
||||
|
||||
let sessions = store.get(&session.sender_key).unwrap();
|
||||
|
@ -273,7 +273,7 @@ mod test {
|
|||
)
|
||||
.unwrap();
|
||||
|
||||
let mut store = GroupSessionStore::new();
|
||||
let store = GroupSessionStore::new();
|
||||
store.add(inbound.clone());
|
||||
|
||||
let loaded_session = store
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
use std::{collections::HashSet, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashSet;
|
||||
use matrix_sdk_common::{
|
||||
identifiers::{DeviceId, RoomId, UserId},
|
||||
locks::Mutex,
|
||||
|
@ -29,8 +30,8 @@ use crate::{
|
|||
pub struct MemoryStore {
|
||||
sessions: SessionStore,
|
||||
inbound_group_sessions: GroupSessionStore,
|
||||
tracked_users: HashSet<UserId>,
|
||||
users_for_key_query: HashSet<UserId>,
|
||||
tracked_users: DashSet<UserId>,
|
||||
users_for_key_query: DashSet<UserId>,
|
||||
devices: DeviceStore,
|
||||
}
|
||||
|
||||
|
@ -39,8 +40,8 @@ impl MemoryStore {
|
|||
MemoryStore {
|
||||
sessions: SessionStore::new(),
|
||||
inbound_group_sessions: GroupSessionStore::new(),
|
||||
tracked_users: HashSet::new(),
|
||||
users_for_key_query: HashSet::new(),
|
||||
tracked_users: DashSet::new(),
|
||||
users_for_key_query: DashSet::new(),
|
||||
devices: DeviceStore::new(),
|
||||
}
|
||||
}
|
||||
|
@ -56,7 +57,7 @@ impl CryptoStore for MemoryStore {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn save_sessions(&mut self, sessions: &[Session]) -> Result<()> {
|
||||
async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
|
||||
for session in sessions {
|
||||
let _ = self.sessions.add(session.clone()).await;
|
||||
}
|
||||
|
@ -64,16 +65,16 @@ impl CryptoStore for MemoryStore {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_sessions(&mut self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
|
||||
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
|
||||
Ok(self.sessions.get(sender_key))
|
||||
}
|
||||
|
||||
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool> {
|
||||
async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result<bool> {
|
||||
Ok(self.inbound_group_sessions.add(session))
|
||||
}
|
||||
|
||||
async fn get_inbound_group_session(
|
||||
&mut self,
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
sender_key: &str,
|
||||
session_id: &str,
|
||||
|
@ -83,15 +84,19 @@ impl CryptoStore for MemoryStore {
|
|||
.get(room_id, sender_key, session_id))
|
||||
}
|
||||
|
||||
fn users_for_key_query(&self) -> &HashSet<UserId> {
|
||||
&self.users_for_key_query
|
||||
fn users_for_key_query(&self) -> HashSet<UserId> {
|
||||
self.users_for_key_query.iter().map(|u| u.clone()).collect()
|
||||
}
|
||||
|
||||
fn is_user_tracked(&self, user_id: &UserId) -> bool {
|
||||
self.tracked_users.contains(user_id)
|
||||
}
|
||||
|
||||
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool> {
|
||||
fn has_users_for_key_query(&self) -> bool {
|
||||
!self.users_for_key_query.is_empty()
|
||||
}
|
||||
|
||||
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
|
||||
if dirty {
|
||||
self.users_for_key_query.insert(user.clone());
|
||||
} else {
|
||||
|
@ -168,7 +173,7 @@ mod test {
|
|||
)
|
||||
.unwrap();
|
||||
|
||||
let mut store = MemoryStore::new();
|
||||
let store = MemoryStore::new();
|
||||
let _ = store
|
||||
.save_inbound_group_session(inbound.clone())
|
||||
.await
|
||||
|
@ -217,7 +222,7 @@ mod test {
|
|||
#[tokio::test]
|
||||
async fn test_tracked_users() {
|
||||
let device = get_device();
|
||||
let mut store = MemoryStore::new();
|
||||
let store = MemoryStore::new();
|
||||
|
||||
assert!(store
|
||||
.update_tracked_user(device.user_id(), false)
|
||||
|
|
|
@ -109,14 +109,14 @@ pub trait CryptoStore: Debug {
|
|||
/// # Arguments
|
||||
///
|
||||
/// * `session` - The sessions that should be stored.
|
||||
async fn save_sessions(&mut self, session: &[Session]) -> Result<()>;
|
||||
async fn save_sessions(&self, session: &[Session]) -> Result<()>;
|
||||
|
||||
/// Get all the sessions that belong to the given sender key.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `sender_key` - The sender key that was used to establish the sessions.
|
||||
async fn get_sessions(&mut self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>>;
|
||||
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>>;
|
||||
|
||||
/// Save the given inbound group session in the store.
|
||||
///
|
||||
|
@ -126,7 +126,7 @@ pub trait CryptoStore: Debug {
|
|||
/// # Arguments
|
||||
///
|
||||
/// * `session` - The session that should be stored.
|
||||
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool>;
|
||||
async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result<bool>;
|
||||
|
||||
/// Get the inbound group session from our store.
|
||||
///
|
||||
|
@ -137,7 +137,7 @@ pub trait CryptoStore: Debug {
|
|||
///
|
||||
/// * `session_id` - The unique id of the session.
|
||||
async fn get_inbound_group_session(
|
||||
&mut self,
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
sender_key: &str,
|
||||
session_id: &str,
|
||||
|
@ -146,9 +146,12 @@ pub trait CryptoStore: Debug {
|
|||
/// Is the given user already tracked.
|
||||
fn is_user_tracked(&self, user_id: &UserId) -> bool;
|
||||
|
||||
/// Are there any tracked users that are marked as dirty.
|
||||
fn has_users_for_key_query(&self) -> bool;
|
||||
|
||||
/// Set of users that we need to query keys for. This is a subset of
|
||||
/// the tracked users.
|
||||
fn users_for_key_query(&self) -> &HashSet<UserId>;
|
||||
fn users_for_key_query(&self) -> HashSet<UserId>;
|
||||
|
||||
/// Add an user for tracking.
|
||||
///
|
||||
|
@ -159,7 +162,7 @@ pub trait CryptoStore: Debug {
|
|||
/// * `user` - The user that should be marked as tracked.
|
||||
///
|
||||
/// * `dirty` - Should the user be also marked for a key query.
|
||||
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool>;
|
||||
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool>;
|
||||
|
||||
/// Save the given devices in the store.
|
||||
///
|
||||
|
|
|
@ -21,6 +21,7 @@ use std::{
|
|||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashSet;
|
||||
use matrix_sdk_common::{
|
||||
events::Algorithm,
|
||||
identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, RoomId, UserId},
|
||||
|
@ -49,8 +50,8 @@ pub struct SqliteStore {
|
|||
sessions: SessionStore,
|
||||
inbound_group_sessions: GroupSessionStore,
|
||||
devices: DeviceStore,
|
||||
tracked_users: HashSet<UserId>,
|
||||
users_for_key_query: HashSet<UserId>,
|
||||
tracked_users: DashSet<UserId>,
|
||||
users_for_key_query: DashSet<UserId>,
|
||||
|
||||
connection: Arc<Mutex<SqliteConnection>>,
|
||||
pickle_passphrase: Option<Zeroizing<String>>,
|
||||
|
@ -137,8 +138,8 @@ impl SqliteStore {
|
|||
path: path.as_ref().to_owned(),
|
||||
connection: Arc::new(Mutex::new(connection)),
|
||||
pickle_passphrase: passphrase,
|
||||
tracked_users: HashSet::new(),
|
||||
users_for_key_query: HashSet::new(),
|
||||
tracked_users: DashSet::new(),
|
||||
users_for_key_query: DashSet::new(),
|
||||
};
|
||||
store.create_tables().await?;
|
||||
Ok(store)
|
||||
|
@ -299,7 +300,7 @@ impl SqliteStore {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn lazy_load_sessions(&mut self, sender_key: &str) -> Result<()> {
|
||||
async fn lazy_load_sessions(&self, sender_key: &str) -> Result<()> {
|
||||
let loaded_sessions = self.sessions.get(sender_key).is_some();
|
||||
|
||||
if !loaded_sessions {
|
||||
|
@ -313,15 +314,12 @@ impl SqliteStore {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_sessions_for(
|
||||
&mut self,
|
||||
sender_key: &str,
|
||||
) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
|
||||
async fn get_sessions_for(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
|
||||
self.lazy_load_sessions(sender_key).await?;
|
||||
Ok(self.sessions.get(sender_key))
|
||||
}
|
||||
|
||||
async fn load_sessions_for(&mut self, sender_key: &str) -> Result<Vec<Session>> {
|
||||
async fn load_sessions_for(&self, sender_key: &str) -> Result<Vec<Session>> {
|
||||
let account_info = self
|
||||
.account_info
|
||||
.as_ref()
|
||||
|
@ -417,7 +415,7 @@ impl SqliteStore {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn load_tracked_users(&self) -> Result<(HashSet<UserId>, HashSet<UserId>)> {
|
||||
async fn load_tracked_users(&self) -> Result<()> {
|
||||
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
|
||||
let mut connection = self.connection.lock().await;
|
||||
|
||||
|
@ -429,24 +427,21 @@ impl SqliteStore {
|
|||
.fetch_all(&mut *connection)
|
||||
.await?;
|
||||
|
||||
let mut users = HashSet::new();
|
||||
let mut users_for_query = HashSet::new();
|
||||
|
||||
for row in rows {
|
||||
let user_id: &str = &row.0;
|
||||
let dirty: bool = row.1;
|
||||
|
||||
if let Ok(u) = UserId::try_from(user_id) {
|
||||
users.insert(u.clone());
|
||||
self.tracked_users.insert(u.clone());
|
||||
if dirty {
|
||||
users_for_query.insert(u);
|
||||
self.users_for_key_query.insert(u);
|
||||
}
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
}
|
||||
|
||||
Ok((users, users_for_query))
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn load_devices(&self) -> Result<DeviceStore> {
|
||||
|
@ -700,9 +695,7 @@ impl CryptoStore for SqliteStore {
|
|||
let devices = self.load_devices().await?;
|
||||
self.devices = devices;
|
||||
|
||||
let (tracked_users, users_for_query) = self.load_tracked_users().await?;
|
||||
self.tracked_users = tracked_users;
|
||||
self.users_for_key_query = users_for_query;
|
||||
self.load_tracked_users().await?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
@ -743,7 +736,7 @@ impl CryptoStore for SqliteStore {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn save_sessions(&mut self, sessions: &[Session]) -> Result<()> {
|
||||
async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
|
||||
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
|
||||
|
||||
// TODO turn this into a transaction
|
||||
|
@ -776,11 +769,11 @@ impl CryptoStore for SqliteStore {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_sessions(&mut self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
|
||||
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
|
||||
Ok(self.get_sessions_for(sender_key).await?)
|
||||
}
|
||||
|
||||
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool> {
|
||||
async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result<bool> {
|
||||
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
|
||||
let pickle = session.pickle(self.get_pickle_mode()).await;
|
||||
let mut connection = self.connection.lock().await;
|
||||
|
@ -808,7 +801,7 @@ impl CryptoStore for SqliteStore {
|
|||
}
|
||||
|
||||
async fn get_inbound_group_session(
|
||||
&mut self,
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
sender_key: &str,
|
||||
session_id: &str,
|
||||
|
@ -822,11 +815,15 @@ impl CryptoStore for SqliteStore {
|
|||
self.tracked_users.contains(user_id)
|
||||
}
|
||||
|
||||
fn users_for_key_query(&self) -> &HashSet<UserId> {
|
||||
&self.users_for_key_query
|
||||
fn has_users_for_key_query(&self) -> bool {
|
||||
!self.users_for_key_query.is_empty()
|
||||
}
|
||||
|
||||
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool> {
|
||||
fn users_for_key_query(&self) -> HashSet<UserId> {
|
||||
self.users_for_key_query.iter().map(|u| u.clone()).collect()
|
||||
}
|
||||
|
||||
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
|
||||
let already_added = self.tracked_users.insert(user.clone());
|
||||
|
||||
if dirty {
|
||||
|
@ -1135,7 +1132,7 @@ mod test {
|
|||
|
||||
#[tokio::test]
|
||||
async fn save_inbound_group_session() {
|
||||
let (account, mut store, _dir) = get_loaded_store().await;
|
||||
let (account, store, _dir) = get_loaded_store().await;
|
||||
|
||||
let identity_keys = account.identity_keys();
|
||||
let outbound_session = OlmOutboundGroupSession::new();
|
||||
|
@ -1155,7 +1152,7 @@ mod test {
|
|||
|
||||
#[tokio::test]
|
||||
async fn load_inbound_group_session() {
|
||||
let (account, mut store, _dir) = get_loaded_store().await;
|
||||
let (account, store, _dir) = get_loaded_store().await;
|
||||
|
||||
let identity_keys = account.identity_keys();
|
||||
let outbound_session = OlmOutboundGroupSession::new();
|
||||
|
@ -1188,7 +1185,7 @@ mod test {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_tracked_users() {
|
||||
let (_account, mut store, dir) = get_loaded_store().await;
|
||||
let (_account, store, dir) = get_loaded_store().await;
|
||||
let device = get_device();
|
||||
|
||||
assert!(store
|
||||
|
|
Loading…
Reference in a new issue