diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index e3298909..0a2fc577 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -39,12 +39,13 @@ use matrix_sdk_common::{ assign, encryption::DeviceKeys, events::{ - forwarded_room_key::ForwardedRoomKeyEventContent, room::encrypted::EncryptedEventContent, - room_key::RoomKeyEventContent, room_key_request::RoomKeyRequestEventContent, - AnyMessageEventContent, AnySyncRoomEvent, AnyToDeviceEvent, EventType, SyncMessageEvent, - ToDeviceEvent, + room::encrypted::EncryptedEventContent, room_key::RoomKeyEventContent, + room_key_request::RoomKeyRequestEventContent, AnyMessageEventContent, AnySyncRoomEvent, + AnyToDeviceEvent, EventType, SyncMessageEvent, ToDeviceEvent, + }, + identifiers::{ + DeviceId, DeviceIdBox, DeviceKeyAlgorithm, EventEncryptionAlgorithm, RoomId, UserId, }, - identifiers::{DeviceId, DeviceKeyAlgorithm, EventEncryptionAlgorithm, RoomId, UserId}, uuid::Uuid, Raw, }; @@ -57,6 +58,7 @@ use super::{ Device, MasterPubkey, OwnUserIdentity, ReadOnlyDevice, SelfSigningPubkey, UserDevices, UserIdentities, UserIdentity, UserSigningPubkey, }, + key_request::KeyRequestMachine, olm::{ Account, EncryptionSettings, ExportedRoomKey, GroupSessionKey, IdentityKeys, InboundGroupSession, OlmMessage, OutboundGroupSession, @@ -71,9 +73,9 @@ use super::{ #[derive(Clone)] pub struct OlmMachine { /// The unique user id that owns this account. - user_id: UserId, + user_id: Arc, /// The unique device id of the device that holds this account. - device_id: Box, + device_id: Arc>, /// Our underlying Olm Account holding our identity keys. account: Account, /// Store for the encryption keys. @@ -85,6 +87,9 @@ pub struct OlmMachine { /// A state machine that is responsible to handle and keep track of SAS /// verification flows. verification_machine: VerificationMachine, + /// The state machine that is responsible to handle outgoing and incoming + /// key requests. + key_request_machine: KeyRequestMachine, } #[cfg(not(tarpaulin_include))] @@ -115,14 +120,17 @@ impl OlmMachine { let store: Box = Box::new(MemoryStore::new()); let store = Store::new(store); let account = Account::new(user_id, device_id); + let user_id = Arc::new(user_id.clone()); + let device_id: Arc = Arc::new(device_id.into()); OlmMachine { user_id: user_id.clone(), - device_id: device_id.into(), + device_id: device_id.clone(), account: account.clone(), store: store.clone(), outbound_group_sessions: Arc::new(DashMap::new()), - verification_machine: VerificationMachine::new(account, store), + verification_machine: VerificationMachine::new(account, store.clone()), + key_request_machine: KeyRequestMachine::new(user_id, device_id, store), } } @@ -163,6 +171,10 @@ impl OlmMachine { let store = Store::new(store); let verification_machine = VerificationMachine::new(account.clone(), store.clone()); + let user_id = Arc::new(user_id.clone()); + let device_id: Arc = Arc::new(device_id.into()); + let key_request_machine = + KeyRequestMachine::new(user_id.clone(), device_id.clone(), store.clone()); Ok(OlmMachine { user_id, @@ -171,6 +183,7 @@ impl OlmMachine { store, outbound_group_sessions: Arc::new(DashMap::new()), verification_machine, + key_request_machine, }) } @@ -238,6 +251,7 @@ impl OlmMachine { } requests.append(&mut self.outgoing_to_device_requests()); + requests.append(&mut self.key_request_machine.outgoing_to_device_requests()); requests } @@ -267,7 +281,7 @@ impl OlmMachine { self.receive_keys_claim_response(response).await?; } IncomingResponse::ToDevice(_) => { - self.mark_to_device_request_as_sent(&request_id); + self.mark_to_device_request_as_sent(&request_id).await?; } }; @@ -500,7 +514,7 @@ impl OlmMachine { for (device_id, device_keys) in device_map.iter() { // We don't need our own device in the device store. - if user_id == &self.user_id && device_id == &self.device_id { + if user_id == self.user_id() && &**device_id == self.device_id() { continue; } @@ -881,7 +895,7 @@ impl OlmMachine { .ok_or_else(|| EventError::MissingField("keys".to_string()))?, )?; - if recipient != self.user_id || sender != &encrytped_sender { + if &recipient != self.user_id() || sender != &encrytped_sender { return Err(EventError::MissmatchedSender.into()); } @@ -1192,16 +1206,6 @@ impl OlmMachine { Ok(requests) } - fn add_forwarded_room_key( - &self, - _sender_key: &str, - _signing_key: &str, - _event: &ToDeviceEvent, - ) -> OlmResult<()> { - Ok(()) - // TODO - } - /// Receive and properly handle a decrypted to-device event. /// /// # Arguments @@ -1228,8 +1232,11 @@ impl OlmMachine { AnyToDeviceEvent::RoomKey(mut e) => { Ok(self.add_room_key(sender_key, signing_key, &mut e).await?) } - AnyToDeviceEvent::ForwardedRoomKey(e) => { - self.add_forwarded_room_key(sender_key, signing_key, &e)?; + AnyToDeviceEvent::ForwardedRoomKey(mut e) => { + // TODO do the mem take dance to remove the key. + self.key_request_machine + .receive_forwarded_room_key(sender_key, &mut e) + .await?; Ok(None) } _ => { @@ -1255,8 +1262,13 @@ impl OlmMachine { } /// Mark an outgoing to-device requests as sent. - fn mark_to_device_request_as_sent(&self, request_id: &Uuid) { + async fn mark_to_device_request_as_sent(&self, request_id: &Uuid) -> StoreResult<()> { self.verification_machine.mark_request_as_sent(request_id); + self.key_request_machine + .mark_outgoing_request_as_sent(request_id) + .await?; + + Ok(()) } /// Get a `Sas` verification object with the given flow id. @@ -1357,7 +1369,14 @@ impl OlmMachine { .get_inbound_group_session(room_id, &content.sender_key, &content.session_id) .await?; // TODO check if the Olm session is wedged and re-request the key. - let session = session.ok_or(MegolmError::MissingSession)?; + let session = if let Some(s) = session { + s + } else { + self.key_request_machine + .create_outgoing_key_request(room_id, &content.sender_key, &content.session_id) + .await?; + return Err(MegolmError::MissingSession); + }; // TODO check the message index. // TODO check if this is from a verified device. @@ -1774,10 +1793,10 @@ pub(crate) mod test { let one_time_key = one_time_keys.iter().next().unwrap(); let mut keys = BTreeMap::new(); keys.insert(one_time_key.0.clone(), one_time_key.1.clone()); - bob_keys.insert(bob.device_id.clone(), keys); + bob_keys.insert(bob.device_id().into(), keys); let mut one_time_keys = BTreeMap::new(); - one_time_keys.insert(bob.user_id.clone(), bob_keys); + one_time_keys.insert(bob.user_id().clone(), bob_keys); let response = claim_keys::Response::new(one_time_keys); @@ -1795,7 +1814,7 @@ pub(crate) mod test { .unwrap(); let event = ToDeviceEvent { - sender: alice.user_id.clone(), + sender: alice.user_id().clone(), content: bob_device .encrypt(EventType::Dummy, json!({})) .await @@ -2045,10 +2064,10 @@ pub(crate) mod test { let one_time_key = one_time_keys.iter().next().unwrap(); let mut keys = BTreeMap::new(); keys.insert(one_time_key.0.clone(), one_time_key.1.clone()); - bob_keys.insert(bob_machine.device_id.clone(), keys); + bob_keys.insert(bob_machine.device_id().into(), keys); let mut one_time_keys = BTreeMap::new(); - one_time_keys.insert(bob_machine.user_id.clone(), bob_keys); + one_time_keys.insert(bob_machine.user_id().clone(), bob_keys); let response = claim_keys::Response::new(one_time_keys); @@ -2077,7 +2096,7 @@ pub(crate) mod test { .unwrap(); let event = ToDeviceEvent { - sender: alice.user_id.clone(), + sender: alice.user_id().clone(), content: bob_device .encrypt(EventType::Dummy, json!({})) .await @@ -2092,7 +2111,7 @@ pub(crate) mod test { .unwrap(); if let AnyToDeviceEvent::Dummy(e) = event { - assert_eq!(e.sender, alice.user_id); + assert_eq!(&e.sender, alice.user_id()); } else { panic!("Wrong event type found {:?}", event); } @@ -2107,14 +2126,14 @@ pub(crate) mod test { let to_device_requests = alice .share_group_session( &room_id, - [bob.user_id.clone()].iter(), + [bob.user_id().clone()].iter(), EncryptionSettings::default(), ) .await .unwrap(); let event = ToDeviceEvent { - sender: alice.user_id.clone(), + sender: alice.user_id().clone(), content: to_device_requests_to_content(to_device_requests), }; @@ -2128,7 +2147,7 @@ pub(crate) mod test { .unwrap(); if let AnyToDeviceEvent::RoomKey(event) = event { - assert_eq!(event.sender, alice.user_id); + assert_eq!(&event.sender, alice.user_id()); assert!(event.content.session_key.is_empty()); } else { panic!("expected RoomKeyEvent found {:?}", event); @@ -2161,7 +2180,7 @@ pub(crate) mod test { .unwrap(); let event = ToDeviceEvent { - sender: alice.user_id.clone(), + sender: alice.user_id().clone(), content: to_device_requests_to_content(to_device_requests), };