crypto: Rework the cryptostore.

This modifies the cryptostore and storage logic in two ways:
    * The cryptostore trait has only one main save method.
    * The receive_sync method tries to save all the objects in one
    `save_changes()` call.

This means that all the changes a sync makes get commited to the store
in one transaction, leaving us in a consistent state.

This also means that we can pass the Changes struct the receive sync
method collects to our caller if the caller wishes to store the room
state and crypto state changes in a single transaction.
master
Damir Jelić 2020-10-20 17:19:37 +02:00
parent 728d80ed06
commit 7cab7cadc9
13 changed files with 711 additions and 477 deletions

View File

@ -971,8 +971,6 @@ impl BaseClient {
return Ok(()); return Ok(());
} }
*self.sync_token.write().await = Some(response.next_batch.clone());
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
{ {
let olm = self.olm.lock().await; let olm = self.olm.lock().await;
@ -982,10 +980,12 @@ impl BaseClient {
// decryptes to-device events, but leaves room events alone. // decryptes to-device events, but leaves room events alone.
// This makes sure that we have the deryption keys for the room // This makes sure that we have the deryption keys for the room
// events at hand. // events at hand.
o.receive_sync_response(response).await; o.receive_sync_response(response).await?;
} }
} }
*self.sync_token.write().await = Some(response.next_batch.clone());
// when events change state, updated_* signals to StateStore to update database // when events change state, updated_* signals to StateStore to update database
self.iter_joined_rooms(response).await?; self.iter_joined_rooms(response).await?;
self.iter_invited_rooms(response).await?; self.iter_invited_rooms(response).await?;

View File

