crypto: Remove our lock around the cryptostore.
parent
707b4c1185
commit
d0a5b86ff3
|
@ -41,7 +41,6 @@ use matrix_sdk_common::{
|
||||||
Algorithm, AnySyncRoomEvent, AnyToDeviceEvent, EventType, SyncMessageEvent, ToDeviceEvent,
|
Algorithm, AnySyncRoomEvent, AnyToDeviceEvent, EventType, SyncMessageEvent, ToDeviceEvent,
|
||||||
},
|
},
|
||||||
identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, RoomId, UserId},
|
identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, RoomId, UserId},
|
||||||
locks::RwLock,
|
|
||||||
uuid::Uuid,
|
uuid::Uuid,
|
||||||
Raw,
|
Raw,
|
||||||
};
|
};
|
||||||
|
@ -80,7 +79,7 @@ pub struct OlmMachine {
|
||||||
/// Store for the encryption keys.
|
/// Store for the encryption keys.
|
||||||
/// Persists all the encryption keys so a client can resume the session
|
/// Persists all the encryption keys so a client can resume the session
|
||||||
/// without the need to create new keys.
|
/// without the need to create new keys.
|
||||||
store: Arc<RwLock<Box<dyn CryptoStore>>>,
|
store: Arc<Box<dyn CryptoStore>>,
|
||||||
/// The currently active outbound group sessions.
|
/// The currently active outbound group sessions.
|
||||||
outbound_group_sessions: Arc<DashMap<RoomId, OutboundGroupSession>>,
|
outbound_group_sessions: Arc<DashMap<RoomId, OutboundGroupSession>>,
|
||||||
/// A state machine that is responsible to handle and keep track of SAS
|
/// A state machine that is responsible to handle and keep track of SAS
|
||||||
|
@ -111,10 +110,9 @@ impl OlmMachine {
|
||||||
/// * `user_id` - The unique id of the user that owns this machine.
|
/// * `user_id` - The unique id of the user that owns this machine.
|
||||||
///
|
///
|
||||||
/// * `device_id` - The unique id of the device that owns this machine.
|
/// * `device_id` - The unique id of the device that owns this machine.
|
||||||
#[allow(clippy::ptr_arg)]
|
|
||||||
pub fn new(user_id: &UserId, device_id: &DeviceId) -> Self {
|
pub fn new(user_id: &UserId, device_id: &DeviceId) -> Self {
|
||||||
let store: Box<dyn CryptoStore> = Box::new(MemoryStore::new());
|
let store: Box<dyn CryptoStore> = Box::new(MemoryStore::new());
|
||||||
let store = Arc::new(RwLock::new(store));
|
let store = Arc::new(store);
|
||||||
let account = Account::new(user_id, device_id);
|
let account = Account::new(user_id, device_id);
|
||||||
|
|
||||||
OlmMachine {
|
OlmMachine {
|
||||||
|
@ -160,7 +158,7 @@ impl OlmMachine {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let store = Arc::new(RwLock::new(store));
|
let store = Arc::new(store);
|
||||||
let verification_machine = VerificationMachine::new(account.clone(), store.clone());
|
let verification_machine = VerificationMachine::new(account.clone(), store.clone());
|
||||||
|
|
||||||
Ok(OlmMachine {
|
Ok(OlmMachine {
|
||||||
|
@ -250,11 +248,7 @@ impl OlmMachine {
|
||||||
self.update_key_count(count);
|
self.update_key_count(count);
|
||||||
|
|
||||||
self.account.mark_keys_as_published().await;
|
self.account.mark_keys_as_published().await;
|
||||||
self.store
|
self.store.save_account(self.account.clone()).await?;
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.save_account(self.account.clone())
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -285,7 +279,7 @@ impl OlmMachine {
|
||||||
let mut missing = BTreeMap::new();
|
let mut missing = BTreeMap::new();
|
||||||
|
|
||||||
for user_id in users {
|
for user_id in users {
|
||||||
let user_devices = self.store.read().await.get_user_devices(user_id).await?;
|
let user_devices = self.store.get_user_devices(user_id).await?;
|
||||||
|
|
||||||
for device in user_devices.devices() {
|
for device in user_devices.devices() {
|
||||||
let sender_key = if let Some(k) = device.get_key(DeviceKeyAlgorithm::Curve25519) {
|
let sender_key = if let Some(k) = device.get_key(DeviceKeyAlgorithm::Curve25519) {
|
||||||
|
@ -294,7 +288,7 @@ impl OlmMachine {
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
|
||||||
let sessions = self.store.write().await.get_sessions(sender_key).await?;
|
let sessions = self.store.get_sessions(sender_key).await?;
|
||||||
|
|
||||||
let is_missing = if let Some(sessions) = sessions {
|
let is_missing = if let Some(sessions) = sessions {
|
||||||
sessions.lock().await.is_empty()
|
sessions.lock().await.is_empty()
|
||||||
|
@ -333,13 +327,7 @@ impl OlmMachine {
|
||||||
|
|
||||||
for (user_id, user_devices) in &response.one_time_keys {
|
for (user_id, user_devices) in &response.one_time_keys {
|
||||||
for (device_id, key_map) in user_devices {
|
for (device_id, key_map) in user_devices {
|
||||||
let device: Device = match self
|
let device: Device = match self.store.get_device(&user_id, device_id).await {
|
||||||
.store
|
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.get_device(&user_id, device_id)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(Some(d)) => d,
|
Ok(Some(d)) => d,
|
||||||
Ok(None) => {
|
Ok(None) => {
|
||||||
warn!(
|
warn!(
|
||||||
|
@ -368,7 +356,7 @@ impl OlmMachine {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Err(e) = self.store.write().await.save_sessions(&[session]).await {
|
if let Err(e) = self.store.save_sessions(&[session]).await {
|
||||||
error!("Failed to store newly created Olm session {}", e);
|
error!("Failed to store newly created Olm session {}", e);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -389,11 +377,7 @@ impl OlmMachine {
|
||||||
let mut changed_devices = Vec::new();
|
let mut changed_devices = Vec::new();
|
||||||
|
|
||||||
for (user_id, device_map) in device_keys_map {
|
for (user_id, device_map) in device_keys_map {
|
||||||
self.store
|
self.store.update_tracked_user(user_id, false).await?;
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.update_tracked_user(user_id, false)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
for (device_id, device_keys) in device_map.iter() {
|
for (device_id, device_keys) in device_map.iter() {
|
||||||
// We don't need our own device in the device store.
|
// We don't need our own device in the device store.
|
||||||
|
@ -409,12 +393,7 @@ impl OlmMachine {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let device = self
|
let device = self.store.get_device(&user_id, device_id).await?;
|
||||||
.store
|
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.get_device(&user_id, device_id)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let device = if let Some(mut device) = device {
|
let device = if let Some(mut device) = device {
|
||||||
if let Err(e) = device.update_device(device_keys) {
|
if let Err(e) = device.update_device(device_keys) {
|
||||||
|
@ -445,13 +424,7 @@ impl OlmMachine {
|
||||||
|
|
||||||
let current_devices: HashSet<&DeviceId> =
|
let current_devices: HashSet<&DeviceId> =
|
||||||
device_map.keys().map(|id| id.as_ref()).collect();
|
device_map.keys().map(|id| id.as_ref()).collect();
|
||||||
let stored_devices = self
|
let stored_devices = self.store.get_user_devices(&user_id).await.unwrap();
|
||||||
.store
|
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.get_user_devices(&user_id)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
let stored_devices_set: HashSet<&DeviceId> = stored_devices.keys().collect();
|
let stored_devices_set: HashSet<&DeviceId> = stored_devices.keys().collect();
|
||||||
|
|
||||||
let deleted_devices = stored_devices_set.difference(¤t_devices);
|
let deleted_devices = stored_devices_set.difference(¤t_devices);
|
||||||
|
@ -459,7 +432,7 @@ impl OlmMachine {
|
||||||
for device_id in deleted_devices {
|
for device_id in deleted_devices {
|
||||||
if let Some(device) = stored_devices.get(device_id) {
|
if let Some(device) = stored_devices.get(device_id) {
|
||||||
device.mark_as_deleted();
|
device.mark_as_deleted();
|
||||||
self.store.write().await.delete_device(device).await?;
|
self.store.delete_device(device).await?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -483,11 +456,7 @@ impl OlmMachine {
|
||||||
let changed_devices = self
|
let changed_devices = self
|
||||||
.handle_devices_from_key_query(&response.device_keys)
|
.handle_devices_from_key_query(&response.device_keys)
|
||||||
.await?;
|
.await?;
|
||||||
self.store
|
self.store.save_devices(&changed_devices).await?;
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.save_devices(&changed_devices)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(changed_devices)
|
Ok(changed_devices)
|
||||||
}
|
}
|
||||||
|
@ -511,7 +480,7 @@ impl OlmMachine {
|
||||||
sender_key: &str,
|
sender_key: &str,
|
||||||
message: &OlmMessage,
|
message: &OlmMessage,
|
||||||
) -> OlmResult<Option<String>> {
|
) -> OlmResult<Option<String>> {
|
||||||
let s = self.store.write().await.get_sessions(sender_key).await?;
|
let s = self.store.get_sessions(sender_key).await?;
|
||||||
|
|
||||||
// We don't have any existing sessions, return early.
|
// We don't have any existing sessions, return early.
|
||||||
let sessions = if let Some(s) = s {
|
let sessions = if let Some(s) = s {
|
||||||
|
@ -561,7 +530,7 @@ impl OlmMachine {
|
||||||
// Decryption was successful, save the new ratchet state of the
|
// Decryption was successful, save the new ratchet state of the
|
||||||
// session that was used to decrypt the message.
|
// session that was used to decrypt the message.
|
||||||
trace!("Saved the new session state for {}", sender);
|
trace!("Saved the new session state for {}", sender);
|
||||||
self.store.write().await.save_sessions(&[session]).await?;
|
self.store.save_sessions(&[session]).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(plaintext)
|
Ok(plaintext)
|
||||||
|
@ -616,11 +585,7 @@ impl OlmMachine {
|
||||||
|
|
||||||
// Save the account since we remove the one-time key that
|
// Save the account since we remove the one-time key that
|
||||||
// was used to create this session.
|
// was used to create this session.
|
||||||
self.store
|
self.store.save_account(self.account.clone()).await?;
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.save_account(self.account.clone())
|
|
||||||
.await?;
|
|
||||||
session
|
session
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -630,7 +595,7 @@ impl OlmMachine {
|
||||||
let plaintext = session.decrypt(message).await?;
|
let plaintext = session.decrypt(message).await?;
|
||||||
|
|
||||||
// Save the new ratcheted state of the session.
|
// Save the new ratcheted state of the session.
|
||||||
self.store.write().await.save_sessions(&[session]).await?;
|
self.store.save_sessions(&[session]).await?;
|
||||||
plaintext
|
plaintext
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -781,12 +746,7 @@ impl OlmMachine {
|
||||||
&event.content.room_id,
|
&event.content.room_id,
|
||||||
session_key,
|
session_key,
|
||||||
)?;
|
)?;
|
||||||
let _ = self
|
let _ = self.store.save_inbound_group_session(session).await?;
|
||||||
.store
|
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.save_inbound_group_session(session)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let event = Raw::from(AnyToDeviceEvent::RoomKey(event.clone()));
|
let event = Raw::from(AnyToDeviceEvent::RoomKey(event.clone()));
|
||||||
Ok(Some(event))
|
Ok(Some(event))
|
||||||
|
@ -808,12 +768,7 @@ impl OlmMachine {
|
||||||
async fn create_outbound_group_session(&self, room_id: &RoomId) -> OlmResult<()> {
|
async fn create_outbound_group_session(&self, room_id: &RoomId) -> OlmResult<()> {
|
||||||
let (outbound, inbound) = self.account.create_group_session_pair(room_id).await;
|
let (outbound, inbound) = self.account.create_group_session_pair(room_id).await;
|
||||||
|
|
||||||
let _ = self
|
let _ = self.store.save_inbound_group_session(inbound).await?;
|
||||||
.store
|
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.save_inbound_group_session(inbound)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let _ = self
|
let _ = self
|
||||||
.outbound_group_sessions
|
.outbound_group_sessions
|
||||||
|
@ -899,8 +854,7 @@ impl OlmMachine {
|
||||||
return Err(EventError::MissingSenderKey.into());
|
return Err(EventError::MissingSenderKey.into());
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut session = if let Some(s) = self.store.write().await.get_sessions(sender_key).await?
|
let mut session = if let Some(s) = self.store.get_sessions(sender_key).await? {
|
||||||
{
|
|
||||||
let session = &s.lock().await[0];
|
let session = &s.lock().await[0];
|
||||||
session.clone()
|
session.clone()
|
||||||
} else {
|
} else {
|
||||||
|
@ -914,7 +868,7 @@ impl OlmMachine {
|
||||||
};
|
};
|
||||||
|
|
||||||
let message = session.encrypt(recipient_device, event_type, content).await;
|
let message = session.encrypt(recipient_device, event_type, content).await;
|
||||||
self.store.write().await.save_sessions(&[session]).await?;
|
self.store.save_sessions(&[session]).await?;
|
||||||
|
|
||||||
message
|
message
|
||||||
}
|
}
|
||||||
|
@ -978,14 +932,7 @@ impl OlmMachine {
|
||||||
let mut devices = Vec::new();
|
let mut devices = Vec::new();
|
||||||
|
|
||||||
for user_id in users {
|
for user_id in users {
|
||||||
for device in self
|
for device in self.store.get_user_devices(user_id).await?.devices() {
|
||||||
.store
|
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.get_user_devices(user_id)
|
|
||||||
.await?
|
|
||||||
.devices()
|
|
||||||
{
|
|
||||||
devices.push(device.clone());
|
devices.push(device.clone());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1192,8 +1139,6 @@ impl OlmMachine {
|
||||||
|
|
||||||
let session = self
|
let session = self
|
||||||
.store
|
.store
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.get_inbound_group_session(room_id, &content.sender_key, &content.session_id)
|
.get_inbound_group_session(room_id, &content.sender_key, &content.session_id)
|
||||||
.await?;
|
.await?;
|
||||||
// TODO check if the Olm session is wedged and re-request the key.
|
// TODO check if the Olm session is wedged and re-request the key.
|
||||||
|
@ -1219,12 +1164,8 @@ impl OlmMachine {
|
||||||
///
|
///
|
||||||
/// Returns true if the user was queued up for a key query, false otherwise.
|
/// Returns true if the user was queued up for a key query, false otherwise.
|
||||||
pub async fn mark_user_as_changed(&self, user_id: &UserId) -> StoreResult<bool> {
|
pub async fn mark_user_as_changed(&self, user_id: &UserId) -> StoreResult<bool> {
|
||||||
if self.store.read().await.is_user_tracked(user_id) {
|
if self.store.is_user_tracked(user_id) {
|
||||||
self.store
|
self.store.update_tracked_user(user_id, true).await?;
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.update_tracked_user(user_id, true)
|
|
||||||
.await?;
|
|
||||||
Ok(true)
|
Ok(true)
|
||||||
} else {
|
} else {
|
||||||
Ok(false)
|
Ok(false)
|
||||||
|
@ -1250,17 +1191,11 @@ impl OlmMachine {
|
||||||
I: IntoIterator<Item = &'a UserId>,
|
I: IntoIterator<Item = &'a UserId>,
|
||||||
{
|
{
|
||||||
for user in users {
|
for user in users {
|
||||||
if self.store.read().await.is_user_tracked(user) {
|
if self.store.is_user_tracked(user) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Err(e) = self
|
if let Err(e) = self.store.update_tracked_user(user, true).await {
|
||||||
.store
|
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.update_tracked_user(user, true)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
warn!("Error storing users for tracking {}", e);
|
warn!("Error storing users for tracking {}", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1268,14 +1203,14 @@ impl OlmMachine {
|
||||||
|
|
||||||
/// Should the client perform a key query request.
|
/// Should the client perform a key query request.
|
||||||
pub async fn should_query_keys(&self) -> bool {
|
pub async fn should_query_keys(&self) -> bool {
|
||||||
self.store.read().await.has_users_for_key_query()
|
self.store.has_users_for_key_query()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the set of users that we need to query keys for.
|
/// Get the set of users that we need to query keys for.
|
||||||
///
|
///
|
||||||
/// Returns a hash set of users that need to be queried for keys.
|
/// Returns a hash set of users that need to be queried for keys.
|
||||||
pub async fn users_for_key_query(&self) -> HashSet<UserId> {
|
pub async fn users_for_key_query(&self) -> HashSet<UserId> {
|
||||||
self.store.read().await.users_for_key_query()
|
self.store.users_for_key_query()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1398,19 +1333,8 @@ mod test {
|
||||||
|
|
||||||
let alice_deivce = Device::from_machine(&alice).await;
|
let alice_deivce = Device::from_machine(&alice).await;
|
||||||
let bob_device = Device::from_machine(&bob).await;
|
let bob_device = Device::from_machine(&bob).await;
|
||||||
alice
|
alice.store.save_devices(&[bob_device]).await.unwrap();
|
||||||
.store
|
bob.store.save_devices(&[alice_deivce]).await.unwrap();
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.save_devices(&[bob_device])
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
bob.store
|
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.save_devices(&[alice_deivce])
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
(alice, bob, otk)
|
(alice, bob, otk)
|
||||||
}
|
}
|
||||||
|
@ -1443,8 +1367,6 @@ mod test {
|
||||||
|
|
||||||
let bob_device = alice
|
let bob_device = alice
|
||||||
.store
|
.store
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.get_device(&bob.user_id, &bob.device_id)
|
.get_device(&bob.user_id, &bob.device_id)
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -1649,13 +1571,7 @@ mod test {
|
||||||
let alice_id = user_id!("@alice:example.org");
|
let alice_id = user_id!("@alice:example.org");
|
||||||
let alice_device_id: &DeviceId = "JLAFKJWSCS".into();
|
let alice_device_id: &DeviceId = "JLAFKJWSCS".into();
|
||||||
|
|
||||||
let alice_devices = machine
|
let alice_devices = machine.store.get_user_devices(&alice_id).await.unwrap();
|
||||||
.store
|
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.get_user_devices(&alice_id)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(alice_devices.devices().peekable().peek().is_none());
|
assert!(alice_devices.devices().peekable().peek().is_none());
|
||||||
|
|
||||||
machine
|
machine
|
||||||
|
@ -1665,8 +1581,6 @@ mod test {
|
||||||
|
|
||||||
let device = machine
|
let device = machine
|
||||||
.store
|
.store
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.get_device(&alice_id, alice_device_id)
|
.get_device(&alice_id, alice_device_id)
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -1718,8 +1632,6 @@ mod test {
|
||||||
|
|
||||||
let session = alice_machine
|
let session = alice_machine
|
||||||
.store
|
.store
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.get_sessions(bob_machine.account.identity_keys().curve25519())
|
.get_sessions(bob_machine.account.identity_keys().curve25519())
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -1734,8 +1646,6 @@ mod test {
|
||||||
|
|
||||||
let bob_device = alice
|
let bob_device = alice
|
||||||
.store
|
.store
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.get_device(&bob.user_id, &bob.device_id)
|
.get_device(&bob.user_id, &bob.device_id)
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -1797,8 +1707,6 @@ mod test {
|
||||||
|
|
||||||
let session = bob
|
let session = bob
|
||||||
.store
|
.store
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.get_inbound_group_session(
|
.get_inbound_group_session(
|
||||||
&room_id,
|
&room_id,
|
||||||
alice.account.identity_keys().curve25519(),
|
alice.account.identity_keys().curve25519(),
|
||||||
|
|
|
@ -26,7 +26,7 @@ use crate::{
|
||||||
device::Device,
|
device::Device,
|
||||||
memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices},
|
memory_stores::{DeviceStore, GroupSessionStore, SessionStore, UserDevices},
|
||||||
};
|
};
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct MemoryStore {
|
pub struct MemoryStore {
|
||||||
sessions: SessionStore,
|
sessions: SessionStore,
|
||||||
inbound_group_sessions: GroupSessionStore,
|
inbound_group_sessions: GroupSessionStore,
|
||||||
|
|
|
@ -41,6 +41,7 @@ use crate::{
|
||||||
};
|
};
|
||||||
|
|
||||||
/// SQLite based implementation of a `CryptoStore`.
|
/// SQLite based implementation of a `CryptoStore`.
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct SqliteStore {
|
pub struct SqliteStore {
|
||||||
user_id: Arc<UserId>,
|
user_id: Arc<UserId>,
|
||||||
device_id: Arc<Box<DeviceId>>,
|
device_id: Arc<Box<DeviceId>>,
|
||||||
|
|
|
@ -22,7 +22,6 @@ use matrix_sdk_common::{
|
||||||
api::r0::to_device::send_event_to_device::IncomingRequest as OwnedToDeviceRequest,
|
api::r0::to_device::send_event_to_device::IncomingRequest as OwnedToDeviceRequest,
|
||||||
events::{AnyToDeviceEvent, AnyToDeviceEventContent},
|
events::{AnyToDeviceEvent, AnyToDeviceEventContent},
|
||||||
identifiers::{DeviceId, UserId},
|
identifiers::{DeviceId, UserId},
|
||||||
locks::RwLock,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::sas::{content_to_request, Sas};
|
use super::sas::{content_to_request, Sas};
|
||||||
|
@ -31,13 +30,13 @@ use crate::{Account, CryptoStore, CryptoStoreError};
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct VerificationMachine {
|
pub struct VerificationMachine {
|
||||||
account: Account,
|
account: Account,
|
||||||
store: Arc<RwLock<Box<dyn CryptoStore>>>,
|
store: Arc<Box<dyn CryptoStore>>,
|
||||||
verifications: Arc<DashMap<String, Sas>>,
|
verifications: Arc<DashMap<String, Sas>>,
|
||||||
outgoing_to_device_messages: Arc<DashMap<String, OwnedToDeviceRequest>>,
|
outgoing_to_device_messages: Arc<DashMap<String, OwnedToDeviceRequest>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl VerificationMachine {
|
impl VerificationMachine {
|
||||||
pub(crate) fn new(account: Account, store: Arc<RwLock<Box<dyn CryptoStore>>>) -> Self {
|
pub(crate) fn new(account: Account, store: Arc<Box<dyn CryptoStore>>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
account,
|
account,
|
||||||
store,
|
store,
|
||||||
|
@ -112,8 +111,6 @@ impl VerificationMachine {
|
||||||
|
|
||||||
if let Some(d) = self
|
if let Some(d) = self
|
||||||
.store
|
.store
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.get_device(&e.sender, &e.content.from_device)
|
.get_device(&e.sender, &e.content.from_device)
|
||||||
.await?
|
.await?
|
||||||
{
|
{
|
||||||
|
@ -179,7 +176,6 @@ mod test {
|
||||||
use matrix_sdk_common::{
|
use matrix_sdk_common::{
|
||||||
events::AnyToDeviceEventContent,
|
events::AnyToDeviceEventContent,
|
||||||
identifiers::{DeviceId, UserId},
|
identifiers::{DeviceId, UserId},
|
||||||
locks::RwLock,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{Sas, VerificationMachine};
|
use super::{Sas, VerificationMachine};
|
||||||
|
@ -209,21 +205,18 @@ mod test {
|
||||||
let alice = Account::new(&alice_id(), &alice_device_id());
|
let alice = Account::new(&alice_id(), &alice_device_id());
|
||||||
let bob = Account::new(&bob_id(), &bob_device_id());
|
let bob = Account::new(&bob_id(), &bob_device_id());
|
||||||
let store = MemoryStore::new();
|
let store = MemoryStore::new();
|
||||||
let bob_store: Arc<RwLock<Box<dyn CryptoStore>>> =
|
let bob_store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
|
||||||
Arc::new(RwLock::new(Box::new(MemoryStore::new())));
|
|
||||||
|
|
||||||
let bob_device = Device::from_account(&bob).await;
|
let bob_device = Device::from_account(&bob).await;
|
||||||
let alice_device = Device::from_account(&alice).await;
|
let alice_device = Device::from_account(&alice).await;
|
||||||
|
|
||||||
store.save_devices(&[bob_device]).await.unwrap();
|
store.save_devices(&[bob_device]).await.unwrap();
|
||||||
bob_store
|
bob_store
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.save_devices(&[alice_device.clone()])
|
.save_devices(&[alice_device.clone()])
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let machine = VerificationMachine::new(alice, Arc::new(RwLock::new(Box::new(store))));
|
let machine = VerificationMachine::new(alice, Arc::new(Box::new(store)));
|
||||||
let (bob_sas, start_content) = Sas::start(bob, alice_device, bob_store);
|
let (bob_sas, start_content) = Sas::start(bob, alice_device, bob_store);
|
||||||
machine
|
machine
|
||||||
.receive_event(&mut wrap_any_to_device_content(
|
.receive_event(&mut wrap_any_to_device_content(
|
||||||
|
@ -240,7 +233,7 @@ mod test {
|
||||||
fn create() {
|
fn create() {
|
||||||
let alice = Account::new(&alice_id(), &alice_device_id());
|
let alice = Account::new(&alice_id(), &alice_device_id());
|
||||||
let store = MemoryStore::new();
|
let store = MemoryStore::new();
|
||||||
let _ = VerificationMachine::new(alice, Arc::new(RwLock::new(Box::new(store))));
|
let _ = VerificationMachine::new(alice, Arc::new(Box::new(store)));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|
|
@ -212,8 +212,6 @@ fn extra_mac_info_send(ids: &SasIds, flow_id: &str) -> String {
|
||||||
///
|
///
|
||||||
/// * `flow_id` - The unique id that identifies this SAS verification process.
|
/// * `flow_id` - The unique id that identifies this SAS verification process.
|
||||||
///
|
///
|
||||||
/// * `we_started` - Flag signaling if the SAS process was started on our side.
|
|
||||||
///
|
|
||||||
/// # Panics
|
/// # Panics
|
||||||
///
|
///
|
||||||
/// This will panic if the public key of the other side wasn't set.
|
/// This will panic if the public key of the other side wasn't set.
|
||||||
|
|
|
@ -31,7 +31,6 @@ use matrix_sdk_common::{
|
||||||
AnyToDeviceEvent, AnyToDeviceEventContent, ToDeviceEvent,
|
AnyToDeviceEvent, AnyToDeviceEventContent, ToDeviceEvent,
|
||||||
},
|
},
|
||||||
identifiers::{DeviceId, UserId},
|
identifiers::{DeviceId, UserId},
|
||||||
locks::RwLock,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{Account, CryptoStore, CryptoStoreError, Device, TrustState};
|
use crate::{Account, CryptoStore, CryptoStoreError, Device, TrustState};
|
||||||
|
@ -45,7 +44,7 @@ use sas_state::{
|
||||||
/// Short authentication string object.
|
/// Short authentication string object.
|
||||||
pub struct Sas {
|
pub struct Sas {
|
||||||
inner: Arc<Mutex<InnerSas>>,
|
inner: Arc<Mutex<InnerSas>>,
|
||||||
store: Arc<RwLock<Box<dyn CryptoStore>>>,
|
store: Arc<Box<dyn CryptoStore>>,
|
||||||
account: Account,
|
account: Account,
|
||||||
other_device: Device,
|
other_device: Device,
|
||||||
flow_id: Arc<String>,
|
flow_id: Arc<String>,
|
||||||
|
@ -100,7 +99,7 @@ impl Sas {
|
||||||
pub(crate) fn start(
|
pub(crate) fn start(
|
||||||
account: Account,
|
account: Account,
|
||||||
other_device: Device,
|
other_device: Device,
|
||||||
store: Arc<RwLock<Box<dyn CryptoStore>>>,
|
store: Arc<Box<dyn CryptoStore>>,
|
||||||
) -> (Sas, StartEventContent) {
|
) -> (Sas, StartEventContent) {
|
||||||
let (inner, content) = InnerSas::start(account.clone(), other_device.clone());
|
let (inner, content) = InnerSas::start(account.clone(), other_device.clone());
|
||||||
let flow_id = inner.verification_flow_id();
|
let flow_id = inner.verification_flow_id();
|
||||||
|
@ -129,7 +128,7 @@ impl Sas {
|
||||||
pub(crate) fn from_start_event(
|
pub(crate) fn from_start_event(
|
||||||
account: Account,
|
account: Account,
|
||||||
other_device: Device,
|
other_device: Device,
|
||||||
store: Arc<RwLock<Box<dyn CryptoStore>>>,
|
store: Arc<Box<dyn CryptoStore>>,
|
||||||
event: &ToDeviceEvent<StartEventContent>,
|
event: &ToDeviceEvent<StartEventContent>,
|
||||||
) -> Result<Sas, AnyToDeviceEventContent> {
|
) -> Result<Sas, AnyToDeviceEventContent> {
|
||||||
let inner = InnerSas::from_start_event(account.clone(), other_device.clone(), event)?;
|
let inner = InnerSas::from_start_event(account.clone(), other_device.clone(), event)?;
|
||||||
|
@ -184,8 +183,6 @@ impl Sas {
|
||||||
pub(crate) async fn mark_device_as_verified(&self) -> Result<bool, CryptoStoreError> {
|
pub(crate) async fn mark_device_as_verified(&self) -> Result<bool, CryptoStoreError> {
|
||||||
let device = self
|
let device = self
|
||||||
.store
|
.store
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.get_device(self.other_user_id(), self.other_device_id())
|
.get_device(self.other_user_id(), self.other_device_id())
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
@ -202,7 +199,7 @@ impl Sas {
|
||||||
);
|
);
|
||||||
|
|
||||||
device.set_trust_state(TrustState::Verified);
|
device.set_trust_state(TrustState::Verified);
|
||||||
self.store.read().await.save_devices(&[device]).await?;
|
self.store.save_devices(&[device]).await?;
|
||||||
|
|
||||||
Ok(true)
|
Ok(true)
|
||||||
} else {
|
} else {
|
||||||
|
@ -560,7 +557,6 @@ mod test {
|
||||||
use matrix_sdk_common::{
|
use matrix_sdk_common::{
|
||||||
events::{EventContent, ToDeviceEvent},
|
events::{EventContent, ToDeviceEvent},
|
||||||
identifiers::{DeviceId, UserId},
|
identifiers::{DeviceId, UserId},
|
||||||
locks::RwLock,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -685,14 +681,10 @@ mod test {
|
||||||
let bob = Account::new(&bob_id(), &bob_device_id());
|
let bob = Account::new(&bob_id(), &bob_device_id());
|
||||||
let bob_device = Device::from_account(&bob).await;
|
let bob_device = Device::from_account(&bob).await;
|
||||||
|
|
||||||
let alice_store: Arc<RwLock<Box<dyn CryptoStore>>> =
|
let alice_store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
|
||||||
Arc::new(RwLock::new(Box::new(MemoryStore::new())));
|
let bob_store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
|
||||||
let bob_store: Arc<RwLock<Box<dyn CryptoStore>>> =
|
|
||||||
Arc::new(RwLock::new(Box::new(MemoryStore::new())));
|
|
||||||
|
|
||||||
bob_store
|
bob_store
|
||||||
.read()
|
|
||||||
.await
|
|
||||||
.save_devices(&[alice_device.clone()])
|
.save_devices(&[alice_device.clone()])
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
Loading…
Reference in New Issue