diff --git a/matrix_sdk_base/src/client.rs b/matrix_sdk_base/src/client.rs index 6557a69b..e4cae5bc 100644 --- a/matrix_sdk_base/src/client.rs +++ b/matrix_sdk_base/src/client.rs @@ -971,8 +971,6 @@ impl BaseClient { return Ok(()); } - *self.sync_token.write().await = Some(response.next_batch.clone()); - #[cfg(feature = "encryption")] { let olm = self.olm.lock().await; @@ -982,10 +980,12 @@ impl BaseClient { // decryptes to-device events, but leaves room events alone. // This makes sure that we have the deryption keys for the room // events at hand. - o.receive_sync_response(response).await; + o.receive_sync_response(response).await?; } } + *self.sync_token.write().await = Some(response.next_batch.clone()); + // when events change state, updated_* signals to StateStore to update database self.iter_joined_rooms(response).await?; self.iter_invited_rooms(response).await?; diff --git a/matrix_sdk_crypto/src/identities/device.rs b/matrix_sdk_crypto/src/identities/device.rs index a0281217..0a75ad29 100644 --- a/matrix_sdk_crypto/src/identities/device.rs +++ b/matrix_sdk_crypto/src/identities/device.rs @@ -39,7 +39,10 @@ use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use tracing::warn; -use crate::olm::{InboundGroupSession, Session}; +use crate::{ + olm::{InboundGroupSession, Session}, + store::{Changes, DeviceChanges}, +}; #[cfg(test)] use crate::{OlmMachine, ReadOnlyAccount}; @@ -118,10 +121,15 @@ impl Device { pub async fn set_local_trust(&self, trust_state: LocalTrust) -> StoreResult<()> { self.inner.set_trust_state(trust_state); - self.verification_machine - .store - .save_devices(&[self.inner.clone()]) - .await + let changes = Changes { + devices: DeviceChanges { + changed: vec![self.inner.clone()], + ..Default::default() + }, + ..Default::default() + }; + + self.verification_machine.store.save_changes(changes).await } /// Encrypt the given content for this `Device`. @@ -135,7 +143,7 @@ impl Device { &self, event_type: EventType, content: Value, - ) -> OlmResult { + ) -> OlmResult<(Session, EncryptedEventContent)> { self.inner .encrypt(&**self.verification_machine.store, event_type, content) .await @@ -146,7 +154,7 @@ impl Device { pub async fn encrypt_session( &self, session: InboundGroupSession, - ) -> OlmResult { + ) -> OlmResult<(Session, EncryptedEventContent)> { let export = session.export().await; let content: ForwardedRoomKeyEventContent = if let Ok(c) = export.try_into() { @@ -364,7 +372,7 @@ impl ReadOnlyDevice { store: &dyn CryptoStore, event_type: EventType, content: Value, - ) -> OlmResult { + ) -> OlmResult<(Session, EncryptedEventContent)> { let sender_key = if let Some(k) = self.get_key(DeviceKeyAlgorithm::Curve25519) { k } else { @@ -396,10 +404,9 @@ impl ReadOnlyDevice { return Err(OlmError::MissingSession); }; - let message = session.encrypt(&self, event_type, content).await; - store.save_sessions(&[session]).await?; + let message = session.encrypt(&self, event_type, content).await?; - message + Ok((session, message)) } /// Update a device with a new device keys struct. diff --git a/matrix_sdk_crypto/src/identities/manager.rs b/matrix_sdk_crypto/src/identities/manager.rs index ba6fb749..d91ffa2c 100644 --- a/matrix_sdk_crypto/src/identities/manager.rs +++ b/matrix_sdk_crypto/src/identities/manager.rs @@ -33,7 +33,7 @@ use crate::{ }, requests::KeysQueryRequest, session_manager::GroupSessionManager, - store::{Result as StoreResult, Store}, + store::{Changes, DeviceChanges, IdentityChanges, Result as StoreResult, Store}, }; #[derive(Debug, Clone)] @@ -79,7 +79,7 @@ impl IdentityManager { pub async fn receive_keys_query_response( &self, response: &KeysQueryResponse, - ) -> OlmResult<(Vec, Vec)> { + ) -> OlmResult<(DeviceChanges, IdentityChanges)> { // TODO create a enum that tells us how the device/identity changed, // e.g. new/deleted/display name change. // @@ -92,9 +92,15 @@ impl IdentityManager { let changed_devices = self .handle_devices_from_key_query(&response.device_keys) .await?; - self.store.save_devices(&changed_devices).await?; let changed_identities = self.handle_cross_singing_keys(response).await?; - self.store.save_user_identities(&changed_identities).await?; + + let changes = Changes { + identities: changed_identities.clone(), + devices: changed_devices.clone(), + ..Default::default() + }; + + self.store.save_changes(changes).await?; Ok((changed_devices, changed_identities)) } @@ -111,9 +117,10 @@ impl IdentityManager { async fn handle_devices_from_key_query( &self, device_keys_map: &BTreeMap>, - ) -> StoreResult> { + ) -> StoreResult { let mut users_with_new_or_deleted_devices = HashSet::new(); - let mut changed_devices = Vec::new(); + + let mut changes = DeviceChanges::default(); for (user_id, device_map) in device_keys_map { // TODO move this out into the handle keys query response method @@ -137,7 +144,7 @@ impl IdentityManager { let device = self.store.get_readonly_device(&user_id, device_id).await?; - let device = if let Some(mut device) = device { + if let Some(mut device) = device { if let Err(e) = device.update_device(device_keys) { warn!( "Failed to update the device keys for {} {}: {:?}", @@ -145,7 +152,7 @@ impl IdentityManager { ); continue; } - device + changes.changed.push(device); } else { let device = match ReadOnlyDevice::try_from(device_keys) { Ok(d) => d, @@ -159,23 +166,21 @@ impl IdentityManager { }; info!("Adding a new device to the device store {:?}", device); users_with_new_or_deleted_devices.insert(user_id); - device - }; - - changed_devices.push(device); + changes.new.push(device); + } } let current_devices: HashSet<&DeviceIdBox> = device_map.keys().collect(); let stored_devices = self.store.get_readonly_devices(&user_id).await?; let stored_devices_set: HashSet<&DeviceIdBox> = stored_devices.keys().collect(); - let deleted_devices = stored_devices_set.difference(¤t_devices); + let deleted_devices_set = stored_devices_set.difference(¤t_devices); - for device_id in deleted_devices { + for device_id in deleted_devices_set { users_with_new_or_deleted_devices.insert(user_id); if let Some(device) = stored_devices.get(*device_id) { device.mark_as_deleted(); - self.store.delete_device(device.clone()).await?; + changes.deleted.push(device.clone()); } } } @@ -183,7 +188,7 @@ impl IdentityManager { self.group_manager .invalidate_sessions_new_devices(&users_with_new_or_deleted_devices); - Ok(changed_devices) + Ok(changes) } /// Handle the device keys part of a key query response. @@ -197,8 +202,8 @@ impl IdentityManager { async fn handle_cross_singing_keys( &self, response: &KeysQueryResponse, - ) -> StoreResult> { - let mut changed = Vec::new(); + ) -> StoreResult { + let mut changes = IdentityChanges::default(); for (user_id, master_key) in &response.master_keys { let master_key = MasterPubkey::from(master_key); @@ -213,7 +218,7 @@ impl IdentityManager { continue; }; - let identity = if let Some(mut i) = self.store.get_user_identity(user_id).await? { + let result = if let Some(mut i) = self.store.get_user_identity(user_id).await? { match &mut i { UserIdentities::Own(ref mut identity) => { let user_signing = if let Some(s) = response.user_signing_keys.get(user_id) @@ -230,11 +235,11 @@ impl IdentityManager { identity .update(master_key, self_signing, user_signing) - .map(|_| i) - } - UserIdentities::Other(ref mut identity) => { - identity.update(master_key, self_signing).map(|_| i) + .map(|_| (i, false)) } + UserIdentities::Other(ref mut identity) => identity + .update(master_key, self_signing) + .map(|_| (i, false)), } } else if user_id == self.user_id() { if let Some(s) = response.user_signing_keys.get(user_id) { @@ -252,7 +257,7 @@ impl IdentityManager { } OwnUserIdentity::new(master_key, self_signing, user_signing) - .map(UserIdentities::Own) + .map(|i| (UserIdentities::Own(i), true)) } else { warn!( "User identity for our own user {} didn't contain a \ @@ -268,17 +273,22 @@ impl IdentityManager { ); continue; } else { - UserIdentity::new(master_key, self_signing).map(UserIdentities::Other) + UserIdentity::new(master_key, self_signing) + .map(|i| (UserIdentities::Other(i), true)) }; - match identity { - Ok(i) => { + match result { + Ok((i, new)) => { trace!( "Updated or created new user identity for {}: {:?}", user_id, i ); - changed.push(i); + if new { + changes.new.push(i); + } else { + changes.changed.push(i); + } } Err(e) => { warn!( @@ -290,7 +300,7 @@ impl IdentityManager { } } - Ok(changed) + Ok(changes) } /// Get a key query request if one is needed. diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index 060a85cb..3a4d26cf 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -41,7 +41,7 @@ use matrix_sdk_common::{ use crate::{ error::{OlmError, OlmResult}, - olm::{InboundGroupSession, OutboundGroupSession}, + olm::{InboundGroupSession, OutboundGroupSession, Session}, requests::{OutgoingRequest, ToDeviceRequest}, store::{CryptoStoreError, Store}, Device, @@ -235,15 +235,18 @@ impl KeyRequestMachine { /// Handle all the incoming key requests that are queued up and empty our /// key request queue. - pub async fn collect_incoming_key_requests(&self) -> OlmResult<()> { + pub async fn collect_incoming_key_requests(&self) -> OlmResult> { + let mut changed_sessions = Vec::new(); for item in self.incoming_key_requests.iter() { let event = item.value(); - self.handle_key_request(event).await?; + if let Some(s) = self.handle_key_request(event).await? { + changed_sessions.push(s); + } } self.incoming_key_requests.clear(); - Ok(()) + Ok(changed_sessions) } /// Store the key share request for later, once we get an Olm session with @@ -294,7 +297,7 @@ impl KeyRequestMachine { async fn handle_key_request( &self, event: &ToDeviceEvent, - ) -> OlmResult<()> { + ) -> OlmResult> { let key_info = match event.content.action { Action::Request => { if let Some(info) = &event.content.body { @@ -305,11 +308,11 @@ impl KeyRequestMachine { action, but no key info was found", event.sender, event.content.requesting_device_id ); - return Ok(()); + return Ok(None); } } // We ignore cancellations here since there's nothing to serve. - Action::CancelRequest => return Ok(()), + Action::CancelRequest => return Ok(None), }; let session = self @@ -328,7 +331,7 @@ impl KeyRequestMachine { "Received a key request from {} {} for an unknown inbound group session {}.", &event.sender, &event.content.requesting_device_id, &key_info.session_id ); - return Ok(()); + return Ok(None); }; let device = self @@ -349,6 +352,8 @@ impl KeyRequestMachine { device.device_id(), e ); + + Ok(None) } else { info!( "Serving a key request for {} from {} {}.", @@ -357,20 +362,20 @@ impl KeyRequestMachine { device.device_id() ); - if let Err(e) = self.share_session(&session, &device).await { - match e { - OlmError::MissingSession => { - info!( - "Key request from {} {} is missing an Olm session, \ - putting the request in the wait queue", - device.user_id(), - device.device_id() - ); - self.handle_key_share_without_session(device, event); - return Ok(()); - } - e => return Err(e), + match self.share_session(&session, &device).await { + Ok(s) => Ok(Some(s)), + Err(OlmError::MissingSession) => { + info!( + "Key request from {} {} is missing an Olm session, \ + putting the request in the wait queue", + device.user_id(), + device.device_id() + ); + self.handle_key_share_without_session(device, event); + + Ok(None) } + Err(e) => Err(e), } } } else { @@ -379,13 +384,17 @@ impl KeyRequestMachine { &event.sender, &event.content.requesting_device_id ); self.store.update_tracked_user(&event.sender, true).await?; - } - Ok(()) + Ok(None) + } } - async fn share_session(&self, session: &InboundGroupSession, device: &Device) -> OlmResult<()> { - let content = device.encrypt_session(session.clone()).await?; + async fn share_session( + &self, + session: &InboundGroupSession, + device: &Device, + ) -> OlmResult { + let (used_session, content) = device.encrypt_session(session.clone()).await?; let id = Uuid::new_v4(); let mut messages = BTreeMap::new(); @@ -412,7 +421,7 @@ impl KeyRequestMachine { self.outgoing_to_device_requests.insert(id, request); - Ok(()) + Ok(used_session) } /// Check if it's ok to share a session with the given device. @@ -569,23 +578,20 @@ impl KeyRequestMachine { Ok(()) } - /// Save an inbound group session we received using a key forward. + /// Mark the given outgoing key info as done. /// - /// At the same time delete the key info since we received the wanted key. - async fn save_session( - &self, - key_info: OugoingKeyInfo, - session: InboundGroupSession, - ) -> Result<(), CryptoStoreError> { + /// This will queue up a request cancelation. + async fn mark_as_done(&self, key_info: OugoingKeyInfo) -> Result<(), CryptoStoreError> { // TODO perhaps only remove the key info if the first known index is 0. trace!( "Successfully received a forwarded room key for {:#?}", key_info ); - self.store.save_inbound_group_sessions(&[session]).await?; self.outgoing_to_device_requests .remove(&key_info.request_id); + // TODO return the key info instead of deleting it so the sync handler + // can delete it in one transaction. self.delete_key_info(&key_info).await?; let content = RoomKeyRequestEventContent { @@ -609,7 +615,8 @@ impl KeyRequestMachine { &self, sender_key: &str, event: &mut ToDeviceEvent, - ) -> Result>, CryptoStoreError> { + ) -> Result<(Option>, Option), CryptoStoreError> + { let key_info = self.get_key_info(&event.content).await?; if let Some(info) = key_info { @@ -626,27 +633,32 @@ impl KeyRequestMachine { // If we have a previous session, check if we have a better version // and store the new one if so. - if let Some(old_session) = old_session { + let session = if let Some(old_session) = old_session { let first_old_index = old_session.first_known_index().await; let first_index = session.first_known_index().await; if first_old_index > first_index { - self.save_session(info, session).await?; + self.mark_as_done(info).await?; + Some(session) + } else { + None } // If we didn't have a previous session, store it. } else { - self.save_session(info, session).await?; - } + self.mark_as_done(info).await?; + Some(session) + }; - Ok(Some(Raw::from(AnyToDeviceEvent::ForwardedRoomKey( - event.clone(), - )))) + Ok(( + Some(Raw::from(AnyToDeviceEvent::ForwardedRoomKey(event.clone()))), + session, + )) } else { info!( "Received a forwarded room key from {}, but no key info was found.", event.sender, ); - Ok(None) + Ok((None, None)) } } } @@ -831,20 +843,20 @@ mod test { .is_none() ); - machine + let (_, first_session) = machine .receive_forwarded_room_key(&session.sender_key, &mut event) .await .unwrap(); - - let first_session = machine - .store - .get_inbound_group_session(session.room_id(), &session.sender_key, session.session_id()) - .await - .unwrap() - .unwrap(); + let first_session = first_session.unwrap(); assert_eq!(first_session.first_known_index().await, 10); + machine + .store + .save_inbound_group_sessions(&[first_session.clone()]) + .await + .unwrap(); + // Get the cancel request. let request = machine.outgoing_to_device_requests.iter().next().unwrap(); let id = request.request_id; @@ -875,19 +887,12 @@ mod test { content, }; - machine + let (_, second_session) = machine .receive_forwarded_room_key(&session.sender_key, &mut event) .await .unwrap(); - let second_session = machine - .store - .get_inbound_group_session(session.room_id(), &session.sender_key, session.session_id()) - .await - .unwrap() - .unwrap(); - - assert_eq!(second_session.first_known_index().await, 10); + assert!(second_session.is_none()); let export = session.export_at_index(0).await.unwrap(); @@ -898,18 +903,12 @@ mod test { content, }; - machine + let (_, second_session) = machine .receive_forwarded_room_key(&session.sender_key, &mut event) .await .unwrap(); - let second_session = machine - .store - .get_inbound_group_session(session.room_id(), &session.sender_key, session.session_id()) - .await - .unwrap() - .unwrap(); - assert_eq!(second_session.first_known_index().await, 0); + assert_eq!(second_session.unwrap().first_known_index().await, 0); } #[async_test] @@ -1132,14 +1131,19 @@ mod test { .unwrap() .is_none()); - let (decrypted, sender_key, _) = + let (_, decrypted, sender_key, _) = alice_account.decrypt_to_device_event(&event).await.unwrap(); if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.deserialize().unwrap() { - alice_machine + let (_, session) = alice_machine .receive_forwarded_room_key(&sender_key, &mut e) .await .unwrap(); + alice_machine + .store + .save_inbound_group_sessions(&[session.unwrap()]) + .await + .unwrap(); } else { panic!("Invalid decrypted event type"); } @@ -1315,14 +1319,19 @@ mod test { .unwrap() .is_none()); - let (decrypted, sender_key, _) = + let (_, decrypted, sender_key, _) = alice_account.decrypt_to_device_event(&event).await.unwrap(); if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.deserialize().unwrap() { - alice_machine + let (_, session) = alice_machine .receive_forwarded_room_key(&sender_key, &mut e) .await .unwrap(); + alice_machine + .store + .save_inbound_group_sessions(&[session.unwrap()]) + .await + .unwrap(); } else { panic!("Invalid decrypted event type"); } diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 675d5be7..0db52888 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -47,15 +47,18 @@ use matrix_sdk_common::{ use crate::store::sqlite::SqliteStore; use crate::{ error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult}, - identities::{Device, IdentityManager, ReadOnlyDevice, UserDevices, UserIdentities}, + identities::{Device, IdentityManager, UserDevices}, key_request::KeyRequestMachine, olm::{ Account, EncryptionSettings, ExportedRoomKey, GroupSessionKey, IdentityKeys, - InboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount, + InboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount, Session, }, requests::{IncomingResponse, OutgoingRequest, UploadSigningKeysRequest}, session_manager::{GroupSessionManager, SessionManager}, - store::{CryptoStore, MemoryStore, Result as StoreResult, Store}, + store::{ + Changes, CryptoStore, DeviceChanges, IdentityChanges, MemoryStore, Result as StoreResult, + Store, + }, verification::{Sas, VerificationMachine}, ToDeviceRequest, }; @@ -467,7 +470,7 @@ impl OlmMachine { async fn receive_keys_query_response( &self, response: &KeysQueryResponse, - ) -> OlmResult<(Vec, Vec)> { + ) -> OlmResult<(DeviceChanges, IdentityChanges)> { self.identity_manager .receive_keys_query_response(response) .await @@ -498,12 +501,12 @@ impl OlmMachine { async fn decrypt_to_device_event( &self, event: &ToDeviceEvent, - ) -> OlmResult> { - let (decrypted_event, sender_key, signing_key) = + ) -> OlmResult<(Session, Raw, Option)> { + let (session, decrypted_event, sender_key, signing_key) = self.account.decrypt_to_device_event(event).await?; // Handle the decrypted event, e.g. fetch out Megolm sessions out of // the event. - if let Some(event) = self + if let (Some(event), group_session) = self .handle_decrypted_to_device_event(&sender_key, &signing_key, &decrypted_event) .await? { @@ -512,9 +515,9 @@ impl OlmMachine { // don't want them to be able to do silly things with it. Handling // events modifies them and returns a modified one, so replace it // here if we get one. - Ok(event) + Ok((session, event, group_session)) } else { - Ok(decrypted_event) + Ok((session, decrypted_event, None)) } } @@ -524,7 +527,7 @@ impl OlmMachine { sender_key: &str, signing_key: &str, event: &mut ToDeviceEvent, - ) -> OlmResult>> { + ) -> OlmResult<(Option>, Option)> { match event.content.algorithm { EventEncryptionAlgorithm::MegolmV1AesSha2 => { let session_key = GroupSessionKey(mem::take(&mut event.content.session_key)); @@ -535,17 +538,15 @@ impl OlmMachine { &event.content.room_id, session_key, )?; - let _ = self.store.save_inbound_group_sessions(&[session]).await?; - let event = Raw::from(AnyToDeviceEvent::RoomKey(event.clone())); - Ok(Some(event)) + Ok((Some(event), Some(session))) } _ => { warn!( "Received room key with unsupported key algorithm {}", event.content.algorithm ); - Ok(None) + Ok((None, None)) } } } @@ -555,9 +556,14 @@ impl OlmMachine { &self, room_id: &RoomId, ) -> OlmResult<()> { - self.group_session_manager + let (_, session) = self + .group_session_manager .create_outbound_group_session(room_id, EncryptionSettings::default()) - .await + .await?; + + self.store.save_inbound_group_sessions(&[session]).await?; + + Ok(()) } /// Encrypt a room message for the given room. @@ -647,12 +653,12 @@ impl OlmMachine { sender_key: &str, signing_key: &str, event: &Raw, - ) -> OlmResult>> { + ) -> OlmResult<(Option>, Option)> { let event = if let Ok(e) = event.deserialize() { e } else { warn!("Decrypted to-device event failed to be parsed correctly"); - return Ok(None); + return Ok((None, None)); }; match event { @@ -665,7 +671,7 @@ impl OlmMachine { .await?), _ => { warn!("Received a unexpected encrypted to-device event"); - Ok(None) + Ok((None, None)) } } } @@ -699,11 +705,8 @@ impl OlmMachine { self.verification_machine.get_sas(flow_id) } - async fn update_one_time_key_count( - &self, - key_count: &BTreeMap, - ) -> StoreResult<()> { - self.account.update_uploaded_key_count(key_count).await + async fn update_one_time_key_count(&self, key_count: &BTreeMap) { + self.account.update_uploaded_key_count(key_count).await; } /// Handle a sync response and update the internal state of the Olm machine. @@ -719,15 +722,19 @@ impl OlmMachine { /// /// [`decrypt_room_event`]: #method.decrypt_room_event #[instrument(skip(response))] - pub async fn receive_sync_response(&self, response: &mut SyncResponse) { + pub async fn receive_sync_response(&self, response: &mut SyncResponse) -> OlmResult<()> { + // Remove verification objects that have expired or are done. self.verification_machine.garbage_collect(); - if let Err(e) = self - .update_one_time_key_count(&response.device_one_time_keys_count) - .await - { - error!("Error updating the one-time key count {:?}", e); - } + // Always save the account, a new session might get created which also + // touches the account. + let mut changes = Changes { + account: Some(self.account.inner.clone()), + ..Default::default() + }; + + self.update_one_time_key_count(&response.device_one_time_keys_count) + .await; for user_id in &response.device_lists.changed { if let Err(e) = self.identity_manager.mark_user_as_changed(&user_id).await { @@ -748,29 +755,36 @@ impl OlmMachine { match &mut event { AnyToDeviceEvent::RoomEncrypted(e) => { - let decrypted_event = match self.decrypt_to_device_event(e).await { - Ok(e) => e, - Err(err) => { - warn!( - "Failed to decrypt to-device event from {} {}", - e.sender, err - ); + let (session, decrypted_event, group_session) = + match self.decrypt_to_device_event(e).await { + Ok(e) => e, + Err(err) => { + warn!( + "Failed to decrypt to-device event from {} {}", + e.sender, err + ); - if let OlmError::SessionWedged(sender, curve_key) = err { - if let Err(e) = self - .session_manager - .mark_device_as_wedged(&sender, &curve_key) - .await - { - error!( - "Couldn't mark device from {} to be unwedged {:?}", - sender, e - ); + if let OlmError::SessionWedged(sender, curve_key) = err { + if let Err(e) = self + .session_manager + .mark_device_as_wedged(&sender, &curve_key) + .await + { + error!( + "Couldn't mark device from {} to be unwedged {:?}", + sender, e + ); + } } + continue; } - continue; - } - }; + }; + + changes.sessions.push(session); + + if let Some(group_session) = group_session { + changes.inbound_group_sessions.push(group_session); + } *event_result = decrypted_event; } @@ -789,13 +803,14 @@ impl OlmMachine { } } - if let Err(e) = self + let changed_sessions = self .key_request_machine .collect_incoming_key_requests() - .await - { - error!("Error collecting our key share requests {:?}", e); - } + .await?; + + changes.sessions.extend(changed_sessions); + + Ok(self.store.save_changes(changes).await?) } /// Decrypt an event from a room timeline. @@ -973,7 +988,13 @@ impl OlmMachine { let num_sessions = sessions.len(); - self.store.save_inbound_group_sessions(&sessions).await?; + let changes = Changes { + inbound_group_sessions: sessions, + ..Default::default() + }; + + self.store.save_changes(changes).await?; + info!( "Successfully imported {} inbound group sessions", num_sessions @@ -1198,15 +1219,19 @@ pub(crate) mod test { .unwrap() .unwrap(); + let (session, content) = bob_device + .encrypt(EventType::Dummy, json!({})) + .await + .unwrap(); + alice.store.save_sessions(&[session]).await.unwrap(); + let event = ToDeviceEvent { sender: alice.user_id().clone(), - content: bob_device - .encrypt(EventType::Dummy, json!({})) - .await - .unwrap(), + content, }; - bob.decrypt_to_device_event(&event).await.unwrap(); + let (session, _, _) = bob.decrypt_to_device_event(&event).await.unwrap(); + bob.store.save_sessions(&[session]).await.unwrap(); (alice, bob) } @@ -1492,13 +1517,15 @@ pub(crate) mod test { content: bob_device .encrypt(EventType::Dummy, json!({})) .await - .unwrap(), + .unwrap() + .1, }; let event = bob .decrypt_to_device_event(&event) .await .unwrap() + .1 .deserialize() .unwrap(); @@ -1534,12 +1561,14 @@ pub(crate) mod test { .get_outbound_group_session(&room_id) .unwrap(); - let event = bob - .decrypt_to_device_event(&event) + let (session, event, group_session) = bob.decrypt_to_device_event(&event).await.unwrap(); + + bob.store.save_sessions(&[session]).await.unwrap(); + bob.store + .save_inbound_group_sessions(&[group_session.unwrap()]) .await - .unwrap() - .deserialize() .unwrap(); + let event = event.deserialize().unwrap(); if let AnyToDeviceEvent::RoomKey(event) = event { assert_eq!(&event.sender, alice.user_id()); @@ -1579,7 +1608,11 @@ pub(crate) mod test { content: to_device_requests_to_content(to_device_requests), }; - bob.decrypt_to_device_event(&event).await.unwrap(); + let (_, _, group_session) = bob.decrypt_to_device_event(&event).await.unwrap(); + bob.store + .save_inbound_group_sessions(&[group_session.unwrap()]) + .await + .unwrap(); let plaintext = "It is a secret to everybody"; diff --git a/matrix_sdk_crypto/src/olm/account.rs b/matrix_sdk_crypto/src/olm/account.rs index 0960038b..422d127f 100644 --- a/matrix_sdk_crypto/src/olm/account.rs +++ b/matrix_sdk_crypto/src/olm/account.rs @@ -52,7 +52,7 @@ use olm_rs::{ use crate::{ error::{EventError, OlmResult, SessionCreationError}, identities::ReadOnlyDevice, - store::{Result as StoreResult, Store}, + store::Store, OlmError, }; @@ -76,7 +76,7 @@ impl Account { pub async fn decrypt_to_device_event( &self, event: &ToDeviceEvent, - ) -> OlmResult<(Raw, String, String)> { + ) -> OlmResult<(Session, Raw, String, String)> { debug!("Decrypting to-device event"); let content = if let EncryptedEventContent::OlmV1Curve25519AesSha2(c) = &event.content { @@ -103,27 +103,28 @@ impl Account { .map_err(|_| EventError::UnsupportedOlmType)?; // Decrypt the OlmMessage and get a Ruma event out of it. - let (decrypted_event, signing_key) = self + let (session, decrypted_event, signing_key) = self .decrypt_olm_message(&event.sender, &content.sender_key, message) .await?; debug!("Decrypted a to-device event {:?}", decrypted_event); - Ok((decrypted_event, content.sender_key.clone(), signing_key)) + Ok(( + session, + decrypted_event, + content.sender_key.clone(), + signing_key, + )) } else { warn!("Olm event doesn't contain a ciphertext for our key"); Err(EventError::MissingCiphertext.into()) } } - pub async fn update_uploaded_key_count( - &self, - key_count: &BTreeMap, - ) -> StoreResult<()> { + pub async fn update_uploaded_key_count(&self, key_count: &BTreeMap) { let one_time_key_count = key_count.get(&DeviceKeyAlgorithm::SignedCurve25519); let count: u64 = one_time_key_count.map_or(0, |c| (*c).into()); self.inner.update_uploaded_key_count(count); - self.store.save_account(self.inner.clone()).await } pub async fn receive_keys_upload_response( @@ -161,7 +162,7 @@ impl Account { sender: &UserId, sender_key: &str, message: &OlmMessage, - ) -> OlmResult> { + ) -> OlmResult> { let s = self.store.get_sessions(sender_key).await?; // We don't have any existing sessions, return early. @@ -171,8 +172,7 @@ impl Account { return Ok(None); }; - let mut session_to_save = None; - let mut plaintext = None; + let mut decrypted: Option<(Session, String)> = None; for session in &mut *sessions.lock().await { let mut matches = false; @@ -191,9 +191,7 @@ impl Account { match ret { Ok(p) => { - plaintext = Some(p); - session_to_save = Some(session.clone()); - + decrypted = Some((session.clone(), p)); break; } Err(e) => { @@ -214,14 +212,7 @@ impl Account { } } - if let Some(session) = session_to_save { - // Decryption was successful, save the new ratchet state of the - // session that was used to decrypt the message. - trace!("Saved the new session state for {}", sender); - self.store.save_sessions(&[session]).await?; - } - - Ok(plaintext) + Ok(decrypted) } /// Decrypt an Olm message, creating a new Olm session if possible. @@ -230,15 +221,15 @@ impl Account { sender: &UserId, sender_key: &str, message: OlmMessage, - ) -> OlmResult<(Raw, String)> { + ) -> OlmResult<(Session, Raw, String)> { // First try to decrypt using an existing session. - let plaintext = if let Some(p) = self + let (session, plaintext) = if let Some(d) = self .try_decrypt_olm_message(sender, sender_key, &message) .await? { - // Decryption succeeded, de-structure the plaintext out of the - // Option. - p + // Decryption succeeded, de-structure the session/plaintext out of + // the Option. + d } else { // Decryption failed with every known session, let's try to create a // new session. @@ -278,9 +269,6 @@ impl Account { } }; - // Save the account since we remove the one-time key that - // was used to create this session. - self.store.save_account(self.inner.clone()).await?; session } }; @@ -288,15 +276,23 @@ impl Account { // Decrypt our message, this shouldn't fail since we're using a // newly created Session. let plaintext = session.decrypt(message).await?; - - // Save the new ratcheted state of the session. - self.store.save_sessions(&[session]).await?; - plaintext + (session, plaintext) }; trace!("Successfully decrypted a Olm message: {}", plaintext); - self.parse_decrypted_to_device_event(sender, &plaintext) + let (event, signing_key) = match self.parse_decrypted_to_device_event(sender, &plaintext) { + Ok(r) => r, + Err(e) => { + // We might created a new session but decryption might still + // have failed, store it for the error case here, this is fine + // since we don't expect this to happen often or at all. + self.store.save_sessions(&[session]).await?; + return Err(e); + } + }; + + Ok((session, event, signing_key)) } /// Parse a decrypted Olm message, check that the plaintext and encrypted diff --git a/matrix_sdk_crypto/src/session_manager/group_sessions.rs b/matrix_sdk_crypto/src/session_manager/group_sessions.rs index d20f8f8e..374610fd 100644 --- a/matrix_sdk_crypto/src/session_manager/group_sessions.rs +++ b/matrix_sdk_crypto/src/session_manager/group_sessions.rs @@ -28,8 +28,8 @@ use tracing::{debug, info}; use crate::{ error::{EventError, MegolmResult, OlmResult}, - olm::{Account, OutboundGroupSession}, - store::Store, + olm::{Account, InboundGroupSession, OutboundGroupSession}, + store::{Changes, Store}, Device, EncryptionSettings, OlmError, ToDeviceRequest, }; @@ -140,19 +140,17 @@ impl GroupSessionManager { &self, room_id: &RoomId, settings: EncryptionSettings, - ) -> OlmResult<()> { + ) -> OlmResult<(OutboundGroupSession, InboundGroupSession)> { let (outbound, inbound) = self .account .create_group_session_pair(room_id, settings) .await .map_err(|_| EventError::UnsupportedAlgorithm)?; - let _ = self.store.save_inbound_group_sessions(&[inbound]).await?; - let _ = self .outbound_group_sessions - .insert(room_id.to_owned(), outbound); - Ok(()) + .insert(room_id.to_owned(), outbound.clone()); + Ok((outbound, inbound)) } /// Get to-device requests to share a group session with users in a room. @@ -169,13 +167,12 @@ impl GroupSessionManager { users: impl Iterator, encryption_settings: impl Into, ) -> OlmResult>> { - self.create_outbound_group_session(room_id, encryption_settings.into()) - .await?; - let session = self.outbound_group_sessions.get(room_id).unwrap(); + let mut changes = Changes::default(); - if session.shared() { - panic!("Session is already shared"); - } + let (session, inbound_session) = self + .create_outbound_group_session(room_id, encryption_settings.into()) + .await?; + changes.inbound_group_sessions.push(inbound_session); let mut devices: Vec = Vec::new(); @@ -196,7 +193,7 @@ impl GroupSessionManager { .encrypt(EventType::RoomKey, key_content.clone()) .await; - let encrypted = match encrypted { + let (used_session, encrypted) = match encrypted { Ok(c) => c, Err(OlmError::MissingSession) | Err(OlmError::EventError(EventError::MissingSenderKey)) => { @@ -205,6 +202,8 @@ impl GroupSessionManager { Err(e) => return Err(e), }; + changes.sessions.push(used_session); + messages .entry(device.user_id().clone()) .or_insert_with(BTreeMap::new) @@ -237,6 +236,8 @@ impl GroupSessionManager { session.mark_as_shared(); } + self.store.save_changes(changes).await?; + Ok(requests) } } diff --git a/matrix_sdk_crypto/src/session_manager/sessions.rs b/matrix_sdk_crypto/src/session_manager/sessions.rs index ab3976ab..c86c09ac 100644 --- a/matrix_sdk_crypto/src/session_manager/sessions.rs +++ b/matrix_sdk_crypto/src/session_manager/sessions.rs @@ -33,7 +33,7 @@ use crate::{ key_request::KeyRequestMachine, olm::Account, requests::{OutgoingRequest, ToDeviceRequest}, - store::{Result as StoreResult, Store}, + store::{Changes, Result as StoreResult, Store}, ReadOnlyDevice, }; @@ -128,7 +128,7 @@ impl SessionManager { .is_some() { if let Some(device) = self.store.get_device(user_id, device_id).await? { - let content = device.encrypt(EventType::Dummy, json!({})).await?; + let (_, content) = device.encrypt(EventType::Dummy, json!({})).await?; let id = Uuid::new_v4(); let mut messages = BTreeMap::new(); @@ -258,6 +258,8 @@ impl SessionManager { pub async fn receive_keys_claim_response(&self, response: &KeysClaimResponse) -> OlmResult<()> { // TODO log the failures here + let mut changes = Changes::default(); + for (user_id, user_devices) in &response.one_time_keys { for (device_id, key_map) in user_devices { let device = match self.store.get_readonly_device(&user_id, device_id).await { @@ -284,15 +286,12 @@ impl SessionManager { let session = match self.account.create_outbound_session(device, &key_map).await { Ok(s) => s, Err(e) => { - warn!("{:?}", e); + warn!("Error creating new outbound session {:?}", e); continue; } }; - if let Err(e) = self.store.save_sessions(&[session]).await { - error!("Failed to store newly created Olm session {}", e); - continue; - } + changes.sessions.push(session); self.key_request_machine.retry_keyshare(&user_id, device_id); @@ -304,7 +303,8 @@ impl SessionManager { } } } - Ok(()) + + Ok(self.store.save_changes(changes).await?) } } diff --git a/matrix_sdk_crypto/src/store/memorystore.rs b/matrix_sdk_crypto/src/store/memorystore.rs index 2e5a8762..f41d4579 100644 --- a/matrix_sdk_crypto/src/store/memorystore.rs +++ b/matrix_sdk_crypto/src/store/memorystore.rs @@ -26,7 +26,7 @@ use matrix_sdk_common_macros::async_trait; use super::{ caches::{DeviceStore, GroupSessionStore, SessionStore}, - CryptoStore, InboundGroupSession, ReadOnlyAccount, Result, Session, + Changes, CryptoStore, InboundGroupSession, ReadOnlyAccount, Result, Session, }; use crate::identities::{ReadOnlyDevice, UserIdentities}; @@ -61,6 +61,30 @@ impl MemoryStore { pub fn new() -> Self { Self::default() } + + pub(crate) async fn save_devices(&self, mut devices: Vec) { + for device in devices.drain(..) { + let _ = self.devices.add(device); + } + } + + async fn delete_devices(&self, mut devices: Vec) { + for device in devices.drain(..) { + let _ = self.devices.remove(device.user_id(), device.device_id()); + } + } + + async fn save_sessions(&self, mut sessions: Vec) { + for session in sessions.drain(..) { + let _ = self.sessions.add(session.clone()).await; + } + } + + async fn save_inbound_group_sessions(&self, mut sessions: Vec) { + for session in sessions.drain(..) { + self.inbound_group_sessions.add(session); + } + } } #[async_trait] @@ -73,9 +97,24 @@ impl CryptoStore for MemoryStore { Ok(()) } - async fn save_sessions(&self, sessions: &[Session]) -> Result<()> { - for session in sessions { - let _ = self.sessions.add(session.clone()).await; + async fn save_changes(&self, mut changes: Changes) -> Result<()> { + self.save_sessions(changes.sessions).await; + self.save_inbound_group_sessions(changes.inbound_group_sessions) + .await; + + self.save_devices(changes.devices.new).await; + self.save_devices(changes.devices.changed).await; + self.delete_devices(changes.devices.deleted).await; + + for identity in changes + .identities + .new + .drain(..) + .chain(changes.identities.changed) + { + let _ = self + .identities + .insert(identity.user_id().to_owned(), identity.clone()); } Ok(()) @@ -85,14 +124,6 @@ impl CryptoStore for MemoryStore { Ok(self.sessions.get(sender_key)) } - async fn save_inbound_group_sessions(&self, sessions: &[InboundGroupSession]) -> Result<()> { - for session in sessions { - self.inbound_group_sessions.add(session.clone()); - } - - Ok(()) - } - async fn get_inbound_group_session( &self, room_id: &RoomId, @@ -151,11 +182,6 @@ impl CryptoStore for MemoryStore { Ok(self.devices.get(user_id, device_id)) } - async fn delete_device(&self, device: ReadOnlyDevice) -> Result<()> { - let _ = self.devices.remove(device.user_id(), device.device_id()); - Ok(()) - } - async fn get_user_devices( &self, user_id: &UserId, @@ -163,28 +189,11 @@ impl CryptoStore for MemoryStore { Ok(self.devices.user_devices(user_id)) } - async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()> { - for device in devices { - let _ = self.devices.add(device.clone()); - } - - Ok(()) - } - async fn get_user_identity(&self, user_id: &UserId) -> Result> { #[allow(clippy::map_clone)] Ok(self.identities.get(user_id).map(|i| i.clone())) } - async fn save_user_identities(&self, identities: &[UserIdentities]) -> Result<()> { - for identity in identities { - let _ = self - .identities - .insert(identity.user_id().to_owned(), identity.clone()); - } - Ok(()) - } - async fn save_value(&self, key: String, value: String) -> Result<()> { self.values.insert(key, value); Ok(()) @@ -217,7 +226,7 @@ mod test { assert!(store.load_account().await.unwrap().is_none()); store.save_account(account).await.unwrap(); - store.save_sessions(&[session.clone()]).await.unwrap(); + store.save_sessions(vec![session.clone()]).await; let sessions = store .get_sessions(&session.sender_key) @@ -250,9 +259,8 @@ mod test { let store = MemoryStore::new(); let _ = store - .save_inbound_group_sessions(&[inbound.clone()]) - .await - .unwrap(); + .save_inbound_group_sessions(vec![inbound.clone()]) + .await; let loaded_session = store .get_inbound_group_session(&room_id, "test_key", outbound.session_id()) @@ -267,7 +275,7 @@ mod test { let device = get_device(); let store = MemoryStore::new(); - store.save_devices(&[device.clone()]).await.unwrap(); + store.save_devices(vec![device.clone()]).await; let loaded_device = store .get_device(device.user_id(), device.device_id()) @@ -286,7 +294,7 @@ mod test { assert_eq!(&device, loaded_device); - store.delete_device(device.clone()).await.unwrap(); + store.delete_devices(vec![device.clone()]).await; assert!(store .get_device(device.user_id(), device.device_id()) .await diff --git a/matrix_sdk_crypto/src/store/mod.rs b/matrix_sdk_crypto/src/store/mod.rs index b8b661ef..478cb709 100644 --- a/matrix_sdk_crypto/src/store/mod.rs +++ b/matrix_sdk_crypto/src/store/mod.rs @@ -100,6 +100,31 @@ pub(crate) struct Store { verification_machine: VerificationMachine, } +#[derive(Debug, Default)] +#[allow(missing_docs)] +pub struct Changes { + pub account: Option, + pub sessions: Vec, + pub inbound_group_sessions: Vec, + pub identities: IdentityChanges, + pub devices: DeviceChanges, +} + +#[derive(Debug, Clone, Default)] +#[allow(missing_docs)] +pub struct IdentityChanges { + pub new: Vec, + pub changed: Vec, +} + +#[derive(Debug, Clone, Default)] +#[allow(missing_docs)] +pub struct DeviceChanges { + pub new: Vec, + pub changed: Vec, + pub deleted: Vec, +} + impl Store { pub fn new( user_id: Arc, @@ -121,6 +146,41 @@ impl Store { self.inner.get_device(user_id, device_id).await } + pub async fn save_sessions(&self, sessions: &[Session]) -> Result<()> { + let changes = Changes { + sessions: sessions.to_vec(), + ..Default::default() + }; + + self.save_changes(changes).await + } + + #[cfg(test)] + pub async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()> { + let changes = Changes { + devices: DeviceChanges { + changed: devices.to_vec(), + ..Default::default() + }, + ..Default::default() + }; + + self.save_changes(changes).await + } + + #[cfg(test)] + pub async fn save_inbound_group_sessions( + &self, + sessions: &[InboundGroupSession], + ) -> Result<()> { + let changes = Changes { + inbound_group_sessions: sessions.to_vec(), + ..Default::default() + }; + + self.save_changes(changes).await + } + pub async fn get_readonly_devices( &self, user_id: &UserId, @@ -271,12 +331,8 @@ pub trait CryptoStore: Debug { /// * `account` - The account that should be stored. async fn save_account(&self, account: ReadOnlyAccount) -> Result<()>; - /// Save the given sessions in the store. - /// - /// # Arguments - /// - /// * `session` - The sessions that should be stored. - async fn save_sessions(&self, session: &[Session]) -> Result<()>; + /// TODO + async fn save_changes(&self, changes: Changes) -> Result<()>; /// Get all the sessions that belong to the given sender key. /// @@ -285,13 +341,6 @@ pub trait CryptoStore: Debug { /// * `sender_key` - The sender key that was used to establish the sessions. async fn get_sessions(&self, sender_key: &str) -> Result>>>>; - /// Save the given inbound group sessions in the store. - /// - /// # Arguments - /// - /// * `sessions` - The sessions that should be stored. - async fn save_inbound_group_sessions(&self, session: &[InboundGroupSession]) -> Result<()>; - /// Get the inbound group session from our store. /// /// # Arguments @@ -331,20 +380,6 @@ pub trait CryptoStore: Debug { /// * `dirty` - Should the user be also marked for a key query. async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result; - /// Save the given devices in the store. - /// - /// # Arguments - /// - /// * `device` - The device that should be stored. - async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()>; - - /// Delete the given device from the store. - /// - /// # Arguments - /// - /// * `device` - The device that should be stored. - async fn delete_device(&self, device: ReadOnlyDevice) -> Result<()>; - /// Get the device for the given user with the given device id. /// /// # Arguments @@ -368,13 +403,6 @@ pub trait CryptoStore: Debug { user_id: &UserId, ) -> Result>; - /// Save the given user identities in the store. - /// - /// # Arguments - /// - /// * `identities` - The identities that should be saved in the store. - async fn save_user_identities(&self, identities: &[UserIdentities]) -> Result<()>; - /// Get the user identity that is attached to the given user id. /// /// # Arguments diff --git a/matrix_sdk_crypto/src/store/sqlite.rs b/matrix_sdk_crypto/src/store/sqlite.rs index f591af79..6759abb5 100644 --- a/matrix_sdk_crypto/src/store/sqlite.rs +++ b/matrix_sdk_crypto/src/store/sqlite.rs @@ -34,7 +34,7 @@ use matrix_sdk_common::{ use sqlx::{query, query_as, sqlite::SqliteConnectOptions, Connection, Executor, SqliteConnection}; use zeroize::Zeroizing; -use super::{caches::SessionStore, CryptoStore, CryptoStoreError, Result}; +use super::{caches::SessionStore, Changes, CryptoStore, CryptoStoreError, Result}; use crate::{ identities::{LocalTrust, OwnUserIdentity, ReadOnlyDevice, UserIdentities, UserIdentity}, olm::{ @@ -456,11 +456,17 @@ impl SqliteStore { Ok(()) } - async fn lazy_load_sessions(&self, sender_key: &str) -> Result<()> { + async fn lazy_load_sessions( + &self, + connection: &mut SqliteConnection, + sender_key: &str, + ) -> Result<()> { let loaded_sessions = self.sessions.get(sender_key).is_some(); if !loaded_sessions { - let sessions = self.load_sessions_for(sender_key).await?; + let sessions = self + .load_sessions_for_helper(connection, sender_key) + .await?; if !sessions.is_empty() { self.sessions.set_for_sender(sender_key, sessions); @@ -470,20 +476,33 @@ impl SqliteStore { Ok(()) } - async fn get_sessions_for(&self, sender_key: &str) -> Result>>>> { - self.lazy_load_sessions(sender_key).await?; + async fn get_sessions_for( + &self, + connection: &mut SqliteConnection, + sender_key: &str, + ) -> Result>>>> { + self.lazy_load_sessions(connection, sender_key).await?; Ok(self.sessions.get(sender_key)) } + #[cfg(test)] async fn load_sessions_for(&self, sender_key: &str) -> Result> { + let mut connection = self.connection.lock().await; + self.load_sessions_for_helper(&mut connection, sender_key) + .await + } + + async fn load_sessions_for_helper( + &self, + connection: &mut SqliteConnection, + sender_key: &str, + ) -> Result> { let account_info = self .account_info .lock() .unwrap() .clone() .ok_or(CryptoStoreError::AccountUnset)?; - let mut connection = self.connection.lock().await; - let mut rows: Vec<(String, String, String, String)> = query_as( "SELECT pickle, sender_key, creation_time, last_use_time FROM sessions WHERE account_id = ? and sender_key = ?", @@ -1231,6 +1250,134 @@ impl SqliteStore { Ok(()) } + #[cfg(test)] + async fn save_sessions(&self, sessions: &[Session]) -> Result<()> { + let mut connection = self.connection.lock().await; + let mut transaction = connection.begin().await?; + + self.save_sessions_helper(&mut transaction, sessions) + .await?; + transaction.commit().await?; + + Ok(()) + } + + async fn save_sessions_helper( + &self, + connection: &mut SqliteConnection, + sessions: &[Session], + ) -> Result<()> { + let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; + + for session in sessions { + self.lazy_load_sessions(connection, &session.sender_key) + .await?; + } + + for session in sessions { + self.sessions.add(session.clone()).await; + + let pickle = session.pickle(self.get_pickle_mode()).await; + + let session_id = session.session_id(); + let creation_time = serde_json::to_string(&pickle.creation_time)?; + let last_use_time = serde_json::to_string(&pickle.last_use_time)?; + + query( + "REPLACE INTO sessions ( + session_id, account_id, creation_time, last_use_time, sender_key, pickle + ) VALUES (?, ?, ?, ?, ?, ?)", + ) + .bind(&session_id) + .bind(&account_id) + .bind(&*creation_time) + .bind(&*last_use_time) + .bind(&pickle.sender_key) + .bind(&pickle.pickle.as_str()) + .execute(&mut *connection) + .await?; + } + + Ok(()) + } + + async fn save_devices( + &self, + mut connection: &mut SqliteConnection, + devices: &[ReadOnlyDevice], + ) -> Result<()> { + for device in devices { + self.save_device_helper(&mut connection, device.clone()) + .await? + } + + Ok(()) + } + + async fn delete_devices( + &self, + connection: &mut SqliteConnection, + devices: &[ReadOnlyDevice], + ) -> Result<()> { + let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; + + for device in devices { + query( + "DELETE FROM devices + WHERE account_id = ?1 and user_id = ?2 and device_id = ?3 + ", + ) + .bind(account_id) + .bind(&device.user_id().to_string()) + .bind(device.device_id().as_str()) + .execute(&mut *connection) + .await?; + } + + Ok(()) + } + + #[cfg(test)] + async fn save_inbound_group_sessions_test( + &self, + sessions: &[InboundGroupSession], + ) -> Result<()> { + let mut connection = self.connection.lock().await; + let mut transaction = connection.begin().await?; + + self.save_inbound_group_sessions(&mut transaction, sessions) + .await?; + + transaction.commit().await?; + Ok(()) + } + + async fn save_inbound_group_sessions( + &self, + connection: &mut SqliteConnection, + sessions: &[InboundGroupSession], + ) -> Result<()> { + let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; + + for session in sessions { + self.save_inbound_group_session_helper(account_id, connection, session) + .await?; + } + + Ok(()) + } + + async fn save_user_identities( + &self, + mut connection: &mut SqliteConnection, + users: &[UserIdentities], + ) -> Result<()> { + for user in users { + self.save_user_helper(&mut connection, user).await?; + } + Ok(()) + } + async fn save_user_helper( &self, mut connection: &mut SqliteConnection, @@ -1369,39 +1516,26 @@ impl CryptoStore for SqliteStore { Ok(()) } - async fn save_sessions(&self, sessions: &[Session]) -> Result<()> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - - for session in sessions { - self.lazy_load_sessions(&session.sender_key).await?; - } - + async fn save_changes(&self, changes: Changes) -> Result<()> { let mut connection = self.connection.lock().await; let mut transaction = connection.begin().await?; - for session in sessions { - self.sessions.add(session.clone()).await; - - let pickle = session.pickle(self.get_pickle_mode()).await; - - let session_id = session.session_id(); - let creation_time = serde_json::to_string(&pickle.creation_time)?; - let last_use_time = serde_json::to_string(&pickle.last_use_time)?; - - query( - "REPLACE INTO sessions ( - session_id, account_id, creation_time, last_use_time, sender_key, pickle - ) VALUES (?, ?, ?, ?, ?, ?)", - ) - .bind(&session_id) - .bind(&account_id) - .bind(&*creation_time) - .bind(&*last_use_time) - .bind(&pickle.sender_key) - .bind(&pickle.pickle.as_str()) - .execute(&mut *transaction) + self.save_sessions_helper(&mut transaction, &changes.sessions) + .await?; + self.save_inbound_group_sessions(&mut transaction, &changes.inbound_group_sessions) + .await?; + + self.save_devices(&mut transaction, &changes.devices.new) + .await?; + self.save_devices(&mut transaction, &changes.devices.changed) + .await?; + self.delete_devices(&mut transaction, &changes.devices.deleted) + .await?; + + self.save_user_identities(&mut transaction, &changes.identities.new) + .await?; + self.save_user_identities(&mut transaction, &changes.identities.changed) .await?; - } transaction.commit().await?; @@ -1409,22 +1543,8 @@ impl CryptoStore for SqliteStore { } async fn get_sessions(&self, sender_key: &str) -> Result>>>> { - Ok(self.get_sessions_for(sender_key).await?) - } - - async fn save_inbound_group_sessions(&self, sessions: &[InboundGroupSession]) -> Result<()> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; let mut connection = self.connection.lock().await; - let mut transaction = connection.begin().await?; - - for session in sessions { - self.save_inbound_group_session_helper(account_id, &mut transaction, session) - .await?; - } - - transaction.commit().await?; - - Ok(()) + Ok(self.get_sessions_for(&mut connection, sender_key).await?) } async fn get_inbound_group_session( @@ -1469,38 +1589,6 @@ impl CryptoStore for SqliteStore { Ok(already_added) } - async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()> { - let mut connection = self.connection.lock().await; - let mut transaction = connection.begin().await?; - - for device in devices { - self.save_device_helper(&mut transaction, device.clone()) - .await? - } - - transaction.commit().await?; - - Ok(()) - } - - async fn delete_device(&self, device: ReadOnlyDevice) -> Result<()> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - let mut connection = self.connection.lock().await; - - query( - "DELETE FROM devices - WHERE account_id = ?1 and user_id = ?2 and device_id = ?3 - ", - ) - .bind(account_id) - .bind(&device.user_id().to_string()) - .bind(device.device_id().as_str()) - .execute(&mut *connection) - .await?; - - Ok(()) - } - async fn get_device( &self, user_id: &UserId, @@ -1520,19 +1608,6 @@ impl CryptoStore for SqliteStore { self.load_user(user_id).await } - async fn save_user_identities(&self, users: &[UserIdentities]) -> Result<()> { - let mut connection = self.connection.lock().await; - let mut transaction = connection.begin().await?; - - for user in users { - self.save_user_helper(&mut transaction, user).await?; - } - - transaction.commit().await?; - - Ok(()) - } - async fn save_value(&self, key: String, value: String) -> Result<()> { let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; let mut connection = self.connection.lock().await; @@ -1598,6 +1673,7 @@ mod test { user::test::{get_other_identity, get_own_identity}, }, olm::{GroupSessionKey, InboundGroupSession, ReadOnlyAccount, Session}, + store::{Changes, DeviceChanges, IdentityChanges}, }; use matrix_sdk_common::{ api::r0::keys::SignedKey, @@ -1854,7 +1930,7 @@ mod test { .expect("Can't create session"); store - .save_inbound_group_sessions(&[session]) + .save_inbound_group_sessions_test(&[session]) .await .expect("Can't save group session"); } @@ -1880,7 +1956,7 @@ mod test { let session = InboundGroupSession::from_export(export).unwrap(); store - .save_inbound_group_sessions(&[session.clone()]) + .save_inbound_group_sessions_test(&[session.clone()]) .await .expect("Can't save group session"); @@ -1952,7 +2028,15 @@ mod test { let (_account, store, dir) = get_loaded_store().await; let device = get_device(); - store.save_devices(&[device.clone()]).await.unwrap(); + let changes = Changes { + devices: DeviceChanges { + changed: vec![device.clone()], + ..Default::default() + }, + ..Default::default() + }; + + store.save_changes(changes).await.unwrap(); drop(store); @@ -1986,8 +2070,25 @@ mod test { let (_account, store, dir) = get_loaded_store().await; let device = get_device(); - store.save_devices(&[device.clone()]).await.unwrap(); - store.delete_device(device.clone()).await.unwrap(); + let changes = Changes { + devices: DeviceChanges { + changed: vec![device.clone()], + ..Default::default() + }, + ..Default::default() + }; + + store.save_changes(changes).await.unwrap(); + + let changes = Changes { + devices: DeviceChanges { + deleted: vec![device.clone()], + ..Default::default() + }, + ..Default::default() + }; + + store.save_changes(changes).await.unwrap(); let store = SqliteStore::open(&alice_id(), &alice_device_id(), dir.path()) .await @@ -2024,8 +2125,16 @@ mod test { let own_identity = get_own_identity(); + let changes = Changes { + identities: IdentityChanges { + changed: vec![own_identity.clone().into()], + ..Default::default() + }, + ..Default::default() + }; + store - .save_user_identities(&[own_identity.clone().into()]) + .save_changes(changes) .await .expect("Can't save identity"); @@ -2052,10 +2161,15 @@ mod test { let other_identity = get_other_identity(); - store - .save_user_identities(&[other_identity.clone().into()]) - .await - .unwrap(); + let changes = Changes { + identities: IdentityChanges { + changed: vec![other_identity.clone().into()], + ..Default::default() + }, + ..Default::default() + }; + + store.save_changes(changes).await.unwrap(); let loaded_user = store .load_user(other_identity.user_id()) @@ -2072,10 +2186,15 @@ mod test { own_identity.mark_as_verified(); - store - .save_user_identities(&[own_identity.into()]) - .await - .unwrap(); + let changes = Changes { + identities: IdentityChanges { + changed: vec![own_identity.into()], + ..Default::default() + }, + ..Default::default() + }; + + store.save_changes(changes).await.unwrap(); let loaded_user = store.load_user(&user_id).await.unwrap().unwrap(); assert!(loaded_user.own().unwrap().is_verified()) } diff --git a/matrix_sdk_crypto/src/verification/machine.rs b/matrix_sdk_crypto/src/verification/machine.rs index ee12d562..71e3cb19 100644 --- a/matrix_sdk_crypto/src/verification/machine.rs +++ b/matrix_sdk_crypto/src/verification/machine.rs @@ -194,18 +194,14 @@ impl VerificationMachine { self.receive_event_helper(&s, event); if s.is_done() { - if !s.mark_device_as_verified().await? { - if let Some(r) = s.cancel() { - self.outgoing_to_device_messages.insert( - r.txn_id, - OutgoingRequest { - request_id: r.txn_id, - request: Arc::new(r.into()), - }, - ); - } - } else { - s.mark_identity_as_verified().await?; + if let Some(r) = s.mark_as_done().await? { + self.outgoing_to_device_messages.insert( + r.txn_id, + OutgoingRequest { + request_id: r.txn_id, + request: Arc::new(r.into()), + }, + ); } } }; @@ -258,17 +254,15 @@ mod test { let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id()); let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id()); let store = MemoryStore::new(); - let bob_store: Arc> = Arc::new(Box::new(MemoryStore::new())); + let bob_store = MemoryStore::new(); let bob_device = ReadOnlyDevice::from_account(&bob).await; let alice_device = ReadOnlyDevice::from_account(&alice).await; - store.save_devices(&[bob_device]).await.unwrap(); - bob_store - .save_devices(&[alice_device.clone()]) - .await - .unwrap(); + store.save_devices(vec![bob_device]).await; + bob_store.save_devices(vec![alice_device.clone()]).await; + let bob_store: Arc> = Arc::new(Box::new(bob_store)); let machine = VerificationMachine::new(alice, Arc::new(Box::new(store))); let (bob_sas, start_content) = Sas::start(bob, alice_device, bob_store, None); machine diff --git a/matrix_sdk_crypto/src/verification/sas/mod.rs b/matrix_sdk_crypto/src/verification/sas/mod.rs index 95d61541..b71070a7 100644 --- a/matrix_sdk_crypto/src/verification/sas/mod.rs +++ b/matrix_sdk_crypto/src/verification/sas/mod.rs @@ -34,7 +34,7 @@ use matrix_sdk_common::{ use crate::{ identities::{LocalTrust, ReadOnlyDevice, UserIdentities}, - store::{CryptoStore, CryptoStoreError}, + store::{Changes, CryptoStore, CryptoStoreError, DeviceChanges}, ReadOnlyAccount, ToDeviceRequest, }; @@ -189,34 +189,64 @@ impl Sas { (content, guard.is_done()) }; - if done { - // TODO move the logic that marks and stores the device into the - // else branch and only after the identity was verified as well. We - // dont' want to verify one without the other. - if !self.mark_device_as_verified().await? { - return Ok(self.cancel()); - } else { - self.mark_identity_as_verified().await?; - } - } + let cancel = if done { + self.mark_as_done().await? + } else { + None + }; - Ok(content.map(|c| { - let content = AnyToDeviceEventContent::KeyVerificationMac(c); - self.content_to_request(content) - })) + if cancel.is_some() { + Ok(cancel) + } else { + Ok(content.map(|c| { + let content = AnyToDeviceEventContent::KeyVerificationMac(c); + self.content_to_request(content) + })) + } } - pub(crate) async fn mark_identity_as_verified(&self) -> Result { + pub(crate) async fn mark_as_done(&self) -> Result, CryptoStoreError> { + if let Some(device) = self.mark_device_as_verified().await? { + let identity = self.mark_identity_as_verified().await?; + + let mut changes = Changes { + devices: DeviceChanges { + changed: vec![device], + ..Default::default() + }, + ..Default::default() + }; + + if let Some(i) = identity { + changes.identities.changed.push(i); + } + + self.store.save_changes(changes).await?; + Ok(None) + } else { + Ok(self.cancel()) + } + } + + pub(crate) async fn mark_identity_as_verified( + &self, + ) -> Result, CryptoStoreError> { // If there wasn't an identity available during the verification flow // return early as there's nothing to do. if self.other_identity.is_none() { - return Ok(false); + return Ok(None); } + // TODO signal an error, e.g. when the identity got deleted so we don't + // verify/save the device either. let identity = self.store.get_user_identity(self.other_user_id()).await?; if let Some(identity) = identity { - if identity.master_key() == self.other_identity.as_ref().unwrap().master_key() { + if self + .other_identity + .as_ref() + .map_or(false, |i| i.master_key() == identity.master_key()) + { if self .verified_identities() .map_or(false, |i| i.contains(&identity)) @@ -228,13 +258,12 @@ impl Sas { if let UserIdentities::Own(i) = &identity { i.mark_as_verified(); - self.store.save_user_identities(&[identity]).await?; } // TODO if we have the private part of the user signing // key we should sign and upload a signature for this // identity. - Ok(true) + Ok(Some(identity)) } else { info!( "The interactive verification process didn't contain a \ @@ -243,7 +272,7 @@ impl Sas { self.verified_identities(), ); - Ok(false) + Ok(None) } } else { warn!( @@ -252,7 +281,7 @@ impl Sas { identity.user_id(), ); - Ok(false) + Ok(None) } } else { info!( @@ -260,11 +289,13 @@ impl Sas { verification was going on.", self.other_user_id(), ); - Ok(false) + Ok(None) } } - pub(crate) async fn mark_device_as_verified(&self) -> Result { + pub(crate) async fn mark_device_as_verified( + &self, + ) -> Result, CryptoStoreError> { let device = self .store .get_device(self.other_user_id(), self.other_device_id()) @@ -283,12 +314,11 @@ impl Sas { ); device.set_trust_state(LocalTrust::Verified); - self.store.save_devices(&[device]).await?; // TODO if this is a device from our own user and we have // the private part of the self signing key, we should sign // the device and upload the signature. - Ok(true) + Ok(Some(device)) } else { info!( "The interactive verification process didn't contain a \ @@ -297,7 +327,7 @@ impl Sas { device.device_id() ); - Ok(false) + Ok(None) } } else { warn!( @@ -306,7 +336,7 @@ impl Sas { device.user_id(), device.device_id() ); - Ok(false) + Ok(None) } } else { let device = self.other_device(); @@ -317,7 +347,7 @@ impl Sas { device.user_id(), device.device_id() ); - Ok(false) + Ok(None) } } @@ -777,12 +807,11 @@ mod test { let bob_device = ReadOnlyDevice::from_account(&bob).await; let alice_store: Arc> = Arc::new(Box::new(MemoryStore::new())); - let bob_store: Arc> = Arc::new(Box::new(MemoryStore::new())); + let bob_store = MemoryStore::new(); - bob_store - .save_devices(&[alice_device.clone()]) - .await - .unwrap(); + bob_store.save_devices(vec![alice_device.clone()]).await; + + let bob_store: Arc> = Arc::new(Box::new(bob_store)); let (alice, content) = Sas::start(alice, bob_device, alice_store, None); let event = wrap_to_device_event(alice.user_id(), content);