diff --git a/matrix_sdk/src/sas.rs b/matrix_sdk/src/sas.rs index 17fa8eb7..1adf963e 100644 --- a/matrix_sdk/src/sas.rs +++ b/matrix_sdk/src/sas.rs @@ -121,6 +121,12 @@ impl Sas { self.inner.decimals() } + /// Does this verification flow support emoji for the short authentication + /// string. + pub fn supports_emoji(&self) -> bool { + self.inner.supports_emoji() + } + /// Is the verification process done. pub fn is_done(&self) -> bool { self.inner.is_done() diff --git a/matrix_sdk_crypto/src/verification/sas/inner_sas.rs b/matrix_sdk_crypto/src/verification/sas/inner_sas.rs index 89080b5b..b43e57cb 100644 --- a/matrix_sdk_crypto/src/verification/sas/inner_sas.rs +++ b/matrix_sdk_crypto/src/verification/sas/inner_sas.rs @@ -18,7 +18,10 @@ use std::time::Instant; use std::sync::Arc; use matrix_sdk_common::{ - events::{key::verification::cancel::CancelCode, AnyMessageEvent, AnyToDeviceEvent}, + events::{ + key::verification::{cancel::CancelCode, ShortAuthenticationString}, + AnyMessageEvent, AnyToDeviceEvent, + }, identifiers::{EventId, RoomId}, }; @@ -61,6 +64,37 @@ impl InnerSas { (InnerSas::Created(sas), content) } + pub fn supports_emoji(&self) -> bool { + match self { + InnerSas::Created(_) => false, + InnerSas::Started(s) => s + .state + .accepted_protocols + .short_auth_string + .contains(&ShortAuthenticationString::Emoji), + InnerSas::Accepted(s) => s + .state + .accepted_protocols + .short_auth_string + .contains(&ShortAuthenticationString::Emoji), + InnerSas::KeyRecieved(s) => s + .state + .accepted_protocols + .short_auth_string + .contains(&ShortAuthenticationString::Emoji), + InnerSas::Confirmed(_) => false, + InnerSas::MacReceived(s) => s + .state + .accepted_protocols + .short_auth_string + .contains(&ShortAuthenticationString::Emoji), + InnerSas::WaitingForDone(_) => false, + InnerSas::WaitingForDoneUnconfirmed(_) => false, + InnerSas::Done(_) => false, + InnerSas::Canceled(_) => false, + } + } + pub fn start_in_room( event_id: EventId, room_id: RoomId, diff --git a/matrix_sdk_crypto/src/verification/sas/mod.rs b/matrix_sdk_crypto/src/verification/sas/mod.rs index 6a7c9332..8b9da496 100644 --- a/matrix_sdk_crypto/src/verification/sas/mod.rs +++ b/matrix_sdk_crypto/src/verification/sas/mod.rs @@ -110,6 +110,12 @@ impl Sas { &self.flow_id } + /// Does this verification flow support displaying emoji for the short + /// authentication string. + pub fn supports_emoji(&self) -> bool { + self.inner.lock().unwrap().supports_emoji() + } + #[cfg(test)] #[allow(dead_code)] pub(crate) fn set_creation_time(&self, time: Instant) { diff --git a/matrix_sdk_crypto/src/verification/sas/sas_state.rs b/matrix_sdk_crypto/src/verification/sas/sas_state.rs index 96bf25f9..5603aef1 100644 --- a/matrix_sdk_crypto/src/verification/sas/sas_state.rs +++ b/matrix_sdk_crypto/src/verification/sas/sas_state.rs @@ -98,12 +98,12 @@ impl FlowId { /// Struct containing the protocols that were agreed to be used for the SAS /// flow. #[derive(Clone, Debug)] -struct AcceptedProtocols { - method: VerificationMethod, - key_agreement_protocol: KeyAgreementProtocol, - hash: HashAlgorithm, - message_auth_code: MessageAuthenticationCode, - short_auth_string: Vec, +pub struct AcceptedProtocols { + pub method: VerificationMethod, + pub key_agreement_protocol: KeyAgreementProtocol, + pub hash: HashAlgorithm, + pub message_auth_code: MessageAuthenticationCode, + pub short_auth_string: Vec, } impl TryFrom for AcceptedProtocols { @@ -133,6 +133,53 @@ impl TryFrom for AcceptedProtocols { } } +impl TryFrom<&MSasV1Content> for AcceptedProtocols { + type Error = CancelCode; + + fn try_from(method_content: &MSasV1Content) -> Result { + if !method_content + .key_agreement_protocols + .contains(&KeyAgreementProtocol::Curve25519HkdfSha256) + || !method_content + .message_authentication_codes + .contains(&MessageAuthenticationCode::HkdfHmacSha256) + || !method_content.hashes.contains(&HashAlgorithm::Sha256) + || (!method_content + .short_authentication_string + .contains(&ShortAuthenticationString::Decimal) + && !method_content + .short_authentication_string + .contains(&ShortAuthenticationString::Emoji)) + { + Err(CancelCode::UnknownMethod) + } else { + let mut short_auth_string = vec![]; + + if method_content + .short_authentication_string + .contains(&ShortAuthenticationString::Decimal) + { + short_auth_string.push(ShortAuthenticationString::Decimal) + } + + if method_content + .short_authentication_string + .contains(&ShortAuthenticationString::Emoji) + { + short_auth_string.push(ShortAuthenticationString::Emoji); + } + + Ok(Self { + method: VerificationMethod::MSasV1, + hash: HashAlgorithm::Sha256, + key_agreement_protocol: KeyAgreementProtocol::Curve25519HkdfSha256, + message_auth_code: MessageAuthenticationCode::HkdfHmacSha256, + short_auth_string, + }) + } + } +} + #[cfg(not(tarpaulin_include))] impl Default for AcceptedProtocols { fn default() -> Self { @@ -176,7 +223,7 @@ pub struct SasState { pub verification_flow_id: Arc, /// The SAS state we're in. - state: Arc, + pub state: Arc, } #[cfg(not(tarpaulin_include))] @@ -200,14 +247,14 @@ pub struct Created { #[derive(Clone, Debug)] pub struct Started { commitment: String, - protocol_definitions: MSasV1Content, + pub accepted_protocols: Arc, } /// The SAS state we're going to be in after the other side accepted our /// verification start event. #[derive(Clone, Debug)] pub struct Accepted { - accepted_protocols: Arc, + pub accepted_protocols: Arc, start_content: Arc, commitment: String, } @@ -220,7 +267,7 @@ pub struct Accepted { pub struct KeyReceived { their_pubkey: String, we_started: bool, - accepted_protocols: Arc, + pub accepted_protocols: Arc, } /// The SAS state we're going to be in after the user has confirmed that the @@ -228,7 +275,7 @@ pub struct KeyReceived { /// other side. #[derive(Clone, Debug)] pub struct Confirmed { - accepted_protocols: Arc, + pub accepted_protocols: Arc, } /// The SAS state we're going to be in after we receive a MAC event from the @@ -240,6 +287,7 @@ pub struct MacReceived { their_pubkey: String, verified_devices: Arc<[ReadOnlyDevice]>, verified_master_keys: Arc<[UserIdentities]>, + pub accepted_protocols: Arc, } /// The SAS state we're going to be in after we receive a MAC event in a DM. DMs @@ -490,6 +538,22 @@ impl SasState { other_identity: Option, content: &StartContent, ) -> Result, SasState> { + let canceled = || SasState { + inner: Arc::new(Mutex::new(OlmSas::new())), + + creation_time: Arc::new(Instant::now()), + last_event_time: Arc::new(Instant::now()), + + ids: SasIds { + account: account.clone(), + other_device: other_device.clone(), + other_identity: other_identity.clone(), + }, + + verification_flow_id: content.flow_id().into(), + state: Arc::new(Canceled::new(CancelCode::UnknownMethod)), + }; + if let StartMethod::MSasV1(method_content) = content.method() { let sas = OlmSas::new(); @@ -501,60 +565,31 @@ impl SasState { pubkey, content, commitment ); - let sas = SasState { - inner: Arc::new(Mutex::new(sas)), + if let Ok(accepted_protocols) = AcceptedProtocols::try_from(method_content) { + Ok(SasState { + inner: Arc::new(Mutex::new(sas)), - ids: SasIds { - account, - other_device, - other_identity, - }, + ids: SasIds { + account, + other_device, + other_identity, + }, - creation_time: Arc::new(Instant::now()), - last_event_time: Arc::new(Instant::now()), + creation_time: Arc::new(Instant::now()), + last_event_time: Arc::new(Instant::now()), - verification_flow_id: content.flow_id().into(), + verification_flow_id: content.flow_id().into(), - state: Arc::new(Started { - protocol_definitions: method_content.clone(), - commitment, - }), - }; - - if !method_content - .key_agreement_protocols - .contains(&KeyAgreementProtocol::Curve25519HkdfSha256) - || !method_content - .message_authentication_codes - .contains(&MessageAuthenticationCode::HkdfHmacSha256) - || !method_content.hashes.contains(&HashAlgorithm::Sha256) - || (!method_content - .short_authentication_string - .contains(&ShortAuthenticationString::Decimal) - && !method_content - .short_authentication_string - .contains(&ShortAuthenticationString::Emoji)) - { - Err(sas.cancel(CancelCode::UnknownMethod)) + state: Arc::new(Started { + accepted_protocols: accepted_protocols.into(), + commitment, + }), + }) } else { - Ok(sas) + Err(canceled()) } } else { - Err(SasState { - inner: Arc::new(Mutex::new(OlmSas::new())), - - creation_time: Arc::new(Instant::now()), - last_event_time: Arc::new(Instant::now()), - - ids: SasIds { - account, - other_device, - other_identity, - }, - - verification_flow_id: content.flow_id().into(), - state: Arc::new(Canceled::new(CancelCode::UnknownMethod)), - }) + Err(canceled()) } } @@ -566,18 +601,24 @@ impl SasState { /// been started because of a /// m.key.verification.request -> m.key.verification.ready flow. pub fn as_content(&self) -> AcceptContent { - let accepted_protocols = AcceptedProtocols::default(); - let method = AcceptMethod::MSasV1( AcceptV1ContentInit { commitment: self.state.commitment.clone(), - hash: accepted_protocols.hash, - key_agreement_protocol: accepted_protocols.key_agreement_protocol, - message_authentication_code: accepted_protocols.message_auth_code, + hash: self.state.accepted_protocols.hash.clone(), + key_agreement_protocol: self + .state + .accepted_protocols + .key_agreement_protocol + .clone(), + message_authentication_code: self + .state + .accepted_protocols + .message_auth_code + .clone(), short_authentication_string: self .state - .protocol_definitions - .short_authentication_string + .accepted_protocols + .short_auth_string .clone(), } .into(), @@ -620,8 +661,6 @@ impl SasState { self.check_event(&sender, &content.flow_id().as_str()) .map_err(|c| self.clone().cancel(c))?; - let accepted_protocols = AcceptedProtocols::default(); - let their_pubkey = content.public_key().to_owned(); self.inner @@ -639,7 +678,7 @@ impl SasState { state: Arc::new(KeyReceived { we_started: false, their_pubkey, - accepted_protocols: Arc::new(accepted_protocols), + accepted_protocols: self.state.accepted_protocols.clone(), }), }) } @@ -823,6 +862,7 @@ impl SasState { their_pubkey: self.state.their_pubkey.clone(), verified_devices: devices.into(), verified_master_keys: master_keys.into(), + accepted_protocols: self.state.accepted_protocols.clone(), }), }) }