@ -39,7 +39,10 @@ use serde::{Deserialize, Serialize};
use serde_json::{json, Value}; use serde_json::{json, Value};
use tracing::warn; use tracing::warn;
use crate::olm::{InboundGroupSession, Session}; use crate::{
olm::{InboundGroupSession, Session},
store::{Changes, DeviceChanges},
};
#[cfg(test)] #[cfg(test)]
use crate::{OlmMachine, ReadOnlyAccount}; use crate::{OlmMachine, ReadOnlyAccount};
@ -118,10 +121,15 @@ impl Device {
pub async fn set_local_trust(&self, trust_state: LocalTrust) -> StoreResult<()> { pub async fn set_local_trust(&self, trust_state: LocalTrust) -> StoreResult<()> {
self.inner.set_trust_state(trust_state); self.inner.set_trust_state(trust_state);
self.verification_machine let changes = Changes {
.store devices: DeviceChanges {
.save_devices(&[self.inner.clone()]) changed: vec![self.inner.clone()],
.await ..Default::default()
},
..Default::default()
};
self.verification_machine.store.save_changes(changes).await
} }
/// Encrypt the given content for this `Device`. /// Encrypt the given content for this `Device`.
@ -135,7 +143,7 @@ impl Device {
&self, &self,
event_type: EventType, event_type: EventType,
content: Value, content: Value,
) -> OlmResult<EncryptedEventContent> { ) -> OlmResult<(Session, EncryptedEventContent)> {
self.inner self.inner
.encrypt(&**self.verification_machine.store, event_type, content) .encrypt(&**self.verification_machine.store, event_type, content)
.await .await
@ -146,7 +154,7 @@ impl Device {
pub async fn encrypt_session( pub async fn encrypt_session(
&self, &self,
session: InboundGroupSession, session: InboundGroupSession,
) -> OlmResult<EncryptedEventContent> { ) -> OlmResult<(Session, EncryptedEventContent)> {
let export = session.export().await; let export = session.export().await;
let content: ForwardedRoomKeyEventContent = if let Ok(c) = export.try_into() { let content: ForwardedRoomKeyEventContent = if let Ok(c) = export.try_into() {
@ -364,7 +372,7 @@ impl ReadOnlyDevice {
store: &dyn CryptoStore, store: &dyn CryptoStore,
event_type: EventType, event_type: EventType,
content: Value, content: Value,
) -> OlmResult<EncryptedEventContent> { ) -> OlmResult<(Session, EncryptedEventContent)> {
let sender_key = if let Some(k) = self.get_key(DeviceKeyAlgorithm::Curve25519) { let sender_key = if let Some(k) = self.get_key(DeviceKeyAlgorithm::Curve25519) {
k k
} else { } else {
@ -396,10 +404,9 @@ impl ReadOnlyDevice {
return Err(OlmError::MissingSession); return Err(OlmError::MissingSession);
}; };
let message = session.encrypt(&self, event_type, content).await; let message = session.encrypt(&self, event_type, content).await?;
store.save_sessions(&[session]).await?;
message Ok((session, message))
} }
/// Update a device with a new device keys struct. /// Update a device with a new device keys struct.

View File

@ -33,7 +33,7 @@ use crate::{
}, },
requests::KeysQueryRequest, requests::KeysQueryRequest,
session_manager::GroupSessionManager, session_manager::GroupSessionManager,
store::{Result as StoreResult, Store}, store::{Changes, DeviceChanges, IdentityChanges, Result as StoreResult, Store},
}; };
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -79,7 +79,7 @@ impl IdentityManager {
pub async fn receive_keys_query_response( pub async fn receive_keys_query_response(
&self, &self,
response: &KeysQueryResponse, response: &KeysQueryResponse,
) -> OlmResult<(Vec<ReadOnlyDevice>, Vec<UserIdentities>)> { ) -> OlmResult<(DeviceChanges, IdentityChanges)> {
// TODO create a enum that tells us how the device/identity changed, // TODO create a enum that tells us how the device/identity changed,
// e.g. new/deleted/display name change. // e.g. new/deleted/display name change.
// //
@ -92,9 +92,15 @@ impl IdentityManager {
let changed_devices = self let changed_devices = self
.handle_devices_from_key_query(&response.device_keys) .handle_devices_from_key_query(&response.device_keys)
.await?; .await?;
self.store.save_devices(&changed_devices).await?;
let changed_identities = self.handle_cross_singing_keys(response).await?; let changed_identities = self.handle_cross_singing_keys(response).await?;
self.store.save_user_identities(&changed_identities).await?;
let changes = Changes {
identities: changed_identities.clone(),
devices: changed_devices.clone(),
..Default::default()
};
self.store.save_changes(changes).await?;
Ok((changed_devices, changed_identities)) Ok((changed_devices, changed_identities))
} }
@ -111,9 +117,10 @@ impl IdentityManager {
async fn handle_devices_from_key_query( async fn handle_devices_from_key_query(
&self, &self,
device_keys_map: &BTreeMap<UserId, BTreeMap<DeviceIdBox, DeviceKeys>>, device_keys_map: &BTreeMap<UserId, BTreeMap<DeviceIdBox, DeviceKeys>>,
) -> StoreResult<Vec<ReadOnlyDevice>> { ) -> StoreResult<DeviceChanges> {
let mut users_with_new_or_deleted_devices = HashSet::new(); let mut users_with_new_or_deleted_devices = HashSet::new();
let mut changed_devices = Vec::new();
let mut changes = DeviceChanges::default();
for (user_id, device_map) in device_keys_map { for (user_id, device_map) in device_keys_map {
// TODO move this out into the handle keys query response method // TODO move this out into the handle keys query response method
@ -137,7 +144,7 @@ impl IdentityManager {
let device = self.store.get_readonly_device(&user_id, device_id).await?; let device = self.store.get_readonly_device(&user_id, device_id).await?;
let device = if let Some(mut device) = device { if let Some(mut device) = device {
if let Err(e) = device.update_device(device_keys) { if let Err(e) = device.update_device(device_keys) {
warn!( warn!(
"Failed to update the device keys for {} {}: {:?}", "Failed to update the device keys for {} {}: {:?}",
@ -145,7 +152,7 @@ impl IdentityManager {
); );
continue; continue;
} }
device changes.changed.push(device);
} else { } else {
let device = match ReadOnlyDevice::try_from(device_keys) { let device = match ReadOnlyDevice::try_from(device_keys) {
Ok(d) => d, Ok(d) => d,
@ -159,23 +166,21 @@ impl IdentityManager {
}; };
info!("Adding a new device to the device store {:?}", device); info!("Adding a new device to the device store {:?}", device);
users_with_new_or_deleted_devices.insert(user_id); users_with_new_or_deleted_devices.insert(user_id);
device changes.new.push(device);
}; }
changed_devices.push(device);
} }
let current_devices: HashSet<&DeviceIdBox> = device_map.keys().collect(); let current_devices: HashSet<&DeviceIdBox> = device_map.keys().collect();
let stored_devices = self.store.get_readonly_devices(&user_id).await?; let stored_devices = self.store.get_readonly_devices(&user_id).await?;
let stored_devices_set: HashSet<&DeviceIdBox> = stored_devices.keys().collect(); let stored_devices_set: HashSet<&DeviceIdBox> = stored_devices.keys().collect();
let deleted_devices = stored_devices_set.difference(&current_devices); let deleted_devices_set = stored_devices_set.difference(&current_devices);
for device_id in deleted_devices { for device_id in deleted_devices_set {
users_with_new_or_deleted_devices.insert(user_id); users_with_new_or_deleted_devices.insert(user_id);
if let Some(device) = stored_devices.get(*device_id) { if let Some(device) = stored_devices.get(*device_id) {
device.mark_as_deleted(); device.mark_as_deleted();
self.store.delete_device(device.clone()).await?; changes.deleted.push(device.clone());
} }
} }
} }
@ -183,7 +188,7 @@ impl IdentityManager {
self.group_manager self.group_manager
.invalidate_sessions_new_devices(&users_with_new_or_deleted_devices); .invalidate_sessions_new_devices(&users_with_new_or_deleted_devices);
Ok(changed_devices) Ok(changes)
} }
/// Handle the device keys part of a key query response. /// Handle the device keys part of a key query response.
@ -197,8 +202,8 @@ impl IdentityManager {
async fn handle_cross_singing_keys( async fn handle_cross_singing_keys(
&self, &self,
response: &KeysQueryResponse, response: &KeysQueryResponse,
) -> StoreResult<Vec<UserIdentities>> { ) -> StoreResult<IdentityChanges> {
let mut changed = Vec::new(); let mut changes = IdentityChanges::default();
for (user_id, master_key) in &response.master_keys { for (user_id, master_key) in &response.master_keys {
let master_key = MasterPubkey::from(master_key); let master_key = MasterPubkey::from(master_key);
@ -213,7 +218,7 @@ impl IdentityManager {
continue; continue;
}; };
let identity = if let Some(mut i) = self.store.get_user_identity(user_id).await? { let result = if let Some(mut i) = self.store.get_user_identity(user_id).await? {
match &mut i { match &mut i {
UserIdentities::Own(ref mut identity) => { UserIdentities::Own(ref mut identity) => {
let user_signing = if let Some(s) = response.user_signing_keys.get(user_id) let user_signing = if let Some(s) = response.user_signing_keys.get(user_id)
@ -230,11 +235,11 @@ impl IdentityManager {
identity identity
.update(master_key, self_signing, user_signing) .update(master_key, self_signing, user_signing)
.map(|_| i) .map(|_| (i, false))
}
UserIdentities::Other(ref mut identity) => {
identity.update(master_key, self_signing).map(|_| i)
} }
UserIdentities::Other(ref mut identity) => identity
.update(master_key, self_signing)
.map(|_| (i, false)),
} }
} else if user_id == self.user_id() { } else if user_id == self.user_id() {
if let Some(s) = response.user_signing_keys.get(user_id) { if let Some(s) = response.user_signing_keys.get(user_id) {
@ -252,7 +257,7 @@ impl IdentityManager {
} }
OwnUserIdentity::new(master_key, self_signing, user_signing) OwnUserIdentity::new(master_key, self_signing, user_signing)
.map(UserIdentities::Own) .map(|i| (UserIdentities::Own(i), true))
} else { } else {
warn!( warn!(
"User identity for our own user {} didn't contain a \ "User identity for our own user {} didn't contain a \
@ -268,17 +273,22 @@ impl IdentityManager {
); );
continue; continue;
} else { } else {
UserIdentity::new(master_key, self_signing).map(UserIdentities::Other) UserIdentity::new(master_key, self_signing)
.map(|i| (UserIdentities::Other(i), true))
}; };
match identity { match result {
Ok(i) => { Ok((i, new)) => {
trace!( trace!(
"Updated or created new user identity for {}: {:?}", "Updated or created new user identity for {}: {:?}",
user_id, user_id,
i i
); );
changed.push(i); if new {
changes.new.push(i);
} else {
changes.changed.push(i);
}
} }
Err(e) => { Err(e) => {
warn!( warn!(
@ -290,7 +300,7 @@ impl IdentityManager {
} }
} }
Ok(changed) Ok(changes)
} }
/// Get a key query request if one is needed. /// Get a key query request if one is needed.

View File

@ -41,7 +41,7 @@ use matrix_sdk_common::{
use crate::{ use crate::{
error::{OlmError, OlmResult}, error::{OlmError, OlmResult},
olm::{InboundGroupSession, OutboundGroupSession}, olm::{InboundGroupSession, OutboundGroupSession, Session},
requests::{OutgoingRequest, ToDeviceRequest}, requests::{OutgoingRequest, ToDeviceRequest},
store::{CryptoStoreError, Store}, store::{CryptoStoreError, Store},
Device, Device,
@ -235,15 +235,18 @@ impl KeyRequestMachine {
/// Handle all the incoming key requests that are queued up and empty our /// Handle all the incoming key requests that are queued up and empty our
/// key request queue. /// key request queue.
pub async fn collect_incoming_key_requests(&self) -> OlmResult<()> { pub async fn collect_incoming_key_requests(&self) -> OlmResult<Vec<Session>> {
let mut changed_sessions = Vec::new();
for item in self.incoming_key_requests.iter() { for item in self.incoming_key_requests.iter() {
let event = item.value(); let event = item.value();
self.handle_key_request(event).await?; if let Some(s) = self.handle_key_request(event).await? {
changed_sessions.push(s);
}
} }
self.incoming_key_requests.clear(); self.incoming_key_requests.clear();
Ok(()) Ok(changed_sessions)
} }
/// Store the key share request for later, once we get an Olm session with /// Store the key share request for later, once we get an Olm session with
@ -294,7 +297,7 @@ impl KeyRequestMachine {
async fn handle_key_request( async fn handle_key_request(
&self, &self,
event: &ToDeviceEvent<RoomKeyRequestEventContent>, event: &ToDeviceEvent<RoomKeyRequestEventContent>,
) -> OlmResult<()> { ) -> OlmResult<Option<Session>> {
let key_info = match event.content.action { let key_info = match event.content.action {
Action::Request => { Action::Request => {
if let Some(info) = &event.content.body { if let Some(info) = &event.content.body {
@ -305,11 +308,11 @@ impl KeyRequestMachine {
action, but no key info was found", action, but no key info was found",
event.sender, event.content.requesting_device_id event.sender, event.content.requesting_device_id
); );
return Ok(()); return Ok(None);
} }
} }
// We ignore cancellations here since there's nothing to serve. // We ignore cancellations here since there's nothing to serve.
Action::CancelRequest => return Ok(()), Action::CancelRequest => return Ok(None),
}; };
let session = self let session = self
@ -328,7 +331,7 @@ impl KeyRequestMachine {
"Received a key request from {} {} for an unknown inbound group session {}.", "Received a key request from {} {} for an unknown inbound group session {}.",
&event.sender, &event.content.requesting_device_id, &key_info.session_id &event.sender, &event.content.requesting_device_id, &key_info.session_id
); );
return Ok(()); return Ok(None);
}; };
let device = self let device = self
@ -349,6 +352,8 @@ impl KeyRequestMachine {
device.device_id(), device.device_id(),
e e
); );
Ok(None)
} else { } else {
info!( info!(
"Serving a key request for {} from {} {}.", "Serving a key request for {} from {} {}.",
@ -357,20 +362,20 @@ impl KeyRequestMachine {
device.device_id() device.device_id()
); );
if let Err(e) = self.share_session(&session, &device).await { match self.share_session(&session, &device).await {
match e { Ok(s) => Ok(Some(s)),
OlmError::MissingSession => { Err(OlmError::MissingSession) => {
info!( info!(
"Key request from {} {} is missing an Olm session, \ "Key request from {} {} is missing an Olm session, \
putting the request in the wait queue", putting the request in the wait queue",
device.user_id(), device.user_id(),
device.device_id() device.device_id()
); );
self.handle_key_share_without_session(device, event); self.handle_key_share_without_session(device, event);
return Ok(());
} Ok(None)
e => return Err(e),
} }
Err(e) => Err(e),
} }
} }
} else { } else {
@ -379,13 +384,17 @@ impl KeyRequestMachine {
&event.sender, &event.content.requesting_device_id &event.sender, &event.content.requesting_device_id
); );
self.store.update_tracked_user(&event.sender, true).await?; self.store.update_tracked_user(&event.sender, true).await?;
}
Ok(()) Ok(None)
}
} }
async fn share_session(&self, session: &InboundGroupSession, device: &Device) -> OlmResult<()> { async fn share_session(
let content = device.encrypt_session(session.clone()).await?; &self,
session: &InboundGroupSession,
device: &Device,
) -> OlmResult<Session> {
let (used_session, content) = device.encrypt_session(session.clone()).await?;
let id = Uuid::new_v4(); let id = Uuid::new_v4();
let mut messages = BTreeMap::new(); let mut messages = BTreeMap::new();
@ -412,7 +421,7 @@ impl KeyRequestMachine {
self.outgoing_to_device_requests.insert(id, request); self.outgoing_to_device_requests.insert(id, request);
Ok(()) Ok(used_session)
} }
/// Check if it's ok to share a session with the given device. /// Check if it's ok to share a session with the given device.
@ -569,23 +578,20 @@ impl KeyRequestMachine {
Ok(()) Ok(())
} }
/// Save an inbound group session we received using a key forward. /// Mark the given outgoing key info as done.
/// ///
/// At the same time delete the key info since we received the wanted key. /// This will queue up a request cancelation.
async fn save_session( async fn mark_as_done(&self, key_info: OugoingKeyInfo) -> Result<(), CryptoStoreError> {
&self,
key_info: OugoingKeyInfo,
session: InboundGroupSession,
) -> Result<(), CryptoStoreError> {
// TODO perhaps only remove the key info if the first known index is 0. // TODO perhaps only remove the key info if the first known index is 0.
trace!( trace!(
"Successfully received a forwarded room key for {:#?}", "Successfully received a forwarded room key for {:#?}",
key_info key_info
); );
self.store.save_inbound_group_sessions(&[session]).await?;
self.outgoing_to_device_requests self.outgoing_to_device_requests
.remove(&key_info.request_id); .remove(&key_info.request_id);
// TODO return the key info instead of deleting it so the sync handler
// can delete it in one transaction.
self.delete_key_info(&key_info).await?; self.delete_key_info(&key_info).await?;
let content = RoomKeyRequestEventContent { let content = RoomKeyRequestEventContent {
@ -609,7 +615,8 @@ impl KeyRequestMachine {
&self, &self,
sender_key: &str, sender_key: &str,
event: &mut ToDeviceEvent<ForwardedRoomKeyEventContent>, event: &mut ToDeviceEvent<ForwardedRoomKeyEventContent>,
) -> Result<Option<Raw<AnyToDeviceEvent>>, CryptoStoreError> { ) -> Result<(Option<Raw<AnyToDeviceEvent>>, Option<InboundGroupSession>), CryptoStoreError>
{
let key_info = self.get_key_info(&event.content).await?; let key_info = self.get_key_info(&event.content).await?;
if let Some(info) = key_info { if let Some(info) = key_info {
@ -626,27 +633,32 @@ impl KeyRequestMachine {
// If we have a previous session, check if we have a better version // If we have a previous session, check if we have a better version
// and store the new one if so. // and store the new one if so.
if let Some(old_session) = old_session { let session = if let Some(old_session) = old_session {
let first_old_index = old_session.first_known_index().await; let first_old_index = old_session.first_known_index().await;
let first_index = session.first_known_index().await; let first_index = session.first_known_index().await;
if first_old_index > first_index { if first_old_index > first_index {
self.save_session(info, session).await?; self.mark_as_done(info).await?;
Some(session)
} else {
None
} }
// If we didn't have a previous session, store it. // If we didn't have a previous session, store it.
} else { } else {
self.save_session(info, session).await?; self.mark_as_done(info).await?;
} Some(session)
};
Ok(Some(Raw::from(AnyToDeviceEvent::ForwardedRoomKey( Ok((
event.clone(), Some(Raw::from(AnyToDeviceEvent::ForwardedRoomKey(event.clone()))),
)))) session,
))
} else { } else {
info!( info!(
"Received a forwarded room key from {}, but no key info was found.", "Received a forwarded room key from {}, but no key info was found.",
event.sender, event.sender,
); );
Ok(None) Ok((None, None))
} }
} }
} }
@ -831,20 +843,20 @@ mod test {
.is_none() .is_none()
); );
machine let (_, first_session) = machine
.receive_forwarded_room_key(&session.sender_key, &mut event) .receive_forwarded_room_key(&session.sender_key, &mut event)
.await .await
.unwrap(); .unwrap();
let first_session = first_session.unwrap();
let first_session = machine
.store
.get_inbound_group_session(session.room_id(), &session.sender_key, session.session_id())
.await
.unwrap()
.unwrap();
assert_eq!(first_session.first_known_index().await, 10); assert_eq!(first_session.first_known_index().await, 10);
machine
.store
.save_inbound_group_sessions(&[first_session.clone()])
.await
.unwrap();
// Get the cancel request. // Get the cancel request.
let request = machine.outgoing_to_device_requests.iter().next().unwrap(); let request = machine.outgoing_to_device_requests.iter().next().unwrap();
let id = request.request_id; let id = request.request_id;
@ -875,19 +887,12 @@ mod test {
content, content,
}; };
machine let (_, second_session) = machine
.receive_forwarded_room_key(&session.sender_key, &mut event) .receive_forwarded_room_key(&session.sender_key, &mut event)
.await .await
.unwrap(); .unwrap();
let second_session = machine assert!(second_session.is_none());
.store
.get_inbound_group_session(session.room_id(), &session.sender_key, session.session_id())
.await
.unwrap()
.unwrap();
assert_eq!(second_session.first_known_index().await, 10);
let export = session.export_at_index(0).await.unwrap(); let export = session.export_at_index(0).await.unwrap();
@ -898,18 +903,12 @@ mod test {
content, content,
}; };
machine let (_, second_session) = machine
.receive_forwarded_room_key(&session.sender_key, &mut event) .receive_forwarded_room_key(&session.sender_key, &mut event)
.await .await
.unwrap(); .unwrap();
let second_session = machine
.store
.get_inbound_group_session(session.room_id(), &session.sender_key, session.session_id())
.await
.unwrap()
.unwrap();
assert_eq!(second_session.first_known_index().await, 0); assert_eq!(second_session.unwrap().first_known_index().await, 0);
} }
#[async_test] #[async_test]
@ -1132,14 +1131,19 @@ mod test {
.unwrap() .unwrap()
.is_none()); .is_none());
let (decrypted, sender_key, _) = let (_, decrypted, sender_key, _) =
alice_account.decrypt_to_device_event(&event).await.unwrap(); alice_account.decrypt_to_device_event(&event).await.unwrap();
if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.deserialize().unwrap() { if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.deserialize().unwrap() {
alice_machine let (_, session) = alice_machine
.receive_forwarded_room_key(&sender_key, &mut e) .receive_forwarded_room_key(&sender_key, &mut e)
.await .await
.unwrap(); .unwrap();
alice_machine
.store
.save_inbound_group_sessions(&[session.unwrap()])
.await
.unwrap();
} else { } else {
panic!("Invalid decrypted event type"); panic!("Invalid decrypted event type");
} }
@ -1315,14 +1319,19 @@ mod test {
.unwrap() .unwrap()
.is_none()); .is_none());
let (decrypted, sender_key, _) = let (_, decrypted, sender_key, _) =
alice_account.decrypt_to_device_event(&event).await.unwrap(); alice_account.decrypt_to_device_event(&event).await.unwrap();
if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.deserialize().unwrap() { if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.deserialize().unwrap() {
alice_machine let (_, session) = alice_machine
.receive_forwarded_room_key(&sender_key, &mut e) .receive_forwarded_room_key(&sender_key, &mut e)
.await .await
.unwrap(); .unwrap();
alice_machine
.store
.save_inbound_group_sessions(&[session.unwrap()])
.await
.unwrap();
} else { } else {
panic!("Invalid decrypted event type"); panic!("Invalid decrypted event type");
} }

View File

@ -47,15 +47,18 @@ use matrix_sdk_common::{
use crate::store::sqlite::SqliteStore; use crate::store::sqlite::SqliteStore;
use crate::{ use crate::{
error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult}, error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult},
identities::{Device, IdentityManager, ReadOnlyDevice, UserDevices, UserIdentities}, identities::{Device, IdentityManager, UserDevices},
key_request::KeyRequestMachine, key_request::KeyRequestMachine,
olm::{ olm::{
Account, EncryptionSettings, ExportedRoomKey, GroupSessionKey, IdentityKeys, Account, EncryptionSettings, ExportedRoomKey, GroupSessionKey, IdentityKeys,
InboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount, InboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount, Session,
}, },
requests::{IncomingResponse, OutgoingRequest, UploadSigningKeysRequest}, requests::{IncomingResponse, OutgoingRequest, UploadSigningKeysRequest},
session_manager::{GroupSessionManager, SessionManager}, session_manager::{GroupSessionManager, SessionManager},
store::{CryptoStore, MemoryStore, Result as StoreResult, Store}, store::{
Changes, CryptoStore, DeviceChanges, IdentityChanges, MemoryStore, Result as StoreResult,
Store,
},
verification::{Sas, VerificationMachine}, verification::{Sas, VerificationMachine},
ToDeviceRequest, ToDeviceRequest,
}; };
@ -467,7 +470,7 @@ impl OlmMachine {
async fn receive_keys_query_response( async fn receive_keys_query_response(
&self, &self,
response: &KeysQueryResponse, response: &KeysQueryResponse,
) -> OlmResult<(Vec<ReadOnlyDevice>, Vec<UserIdentities>)> { ) -> OlmResult<(DeviceChanges, IdentityChanges)> {
self.identity_manager self.identity_manager
.receive_keys_query_response(response) .receive_keys_query_response(response)
.await .await
@ -498,12 +501,12 @@ impl OlmMachine {
async fn decrypt_to_device_event( async fn decrypt_to_device_event(
&self, &self,
event: &ToDeviceEvent<EncryptedEventContent>, event: &ToDeviceEvent<EncryptedEventContent>,
) -> OlmResult<Raw<AnyToDeviceEvent>> { ) -> OlmResult<(Session, Raw<AnyToDeviceEvent>, Option<InboundGroupSession>)> {
let (decrypted_event, sender_key, signing_key) = let (session, decrypted_event, sender_key, signing_key) =
self.account.decrypt_to_device_event(event).await?; self.account.decrypt_to_device_event(event).await?;
// Handle the decrypted event, e.g. fetch out Megolm sessions out of // Handle the decrypted event, e.g. fetch out Megolm sessions out of
// the event. // the event.
if let Some(event) = self if let (Some(event), group_session) = self
.handle_decrypted_to_device_event(&sender_key, &signing_key, &decrypted_event) .handle_decrypted_to_device_event(&sender_key, &signing_key, &decrypted_event)
.await? .await?
{ {
@ -512,9 +515,9 @@ impl OlmMachine {
// don't want them to be able to do silly things with it. Handling // don't want them to be able to do silly things with it. Handling
// events modifies them and returns a modified one, so replace it // events modifies them and returns a modified one, so replace it
// here if we get one. // here if we get one.
Ok(event) Ok((session, event, group_session))
} else { } else {
Ok(decrypted_event) Ok((session, decrypted_event, None))
} }
} }
@ -524,7 +527,7 @@ impl OlmMachine {
sender_key: &str, sender_key: &str,
signing_key: &str, signing_key: &str,
event: &mut ToDeviceEvent<RoomKeyEventContent>, event: &mut ToDeviceEvent<RoomKeyEventContent>,
) -> OlmResult<Option<Raw<AnyToDeviceEvent>>> { ) -> OlmResult<(Option<Raw<AnyToDeviceEvent>>, Option<InboundGroupSession>)> {
match event.content.algorithm { match event.content.algorithm {
EventEncryptionAlgorithm::MegolmV1AesSha2 => { EventEncryptionAlgorithm::MegolmV1AesSha2 => {
let session_key = GroupSessionKey(mem::take(&mut event.content.session_key)); let session_key = GroupSessionKey(mem::take(&mut event.content.session_key));
@ -535,17 +538,15 @@ impl OlmMachine {
&event.content.room_id, &event.content.room_id,
session_key, session_key,
)?; )?;
let _ = self.store.save_inbound_group_sessions(&[session]).await?;
let event = Raw::from(AnyToDeviceEvent::RoomKey(event.clone())); let event = Raw::from(AnyToDeviceEvent::RoomKey(event.clone()));
Ok(Some(event)) Ok((Some(event), Some(session)))
} }
_ => { _ => {
warn!( warn!(
"Received room key with unsupported key algorithm {}", "Received room key with unsupported key algorithm {}",
event.content.algorithm event.content.algorithm
); );
Ok(None) Ok((None, None))
} }
} }
} }
@ -555,9 +556,14 @@ impl OlmMachine {
&self, &self,
room_id: &RoomId, room_id: &RoomId,
) -> OlmResult<()> { ) -> OlmResult<()> {
self.group_session_manager let (_, session) = self
.group_session_manager
.create_outbound_group_session(room_id, EncryptionSettings::default()) .create_outbound_group_session(room_id, EncryptionSettings::default())
.await .await?;
self.store.save_inbound_group_sessions(&[session]).await?;
Ok(())
} }
/// Encrypt a room message for the given room. /// Encrypt a room message for the given room.
@ -647,12 +653,12 @@ impl OlmMachine {
sender_key: &str, sender_key: &str,
signing_key: &str, signing_key: &str,
event: &Raw<AnyToDeviceEvent>, event: &Raw<AnyToDeviceEvent>,
) -> OlmResult<Option<Raw<AnyToDeviceEvent>>> { ) -> OlmResult<(Option<Raw<AnyToDeviceEvent>>, Option<InboundGroupSession>)> {
let event = if let Ok(e) = event.deserialize() { let event = if let Ok(e) = event.deserialize() {
e e
} else { } else {
warn!("Decrypted to-device event failed to be parsed correctly"); warn!("Decrypted to-device event failed to be parsed correctly");
return Ok(None); return Ok((None, None));
}; };
match event { match event {
@ -665,7 +671,7 @@ impl OlmMachine {
.await?), .await?),
_ => { _ => {
warn!("Received a unexpected encrypted to-device event"); warn!("Received a unexpected encrypted to-device event");
Ok(None) Ok((None, None))
} }
} }
} }
@ -699,11 +705,8 @@ impl OlmMachine {
self.verification_machine.get_sas(flow_id) self.verification_machine.get_sas(flow_id)
} }
async fn update_one_time_key_count( async fn update_one_time_key_count(&self, key_count: &BTreeMap<DeviceKeyAlgorithm, UInt>) {
&self, self.account.update_uploaded_key_count(key_count).await;
key_count: &BTreeMap<DeviceKeyAlgorithm, UInt>,
) -> StoreResult<()> {
self.account.update_uploaded_key_count(key_count).await
} }
/// Handle a sync response and update the internal state of the Olm machine. /// Handle a sync response and update the internal state of the Olm machine.
@ -719,15 +722,19 @@ impl OlmMachine {
/// ///
/// [`decrypt_room_event`]: #method.decrypt_room_event /// [`decrypt_room_event`]: #method.decrypt_room_event
#[instrument(skip(response))] #[instrument(skip(response))]
pub async fn receive_sync_response(&self, response: &mut SyncResponse) { pub async fn receive_sync_response(&self, response: &mut SyncResponse) -> OlmResult<()> {
// Remove verification objects that have expired or are done.
self.verification_machine.garbage_collect(); self.verification_machine.garbage_collect();
if let Err(e) = self // Always save the account, a new session might get created which also
.update_one_time_key_count(&response.device_one_time_keys_count) // touches the account.
.await let mut changes = Changes {
{ account: Some(self.account.inner.clone()),
error!("Error updating the one-time key count {:?}", e); ..Default::default()
} };
self.update_one_time_key_count(&response.device_one_time_keys_count)
.await;
for user_id in &response.device_lists.changed { for user_id in &response.device_lists.changed {
if let Err(e) = self.identity_manager.mark_user_as_changed(&user_id).await { if let Err(e) = self.identity_manager.mark_user_as_changed(&user_id).await {
@ -748,29 +755,36 @@ impl OlmMachine {
match &mut event { match &mut event {
AnyToDeviceEvent::RoomEncrypted(e) => { AnyToDeviceEvent::RoomEncrypted(e) => {
let decrypted_event = match self.decrypt_to_device_event(e).await { let (session, decrypted_event, group_session) =
Ok(e) => e, match self.decrypt_to_device_event(e).await {
Err(err) => { Ok(e) => e,
warn!( Err(err) => {
"Failed to decrypt to-device event from {} {}", warn!(
e.sender, err "Failed to decrypt to-device event from {} {}",
); e.sender, err
);
if let OlmError::SessionWedged(sender, curve_key) = err { if let OlmError::SessionWedged(sender, curve_key) = err {
if let Err(e) = self if let Err(e) = self
.session_manager .session_manager
.mark_device_as_wedged(&sender, &curve_key) .mark_device_as_wedged(&sender, &curve_key)
.await .await
{ {
error!( error!(
"Couldn't mark device from {} to be unwedged {:?}", "Couldn't mark device from {} to be unwedged {:?}",
sender, e sender, e
); );
}
} }
continue;
} }
continue; };
}
}; changes.sessions.push(session);
if let Some(group_session) = group_session {
changes.inbound_group_sessions.push(group_session);
}
*event_result = decrypted_event; *event_result = decrypted_event;
} }
@ -789,13 +803,14 @@ impl OlmMachine {
} }
} }
if let Err(e) = self let changed_sessions = self
.key_request_machine .key_request_machine
.collect_incoming_key_requests() .collect_incoming_key_requests()
.await .await?;
{
error!("Error collecting our key share requests {:?}", e); changes.sessions.extend(changed_sessions);
}
Ok(self.store.save_changes(changes).await?)
} }
/// Decrypt an event from a room timeline. /// Decrypt an event from a room timeline.
@ -973,7 +988,13 @@ impl OlmMachine {
let num_sessions = sessions.len(); let num_sessions = sessions.len();
self.store.save_inbound_group_sessions(&sessions).await?; let changes = Changes {
inbound_group_sessions: sessions,
..Default::default()
};
self.store.save_changes(changes).await?;
info!( info!(
"Successfully imported {} inbound group sessions", "Successfully imported {} inbound group sessions",
num_sessions num_sessions
@ -1198,15 +1219,19 @@ pub(crate) mod test {
.unwrap() .unwrap()
.unwrap(); .unwrap();
let (session, content) = bob_device
.encrypt(EventType::Dummy, json!({}))
.await
.unwrap();
alice.store.save_sessions(&[session]).await.unwrap();
let event = ToDeviceEvent { let event = ToDeviceEvent {
sender: alice.user_id().clone(), sender: alice.user_id().clone(),
content: bob_device content,
.encrypt(EventType::Dummy, json!({}))
.await
.unwrap(),
}; };
bob.decrypt_to_device_event(&event).await.unwrap(); let (session, _, _) = bob.decrypt_to_device_event(&event).await.unwrap();
bob.store.save_sessions(&[session]).await.unwrap();
(alice, bob) (alice, bob)
} }
@ -1492,13 +1517,15 @@ pub(crate) mod test {
content: bob_device content: bob_device
.encrypt(EventType::Dummy, json!({})) .encrypt(EventType::Dummy, json!({}))
.await .await
.unwrap(), .unwrap()
.1,
}; };
let event = bob let event = bob
.decrypt_to_device_event(&event) .decrypt_to_device_event(&event)
.await .await
.unwrap() .unwrap()
.1
.deserialize() .deserialize()
.unwrap(); .unwrap();
@ -1534,12 +1561,14 @@ pub(crate) mod test {
.get_outbound_group_session(&room_id) .get_outbound_group_session(&room_id)
.unwrap(); .unwrap();
let event = bob let (session, event, group_session) = bob.decrypt_to_device_event(&event).await.unwrap();
.decrypt_to_device_event(&event)
bob.store.save_sessions(&[session]).await.unwrap();
bob.store
.save_inbound_group_sessions(&[group_session.unwrap()])
.await .await
.unwrap()
.deserialize()
.unwrap(); .unwrap();
let event = event.deserialize().unwrap();
if let AnyToDeviceEvent::RoomKey(event) = event { if let AnyToDeviceEvent::RoomKey(event) = event {
assert_eq!(&event.sender, alice.user_id()); assert_eq!(&event.sender, alice.user_id());
@ -1579,7 +1608,11 @@ pub(crate) mod test {
content: to_device_requests_to_content(to_device_requests), content: to_device_requests_to_content(to_device_requests),
}; };
bob.decrypt_to_device_event(&event).await.unwrap(); let (_, _, group_session) = bob.decrypt_to_device_event(&event).await.unwrap();
bob.store
.save_inbound_group_sessions(&[group_session.unwrap()])
.await
.unwrap();
let plaintext = "It is a secret to everybody"; let plaintext = "It is a secret to everybody";

View File

@ -52,7 +52,7 @@ use olm_rs::{
use crate::{ use crate::{
error::{EventError, OlmResult, SessionCreationError}, error::{EventError, OlmResult, SessionCreationError},
identities::ReadOnlyDevice, identities::ReadOnlyDevice,
store::{Result as StoreResult, Store}, store::Store,
OlmError, OlmError,
}; };
@ -76,7 +76,7 @@ impl Account {
pub async fn decrypt_to_device_event( pub async fn decrypt_to_device_event(
&self, &self,
event: &ToDeviceEvent<EncryptedEventContent>, event: &ToDeviceEvent<EncryptedEventContent>,
) -> OlmResult<(Raw<AnyToDeviceEvent>, String, String)> { ) -> OlmResult<(Session, Raw<AnyToDeviceEvent>, String, String)> {
debug!("Decrypting to-device event"); debug!("Decrypting to-device event");
let content = if let EncryptedEventContent::OlmV1Curve25519AesSha2(c) = &event.content { let content = if let EncryptedEventContent::OlmV1Curve25519AesSha2(c) = &event.content {
@ -103,27 +103,28 @@ impl Account {
.map_err(|_| EventError::UnsupportedOlmType)?; .map_err(|_| EventError::UnsupportedOlmType)?;
// Decrypt the OlmMessage and get a Ruma event out of it. // Decrypt the OlmMessage and get a Ruma event out of it.
let (decrypted_event, signing_key) = self let (session, decrypted_event, signing_key) = self
.decrypt_olm_message(&event.sender, &content.sender_key, message) .decrypt_olm_message(&event.sender, &content.sender_key, message)
.await?; .await?;
debug!("Decrypted a to-device event {:?}", decrypted_event); debug!("Decrypted a to-device event {:?}", decrypted_event);
Ok((decrypted_event, content.sender_key.clone(), signing_key)) Ok((
session,
decrypted_event,
content.sender_key.clone(),
signing_key,
))
} else { } else {
warn!("Olm event doesn't contain a ciphertext for our key"); warn!("Olm event doesn't contain a ciphertext for our key");
Err(EventError::MissingCiphertext.into()) Err(EventError::MissingCiphertext.into())
} }
} }
pub async fn update_uploaded_key_count( pub async fn update_uploaded_key_count(&self, key_count: &BTreeMap<DeviceKeyAlgorithm, UInt>) {
&self,
key_count: &BTreeMap<DeviceKeyAlgorithm, UInt>,
) -> StoreResult<()> {
let one_time_key_count = key_count.get(&DeviceKeyAlgorithm::SignedCurve25519); let one_time_key_count = key_count.get(&DeviceKeyAlgorithm::SignedCurve25519);
let count: u64 = one_time_key_count.map_or(0, |c| (*c).into()); let count: u64 = one_time_key_count.map_or(0, |c| (*c).into());
self.inner.update_uploaded_key_count(count); self.inner.update_uploaded_key_count(count);
self.store.save_account(self.inner.clone()).await
} }
pub async fn receive_keys_upload_response( pub async fn receive_keys_upload_response(
@ -161,7 +162,7 @@ impl Account {
sender: &UserId, sender: &UserId,
sender_key: &str, sender_key: &str,
message: &OlmMessage, message: &OlmMessage,
) -> OlmResult<Option<String>> { ) -> OlmResult<Option<(Session, String)>> {
let s = self.store.get_sessions(sender_key).await?; let s = self.store.get_sessions(sender_key).await?;
// We don't have any existing sessions, return early. // We don't have any existing sessions, return early.
@ -171,8 +172,7 @@ impl Account {
return Ok(None); return Ok(None);
}; };
let mut session_to_save = None; let mut decrypted: Option<(Session, String)> = None;
let mut plaintext = None;
for session in &mut *sessions.lock().await { for session in &mut *sessions.lock().await {
let mut matches = false; let mut matches = false;
@ -191,9 +191,7 @@ impl Account {
match ret { match ret {
Ok(p) => { Ok(p) => {
plaintext = Some(p); decrypted = Some((session.clone(), p));
session_to_save = Some(session.clone());
break; break;
} }
Err(e) => { Err(e) => {
@ -214,14 +212,7 @@ impl Account {
} }
} }
if let Some(session) = session_to_save { Ok(decrypted)
// Decryption was successful, save the new ratchet state of the
// session that was used to decrypt the message.
trace!("Saved the new session state for {}", sender);
self.store.save_sessions(&[session]).await?;
}
Ok(plaintext)
} }
/// Decrypt an Olm message, creating a new Olm session if possible. /// Decrypt an Olm message, creating a new Olm session if possible.
@ -230,15 +221,15 @@ impl Account {
sender: &UserId, sender: &UserId,
sender_key: &str, sender_key: &str,
message: OlmMessage, message: OlmMessage,
) -> OlmResult<(Raw<AnyToDeviceEvent>, String)> { ) -> OlmResult<(Session, Raw<AnyToDeviceEvent>, String)> {
// First try to decrypt using an existing session. // First try to decrypt using an existing session.
let plaintext = if let Some(p) = self let (session, plaintext) = if let Some(d) = self
.try_decrypt_olm_message(sender, sender_key, &message) .try_decrypt_olm_message(sender, sender_key, &message)
.await? .await?
{ {
// Decryption succeeded, de-structure the plaintext out of the // Decryption succeeded, de-structure the session/plaintext out of
// Option. // the Option.
p d
} else { } else {
// Decryption failed with every known session, let's try to create a // Decryption failed with every known session, let's try to create a
// new session. // new session.
@ -278,9 +269,6 @@ impl Account {
} }
}; };
// Save the account since we remove the one-time key that
// was used to create this session.
self.store.save_account(self.inner.clone()).await?;
session session
} }
}; };
@ -288,15 +276,23 @@ impl Account {
// Decrypt our message, this shouldn't fail since we're using a // Decrypt our message, this shouldn't fail since we're using a
// newly created Session. // newly created Session.
let plaintext = session.decrypt(message).await?; let plaintext = session.decrypt(message).await?;
(session, plaintext)
// Save the new ratcheted state of the session.
self.store.save_sessions(&[session]).await?;
plaintext
}; };
trace!("Successfully decrypted a Olm message: {}", plaintext); trace!("Successfully decrypted a Olm message: {}", plaintext);
self.parse_decrypted_to_device_event(sender, &plaintext) let (event, signing_key) = match self.parse_decrypted_to_device_event(sender, &plaintext) {
Ok(r) => r,
Err(e) => {
// We might created a new session but decryption might still
// have failed, store it for the error case here, this is fine
// since we don't expect this to happen often or at all.
self.store.save_sessions(&[session]).await?;
return Err(e);
}
};
Ok((session, event, signing_key))
} }
/// Parse a decrypted Olm message, check that the plaintext and encrypted /// Parse a decrypted Olm message, check that the plaintext and encrypted

View File

@ -28,8 +28,8 @@ use tracing::{debug, info};
use crate::{ use crate::{
error::{EventError, MegolmResult, OlmResult}, error::{EventError, MegolmResult, OlmResult},
olm::{Account, OutboundGroupSession}, olm::{Account, InboundGroupSession, OutboundGroupSession},
store::Store, store::{Changes, Store},
Device, EncryptionSettings, OlmError, ToDeviceRequest, Device, EncryptionSettings, OlmError, ToDeviceRequest,
}; };
@ -140,19 +140,17 @@ impl GroupSessionManager {
&self, &self,
room_id: &RoomId, room_id: &RoomId,
settings: EncryptionSettings, settings: EncryptionSettings,
) -> OlmResult<()> { ) -> OlmResult<(OutboundGroupSession, InboundGroupSession)> {
let (outbound, inbound) = self let (outbound, inbound) = self
.account .account
.create_group_session_pair(room_id, settings) .create_group_session_pair(room_id, settings)
.await .await
.map_err(|_| EventError::UnsupportedAlgorithm)?; .map_err(|_| EventError::UnsupportedAlgorithm)?;
let _ = self.store.save_inbound_group_sessions(&[inbound]).await?;
let _ = self let _ = self
.outbound_group_sessions .outbound_group_sessions
.insert(room_id.to_owned(), outbound); .insert(room_id.to_owned(), outbound.clone());
Ok(()) Ok((outbound, inbound))
} }
/// Get to-device requests to share a group session with users in a room. /// Get to-device requests to share a group session with users in a room.
@ -169,13 +167,12 @@ impl GroupSessionManager {
users: impl Iterator<Item = &UserId>, users: impl Iterator<Item = &UserId>,
encryption_settings: impl Into<EncryptionSettings>, encryption_settings: impl Into<EncryptionSettings>,
) -> OlmResult<Vec<Arc<ToDeviceRequest>>> { ) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
self.create_outbound_group_session(room_id, encryption_settings.into()) let mut changes = Changes::default();
.await?;
let session = self.outbound_group_sessions.get(room_id).unwrap();
if session.shared() { let (session, inbound_session) = self
panic!("Session is already shared"); .create_outbound_group_session(room_id, encryption_settings.into())
} .await?;
changes.inbound_group_sessions.push(inbound_session);
let mut devices: Vec<Device> = Vec::new(); let mut devices: Vec<Device> = Vec::new();
@ -196,7 +193,7 @@ impl GroupSessionManager {
.encrypt(EventType::RoomKey, key_content.clone()) .encrypt(EventType::RoomKey, key_content.clone())
.await; .await;
let encrypted = match encrypted { let (used_session, encrypted) = match encrypted {
Ok(c) => c, Ok(c) => c,
Err(OlmError::MissingSession) Err(OlmError::MissingSession)
| Err(OlmError::EventError(EventError::MissingSenderKey)) => { | Err(OlmError::EventError(EventError::MissingSenderKey)) => {
@ -205,6 +202,8 @@ impl GroupSessionManager {
Err(e) => return Err(e), Err(e) => return Err(e),
}; };
changes.sessions.push(used_session);
messages messages
.entry(device.user_id().clone()) .entry(device.user_id().clone())
.or_insert_with(BTreeMap::new) .or_insert_with(BTreeMap::new)
@ -237,6 +236,8 @@ impl GroupSessionManager {
session.mark_as_shared(); session.mark_as_shared();
} }
self.store.save_changes(changes).await?;
Ok(requests) Ok(requests)
} }
} }

View File

@ -33,7 +33,7 @@ use crate::{
key_request::KeyRequestMachine, key_request::KeyRequestMachine,
olm::Account, olm::Account,
requests::{OutgoingRequest, ToDeviceRequest}, requests::{OutgoingRequest, ToDeviceRequest},
store::{Result as StoreResult, Store}, store::{Changes, Result as StoreResult, Store},
ReadOnlyDevice, ReadOnlyDevice,
}; };
@ -128,7 +128,7 @@ impl SessionManager {
.is_some() .is_some()
{ {
if let Some(device) = self.store.get_device(user_id, device_id).await? { if let Some(device) = self.store.get_device(user_id, device_id).await? {
let content = device.encrypt(EventType::Dummy, json!({})).await?; let (_, content) = device.encrypt(EventType::Dummy, json!({})).await?;
let id = Uuid::new_v4(); let id = Uuid::new_v4();
let mut messages = BTreeMap::new(); let mut messages = BTreeMap::new();
@ -258,6 +258,8 @@ impl SessionManager {
pub async fn receive_keys_claim_response(&self, response: &KeysClaimResponse) -> OlmResult<()> { pub async fn receive_keys_claim_response(&self, response: &KeysClaimResponse) -> OlmResult<()> {
// TODO log the failures here // TODO log the failures here
let mut changes = Changes::default();
for (user_id, user_devices) in &response.one_time_keys { for (user_id, user_devices) in &response.one_time_keys {
for (device_id, key_map) in user_devices { for (device_id, key_map) in user_devices {
let device = match self.store.get_readonly_device(&user_id, device_id).await { let device = match self.store.get_readonly_device(&user_id, device_id).await {
@ -284,15 +286,12 @@ impl SessionManager {
let session = match self.account.create_outbound_session(device, &key_map).await { let session = match self.account.create_outbound_session(device, &key_map).await {
Ok(s) => s, Ok(s) => s,
Err(e) => { Err(e) => {
warn!("{:?}", e); warn!("Error creating new outbound session {:?}", e);
continue; continue;
} }
}; };
if let Err(e) = self.store.save_sessions(&[session]).await { changes.sessions.push(session);
error!("Failed to store newly created Olm session {}", e);
continue;
}
self.key_request_machine.retry_keyshare(&user_id, device_id); self.key_request_machine.retry_keyshare(&user_id, device_id);
@ -304,7 +303,8 @@ impl SessionManager {
} }
} }
} }
Ok(())
Ok(self.store.save_changes(changes).await?)
} }
} }

View File

@ -26,7 +26,7 @@ use matrix_sdk_common_macros::async_trait;
use super::{ use super::{
caches::{DeviceStore, GroupSessionStore, SessionStore}, caches::{DeviceStore, GroupSessionStore, SessionStore},
CryptoStore, InboundGroupSession, ReadOnlyAccount, Result, Session, Changes, CryptoStore, InboundGroupSession, ReadOnlyAccount, Result, Session,
}; };
use crate::identities::{ReadOnlyDevice, UserIdentities}; use crate::identities::{ReadOnlyDevice, UserIdentities};
@ -61,6 +61,30 @@ impl MemoryStore {
pub fn new() -> Self { pub fn new() -> Self {
Self::default() Self::default()
} }
pub(crate) async fn save_devices(&self, mut devices: Vec<ReadOnlyDevice>) {
for device in devices.drain(..) {
let _ = self.devices.add(device);
}
}
async fn delete_devices(&self, mut devices: Vec<ReadOnlyDevice>) {
for device in devices.drain(..) {
let _ = self.devices.remove(device.user_id(), device.device_id());
}
}
async fn save_sessions(&self, mut sessions: Vec<Session>) {
for session in sessions.drain(..) {
let _ = self.sessions.add(session.clone()).await;
}
}
async fn save_inbound_group_sessions(&self, mut sessions: Vec<InboundGroupSession>) {
for session in sessions.drain(..) {
self.inbound_group_sessions.add(session);
}
}
} }
#[async_trait] #[async_trait]
@ -73,9 +97,24 @@ impl CryptoStore for MemoryStore {
Ok(()) Ok(())
} }
async fn save_sessions(&self, sessions: &[Session]) -> Result<()> { async fn save_changes(&self, mut changes: Changes) -> Result<()> {
for session in sessions { self.save_sessions(changes.sessions).await;
let _ = self.sessions.add(session.clone()).await; self.save_inbound_group_sessions(changes.inbound_group_sessions)
.await;
self.save_devices(changes.devices.new).await;
self.save_devices(changes.devices.changed).await;
self.delete_devices(changes.devices.deleted).await;
for identity in changes
.identities
.new
.drain(..)
.chain(changes.identities.changed)
{
let _ = self
.identities
.insert(identity.user_id().to_owned(), identity.clone());
} }
Ok(()) Ok(())
@ -85,14 +124,6 @@ impl CryptoStore for MemoryStore {
Ok(self.sessions.get(sender_key)) Ok(self.sessions.get(sender_key))
} }
async fn save_inbound_group_sessions(&self, sessions: &[InboundGroupSession]) -> Result<()> {
for session in sessions {
self.inbound_group_sessions.add(session.clone());
}
Ok(())
}
async fn get_inbound_group_session( async fn get_inbound_group_session(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
@ -151,11 +182,6 @@ impl CryptoStore for MemoryStore {
Ok(self.devices.get(user_id, device_id)) Ok(self.devices.get(user_id, device_id))
} }
async fn delete_device(&self, device: ReadOnlyDevice) -> Result<()> {
let _ = self.devices.remove(device.user_id(), device.device_id());
Ok(())
}
async fn get_user_devices( async fn get_user_devices(
&self, &self,
user_id: &UserId, user_id: &UserId,
@ -163,28 +189,11 @@ impl CryptoStore for MemoryStore {
Ok(self.devices.user_devices(user_id)) Ok(self.devices.user_devices(user_id))
} }
async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()> {
for device in devices {
let _ = self.devices.add(device.clone());
}
Ok(())
}
async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentities>> { async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentities>> {
#[allow(clippy::map_clone)] #[allow(clippy::map_clone)]
Ok(self.identities.get(user_id).map(|i| i.clone())) Ok(self.identities.get(user_id).map(|i| i.clone()))
} }
async fn save_user_identities(&self, identities: &[UserIdentities]) -> Result<()> {
for identity in identities {
let _ = self
.identities
.insert(identity.user_id().to_owned(), identity.clone());
}
Ok(())
}
async fn save_value(&self, key: String, value: String) -> Result<()> { async fn save_value(&self, key: String, value: String) -> Result<()> {
self.values.insert(key, value); self.values.insert(key, value);
Ok(()) Ok(())
@ -217,7 +226,7 @@ mod test {
assert!(store.load_account().await.unwrap().is_none()); assert!(store.load_account().await.unwrap().is_none());
store.save_account(account).await.unwrap(); store.save_account(account).await.unwrap();
store.save_sessions(&[session.clone()]).await.unwrap(); store.save_sessions(vec![session.clone()]).await;
let sessions = store let sessions = store
.get_sessions(&session.sender_key) .get_sessions(&session.sender_key)
@ -250,9 +259,8 @@ mod test {
let store = MemoryStore::new(); let store = MemoryStore::new();
let _ = store let _ = store
.save_inbound_group_sessions(&[inbound.clone()]) .save_inbound_group_sessions(vec![inbound.clone()])
.await .await;
.unwrap();
let loaded_session = store let loaded_session = store
.get_inbound_group_session(&room_id, "test_key", outbound.session_id()) .get_inbound_group_session(&room_id, "test_key", outbound.session_id())
@ -267,7 +275,7 @@ mod test {
let device = get_device(); let device = get_device();
let store = MemoryStore::new(); let store = MemoryStore::new();
store.save_devices(&[device.clone()]).await.unwrap(); store.save_devices(vec![device.clone()]).await;
let loaded_device = store let loaded_device = store
.get_device(device.user_id(), device.device_id()) .get_device(device.user_id(), device.device_id())
@ -286,7 +294,7 @@ mod test {
assert_eq!(&device, loaded_device); assert_eq!(&device, loaded_device);
store.delete_device(device.clone()).await.unwrap(); store.delete_devices(vec![device.clone()]).await;
assert!(store assert!(store
.get_device(device.user_id(), device.device_id()) .get_device(device.user_id(), device.device_id())
.await .await

View File

@ -100,6 +100,31 @@ pub(crate) struct Store {
verification_machine: VerificationMachine, verification_machine: VerificationMachine,
} }
#[derive(Debug, Default)]
#[allow(missing_docs)]
pub struct Changes {
pub account: Option<ReadOnlyAccount>,
pub sessions: Vec<Session>,
pub inbound_group_sessions: Vec<InboundGroupSession>,
pub identities: IdentityChanges,
pub devices: DeviceChanges,
}
#[derive(Debug, Clone, Default)]
#[allow(missing_docs)]
pub struct IdentityChanges {
pub new: Vec<UserIdentities>,
pub changed: Vec<UserIdentities>,
}
#[derive(Debug, Clone, Default)]
#[allow(missing_docs)]
pub struct DeviceChanges {
pub new: Vec<ReadOnlyDevice>,
pub changed: Vec<ReadOnlyDevice>,
pub deleted: Vec<ReadOnlyDevice>,
}
impl Store { impl Store {
pub fn new( pub fn new(
user_id: Arc<UserId>, user_id: Arc<UserId>,
@ -121,6 +146,41 @@ impl Store {
self.inner.get_device(user_id, device_id).await self.inner.get_device(user_id, device_id).await
} }
pub async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
let changes = Changes {
sessions: sessions.to_vec(),
..Default::default()
};
self.save_changes(changes).await
}
#[cfg(test)]
pub async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()> {
let changes = Changes {
devices: DeviceChanges {
changed: devices.to_vec(),
..Default::default()
},
..Default::default()
};
self.save_changes(changes).await
}
#[cfg(test)]
pub async fn save_inbound_group_sessions(
&self,
sessions: &[InboundGroupSession],
) -> Result<()> {
let changes = Changes {
inbound_group_sessions: sessions.to_vec(),
..Default::default()
};
self.save_changes(changes).await
}
pub async fn get_readonly_devices( pub async fn get_readonly_devices(
&self, &self,
user_id: &UserId, user_id: &UserId,
@ -271,12 +331,8 @@ pub trait CryptoStore: Debug {
/// * `account` - The account that should be stored. /// * `account` - The account that should be stored.
async fn save_account(&self, account: ReadOnlyAccount) -> Result<()>; async fn save_account(&self, account: ReadOnlyAccount) -> Result<()>;
/// Save the given sessions in the store. /// TODO
/// async fn save_changes(&self, changes: Changes) -> Result<()>;
/// # Arguments
///
/// * `session` - The sessions that should be stored.
async fn save_sessions(&self, session: &[Session]) -> Result<()>;
/// Get all the sessions that belong to the given sender key. /// Get all the sessions that belong to the given sender key.
/// ///
@ -285,13 +341,6 @@ pub trait CryptoStore: Debug {
/// * `sender_key` - The sender key that was used to establish the sessions. /// * `sender_key` - The sender key that was used to establish the sessions.
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>>; async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>>;
/// Save the given inbound group sessions in the store.
///
/// # Arguments
///
/// * `sessions` - The sessions that should be stored.
async fn save_inbound_group_sessions(&self, session: &[InboundGroupSession]) -> Result<()>;
/// Get the inbound group session from our store. /// Get the inbound group session from our store.
/// ///
/// # Arguments /// # Arguments
@ -331,20 +380,6 @@ pub trait CryptoStore: Debug {
/// * `dirty` - Should the user be also marked for a key query. /// * `dirty` - Should the user be also marked for a key query.
async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool>; async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool>;
/// Save the given devices in the store.
///
/// # Arguments
///
/// * `device` - The device that should be stored.
async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()>;
/// Delete the given device from the store.
///
/// # Arguments
///
/// * `device` - The device that should be stored.
async fn delete_device(&self, device: ReadOnlyDevice) -> Result<()>;
/// Get the device for the given user with the given device id. /// Get the device for the given user with the given device id.
/// ///
/// # Arguments /// # Arguments
@ -368,13 +403,6 @@ pub trait CryptoStore: Debug {
user_id: &UserId, user_id: &UserId,
) -> Result<HashMap<DeviceIdBox, ReadOnlyDevice>>; ) -> Result<HashMap<DeviceIdBox, ReadOnlyDevice>>;
/// Save the given user identities in the store.
///
/// # Arguments
///
/// * `identities` - The identities that should be saved in the store.
async fn save_user_identities(&self, identities: &[UserIdentities]) -> Result<()>;
/// Get the user identity that is attached to the given user id. /// Get the user identity that is attached to the given user id.
/// ///
/// # Arguments /// # Arguments

View File

@ -34,7 +34,7 @@ use matrix_sdk_common::{
use sqlx::{query, query_as, sqlite::SqliteConnectOptions, Connection, Executor, SqliteConnection}; use sqlx::{query, query_as, sqlite::SqliteConnectOptions, Connection, Executor, SqliteConnection};
use zeroize::Zeroizing; use zeroize::Zeroizing;
use super::{caches::SessionStore, CryptoStore, CryptoStoreError, Result}; use super::{caches::SessionStore, Changes, CryptoStore, CryptoStoreError, Result};
use crate::{ use crate::{
identities::{LocalTrust, OwnUserIdentity, ReadOnlyDevice, UserIdentities, UserIdentity}, identities::{LocalTrust, OwnUserIdentity, ReadOnlyDevice, UserIdentities, UserIdentity},
olm::{ olm::{
@ -456,11 +456,17 @@ impl SqliteStore {
Ok(()) Ok(())
} }
async fn lazy_load_sessions(&self, sender_key: &str) -> Result<()> { async fn lazy_load_sessions(
&self,
connection: &mut SqliteConnection,
sender_key: &str,
) -> Result<()> {
let loaded_sessions = self.sessions.get(sender_key).is_some(); let loaded_sessions = self.sessions.get(sender_key).is_some();
if !loaded_sessions { if !loaded_sessions {
let sessions = self.load_sessions_for(sender_key).await?; let sessions = self
.load_sessions_for_helper(connection, sender_key)
.await?;
if !sessions.is_empty() { if !sessions.is_empty() {
self.sessions.set_for_sender(sender_key, sessions); self.sessions.set_for_sender(sender_key, sessions);
@ -470,20 +476,33 @@ impl SqliteStore {
Ok(()) Ok(())
} }
async fn get_sessions_for(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> { async fn get_sessions_for(
self.lazy_load_sessions(sender_key).await?; &self,
connection: &mut SqliteConnection,
sender_key: &str,
) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
self.lazy_load_sessions(connection, sender_key).await?;
Ok(self.sessions.get(sender_key)) Ok(self.sessions.get(sender_key))
} }
#[cfg(test)]
async fn load_sessions_for(&self, sender_key: &str) -> Result<Vec<Session>> { async fn load_sessions_for(&self, sender_key: &str) -> Result<Vec<Session>> {
let mut connection = self.connection.lock().await;
self.load_sessions_for_helper(&mut connection, sender_key)
.await
}
async fn load_sessions_for_helper(
&self,
connection: &mut SqliteConnection,
sender_key: &str,
) -> Result<Vec<Session>> {
let account_info = self let account_info = self
.account_info .account_info
.lock() .lock()
.unwrap() .unwrap()
.clone() .clone()
.ok_or(CryptoStoreError::AccountUnset)?; .ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
let mut rows: Vec<(String, String, String, String)> = query_as( let mut rows: Vec<(String, String, String, String)> = query_as(
"SELECT pickle, sender_key, creation_time, last_use_time "SELECT pickle, sender_key, creation_time, last_use_time
FROM sessions WHERE account_id = ? and sender_key = ?", FROM sessions WHERE account_id = ? and sender_key = ?",
@ -1231,6 +1250,134 @@ impl SqliteStore {
Ok(()) Ok(())
} }
#[cfg(test)]
async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
let mut connection = self.connection.lock().await;
let mut transaction = connection.begin().await?;
self.save_sessions_helper(&mut transaction, sessions)
.await?;
transaction.commit().await?;
Ok(())
}
async fn save_sessions_helper(
&self,
connection: &mut SqliteConnection,
sessions: &[Session],
) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
for session in sessions {
self.lazy_load_sessions(connection, &session.sender_key)
.await?;
}
for session in sessions {
self.sessions.add(session.clone()).await;
let pickle = session.pickle(self.get_pickle_mode()).await;
let session_id = session.session_id();
let creation_time = serde_json::to_string(&pickle.creation_time)?;
let last_use_time = serde_json::to_string(&pickle.last_use_time)?;
query(
"REPLACE INTO sessions (
session_id, account_id, creation_time, last_use_time, sender_key, pickle
) VALUES (?, ?, ?, ?, ?, ?)",
)
.bind(&session_id)
.bind(&account_id)
.bind(&*creation_time)
.bind(&*last_use_time)
.bind(&pickle.sender_key)
.bind(&pickle.pickle.as_str())
.execute(&mut *connection)
.await?;
}
Ok(())
}
async fn save_devices(
&self,
mut connection: &mut SqliteConnection,
devices: &[ReadOnlyDevice],
) -> Result<()> {
for device in devices {
self.save_device_helper(&mut connection, device.clone())
.await?
}
Ok(())
}
async fn delete_devices(
&self,
connection: &mut SqliteConnection,
devices: &[ReadOnlyDevice],
) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
for device in devices {
query(
"DELETE FROM devices
WHERE account_id = ?1 and user_id = ?2 and device_id = ?3
",
)
.bind(account_id)
.bind(&device.user_id().to_string())
.bind(device.device_id().as_str())
.execute(&mut *connection)
.await?;
}
Ok(())
}
#[cfg(test)]
async fn save_inbound_group_sessions_test(
&self,
sessions: &[InboundGroupSession],
) -> Result<()> {
let mut connection = self.connection.lock().await;
let mut transaction = connection.begin().await?;
self.save_inbound_group_sessions(&mut transaction, sessions)
.await?;
transaction.commit().await?;
Ok(())
}
async fn save_inbound_group_sessions(
&self,
connection: &mut SqliteConnection,
sessions: &[InboundGroupSession],
) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
for session in sessions {
self.save_inbound_group_session_helper(account_id, connection, session)
.await?;
}
Ok(())
}
async fn save_user_identities(
&self,
mut connection: &mut SqliteConnection,
users: &[UserIdentities],
) -> Result<()> {
for user in users {
self.save_user_helper(&mut connection, user).await?;
}
Ok(())
}
async fn save_user_helper( async fn save_user_helper(
&self, &self,
mut connection: &mut SqliteConnection, mut connection: &mut SqliteConnection,
@ -1369,39 +1516,26 @@ impl CryptoStore for SqliteStore {
Ok(()) Ok(())
} }
async fn save_sessions(&self, sessions: &[Session]) -> Result<()> { async fn save_changes(&self, changes: Changes) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
for session in sessions {
self.lazy_load_sessions(&session.sender_key).await?;
}
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
let mut transaction = connection.begin().await?; let mut transaction = connection.begin().await?;
for session in sessions { self.save_sessions_helper(&mut transaction, &changes.sessions)
self.sessions.add(session.clone()).await; .await?;
self.save_inbound_group_sessions(&mut transaction, &changes.inbound_group_sessions)
let pickle = session.pickle(self.get_pickle_mode()).await; .await?;
let session_id = session.session_id(); self.save_devices(&mut transaction, &changes.devices.new)
let creation_time = serde_json::to_string(&pickle.creation_time)?; .await?;
let last_use_time = serde_json::to_string(&pickle.last_use_time)?; self.save_devices(&mut transaction, &changes.devices.changed)
.await?;
query( self.delete_devices(&mut transaction, &changes.devices.deleted)
"REPLACE INTO sessions ( .await?;
session_id, account_id, creation_time, last_use_time, sender_key, pickle
) VALUES (?, ?, ?, ?, ?, ?)", self.save_user_identities(&mut transaction, &changes.identities.new)
) .await?;
.bind(&session_id) self.save_user_identities(&mut transaction, &changes.identities.changed)
.bind(&account_id)
.bind(&*creation_time)
.bind(&*last_use_time)
.bind(&pickle.sender_key)
.bind(&pickle.pickle.as_str())
.execute(&mut *transaction)
.await?; .await?;
}
transaction.commit().await?; transaction.commit().await?;
@ -1409,22 +1543,8 @@ impl CryptoStore for SqliteStore {
} }
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> { async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
Ok(self.get_sessions_for(sender_key).await?)
}
async fn save_inbound_group_sessions(&self, sessions: &[InboundGroupSession]) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
let mut transaction = connection.begin().await?; Ok(self.get_sessions_for(&mut connection, sender_key).await?)
for session in sessions {
self.save_inbound_group_session_helper(account_id, &mut transaction, session)
.await?;
}
transaction.commit().await?;
Ok(())
} }
async fn get_inbound_group_session( async fn get_inbound_group_session(
@ -1469,38 +1589,6 @@ impl CryptoStore for SqliteStore {
Ok(already_added) Ok(already_added)
} }
async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()> {
let mut connection = self.connection.lock().await;
let mut transaction = connection.begin().await?;
for device in devices {
self.save_device_helper(&mut transaction, device.clone())
.await?
}
transaction.commit().await?;
Ok(())
}
async fn delete_device(&self, device: ReadOnlyDevice) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await;
query(
"DELETE FROM devices
WHERE account_id = ?1 and user_id = ?2 and device_id = ?3
",
)
.bind(account_id)
.bind(&device.user_id().to_string())
.bind(device.device_id().as_str())
.execute(&mut *connection)
.await?;
Ok(())
}
async fn get_device( async fn get_device(
&self, &self,
user_id: &UserId, user_id: &UserId,
@ -1520,19 +1608,6 @@ impl CryptoStore for SqliteStore {
self.load_user(user_id).await self.load_user(user_id).await
} }
async fn save_user_identities(&self, users: &[UserIdentities]) -> Result<()> {
let mut connection = self.connection.lock().await;
let mut transaction = connection.begin().await?;
for user in users {
self.save_user_helper(&mut transaction, user).await?;
}
transaction.commit().await?;
Ok(())
}
async fn save_value(&self, key: String, value: String) -> Result<()> { async fn save_value(&self, key: String, value: String) -> Result<()> {
let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?;
let mut connection = self.connection.lock().await; let mut connection = self.connection.lock().await;
@ -1598,6 +1673,7 @@ mod test {
user::test::{get_other_identity, get_own_identity}, user::test::{get_other_identity, get_own_identity},
}, },
olm::{GroupSessionKey, InboundGroupSession, ReadOnlyAccount, Session}, olm::{GroupSessionKey, InboundGroupSession, ReadOnlyAccount, Session},
store::{Changes, DeviceChanges, IdentityChanges},
}; };
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::keys::SignedKey, api::r0::keys::SignedKey,
@ -1854,7 +1930,7 @@ mod test {
.expect("Can't create session"); .expect("Can't create session");
store store
.save_inbound_group_sessions(&[session]) .save_inbound_group_sessions_test(&[session])
.await .await
.expect("Can't save group session"); .expect("Can't save group session");
} }
@ -1880,7 +1956,7 @@ mod test {
let session = InboundGroupSession::from_export(export).unwrap(); let session = InboundGroupSession::from_export(export).unwrap();
store store
.save_inbound_group_sessions(&[session.clone()]) .save_inbound_group_sessions_test(&[session.clone()])
.await .await
.expect("Can't save group session"); .expect("Can't save group session");
@ -1952,7 +2028,15 @@ mod test {
let (_account, store, dir) = get_loaded_store().await; let (_account, store, dir) = get_loaded_store().await;
let device = get_device(); let device = get_device();
store.save_devices(&[device.clone()]).await.unwrap(); let changes = Changes {
devices: DeviceChanges {
changed: vec![device.clone()],
..Default::default()
},
..Default::default()
};
store.save_changes(changes).await.unwrap();
drop(store); drop(store);
@ -1986,8 +2070,25 @@ mod test {
let (_account, store, dir) = get_loaded_store().await; let (_account, store, dir) = get_loaded_store().await;
let device = get_device(); let device = get_device();
store.save_devices(&[device.clone()]).await.unwrap(); let changes = Changes {
store.delete_device(device.clone()).await.unwrap(); devices: DeviceChanges {
changed: vec![device.clone()],
..Default::default()
},
..Default::default()
};
store.save_changes(changes).await.unwrap();
let changes = Changes {
devices: DeviceChanges {
deleted: vec![device.clone()],
..Default::default()
},
..Default::default()
};
store.save_changes(changes).await.unwrap();
let store = SqliteStore::open(&alice_id(), &alice_device_id(), dir.path()) let store = SqliteStore::open(&alice_id(), &alice_device_id(), dir.path())
.await .await
@ -2024,8 +2125,16 @@ mod test {
let own_identity = get_own_identity(); let own_identity = get_own_identity();
let changes = Changes {
identities: IdentityChanges {
changed: vec![own_identity.clone().into()],
..Default::default()
},
..Default::default()
};
store store
.save_user_identities(&[own_identity.clone().into()]) .save_changes(changes)
.await .await
.expect("Can't save identity"); .expect("Can't save identity");
@ -2052,10 +2161,15 @@ mod test {
let other_identity = get_other_identity(); let other_identity = get_other_identity();
store let changes = Changes {
.save_user_identities(&[other_identity.clone().into()]) identities: IdentityChanges {
.await changed: vec![other_identity.clone().into()],
.unwrap(); ..Default::default()
},
..Default::default()
};
store.save_changes(changes).await.unwrap();
let loaded_user = store let loaded_user = store
.load_user(other_identity.user_id()) .load_user(other_identity.user_id())
@ -2072,10 +2186,15 @@ mod test {
own_identity.mark_as_verified(); own_identity.mark_as_verified();
store let changes = Changes {
.save_user_identities(&[own_identity.into()]) identities: IdentityChanges {
.await changed: vec![own_identity.into()],
.unwrap(); ..Default::default()
},
..Default::default()
};
store.save_changes(changes).await.unwrap();
let loaded_user = store.load_user(&user_id).await.unwrap().unwrap(); let loaded_user = store.load_user(&user_id).await.unwrap().unwrap();
assert!(loaded_user.own().unwrap().is_verified()) assert!(loaded_user.own().unwrap().is_verified())
} }

View File

@ -194,18 +194,14 @@ impl VerificationMachine {
self.receive_event_helper(&s, event); self.receive_event_helper(&s, event);
if s.is_done() { if s.is_done() {
if !s.mark_device_as_verified().await? { if let Some(r) = s.mark_as_done().await? {
if let Some(r) = s.cancel() { self.outgoing_to_device_messages.insert(
self.outgoing_to_device_messages.insert( r.txn_id,
r.txn_id, OutgoingRequest {
OutgoingRequest { request_id: r.txn_id,
request_id: r.txn_id, request: Arc::new(r.into()),
request: Arc::new(r.into()), },
}, );
);
}
} else {
s.mark_identity_as_verified().await?;
} }
} }
}; };
@ -258,17 +254,15 @@ mod test {
let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id()); let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id());
let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id()); let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id());
let store = MemoryStore::new(); let store = MemoryStore::new();
let bob_store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new())); let bob_store = MemoryStore::new();
let bob_device = ReadOnlyDevice::from_account(&bob).await; let bob_device = ReadOnlyDevice::from_account(&bob).await;
let alice_device = ReadOnlyDevice::from_account(&alice).await; let alice_device = ReadOnlyDevice::from_account(&alice).await;
store.save_devices(&[bob_device]).await.unwrap(); store.save_devices(vec![bob_device]).await;
bob_store bob_store.save_devices(vec![alice_device.clone()]).await;
.save_devices(&[alice_device.clone()])
.await
.unwrap();
let bob_store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(bob_store));
let machine = VerificationMachine::new(alice, Arc::new(Box::new(store))); let machine = VerificationMachine::new(alice, Arc::new(Box::new(store)));
let (bob_sas, start_content) = Sas::start(bob, alice_device, bob_store, None); let (bob_sas, start_content) = Sas::start(bob, alice_device, bob_store, None);
machine machine

View File

@ -34,7 +34,7 @@ use matrix_sdk_common::{
use crate::{ use crate::{
identities::{LocalTrust, ReadOnlyDevice, UserIdentities}, identities::{LocalTrust, ReadOnlyDevice, UserIdentities},
store::{CryptoStore, CryptoStoreError}, store::{Changes, CryptoStore, CryptoStoreError, DeviceChanges},
ReadOnlyAccount, ToDeviceRequest, ReadOnlyAccount, ToDeviceRequest,
}; };
@ -189,34 +189,64 @@ impl Sas {
(content, guard.is_done()) (content, guard.is_done())
}; };
if done { let cancel = if done {
// TODO move the logic that marks and stores the device into the self.mark_as_done().await?
// else branch and only after the identity was verified as well. We } else {
// dont' want to verify one without the other. None
if !self.mark_device_as_verified().await? { };
return Ok(self.cancel());
} else {
self.mark_identity_as_verified().await?;
}
}
Ok(content.map(|c| { if cancel.is_some() {
let content = AnyToDeviceEventContent::KeyVerificationMac(c); Ok(cancel)
self.content_to_request(content) } else {
})) Ok(content.map(|c| {
let content = AnyToDeviceEventContent::KeyVerificationMac(c);
self.content_to_request(content)
}))
}
} }
pub(crate) async fn mark_identity_as_verified(&self) -> Result<bool, CryptoStoreError> { pub(crate) async fn mark_as_done(&self) -> Result<Option<ToDeviceRequest>, CryptoStoreError> {
if let Some(device) = self.mark_device_as_verified().await? {
let identity = self.mark_identity_as_verified().await?;
let mut changes = Changes {
devices: DeviceChanges {
changed: vec![device],
..Default::default()
},
..Default::default()
};
if let Some(i) = identity {
changes.identities.changed.push(i);
}
self.store.save_changes(changes).await?;
Ok(None)
} else {
Ok(self.cancel())
}
}
pub(crate) async fn mark_identity_as_verified(
&self,
) -> Result<Option<UserIdentities>, CryptoStoreError> {
// If there wasn't an identity available during the verification flow // If there wasn't an identity available during the verification flow
// return early as there's nothing to do. // return early as there's nothing to do.
if self.other_identity.is_none() { if self.other_identity.is_none() {
return Ok(false); return Ok(None);
} }
// TODO signal an error, e.g. when the identity got deleted so we don't
// verify/save the device either.
let identity = self.store.get_user_identity(self.other_user_id()).await?; let identity = self.store.get_user_identity(self.other_user_id()).await?;
if let Some(identity) = identity { if let Some(identity) = identity {
if identity.master_key() == self.other_identity.as_ref().unwrap().master_key() { if self
.other_identity
.as_ref()
.map_or(false, |i| i.master_key() == identity.master_key())
{
if self if self
.verified_identities() .verified_identities()
.map_or(false, |i| i.contains(&identity)) .map_or(false, |i| i.contains(&identity))
@ -228,13 +258,12 @@ impl Sas {
if let UserIdentities::Own(i) = &identity { if let UserIdentities::Own(i) = &identity {
i.mark_as_verified(); i.mark_as_verified();
self.store.save_user_identities(&[identity]).await?;
} }
// TODO if we have the private part of the user signing // TODO if we have the private part of the user signing
// key we should sign and upload a signature for this // key we should sign and upload a signature for this
// identity. // identity.
Ok(true) Ok(Some(identity))
} else { } else {
info!( info!(
"The interactive verification process didn't contain a \ "The interactive verification process didn't contain a \
@ -243,7 +272,7 @@ impl Sas {
self.verified_identities(), self.verified_identities(),
); );
Ok(false) Ok(None)
} }
} else { } else {
warn!( warn!(
@ -252,7 +281,7 @@ impl Sas {
identity.user_id(), identity.user_id(),
); );
Ok(false) Ok(None)
} }
} else { } else {
info!( info!(
@ -260,11 +289,13 @@ impl Sas {
verification was going on.", verification was going on.",
self.other_user_id(), self.other_user_id(),
); );
Ok(false) Ok(None)
} }
} }
pub(crate) async fn mark_device_as_verified(&self) -> Result<bool, CryptoStoreError> { pub(crate) async fn mark_device_as_verified(
&self,
) -> Result<Option<ReadOnlyDevice>, CryptoStoreError> {
let device = self let device = self
.store .store
.get_device(self.other_user_id(), self.other_device_id()) .get_device(self.other_user_id(), self.other_device_id())
@ -283,12 +314,11 @@ impl Sas {
); );
device.set_trust_state(LocalTrust::Verified); device.set_trust_state(LocalTrust::Verified);
self.store.save_devices(&[device]).await?;
// TODO if this is a device from our own user and we have // TODO if this is a device from our own user and we have
// the private part of the self signing key, we should sign // the private part of the self signing key, we should sign
// the device and upload the signature. // the device and upload the signature.
Ok(true) Ok(Some(device))
} else { } else {
info!( info!(
"The interactive verification process didn't contain a \ "The interactive verification process didn't contain a \
@ -297,7 +327,7 @@ impl Sas {
device.device_id() device.device_id()
); );
Ok(false) Ok(None)
} }
} else { } else {
warn!( warn!(
@ -306,7 +336,7 @@ impl Sas {
device.user_id(), device.user_id(),
device.device_id() device.device_id()
); );
Ok(false) Ok(None)
} }
} else { } else {
let device = self.other_device(); let device = self.other_device();
@ -317,7 +347,7 @@ impl Sas {
device.user_id(), device.user_id(),
device.device_id() device.device_id()
); );
Ok(false) Ok(None)
} }
} }
@ -777,12 +807,11 @@ mod test {
let bob_device = ReadOnlyDevice::from_account(&bob).await; let bob_device = ReadOnlyDevice::from_account(&bob).await;
let alice_store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new())); let alice_store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new()));
let bob_store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(MemoryStore::new())); let bob_store = MemoryStore::new();
bob_store bob_store.save_devices(vec![alice_device.clone()]).await;
.save_devices(&[alice_device.clone()])
.await let bob_store: Arc<Box<dyn CryptoStore>> = Arc::new(Box::new(bob_store));
.unwrap();
let (alice, content) = Sas::start(alice, bob_device, alice_store, None); let (alice, content) = Sas::start(alice, bob_device, alice_store, None);
let event = wrap_to_device_event(alice.user_id(), content); let event = wrap_to_device_event(alice.user_id(), content);