crypto: Move the cryptostore behind a lock.
This commit is contained in:
parent
57b65ec8c4
commit
7e95d85f17
2 changed files with 128 additions and 37 deletions
|
@ -1260,7 +1260,7 @@ impl BaseClient {
|
|||
let olm = self.olm.lock().await;
|
||||
|
||||
match &*olm {
|
||||
Some(o) => o.should_query_keys(),
|
||||
Some(o) => o.should_query_keys().await,
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
@ -1348,7 +1348,7 @@ impl BaseClient {
|
|||
let olm = self.olm.lock().await;
|
||||
|
||||
match &*olm {
|
||||
Some(o) => Ok(o.users_for_key_query()),
|
||||
Some(o) => Ok(o.users_for_key_query().await),
|
||||
None => Err(()),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,12 +13,12 @@
|
|||
// limitations under the License.
|
||||
|
||||
use std::collections::{BTreeMap, HashMap, HashSet};
|
||||
use std::convert::TryFrom;
|
||||
use std::convert::TryInto;
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
use std::mem;
|
||||
#[cfg(feature = "sqlite-cryptostore")]
|
||||
use std::path::Path;
|
||||
use std::result::Result as StdResult;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult};
|
||||
use super::olm::{
|
||||
|
@ -36,6 +36,7 @@ use matrix_sdk_common::events::{
|
|||
EventType, SyncMessageEvent, ToDeviceEvent,
|
||||
};
|
||||
use matrix_sdk_common::identifiers::{DeviceId, RoomId, UserId};
|
||||
use matrix_sdk_common::locks::RwLock;
|
||||
use matrix_sdk_common::uuid::Uuid;
|
||||
use matrix_sdk_common::{api, Raw};
|
||||
|
||||
|
@ -66,7 +67,7 @@ pub struct OlmMachine {
|
|||
/// Store for the encryption keys.
|
||||
/// Persists all the encryption keys so a client can resume the session
|
||||
/// without the need to create new keys.
|
||||
store: Box<dyn CryptoStore>,
|
||||
store: Arc<RwLock<Box<dyn CryptoStore>>>,
|
||||
/// The currently active outbound group sessions.
|
||||
outbound_group_sessions: HashMap<RoomId, OutboundGroupSession>,
|
||||
}
|
||||
|
@ -100,7 +101,7 @@ impl OlmMachine {
|
|||
user_id: user_id.clone(),
|
||||
device_id: device_id.into(),
|
||||
account: Account::new(user_id, &device_id),
|
||||
store: Box::new(MemoryStore::new()),
|
||||
store: Arc::new(RwLock::new(Box::new(MemoryStore::new()))),
|
||||
outbound_group_sessions: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
@ -142,7 +143,7 @@ impl OlmMachine {
|
|||
user_id,
|
||||
device_id,
|
||||
account,
|
||||
store,
|
||||
store: Arc::new(RwLock::new(store)),
|
||||
outbound_group_sessions: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
@ -224,7 +225,11 @@ impl OlmMachine {
|
|||
self.update_key_count(count);
|
||||
|
||||
self.account.mark_keys_as_published().await;
|
||||
self.store.save_account(self.account.clone()).await?;
|
||||
self.store
|
||||
.write()
|
||||
.await
|
||||
.save_account(self.account.clone())
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -255,7 +260,7 @@ impl OlmMachine {
|
|||
let mut missing = BTreeMap::new();
|
||||
|
||||
for user_id in users {
|
||||
let user_devices = self.store.get_user_devices(user_id).await?;
|
||||
let user_devices = self.store.read().await.get_user_devices(user_id).await?;
|
||||
|
||||
for device in user_devices.devices() {
|
||||
let sender_key = if let Some(k) = device.get_key(KeyAlgorithm::Curve25519) {
|
||||
|
@ -264,7 +269,7 @@ impl OlmMachine {
|
|||
continue;
|
||||
};
|
||||
|
||||
let sessions = self.store.get_sessions(sender_key).await?;
|
||||
let sessions = self.store.write().await.get_sessions(sender_key).await?;
|
||||
|
||||
let is_missing = if let Some(sessions) = sessions {
|
||||
sessions.lock().await.is_empty()
|
||||
|
@ -301,7 +306,13 @@ impl OlmMachine {
|
|||
|
||||
for (user_id, user_devices) in &response.one_time_keys {
|
||||
for (device_id, key_map) in user_devices {
|
||||
let device: Device = match self.store.get_device(&user_id, device_id).await {
|
||||
let device: Device = match self
|
||||
.store
|
||||
.read()
|
||||
.await
|
||||
.get_device(&user_id, device_id)
|
||||
.await
|
||||
{
|
||||
Ok(Some(d)) => d,
|
||||
Ok(None) => {
|
||||
warn!(
|
||||
|
@ -330,7 +341,7 @@ impl OlmMachine {
|
|||
}
|
||||
};
|
||||
|
||||
if let Err(e) = self.store.save_sessions(&[session]).await {
|
||||
if let Err(e) = self.store.write().await.save_sessions(&[session]).await {
|
||||
error!("Failed to store newly created Olm session {}", e);
|
||||
continue;
|
||||
}
|
||||
|
@ -351,7 +362,11 @@ impl OlmMachine {
|
|||
let mut changed_devices = Vec::new();
|
||||
|
||||
for (user_id, device_map) in device_keys_map {
|
||||
self.store.update_tracked_user(user_id, false).await?;
|
||||
self.store
|
||||
.write()
|
||||
.await
|
||||
.update_tracked_user(user_id, false)
|
||||
.await?;
|
||||
|
||||
for (device_id, device_keys) in device_map.iter() {
|
||||
// We don't need our own device in the device store.
|
||||
|
@ -367,7 +382,12 @@ impl OlmMachine {
|
|||
continue;
|
||||
}
|
||||
|
||||
let device = self.store.get_device(&user_id, device_id).await?;
|
||||
let device = self
|
||||
.store
|
||||
.read()
|
||||
.await
|
||||
.get_device(&user_id, device_id)
|
||||
.await?;
|
||||
|
||||
let device = if let Some(mut device) = device {
|
||||
if let Err(e) = device.update_device(device_keys) {
|
||||
|
@ -398,7 +418,13 @@ impl OlmMachine {
|
|||
|
||||
let current_devices: HashSet<&DeviceId> =
|
||||
device_map.keys().map(|id| id.as_ref()).collect();
|
||||
let stored_devices = self.store.get_user_devices(&user_id).await.unwrap();
|
||||
let stored_devices = self
|
||||
.store
|
||||
.read()
|
||||
.await
|
||||
.get_user_devices(&user_id)
|
||||
.await
|
||||
.unwrap();
|
||||
let stored_devices_set: HashSet<&DeviceId> = stored_devices.keys().collect();
|
||||
|
||||
let deleted_devices = stored_devices_set.difference(¤t_devices);
|
||||
|
@ -406,7 +432,7 @@ impl OlmMachine {
|
|||
for device_id in deleted_devices {
|
||||
if let Some(device) = stored_devices.get(device_id) {
|
||||
device.mark_as_deleted();
|
||||
self.store.delete_device(device).await?;
|
||||
self.store.write().await.delete_device(device).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -430,7 +456,11 @@ impl OlmMachine {
|
|||
let changed_devices = self
|
||||
.handle_devices_from_key_query(&response.device_keys)
|
||||
.await?;
|
||||
self.store.save_devices(&changed_devices).await?;
|
||||
self.store
|
||||
.write()
|
||||
.await
|
||||
.save_devices(&changed_devices)
|
||||
.await?;
|
||||
|
||||
Ok(changed_devices)
|
||||
}
|
||||
|
@ -454,7 +484,7 @@ impl OlmMachine {
|
|||
sender_key: &str,
|
||||
message: &OlmMessage,
|
||||
) -> OlmResult<Option<String>> {
|
||||
let s = self.store.get_sessions(sender_key).await?;
|
||||
let s = self.store.write().await.get_sessions(sender_key).await?;
|
||||
|
||||
// We don't have any existing sessions, return early.
|
||||
let sessions = if let Some(s) = s {
|
||||
|
@ -504,7 +534,7 @@ impl OlmMachine {
|
|||
// Decryption was successful, save the new ratchet state of the
|
||||
// session that was used to decrypt the message.
|
||||
trace!("Saved the new session state for {}", sender);
|
||||
self.store.save_sessions(&[session]).await?;
|
||||
self.store.write().await.save_sessions(&[session]).await?;
|
||||
}
|
||||
|
||||
Ok(plaintext)
|
||||
|
@ -559,7 +589,11 @@ impl OlmMachine {
|
|||
|
||||
// Save the account since we remove the one-time key that
|
||||
// was used to create this session.
|
||||
self.store.save_account(self.account.clone()).await?;
|
||||
self.store
|
||||
.write()
|
||||
.await
|
||||
.save_account(self.account.clone())
|
||||
.await?;
|
||||
session
|
||||
}
|
||||
};
|
||||
|
@ -569,7 +603,7 @@ impl OlmMachine {
|
|||
let plaintext = session.decrypt(message).await?;
|
||||
|
||||
// Save the new ratcheted state of the session.
|
||||
self.store.save_sessions(&[session]).await?;
|
||||
self.store.write().await.save_sessions(&[session]).await?;
|
||||
plaintext
|
||||
};
|
||||
|
||||
|
@ -720,7 +754,12 @@ impl OlmMachine {
|
|||
&event.content.room_id,
|
||||
session_key,
|
||||
)?;
|
||||
let _ = self.store.save_inbound_group_session(session).await?;
|
||||
let _ = self
|
||||
.store
|
||||
.write()
|
||||
.await
|
||||
.save_inbound_group_session(session)
|
||||
.await?;
|
||||
|
||||
let event = Raw::from(AnyToDeviceEvent::RoomKey(event.clone()));
|
||||
Ok(Some(event))
|
||||
|
@ -742,7 +781,12 @@ impl OlmMachine {
|
|||
async fn create_outbound_group_session(&mut self, room_id: &RoomId) -> OlmResult<()> {
|
||||
let (outbound, inbound) = self.account.create_group_session_pair(room_id).await;
|
||||
|
||||
let _ = self.store.save_inbound_group_session(inbound).await?;
|
||||
let _ = self
|
||||
.store
|
||||
.write()
|
||||
.await
|
||||
.save_inbound_group_session(inbound)
|
||||
.await?;
|
||||
|
||||
let _ = self
|
||||
.outbound_group_sessions
|
||||
|
@ -819,7 +863,8 @@ impl OlmMachine {
|
|||
return Err(EventError::MissingSenderKey.into());
|
||||
};
|
||||
|
||||
let mut session = if let Some(s) = self.store.get_sessions(sender_key).await? {
|
||||
let mut session = if let Some(s) = self.store.write().await.get_sessions(sender_key).await?
|
||||
{
|
||||
let session = &s.lock().await[0];
|
||||
session.clone()
|
||||
} else {
|
||||
|
@ -833,7 +878,7 @@ impl OlmMachine {
|
|||
};
|
||||
|
||||
let message = session.encrypt(recipient_device, event_type, content).await;
|
||||
self.store.save_sessions(&[session]).await?;
|
||||
self.store.write().await.save_sessions(&[session]).await?;
|
||||
|
||||
message
|
||||
}
|
||||
|
@ -897,7 +942,14 @@ impl OlmMachine {
|
|||
let mut devices = Vec::new();
|
||||
|
||||
for user_id in users {
|
||||
for device in self.store.get_user_devices(user_id).await?.devices() {
|
||||
for device in self
|
||||
.store
|
||||
.read()
|
||||
.await
|
||||
.get_user_devices(user_id)
|
||||
.await?
|
||||
.devices()
|
||||
{
|
||||
// TODO abort if the device isn't verified
|
||||
devices.push(device.clone());
|
||||
}
|
||||
|
@ -1086,6 +1138,8 @@ impl OlmMachine {
|
|||
|
||||
let session = self
|
||||
.store
|
||||
.write()
|
||||
.await
|
||||
.get_inbound_group_session(room_id, &content.sender_key, &content.session_id)
|
||||
.await?;
|
||||
// TODO check if the Olm session is wedged and re-request the key.
|
||||
|
@ -1111,8 +1165,12 @@ impl OlmMachine {
|
|||
///
|
||||
/// Returns true if the user was queued up for a key query, false otherwise.
|
||||
pub async fn mark_user_as_changed(&mut self, user_id: &UserId) -> StoreResult<bool> {
|
||||
if self.store.tracked_users().contains(user_id) {
|
||||
self.store.update_tracked_user(user_id, true).await?;
|
||||
if self.store.read().await.tracked_users().contains(user_id) {
|
||||
self.store
|
||||
.write()
|
||||
.await
|
||||
.update_tracked_user(user_id, true)
|
||||
.await?;
|
||||
Ok(true)
|
||||
} else {
|
||||
Ok(false)
|
||||
|
@ -1138,26 +1196,32 @@ impl OlmMachine {
|
|||
I: IntoIterator<Item = &'a UserId>,
|
||||
{
|
||||
for user in users {
|
||||
if self.store.tracked_users().contains(user) {
|
||||
if self.store.read().await.tracked_users().contains(user) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(e) = self.store.update_tracked_user(user, true).await {
|
||||
if let Err(e) = self
|
||||
.store
|
||||
.write()
|
||||
.await
|
||||
.update_tracked_user(user, true)
|
||||
.await
|
||||
{
|
||||
warn!("Error storing users for tracking {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Should the client perform a key query request.
|
||||
pub fn should_query_keys(&self) -> bool {
|
||||
!self.store.users_for_key_query().is_empty()
|
||||
pub async fn should_query_keys(&self) -> bool {
|
||||
!self.store.read().await.users_for_key_query().is_empty()
|
||||
}
|
||||
|
||||
/// Get the set of users that we need to query keys for.
|
||||
///
|
||||
/// Returns a hash set of users that need to be queried for keys.
|
||||
pub fn users_for_key_query(&self) -> HashSet<UserId> {
|
||||
self.store.users_for_key_query().clone()
|
||||
pub async fn users_for_key_query(&self) -> HashSet<UserId> {
|
||||
self.store.read().await.users_for_key_query().clone()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1277,8 +1341,19 @@ mod test {
|
|||
|
||||
let alice_deivce = Device::from_machine(&alice).await;
|
||||
let bob_device = Device::from_machine(&bob).await;
|
||||
alice.store.save_devices(&[bob_device]).await.unwrap();
|
||||
bob.store.save_devices(&[alice_deivce]).await.unwrap();
|
||||
alice
|
||||
.store
|
||||
.write()
|
||||
.await
|
||||
.save_devices(&[bob_device])
|
||||
.await
|
||||
.unwrap();
|
||||
bob.store
|
||||
.write()
|
||||
.await
|
||||
.save_devices(&[alice_deivce])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
(alice, bob, otk)
|
||||
}
|
||||
|
@ -1311,6 +1386,8 @@ mod test {
|
|||
|
||||
let bob_device = alice
|
||||
.store
|
||||
.read()
|
||||
.await
|
||||
.get_device(&bob.user_id, &bob.device_id)
|
||||
.await
|
||||
.unwrap()
|
||||
|
@ -1515,7 +1592,13 @@ mod test {
|
|||
let alice_id = UserId::try_from("@alice:example.org").unwrap();
|
||||
let alice_device_id: &DeviceId = "JLAFKJWSCS".into();
|
||||
|
||||
let alice_devices = machine.store.get_user_devices(&alice_id).await.unwrap();
|
||||
let alice_devices = machine
|
||||
.store
|
||||
.read()
|
||||
.await
|
||||
.get_user_devices(&alice_id)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(alice_devices.devices().peekable().peek().is_none());
|
||||
|
||||
machine
|
||||
|
@ -1525,6 +1608,8 @@ mod test {
|
|||
|
||||
let device = machine
|
||||
.store
|
||||
.read()
|
||||
.await
|
||||
.get_device(&alice_id, alice_device_id)
|
||||
.await
|
||||
.unwrap()
|
||||
|
@ -1576,6 +1661,8 @@ mod test {
|
|||
|
||||
let session = alice_machine
|
||||
.store
|
||||
.write()
|
||||
.await
|
||||
.get_sessions(bob_machine.account.identity_keys().curve25519())
|
||||
.await
|
||||
.unwrap()
|
||||
|
@ -1590,6 +1677,8 @@ mod test {
|
|||
|
||||
let bob_device = alice
|
||||
.store
|
||||
.read()
|
||||
.await
|
||||
.get_device(&bob.user_id, &bob.device_id)
|
||||
.await
|
||||
.unwrap()
|
||||
|
@ -1651,6 +1740,8 @@ mod test {
|
|||
|
||||
let session = bob
|
||||
.store
|
||||
.write()
|
||||
.await
|
||||
.get_inbound_group_session(
|
||||
&room_id,
|
||||
alice.account.identity_keys().curve25519(),
|
||||
|
|
Loading…
Reference in a new issue