diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index b290bf54..f4755f4d 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -1825,7 +1825,7 @@ pub(crate) mod test { let mut event = request_to_event(alice.user_id(), &request); bob.handle_verification_event(&mut event).await; - let bob_sas = bob.get_verification(alice_sas.flow_id()).unwrap(); + let bob_sas = bob.get_verification(alice_sas.flow_id().as_str()).unwrap(); assert!(alice_sas.emoji().is_none()); assert!(bob_sas.emoji().is_none()); diff --git a/matrix_sdk_crypto/src/verification/machine.rs b/matrix_sdk_crypto/src/verification/machine.rs index 5efa5b5d..4cfb04f9 100644 --- a/matrix_sdk_crypto/src/verification/machine.rs +++ b/matrix_sdk_crypto/src/verification/machine.rs @@ -88,7 +88,7 @@ impl VerificationMachine { ); self.verifications - .insert(sas.flow_id().to_owned(), sas.clone()); + .insert(sas.flow_id().to_string(), sas.clone()); Ok((sas, request)) } @@ -367,7 +367,7 @@ mod test { async fn full_flow() { let (alice_machine, bob) = setup_verification_machine().await; - let alice = alice_machine.get_sas(bob.flow_id()).unwrap(); + let alice = alice_machine.get_sas(bob.flow_id().as_str()).unwrap(); let mut event = alice .accept() @@ -428,7 +428,7 @@ mod test { #[tokio::test] async fn timing_out() { let (alice_machine, bob) = setup_verification_machine().await; - let alice = alice_machine.get_sas(bob.flow_id()).unwrap(); + let alice = alice_machine.get_sas(bob.flow_id().as_str()).unwrap(); assert!(!alice.timed_out()); assert!(alice_machine.outgoing_to_device_messages.is_empty()); diff --git a/matrix_sdk_crypto/src/verification/sas/helpers.rs b/matrix_sdk_crypto/src/verification/sas/helpers.rs index 37974d14..99d81280 100644 --- a/matrix_sdk_crypto/src/verification/sas/helpers.rs +++ b/matrix_sdk_crypto/src/verification/sas/helpers.rs @@ -38,6 +38,8 @@ use crate::{ ReadOnlyAccount, ToDeviceRequest, }; +use super::sas_state::FlowId; + #[derive(Clone, Debug)] pub struct SasIds { pub account: ReadOnlyAccount, @@ -298,12 +300,12 @@ fn extra_mac_info_send(ids: &SasIds, flow_id: &str) -> String { /// # Panics /// /// This will panic if the public key of the other side wasn't set. -pub fn get_mac_content(sas: &OlmSas, ids: &SasIds, flow_id: &str) -> MacToDeviceEventContent { +pub fn get_mac_content(sas: &OlmSas, ids: &SasIds, flow_id: &FlowId) -> MacToDeviceEventContent { let mut mac: BTreeMap = BTreeMap::new(); let key_id = DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, ids.account.device_id()); let key = ids.account.identity_keys().ed25519(); - let info = extra_mac_info_send(ids, flow_id); + let info = extra_mac_info_send(ids, flow_id.as_str()); mac.insert( key_id.to_string(), @@ -319,10 +321,13 @@ pub fn get_mac_content(sas: &OlmSas, ids: &SasIds, flow_id: &str) -> MacToDevice .calculate_mac(&keys.join(","), &format!("{}KEY_IDS", &info)) .expect("Can't calculate SAS MAC"); - MacToDeviceEventContent { - transaction_id: flow_id.to_owned(), - keys, - mac, + match flow_id { + FlowId::ToDevice(s) => MacToDeviceEventContent { + transaction_id: s.to_string(), + keys, + mac, + }, + _ => todo!(), } } diff --git a/matrix_sdk_crypto/src/verification/sas/inner_sas.rs b/matrix_sdk_crypto/src/verification/sas/inner_sas.rs index 6a566990..034e4734 100644 --- a/matrix_sdk_crypto/src/verification/sas/inner_sas.rs +++ b/matrix_sdk_crypto/src/verification/sas/inner_sas.rs @@ -31,7 +31,8 @@ use crate::{ }; use super::sas_state::{ - Accepted, Canceled, Confirmed, Created, Done, KeyReceived, MacReceived, SasState, Started, + Accepted, Canceled, Confirmed, Created, Done, FlowId, KeyReceived, MacReceived, SasState, + Started, }; #[derive(Clone, Debug)] @@ -220,7 +221,7 @@ impl InnerSas { } } - pub fn verification_flow_id(&self) -> Arc { + pub fn verification_flow_id(&self) -> Arc { match self { InnerSas::Created(s) => s.verification_flow_id.clone(), InnerSas::Started(s) => s.verification_flow_id.clone(), diff --git a/matrix_sdk_crypto/src/verification/sas/mod.rs b/matrix_sdk_crypto/src/verification/sas/mod.rs index 025f0fa7..6d9f3ffe 100644 --- a/matrix_sdk_crypto/src/verification/sas/mod.rs +++ b/matrix_sdk_crypto/src/verification/sas/mod.rs @@ -31,7 +31,7 @@ use matrix_sdk_common::{ }, AnyToDeviceEvent, AnyToDeviceEventContent, ToDeviceEvent, }, - identifiers::{DeviceId, UserId}, + identifiers::{DeviceId, EventId, RoomId, UserId}, }; use crate::{ @@ -44,6 +44,7 @@ use crate::{ pub use helpers::content_to_request; use inner_sas::InnerSas; +pub use sas_state::FlowId; #[derive(Debug)] /// A result of a verification flow. @@ -65,7 +66,7 @@ pub struct Sas { private_identity: PrivateCrossSigningIdentity, other_device: ReadOnlyDevice, other_identity: Option, - flow_id: Arc, + flow_id: Arc, } impl Sas { @@ -95,7 +96,7 @@ impl Sas { } /// Get the unique ID that identifies this SAS verification flow. - pub fn flow_id(&self) -> &str { + pub fn flow_id(&self) -> &FlowId { &self.flow_id } diff --git a/matrix_sdk_crypto/src/verification/sas/sas_state.rs b/matrix_sdk_crypto/src/verification/sas/sas_state.rs index 8ddd0e9c..817bb247 100644 --- a/matrix_sdk_crypto/src/verification/sas/sas_state.rs +++ b/matrix_sdk_crypto/src/verification/sas/sas_state.rs @@ -36,7 +36,7 @@ use matrix_sdk_common::{ }, AnyToDeviceEventContent, ToDeviceEvent, }, - identifiers::{DeviceId, UserId}, + identifiers::{DeviceId, RoomId, UserId}, uuid::Uuid, }; use tracing::error; @@ -65,6 +65,28 @@ const MAX_AGE: Duration = Duration::from_secs(60 * 5); // The max time a SAS object will wait for a new event to arrive. const MAX_EVENT_TIMEOUT: Duration = Duration::from_secs(60); +#[derive(Clone, Debug)] +pub enum FlowId { + ToDevice(String), + InRoom(RoomId), +} + +impl FlowId { + pub fn to_string(&self) -> String { + match self { + FlowId::InRoom(r) => r.to_string(), + FlowId::ToDevice(t) => t.to_string(), + } + } + + pub fn as_str(&self) -> &str { + match self { + FlowId::InRoom(r) => r.as_str(), + FlowId::ToDevice(t) => t.as_str(), + } + } +} + /// Struct containing the protocols that were agreed to be used for the SAS /// flow. #[derive(Clone, Debug)] @@ -143,7 +165,7 @@ pub struct SasState { /// /// This will be the transaction id for to-device events and the relates_to /// field for in-room events. - pub verification_flow_id: Arc, + pub verification_flow_id: Arc, /// The SAS state we're in. state: Arc, @@ -268,7 +290,7 @@ impl SasState { } fn check_event(&self, sender: &UserId, flow_id: &str) -> Result<(), CancelCode> { - if *flow_id != *self.verification_flow_id { + if *flow_id != *self.verification_flow_id.as_str() { Err(CancelCode::UnknownTransaction) } else if sender != self.ids.other_device.user_id() { Err(CancelCode::UserMismatch) @@ -302,7 +324,7 @@ impl SasState { other_device, other_identity, }, - verification_flow_id: verification_flow_id.into(), + verification_flow_id: FlowId::ToDevice(verification_flow_id).into(), creation_time: Arc::new(Instant::now()), last_event_time: Arc::new(Instant::now()), @@ -413,7 +435,7 @@ impl SasState { creation_time: Arc::new(Instant::now()), last_event_time: Arc::new(Instant::now()), - verification_flow_id: event.content.transaction_id.as_str().into(), + verification_flow_id: FlowId::ToDevice(event.content.transaction_id.clone()).into(), state: Arc::new(Started { protocol_definitions: content.clone(), @@ -452,7 +474,7 @@ impl SasState { other_identity, }, - verification_flow_id: event.content.transaction_id.as_str().into(), + verification_flow_id: FlowId::ToDevice(event.content.transaction_id.clone()).into(), state: Arc::new(Canceled::new(CancelCode::UnknownMethod)), }) } @@ -575,9 +597,14 @@ impl SasState { /// /// The content needs to be automatically sent to the other side. pub fn as_content(&self) -> KeyToDeviceEventContent { - KeyToDeviceEventContent { - transaction_id: self.verification_flow_id.to_string(), - key: self.inner.lock().unwrap().public_key(), + match &*self.verification_flow_id { + FlowId::ToDevice(s) => KeyToDeviceEventContent { + transaction_id: s.to_string(), + key: self.inner.lock().unwrap().public_key(), + }, + FlowId::InRoom(r) => { + todo!("In-room verifications aren't implemented") + } } } } @@ -588,9 +615,12 @@ impl SasState { /// The content needs to be automatically sent to the other side if and only /// if we_started is false. pub fn as_content(&self) -> KeyToDeviceEventContent { - KeyToDeviceEventContent { - transaction_id: self.verification_flow_id.to_string(), - key: self.inner.lock().unwrap().public_key(), + match self.verification_flow_id.as_ref() { + FlowId::ToDevice(s) => KeyToDeviceEventContent { + transaction_id: s.to_string(), + key: self.inner.lock().unwrap().public_key(), + }, + _ => todo!(), } } @@ -603,7 +633,7 @@ impl SasState { &self.inner.lock().unwrap(), &self.ids, &self.state.their_pubkey, - &self.verification_flow_id, + self.verification_flow_id.as_str(), self.state.we_started, ) } @@ -617,7 +647,7 @@ impl SasState { &self.inner.lock().unwrap(), &self.ids, &self.state.their_pubkey, - &self.verification_flow_id, + self.verification_flow_id.as_str(), self.state.we_started, ) } @@ -639,7 +669,7 @@ impl SasState { let (devices, master_keys) = receive_mac_event( &self.inner.lock().unwrap(), &self.ids, - &self.verification_flow_id, + self.verification_flow_id.as_str(), event, ) .map_err(|c| self.clone().cancel(c))?; @@ -695,7 +725,7 @@ impl SasState { let (devices, master_keys) = receive_mac_event( &self.inner.lock().unwrap(), &self.ids, - &self.verification_flow_id, + &self.verification_flow_id.as_str(), event, ) .map_err(|c| self.clone().cancel(c))?; @@ -754,7 +784,7 @@ impl SasState { &self.inner.lock().unwrap(), &self.ids, &self.state.their_pubkey, - &self.verification_flow_id, + &self.verification_flow_id.as_str(), self.state.we_started, ) } @@ -768,7 +798,7 @@ impl SasState { &self.inner.lock().unwrap(), &self.ids, &self.state.their_pubkey, - &self.verification_flow_id, + &self.verification_flow_id.as_str(), self.state.we_started, ) } @@ -828,11 +858,16 @@ impl Canceled { impl SasState { pub fn as_content(&self) -> AnyToDeviceEventContent { - AnyToDeviceEventContent::KeyVerificationCancel(CancelToDeviceEventContent { - transaction_id: self.verification_flow_id.to_string(), - reason: self.state.reason.to_string(), - code: self.state.cancel_code.clone(), - }) + match self.verification_flow_id.as_ref() { + FlowId::ToDevice(s) => { + AnyToDeviceEventContent::KeyVerificationCancel(CancelToDeviceEventContent { + transaction_id: self.verification_flow_id.to_string(), + reason: self.state.reason.to_string(), + code: self.state.cancel_code.clone(), + }) + } + _ => todo!(), + } } }