diff --git a/matrix_sdk_crypto/src/verification/sas/helpers.rs b/matrix_sdk_crypto/src/verification/sas/helpers.rs index 34032aad..9c6e36fc 100644 --- a/matrix_sdk_crypto/src/verification/sas/helpers.rs +++ b/matrix_sdk_crypto/src/verification/sas/helpers.rs @@ -1,5 +1,7 @@ use std::{collections::BTreeMap, convert::TryInto}; +use tracing::trace; + use olm_rs::sas::OlmSas; use matrix_sdk_common::{ @@ -264,32 +266,42 @@ pub fn get_mac_content(sas: &OlmSas, ids: &SasIds, flow_id: &str) -> MacEventCon /// * `flow_id` - The unique id that identifies this SAS verification process. /// /// * `we_started` - Flag signaling if the SAS process was started on our side. -fn extra_info_sas(ids: &SasIds, flow_id: &str, we_started: bool) -> String { - let (first_user, first_device, second_user, second_device) = if we_started { - ( - ids.account.user_id(), - ids.account.device_id(), - ids.other_device.user_id(), - ids.other_device.device_id(), - ) +fn extra_info_sas( + ids: &SasIds, + own_pubkey: &str, + their_pubkey: &str, + flow_id: &str, + we_started: bool, +) -> String { + let our_info = format!( + "{}|{}|{}", + ids.account.user_id(), + ids.account.device_id(), + own_pubkey + ); + let their_info = format!( + "{}|{}|{}", + ids.other_device.user_id(), + ids.other_device.device_id(), + their_pubkey + ); + + let (first_info, second_info) = if we_started { + (our_info, their_info) } else { - ( - ids.other_device.user_id(), - ids.other_device.device_id(), - ids.account.user_id(), - ids.account.device_id(), - ) + (their_info, our_info) }; - format!( - "MATRIX_KEY_VERIFICATION_SAS{first_user}{first_device}\ - {second_user}{second_device}{transaction_id}", - first_user = first_user, - first_device = first_device, - second_user = second_user, - second_device = second_device, - transaction_id = flow_id, - ) + let info = format!( + "MATRIX_KEY_VERIFICATION_SAS|{first_info}|{second_info}|{flow_id}", + first_info = first_info, + second_info = second_info, + flow_id = flow_id, + ); + + trace!("Generated a SAS extra info: {}", info); + + info } /// Get the emoji version of the short authentication string. @@ -314,11 +326,15 @@ fn extra_info_sas(ids: &SasIds, flow_id: &str, we_started: bool) -> String { pub fn get_emoji( sas: &OlmSas, ids: &SasIds, + their_pubkey: &str, flow_id: &str, we_started: bool, ) -> Vec<(&'static str, &'static str)> { let bytes = sas - .generate_bytes(&extra_info_sas(&ids, &flow_id, we_started), 6) + .generate_bytes( + &extra_info_sas(&ids, &sas.public_key(), their_pubkey, &flow_id, we_started), + 6, + ) .expect("Can't generate bytes"); bytes_to_emoji(bytes) @@ -374,9 +390,18 @@ fn bytes_to_emoji(bytes: Vec) -> Vec<(&'static str, &'static str)> { /// # Panics /// /// This will panic if the public key of the other side wasn't set. -pub fn get_decimal(sas: &OlmSas, ids: &SasIds, flow_id: &str, we_started: bool) -> (u16, u16, u16) { +pub fn get_decimal( + sas: &OlmSas, + ids: &SasIds, + their_pubkey: &str, + flow_id: &str, + we_started: bool, +) -> (u16, u16, u16) { let bytes = sas - .generate_bytes(&extra_info_sas(&ids, &flow_id, we_started), 5) + .generate_bytes( + &extra_info_sas(&ids, &sas.public_key(), their_pubkey, &flow_id, we_started), + 5, + ) .expect("Can't generate bytes"); bytes_to_decimal(bytes) diff --git a/matrix_sdk_crypto/src/verification/sas/sas_state.rs b/matrix_sdk_crypto/src/verification/sas/sas_state.rs index e1765091..d2a70be8 100644 --- a/matrix_sdk_crypto/src/verification/sas/sas_state.rs +++ b/matrix_sdk_crypto/src/verification/sas/sas_state.rs @@ -150,6 +150,7 @@ pub struct Accepted { /// From now on we can show the short auth string to the user. #[derive(Clone, Debug)] pub struct KeyReceived { + their_pubkey: String, we_started: bool, accepted_protocols: Arc, } @@ -168,6 +169,7 @@ pub struct Confirmed { #[derive(Clone, Debug)] pub struct MacReceived { we_started: bool, + their_pubkey: String, verified_devices: Arc>, verified_master_keys: Arc>, } @@ -436,10 +438,15 @@ impl SasState { let accepted_protocols = AcceptedProtocols::default(); + let their_pubkey = mem::take(&mut event.content.key); + + // The SAS object clears the public key, so we make a copy. + let pubkey_copy = their_pubkey.clone(); + self.inner .lock() .unwrap() - .set_their_public_key(&mem::take(&mut event.content.key)) + .set_their_public_key(&pubkey_copy) .expect("Can't set public key"); Ok(SasState { @@ -448,6 +455,7 @@ impl SasState { verification_flow_id: self.verification_flow_id, state: Arc::new(KeyReceived { we_started: false, + their_pubkey, accepted_protocols: Arc::new(accepted_protocols), }), }) @@ -479,10 +487,15 @@ impl SasState { if self.state.commitment != commitment { Err(self.cancel(CancelCode::InvalidMessage)) } else { + let their_pubkey = mem::take(&mut event.content.key); + + // The SAS object clears the public key, so we make a copy. + let pubkey_copy = their_pubkey.clone(); + self.inner .lock() .unwrap() - .set_their_public_key(&mem::take(&mut event.content.key)) + .set_their_public_key(&pubkey_copy) .expect("Can't set public key"); Ok(SasState { @@ -490,6 +503,7 @@ impl SasState { ids: self.ids, verification_flow_id: self.verification_flow_id, state: Arc::new(KeyReceived { + their_pubkey, we_started: true, accepted_protocols: self.state.accepted_protocols.clone(), }), @@ -528,6 +542,7 @@ impl SasState { get_emoji( &self.inner.lock().unwrap(), &self.ids, + &self.state.their_pubkey, &self.verification_flow_id, self.state.we_started, ) @@ -541,6 +556,7 @@ impl SasState { get_decimal( &self.inner.lock().unwrap(), &self.ids, + &self.state.their_pubkey, &self.verification_flow_id, self.state.we_started, ) @@ -574,6 +590,7 @@ impl SasState { ids: self.ids, state: Arc::new(MacReceived { we_started: self.state.we_started, + their_pubkey: self.state.their_pubkey.clone(), verified_devices: Arc::new(devices), verified_master_keys: Arc::new(master_keys), }), @@ -668,6 +685,7 @@ impl SasState { get_emoji( &self.inner.lock().unwrap(), &self.ids, + &self.state.their_pubkey, &self.verification_flow_id, self.state.we_started, ) @@ -681,6 +699,7 @@ impl SasState { get_decimal( &self.inner.lock().unwrap(), &self.ids, + &self.state.their_pubkey, &self.verification_flow_id, self.state.we_started, )