diff --git a/matrix_sdk_crypto/src/verification/machine.rs b/matrix_sdk_crypto/src/verification/machine.rs index 374e4410..3a89b070 100644 --- a/matrix_sdk_crypto/src/verification/machine.rs +++ b/matrix_sdk_crypto/src/verification/machine.rs @@ -419,7 +419,7 @@ impl VerificationMachine { private_identity, device, identity, - false, + None, false, ) { Ok(sas) => { diff --git a/matrix_sdk_crypto/src/verification/qrcode.rs b/matrix_sdk_crypto/src/verification/qrcode.rs index d460f54b..b472db4a 100644 --- a/matrix_sdk_crypto/src/verification/qrcode.rs +++ b/matrix_sdk_crypto/src/verification/qrcode.rs @@ -40,6 +40,7 @@ use tracing::trace; use super::{ event_enums::{CancelContent, DoneContent, OutgoingContent, OwnedStartContent, StartContent}, + requests::RequestHandle, Cancelled, Done, FlowId, IdentitiesBeingVerified, VerificationResult, }; use crate::{ @@ -84,6 +85,7 @@ pub struct QrVerification { inner: Arc, state: Arc>, identities: IdentitiesBeingVerified, + request_handle: Option, we_started: bool, } @@ -222,11 +224,15 @@ impl QrVerification { /// /// [`cancel()`]: #method.cancel pub fn cancel_with_code(&self, code: CancelCode) -> Option { + let mut state = self.state.lock().unwrap(); + + if let Some(request) = &self.request_handle { + request.cancel_with_code(&code); + } + let new_state = QrState::::new(true, code); let content = new_state.as_content(self.flow_id()); - let mut state = self.state.lock().unwrap(); - match &*state { InnerState::Confirmed(_) | InnerState::Created(_) @@ -438,6 +444,7 @@ impl QrVerification { other_device_key: String, identities: IdentitiesBeingVerified, we_started: bool, + request_handle: Option, ) -> Self { let secret = Self::generate_secret(); @@ -449,7 +456,7 @@ impl QrVerification { ) .into(); - Self::new_helper(store, flow_id, inner, identities, we_started) + Self::new_helper(store, flow_id, inner, identities, we_started, request_handle) } pub(crate) fn new_self_no_master( @@ -459,6 +466,7 @@ impl QrVerification { own_master_key: String, identities: IdentitiesBeingVerified, we_started: bool, + request_handle: Option, ) -> QrVerification { let secret = Self::generate_secret(); @@ -470,7 +478,7 @@ impl QrVerification { ) .into(); - Self::new_helper(store, flow_id, inner, identities, we_started) + Self::new_helper(store, flow_id, inner, identities, we_started, request_handle) } pub(crate) fn new_cross( @@ -480,6 +488,7 @@ impl QrVerification { other_master_key: String, identities: IdentitiesBeingVerified, we_started: bool, + request_handle: Option, ) -> Self { let secret = Self::generate_secret(); @@ -492,7 +501,7 @@ impl QrVerification { let inner: QrVerificationData = VerificationData::new(event_id, own_master_key, other_master_key, secret).into(); - Self::new_helper(store, flow_id, inner, identities, we_started) + Self::new_helper(store, flow_id, inner, identities, we_started, request_handle) } #[allow(clippy::too_many_arguments)] @@ -505,6 +514,7 @@ impl QrVerification { flow_id: FlowId, qr_code: QrVerificationData, we_started: bool, + request_handle: Option, ) -> Result { if flow_id.as_str() != qr_code.flow_id() { return Err(ScanError::FlowIdMismatch { @@ -602,6 +612,7 @@ impl QrVerification { .into(), identities, we_started, + request_handle, }) } @@ -611,6 +622,7 @@ impl QrVerification { inner: QrVerificationData, identities: IdentitiesBeingVerified, we_started: bool, + request_handle: Option, ) -> Self { let secret = inner.secret().to_owned(); @@ -621,6 +633,7 @@ impl QrVerification { state: Mutex::new(InnerState::Created(QrState { state: Created { secret } })).into(), identities, we_started, + request_handle, } } } @@ -848,6 +861,7 @@ mod test { master_key.clone(), identities.clone(), false, + None, ); assert_eq!(verification.inner.first_key(), &device_key); @@ -860,6 +874,7 @@ mod test { device_key.clone(), identities.clone(), false, + None, ); assert_eq!(verification.inner.first_key(), &master_key); @@ -878,6 +893,7 @@ mod test { bob_master_key.clone(), identities, false, + None, ); assert_eq!(verification.inner.first_key(), &master_key); @@ -922,6 +938,7 @@ mod test { master_key.clone(), identities, false, + None, ); let bob_store = memory_store(); @@ -943,6 +960,7 @@ mod test { flow_id, qr_code, false, + None, ) .await .unwrap(); diff --git a/matrix_sdk_crypto/src/verification/requests.rs b/matrix_sdk_crypto/src/verification/requests.rs index 1c661e82..435e7e03 100644 --- a/matrix_sdk_crypto/src/verification/requests.rs +++ b/matrix_sdk_crypto/src/verification/requests.rs @@ -71,6 +71,31 @@ pub struct VerificationRequest { we_started: bool, } +/// A handle to a request so child verification flows can cancel the request. +/// +/// A verification flow can branch off into different types of verification +/// flows after the initial request handshake is done. +/// +/// Cancelling a QR code verification should also cancel the request. This +/// `RequestHandle` allows the QR code verification object to cancel the parent +/// `VerificationRequest` object. +#[derive(Clone, Debug)] +pub(crate) struct RequestHandle { + inner: Arc>, +} + +impl RequestHandle { + pub fn cancel_with_code(&self, cancel_code: &CancelCode) { + self.inner.lock().unwrap().cancel(true, cancel_code) + } +} + +impl From>> for RequestHandle { + fn from(inner: Arc>) -> Self { + Self { inner } + } +} + impl VerificationRequest { pub(crate) fn new( cache: VerificationCache, @@ -254,7 +279,11 @@ impl VerificationRequest { /// Generate a QR code that can be used by another client to start a QR code /// based verification. pub async fn generate_qr_code(&self) -> Result, CryptoStoreError> { - self.inner.lock().unwrap().generate_qr_code(self.we_started).await + self.inner + .lock() + .unwrap() + .generate_qr_code(self.we_started, self.inner.clone().into()) + .await } /// Start a QR code verification by providing a scanned QR code for this @@ -280,6 +309,7 @@ impl VerificationRequest { r.flow_id.as_ref().to_owned(), data, self.we_started, + Some(self.inner.clone().into()), ) .await?; @@ -411,7 +441,7 @@ impl VerificationRequest { let inner = self.inner.lock().unwrap().clone(); if let InnerRequest::Ready(s) = inner { - s.receive_start(sender, content, self.we_started).await?; + s.receive_start(sender, content, self.we_started, self.inner.clone().into()).await?; } else { warn!( sender = sender.as_str(), @@ -457,6 +487,7 @@ impl VerificationRequest { s.account.clone(), s.private_cross_signing_identity.clone(), self.we_started, + self.inner.clone().into(), ) .await? { @@ -544,11 +575,12 @@ impl InnerRequest { async fn generate_qr_code( &self, we_started: bool, + request_handle: RequestHandle, ) -> Result, CryptoStoreError> { match self { InnerRequest::Created(_) => Ok(None), InnerRequest::Requested(_) => Ok(None), - InnerRequest::Ready(s) => s.generate_qr_code(we_started).await, + InnerRequest::Ready(s) => s.generate_qr_code(we_started, request_handle).await, InnerRequest::Passive(_) => Ok(None), InnerRequest::Done(_) => Ok(None), InnerRequest::Cancelled(_) => Ok(None), @@ -752,6 +784,7 @@ impl RequestState { other_device: ReadOnlyDevice, other_identity: Option, we_started: bool, + request_handle: RequestHandle, ) -> Result { Sas::from_start_event( (&*self.flow_id).to_owned(), @@ -761,7 +794,7 @@ impl RequestState { self.private_cross_signing_identity.clone(), other_device, other_identity, - true, + Some(request_handle), we_started, ) } @@ -769,6 +802,7 @@ impl RequestState { async fn generate_qr_code( &self, we_started: bool, + request_handle: RequestHandle, ) -> Result, CryptoStoreError> { // If we didn't state that we support showing QR codes or if the other // side doesn't support scanning QR codes bail early. @@ -814,6 +848,7 @@ impl RequestState { device_key.to_owned(), identites, we_started, + Some(request_handle), )) } else { warn!( @@ -832,6 +867,7 @@ impl RequestState { master_key.to_owned(), identites, we_started, + Some(request_handle), )) } } else { @@ -862,6 +898,7 @@ impl RequestState { other_master.to_owned(), identites, we_started, + Some(request_handle), )) } else { warn!( @@ -906,6 +943,7 @@ impl RequestState { sender: &UserId, content: &StartContent<'_>, we_started: bool, + request_handle: RequestHandle, ) -> Result<(), CryptoStoreError> { info!( sender = sender.as_str(), @@ -929,7 +967,13 @@ impl RequestState { match content.method() { StartMethod::SasV1(_) => { - match self.to_started_sas(content, device.clone(), identity, we_started) { + match self.to_started_sas( + content, + device.clone(), + identity, + we_started, + request_handle, + ) { // TODO check if there is already a SAS verification, i.e. we // already started one before the other side tried to do the // same; ignore it if we did and we're the lexicographically @@ -982,6 +1026,7 @@ impl RequestState { account: ReadOnlyAccount, private_identity: PrivateCrossSigningIdentity, we_started: bool, + request_handle: RequestHandle, ) -> Result, CryptoStoreError> { if !self.state.their_methods.contains(&VerificationMethod::SasV1) { return Ok(None); @@ -1027,6 +1072,7 @@ impl RequestState { store, other_identity, we_started, + request_handle, ); (sas, content) } diff --git a/matrix_sdk_crypto/src/verification/sas/mod.rs b/matrix_sdk_crypto/src/verification/sas/mod.rs index bdf1f107..c0cb4fe4 100644 --- a/matrix_sdk_crypto/src/verification/sas/mod.rs +++ b/matrix_sdk_crypto/src/verification/sas/mod.rs @@ -38,6 +38,7 @@ use tracing::trace; use super::{ event_enums::{AnyVerificationContent, OutgoingContent, OwnedAcceptContent, StartContent}, + requests::RequestHandle, FlowId, IdentitiesBeingVerified, VerificationResult, }; use crate::{ @@ -56,6 +57,7 @@ pub struct Sas { identities_being_verified: IdentitiesBeingVerified, flow_id: Arc, we_started: bool, + request_handle: Option, } impl Sas { @@ -145,6 +147,7 @@ impl Sas { self.inner.lock().unwrap().set_creation_time(time) } + #[allow(clippy::too_many_arguments)] fn start_helper( inner_sas: InnerSas, account: ReadOnlyAccount, @@ -153,6 +156,7 @@ impl Sas { store: Arc, other_identity: Option, we_started: bool, + request_handle: Option, ) -> Sas { let flow_id = inner_sas.verification_flow_id(); @@ -169,6 +173,7 @@ impl Sas { identities_being_verified: identities, flow_id, we_started, + request_handle, } } @@ -207,6 +212,7 @@ impl Sas { store, other_identity, we_started, + None, ), content, ) @@ -232,6 +238,7 @@ impl Sas { store: Arc, other_identity: Option, we_started: bool, + request_handle: RequestHandle, ) -> (Sas, OutgoingContent) { let (inner, content) = InnerSas::start_in_room( flow_id, @@ -250,6 +257,7 @@ impl Sas { store, other_identity, we_started, + Some(request_handle), ), content, ) @@ -274,7 +282,7 @@ impl Sas { private_identity: PrivateCrossSigningIdentity, other_device: ReadOnlyDevice, other_identity: Option, - started_from_request: bool, + request_handle: Option, we_started: bool, ) -> Result { let inner = InnerSas::from_start_event( @@ -283,7 +291,7 @@ impl Sas { flow_id, content, other_identity.clone(), - started_from_request, + request_handle.is_some(), )?; Ok(Self::start_helper( @@ -294,6 +302,7 @@ impl Sas { store, other_identity, we_started, + request_handle, )) } @@ -418,6 +427,11 @@ impl Sas { /// [`cancel()`]: #method.cancel pub fn cancel_with_code(&self, code: CancelCode) -> Option { let mut guard = self.inner.lock().unwrap(); + + if let Some(request) = &self.request_handle { + request.cancel_with_code(&code) + } + let sas: InnerSas = (*guard).clone(); let (sas, content) = sas.cancel(true, code); *guard = sas; @@ -626,7 +640,7 @@ mod test { PrivateCrossSigningIdentity::empty(bob_id()), alice_device, None, - false, + None, false, ) .unwrap();