diff --git a/matrix_sdk_crypto/src/device.rs b/matrix_sdk_crypto/src/device.rs index 3f96736d..4ac12dd4 100644 --- a/matrix_sdk_crypto/src/device.rs +++ b/matrix_sdk_crypto/src/device.rs @@ -28,15 +28,17 @@ use matrix_sdk_common::{ keys::SignedKey, to_device::send_event_to_device::IncomingRequest as OwnedToDeviceRequest, }, encryption::DeviceKeys, + events::{room::encrypted::EncryptedEventContent, EventType}, identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, UserId}, }; use serde_json::{json, Value}; +use tracing::warn; #[cfg(test)] use super::{Account, OlmMachine}; use crate::{ - error::SignatureError, + error::{EventError, OlmError, OlmResult, SignatureError}, store::Result as StoreResult, user_identity::{OwnUserIdentity, UserIdentities}, verification::VerificationMachine, @@ -136,6 +138,57 @@ impl Device { .save_devices(&[self.inner.clone()]) .await } + + /// Encrypt the given content for this `Device`. + /// + /// # Arguments + /// + /// * `event_type` - The type of the event. + /// + /// * `content` - The content of the event that should be encrypted. + pub(crate) async fn encrypt( + &self, + event_type: EventType, + content: Value, + ) -> OlmResult { + let sender_key = if let Some(k) = self.inner.get_key(DeviceKeyAlgorithm::Curve25519) { + k + } else { + warn!( + "Trying to encrypt a Megolm session for user {} on device {}, \ + but the device doesn't have a curve25519 key", + self.user_id(), + self.device_id() + ); + return Err(EventError::MissingSenderKey.into()); + }; + + let mut session = if let Some(s) = self + .verification_machine + .store + .get_sessions(sender_key) + .await? + { + let session = &s.lock().await[0]; + session.clone() + } else { + warn!( + "Trying to encrypt a Megolm session for user {} on device {}, \ + but no Olm session is found", + self.user_id(), + self.device_id() + ); + return Err(OlmError::MissingSession); + }; + + let message = session.encrypt(&self.inner, event_type, content).await; + self.verification_machine + .store + .save_sessions(&[session]) + .await?; + + message + } } /// A read only view over all devices belonging to a user. diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index d4b5755b..ada96b3d 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -993,53 +993,6 @@ impl OlmMachine { Ok(session.encrypt(content).await) } - /// Encrypt the given event for the given Device - /// - /// # Arguments - /// - /// * `reciepient_device` - The device that the event should be encrypted - /// for. - /// - /// * `event_type` - The type of the event. - /// - /// * `content` - The content of the event that should be encrypted. - async fn olm_encrypt( - &self, - recipient_device: &ReadOnlyDevice, - event_type: EventType, - content: Value, - ) -> OlmResult { - let sender_key = if let Some(k) = recipient_device.get_key(DeviceKeyAlgorithm::Curve25519) { - k - } else { - warn!( - "Trying to encrypt a Megolm session for user {} on device {}, \ - but the device doesn't have a curve25519 key", - recipient_device.user_id(), - recipient_device.device_id() - ); - return Err(EventError::MissingSenderKey.into()); - }; - - let mut session = if let Some(s) = self.store.get_sessions(sender_key).await? { - let session = &s.lock().await[0]; - session.clone() - } else { - warn!( - "Trying to encrypt a Megolm session for user {} on device {}, \ - but no Olm session is found", - recipient_device.user_id(), - recipient_device.device_id() - ); - return Err(OlmError::MissingSession); - }; - - let message = session.encrypt(recipient_device, event_type, content).await; - self.store.save_sessions(&[session]).await?; - - message - } - /// Should the client share a group session for the given room. /// /// Returns true if a session needs to be shared before room messages can be @@ -1100,7 +1053,7 @@ impl OlmMachine { let mut devices = Vec::new(); for user_id in users { - for device in self.store.get_user_devices(user_id).await?.devices() { + for device in self.get_user_devices(user_id).await?.devices() { if !device.is_blacklisted() { devices.push(device.clone()); } @@ -1114,8 +1067,8 @@ impl OlmMachine { let mut messages = BTreeMap::new(); for device in device_map_chunk { - let encrypted = self - .olm_encrypt(&device, EventType::RoomKey, key_content.clone()) + let encrypted = device + .encrypt(EventType::RoomKey, key_content.clone()) .await; let encrypted = match encrypted { @@ -1643,16 +1596,14 @@ pub(crate) mod test { let (alice, bob) = get_machine_pair_with_session().await; let bob_device = alice - .store .get_device(&bob.user_id, &bob.device_id) .await - .unwrap() .unwrap(); let event = ToDeviceEvent { sender: alice.user_id.clone(), - content: alice - .olm_encrypt(&bob_device, EventType::Dummy, json!({})) + content: bob_device + .encrypt(EventType::Dummy, json!({})) .await .unwrap(), }; @@ -1924,16 +1875,14 @@ pub(crate) mod test { let (alice, bob) = get_machine_pair_with_session().await; let bob_device = alice - .store .get_device(&bob.user_id, &bob.device_id) .await - .unwrap() .unwrap(); let event = ToDeviceEvent { sender: alice.user_id.clone(), - content: alice - .olm_encrypt(&bob_device, EventType::Dummy, json!({})) + content: bob_device + .encrypt(EventType::Dummy, json!({})) .await .unwrap(), };