crypto: Remove most mutable self borrows from the crypto-store trait.

This commit is contained in:
Damir Jelić 2020-08-11 14:34:42 +02:00
parent ac2469d270
commit 01bcbaf063
5 changed files with 62 additions and 57 deletions

View file

@ -1268,14 +1268,14 @@ impl OlmMachine {
/// Should the client perform a key query request.
pub async fn should_query_keys(&self) -> bool {
!self.store.read().await.users_for_key_query().is_empty()
self.store.read().await.has_users_for_key_query()
}
/// Get the set of users that we need to query keys for.
///
/// Returns a hash set of users that need to be queried for keys.
pub async fn users_for_key_query(&self) -> HashSet<UserId> {
self.store.read().await.users_for_key_query().clone()
self.store.read().await.users_for_key_query()
}
}

View file

@ -43,7 +43,7 @@ impl SessionStore {
///
/// Returns true if the the session was added, false if the session was
/// already in the store.
pub async fn add(&mut self, session: Session) -> bool {
pub async fn add(&self, session: Session) -> bool {
if !self.entries.contains_key(&*session.sender_key) {
self.entries.insert(
session.sender_key.to_string(),
@ -67,7 +67,7 @@ impl SessionStore {
}
/// Add a list of sessions belonging to the sender key.
pub fn set_for_sender(&mut self, sender_key: &str, sessions: Vec<Session>) {
pub fn set_for_sender(&self, sender_key: &str, sessions: Vec<Session>) {
self.entries
.insert(sender_key.to_owned(), Arc::new(Mutex::new(sessions)));
}
@ -92,7 +92,7 @@ impl GroupSessionStore {
///
/// Returns true if the the session was added, false if the session was
/// already in the store.
pub fn add(&mut self, session: InboundGroupSession) -> bool {
pub fn add(&self, session: InboundGroupSession) -> bool {
if !self.entries.contains_key(&session.room_id) {
let room_id = &*session.room_id;
self.entries.insert(room_id.clone(), HashMap::new());
@ -225,7 +225,7 @@ mod test {
async fn test_session_store() {
let (_, session) = get_account_and_session().await;
let mut store = SessionStore::new();
let store = SessionStore::new();
assert!(store.add(session.clone()).await);
assert!(!store.add(session.clone()).await);
@ -242,7 +242,7 @@ mod test {
async fn test_session_store_bulk_storing() {
let (_, session) = get_account_and_session().await;
let mut store = SessionStore::new();
let store = SessionStore::new();
store.set_for_sender(&session.sender_key, vec![session.clone()]);
let sessions = store.get(&session.sender_key).unwrap();
@ -273,7 +273,7 @@ mod test {
)
.unwrap();
let mut store = GroupSessionStore::new();
let store = GroupSessionStore::new();
store.add(inbound.clone());
let loaded_session = store

View file

@ -15,6 +15,7 @@
use std::{collections::HashSet, sync::Arc};
use async_trait::async_trait;
use dashmap::DashSet;
use matrix_sdk_common::{
identifiers::{DeviceId, RoomId, UserId},
locks::Mutex,
@ -29,8 +30,8 @@ use crate::{
pub struct MemoryStore {
sessions: SessionStore,
inbound_group_sessions: GroupSessionStore,
tracked_users: HashSet<UserId>,
users_for_key_query: HashSet<UserId>,
tracked_users: DashSet<UserId>,
users_for_key_query: DashSet<UserId>,
devices: DeviceStore,
}
@ -39,8 +40,8 @@ impl MemoryStore {
MemoryStore {
sessions: SessionStore::new(),
inbound_group_sessions: GroupSessionStore::new(),
tracked_users: HashSet::new(),
users_for_key_query: HashSet::new(),
tracked_users: DashSet::new(),
users_for_key_query: DashSet::new(),
devices: DeviceStore::new(),
}
}
@ -56,7 +57,7 @@ impl CryptoStore for MemoryStore {
Ok(())
}
async fn save_sessions(&mut self, sessions: &[Session]) -> Result<()> {
async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
for session in sessions {
let _ = self.sessions.add(session.clone()).await;
}
@ -64,16 +65,16 @@ impl CryptoStore for MemoryStore {
Ok(())
}
async fn get_sessions(&mut self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
Ok(self.sessions.get(sender_key))
}
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool> {
async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result<bool> {
Ok(self.inbound_group_sessions.add(session))
}
async fn get_inbound_group_session(
&mut self,
&self,
room_id: &RoomId,
sender_key: &str,
session_id: &str,
@ -83,15 +84,19 @@ impl CryptoStore for MemoryStore {
.get(room_id, sender_key, session_id))
}
fn users_for_key_query(&self) -> &HashSet<UserId> {
&self.users_for_key_query
fn users_for_key_query(&self) -> HashSet<UserId> {
self.users_for_key_query.iter().map(|u| u.clone()).collect()
}
fn is_user_tracked(&self, user_id: &UserId) -> bool {
self.tracked_users.contains(user_id)
}
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool> {
fn has_users_for_key_query(&self) -> bool {
!self.users_for_key_query.is_empty()
}
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
if dirty {
self.users_for_key_query.insert(user.clone());
} else {
@ -168,7 +173,7 @@ mod test {
)
.unwrap();
let mut store = MemoryStore::new();
let store = MemoryStore::new();
let _ = store
.save_inbound_group_session(inbound.clone())
.await
@ -217,7 +222,7 @@ mod test {
#[tokio::test]
async fn test_tracked_users() {
let device = get_device();
let mut store = MemoryStore::new();
let store = MemoryStore::new();
assert!(store
.update_tracked_user(device.user_id(), false)

View file

@ -109,14 +109,14 @@ pub trait CryptoStore: Debug {
/// # Arguments
///
/// * `session` - The sessions that should be stored.
async fn save_sessions(&mut self, session: &[Session]) -> Result<()>;
async fn save_sessions(&self, session: &[Session]) -> Result<()>;
/// Get all the sessions that belong to the given sender key.
///
/// # Arguments
///
/// * `sender_key` - The sender key that was used to establish the sessions.
async fn get_sessions(&mut self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>>;
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>>;
/// Save the given inbound group session in the store.
///
@ -126,7 +126,7 @@ pub trait CryptoStore: Debug {
/// # Arguments
///
/// * `session` - The session that should be stored.
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool>;
async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result<bool>;
/// Get the inbound group session from our store.
///
@ -137,7 +137,7 @@ pub trait CryptoStore: Debug {
///
/// * `session_id` - The unique id of the session.
async fn get_inbound_group_session(
&mut self,
&self,
room_id: &RoomId,
sender_key: &str,
session_id: &str,
@ -146,9 +146,12 @@ pub trait CryptoStore: Debug {
/// Is the given user already tracked.
fn is_user_tracked(&self, user_id: &UserId) -> bool;
/// Are there any tracked users that are marked as dirty.
fn has_users_for_key_query(&self) -> bool;
/// Set of users that we need to query keys for. This is a subset of
/// the tracked users.
fn users_for_key_query(&self) -> &HashSet<UserId>;
fn users_for_key_query(&self) -> HashSet<UserId>;
/// Add an user for tracking.
///
@ -159,7 +162,7 @@ pub trait CryptoStore: Debug {
/// * `user` - The user that should be marked as tracked.
///
/// * `dirty` - Should the user be also marked for a key query.
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool>;
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool>;
/// Save the given devices in the store.
///

View file

@ -21,6 +21,7 @@ use std::{
};
use async_trait::async_trait;
use dashmap::DashSet;
use matrix_sdk_common::{
events::Algorithm,
identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, RoomId, UserId},
@ -49,8 +50,8 @@ pub struct SqliteStore {
sessions: SessionStore,
inbound_group_sessions: GroupSessionStore,
devices: DeviceStore,
tracked_users: HashSet<UserId>,
users_for_key_query: HashSet<UserId>,
tracked_users: DashSet<UserId>,
users_for_key_query: DashSet<UserId>,
connection: Arc<Mutex<SqliteConnection>>,
pickle_passphrase: Option<Zeroizing<String>>,
@ -137,8 +138,8 @@ impl SqliteStore {
path: path.as_ref().to_owned(),
connection: Arc::new(Mutex::new(connection)),
pickle_passphrase: passphrase,
tracked_users: HashSet::new(),
users_for_key_query: HashSet::new(),
tracked_users: DashSet::new(),
users_for_key_query: DashSet::new(),
};
store.create_tables().await?;
Ok(store)
@ -299,7 +300,7 @@ impl SqliteStore {
Ok(())
}
async fn lazy_load_sessions(&mut self, sender_key: &str) -> Result<()> {
async fn lazy_load_sessions(&self, sender_key: &str) -> Result<()> {
let loaded_sessions = self.sessions.get(sender_key).is_some();
if !loaded_sessions {
@ -313,15 +314,12 @@ impl SqliteStore {
Ok(())
}
async fn get_sessions_for(
&mut self,
sender_key: &str,
) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
async fn get_sessions_for(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
self.lazy_load_sessions(sender_key).await?;
Ok(self.sessions.get(sender_key))
}
async fn load_sessions_for(&mut self, sender_key: &str) -> Result<Vec<Session>> {
async fn load_sessions_for(&self, sender_key: &str) -> Result<Vec<Session>> {
let account_info = self
.account_info
.as_ref()
@ -417,7 +415,7 @@ impl SqliteStore {
Ok(())
}
async fn load_tracked_users(&self) -> Result<(HashSet<UserId>, HashSet<UserId>)> {
async fn load_tracked_users(&self) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
@ -429,24 +427,21 @@ impl SqliteStore {
.fetch_all(&mut *connection)
.await?;
let mut users = HashSet::new();
let mut users_for_query = HashSet::new();
for row in rows {
let user_id: &str = &row.0;
let dirty: bool = row.1;
if let Ok(u) = UserId::try_from(user_id) {
users.insert(u.clone());
self.tracked_users.insert(u.clone());
if dirty {
users_for_query.insert(u);
self.users_for_key_query.insert(u);
}
} else {
continue;
};
}
Ok((users, users_for_query))
Ok(())
}
async fn load_devices(&self) -> Result<DeviceStore> {
@ -700,9 +695,7 @@ impl CryptoStore for SqliteStore {
let devices = self.load_devices().await?;
self.devices = devices;
let (tracked_users, users_for_query) = self.load_tracked_users().await?;
self.tracked_users = tracked_users;
self.users_for_key_query = users_for_query;
self.load_tracked_users().await?;
Ok(result)
}
@ -743,7 +736,7 @@ impl CryptoStore for SqliteStore {
Ok(())
}
async fn save_sessions(&mut self, sessions: &[Session]) -> Result<()> {
async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
// TODO turn this into a transaction
@ -776,11 +769,11 @@ impl CryptoStore for SqliteStore {
Ok(())
}
async fn get_sessions(&mut self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
Ok(self.get_sessions_for(sender_key).await?)
}
async fn save_inbound_group_session(&mut self, session: InboundGroupSession) -> Result<bool> {
async fn save_inbound_group_session(&self, session: InboundGroupSession) -> Result<bool> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let pickle = session.pickle(self.get_pickle_mode()).await;
let mut connection = self.connection.lock().await;
@ -808,7 +801,7 @@ impl CryptoStore for SqliteStore {
}
async fn get_inbound_group_session(
&mut self,
&self,
room_id: &RoomId,
sender_key: &str,
session_id: &str,
@ -822,11 +815,15 @@ impl CryptoStore for SqliteStore {
self.tracked_users.contains(user_id)
}
fn users_for_key_query(&self) -> &HashSet<UserId> {
&self.users_for_key_query
fn has_users_for_key_query(&self) -> bool {
!self.users_for_key_query.is_empty()
}
async fn update_tracked_user(&mut self, user: &UserId, dirty: bool) -> Result<bool> {
fn users_for_key_query(&self) -> HashSet<UserId> {
self.users_for_key_query.iter().map(|u| u.clone()).collect()
}
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
let already_added = self.tracked_users.insert(user.clone());
if dirty {
@ -1135,7 +1132,7 @@ mod test {
#[tokio::test]
async fn save_inbound_group_session() {
let (account, mut store, _dir) = get_loaded_store().await;
let (account, store, _dir) = get_loaded_store().await;
let identity_keys = account.identity_keys();
let outbound_session = OlmOutboundGroupSession::new();
@ -1155,7 +1152,7 @@ mod test {
#[tokio::test]
async fn load_inbound_group_session() {
let (account, mut store, _dir) = get_loaded_store().await;
let (account, store, _dir) = get_loaded_store().await;
let identity_keys = account.identity_keys();
let outbound_session = OlmOutboundGroupSession::new();
@ -1188,7 +1185,7 @@ mod test {
#[tokio::test]
async fn test_tracked_users() {
let (_account, mut store, dir) = get_loaded_store().await;
let (_account, store, dir) = get_loaded_store().await;
let device = get_device();
assert!(store