crypto: Don't require the load_account to mutably borrow self.

master
Damir Jelić 2020-08-11 15:08:07 +02:00
parent 8f4ac3da7f
commit 947fa08dae
4 changed files with 23 additions and 16 deletions

View File

@ -147,7 +147,7 @@ impl OlmMachine {
pub async fn new_with_store( pub async fn new_with_store(
user_id: UserId, user_id: UserId,
device_id: Box<DeviceId>, device_id: Box<DeviceId>,
mut store: Box<dyn CryptoStore>, store: Box<dyn CryptoStore>,
) -> StoreResult<Self> { ) -> StoreResult<Self> {
let account = match store.load_account().await? { let account = match store.load_account().await? {
Some(a) => { Some(a) => {

View File

@ -49,7 +49,7 @@ impl MemoryStore {
#[async_trait] #[async_trait]
impl CryptoStore for MemoryStore { impl CryptoStore for MemoryStore {
async fn load_account(&mut self) -> Result<Option<Account>> { async fn load_account(&self) -> Result<Option<Account>> {
Ok(None) Ok(None)
} }

View File

@ -95,7 +95,7 @@ pub type Result<T> = std::result::Result<T, CryptoStoreError>;
/// keys. /// keys.
pub trait CryptoStore: Debug { pub trait CryptoStore: Debug {
/// Load an account that was previously stored. /// Load an account that was previously stored.
async fn load_account(&mut self) -> Result<Option<Account>>; async fn load_account(&self) -> Result<Option<Account>>;
/// Save the given account in the store. /// Save the given account in the store.
/// ///

View File

@ -17,7 +17,7 @@ use std::{
convert::TryFrom, convert::TryFrom,
path::{Path, PathBuf}, path::{Path, PathBuf},
result::Result as StdResult, result::Result as StdResult,
sync::Arc, sync::{Arc, Mutex as SyncMutex},
}; };
use async_trait::async_trait; use async_trait::async_trait;
@ -44,7 +44,7 @@ use crate::{
pub struct SqliteStore { pub struct SqliteStore {
user_id: Arc<UserId>, user_id: Arc<UserId>,
device_id: Arc<Box<DeviceId>>, device_id: Arc<Box<DeviceId>>,
account_info: Option<AccountInfo>, account_info: Arc<SyncMutex<Option<AccountInfo>>>,
path: PathBuf, path: PathBuf,
sessions: SessionStore, sessions: SessionStore,
@ -57,6 +57,7 @@ pub struct SqliteStore {
pickle_passphrase: Option<Zeroizing<String>>, pickle_passphrase: Option<Zeroizing<String>>,
} }
#[derive(Clone)]
struct AccountInfo { struct AccountInfo {
account_id: i64, account_id: i64,
identity_keys: Arc<IdentityKeys>, identity_keys: Arc<IdentityKeys>,
@ -131,7 +132,7 @@ impl SqliteStore {
let store = SqliteStore { let store = SqliteStore {
user_id: Arc::new(user_id.to_owned()), user_id: Arc::new(user_id.to_owned()),
device_id: Arc::new(device_id.into()), device_id: Arc::new(device_id.into()),
account_info: None, account_info: Arc::new(SyncMutex::new(None)),
sessions: SessionStore::new(), sessions: SessionStore::new(),
inbound_group_sessions: GroupSessionStore::new(), inbound_group_sessions: GroupSessionStore::new(),
devices: DeviceStore::new(), devices: DeviceStore::new(),
@ -146,7 +147,11 @@ impl SqliteStore {
} }
fn account_id(&self) -> Option<i64> { fn account_id(&self) -> Option<i64> {
self.account_info.as_ref().map(|i| i.account_id) self.account_info
.lock()
.unwrap()
.as_ref()
.map(|i| i.account_id)
} }
async fn create_tables(&self) -> Result<()> { async fn create_tables(&self) -> Result<()> {
@ -322,7 +327,9 @@ impl SqliteStore {
async fn load_sessions_for(&self, sender_key: &str) -> Result<Vec<Session>> { async fn load_sessions_for(&self, sender_key: &str) -> Result<Vec<Session>> {
let account_info = self let account_info = self
.account_info .account_info
.as_ref() .lock()
.unwrap()
.clone()
.ok_or(CryptoStoreError::AccountUnset)?; .ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
@ -656,7 +663,7 @@ impl SqliteStore {
#[async_trait] #[async_trait]
impl CryptoStore for SqliteStore { impl CryptoStore for SqliteStore {
async fn load_account(&mut self) -> Result<Option<Account>> { async fn load_account(&self) -> Result<Option<Account>> {
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
let row: Option<(i64, String, bool, i64)> = query_as( let row: Option<(i64, String, bool, i64)> = query_as(
@ -678,7 +685,7 @@ impl CryptoStore for SqliteStore {
&self.device_id, &self.device_id,
)?; )?;
self.account_info = Some(AccountInfo { *self.account_info.lock().unwrap() = Some(AccountInfo {
account_id: id, account_id: id,
identity_keys: account.identity_keys.clone(), identity_keys: account.identity_keys.clone(),
}); });
@ -725,7 +732,7 @@ impl CryptoStore for SqliteStore {
.fetch_one(&mut *connection) .fetch_one(&mut *connection)
.await?; .await?;
self.account_info = Some(AccountInfo { *self.account_info.lock().unwrap() = Some(AccountInfo {
account_id: account_id.0, account_id: account_id.0,
identity_keys: account.identity_keys.clone(), identity_keys: account.identity_keys.clone(),
}); });
@ -1113,7 +1120,7 @@ mod test {
drop(store); drop(store);
let mut store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path()) let store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path())
.await .await
.expect("Can't create store"); .expect("Can't create store");
@ -1199,7 +1206,7 @@ mod test {
assert!(store.users_for_key_query().contains(device.user_id())); assert!(store.users_for_key_query().contains(device.user_id()));
drop(store); drop(store);
let mut store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path()) let store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path())
.await .await
.expect("Can't create store"); .expect("Can't create store");
@ -1214,7 +1221,7 @@ mod test {
.unwrap(); .unwrap();
assert!(!store.users_for_key_query().contains(device.user_id())); assert!(!store.users_for_key_query().contains(device.user_id()));
let mut store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path()) let store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path())
.await .await
.expect("Can't create store"); .expect("Can't create store");
@ -1232,7 +1239,7 @@ mod test {
drop(store); drop(store);
let mut store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path()) let store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path())
.await .await
.expect("Can't create store"); .expect("Can't create store");
@ -1265,7 +1272,7 @@ mod test {
store.save_devices(&[device.clone()]).await.unwrap(); store.save_devices(&[device.clone()]).await.unwrap();
store.delete_device(device.clone()).await.unwrap(); store.delete_device(device.clone()).await.unwrap();
let mut store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path()) let store = SqliteStore::open(&example_user_id(), example_device_id(), dir.path())
.await .await
.expect("Can't create store"); .expect("Can't create store");