diff --git a/matrix_sdk_crypto/src/verification/sas/sas_state.rs b/matrix_sdk_crypto/src/verification/sas/sas_state.rs index 28506898..b3d7e099 100644 --- a/matrix_sdk_crypto/src/verification/sas/sas_state.rs +++ b/matrix_sdk_crypto/src/verification/sas/sas_state.rs @@ -15,6 +15,7 @@ use std::{ mem, sync::{Arc, Mutex}, + time::{Duration, Instant}, }; use olm_rs::{sas::OlmSas, utility::OlmUtility}; @@ -52,6 +53,12 @@ const STRINGS: &[ShortAuthenticationString] = &[ ShortAuthenticationString::Emoji, ]; +// The max time a SAS flow can take from start to done. +const MAX_AGE: Duration = Duration::from_secs(60 * 5); + +// The max time a SAS object will wait for a new event to arrive. +const MAX_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 1); + /// Struct containing the protocols that were agreed to be used for the SAS /// flow. #[derive(Clone, Debug)] @@ -101,13 +108,24 @@ impl Default for AcceptedProtocols { pub struct SasState { /// The Olm SAS struct. inner: Arc>, + /// Struct holding the identities that are doing the SAS dance. ids: SasIds, + + /// The instant when the SAS object was created. If this more than + /// MAX_AGE seconds are elapsed, the event will be canceled with a + /// `CancelCode::Timeout` + creation_time: Arc, + + /// The instant the SAS object last received an event. + last_event_time: Arc, + /// The unique identifier of this SAS flow. /// /// This will be the transaction id for to-device events and the relates_to /// field for in-room events. pub verification_flow_id: Arc, + /// The SAS state we're in. state: Arc, } @@ -209,12 +227,19 @@ impl SasState { SasState { inner: self.inner, ids: self.ids, + creation_time: self.creation_time, + last_event_time: self.last_event_time, verification_flow_id: self.verification_flow_id, state: Arc::new(Canceled::new(cancel_code)), } } - fn check_sender_and_txid(&self, sender: &UserId, flow_id: &str) -> Result<(), CancelCode> { + /// Did our SAS verification time out. + fn timed_out(&self) -> bool { + self.creation_time.elapsed() > MAX_AGE || self.last_event_time.elapsed() > MAX_EVENT_TIMEOUT + } + + fn check_event(&self, sender: &UserId, flow_id: &str) -> Result<(), CancelCode> { if flow_id != *self.verification_flow_id { Err(CancelCode::UnknownTransaction) } else if sender != self.ids.other_device.user_id() { @@ -244,6 +269,9 @@ impl SasState { }, verification_flow_id: Arc::new(verification_flow_id), + creation_time: Arc::new(Instant::now()), + last_event_time: Arc::new(Instant::now()), + state: Arc::new(Created { protocol_definitions: MSasV1ContentInit { short_authentication_string: STRINGS.to_vec(), @@ -280,7 +308,7 @@ impl SasState { self, event: &ToDeviceEvent, ) -> Result, SasState> { - self.check_sender_and_txid(&event.sender, &event.content.transaction_id) + self.check_event(&event.sender, &event.content.transaction_id) .map_err(|c| self.clone().cancel(c))?; if let AcceptMethod::MSasV1(content) = &event.content.method { @@ -303,6 +331,8 @@ impl SasState { inner: self.inner, ids: self.ids, verification_flow_id: self.verification_flow_id, + creation_time: self.creation_time, + last_event_time: self.last_event_time, state: Arc::new(Accepted { json_start_content, commitment: content.commitment.clone(), @@ -351,6 +381,9 @@ impl SasState { other_device, }, + creation_time: Arc::new(Instant::now()), + last_event_time: Arc::new(Instant::now()), + verification_flow_id: Arc::new(event.content.transaction_id.clone()), state: Arc::new(Started { @@ -381,6 +414,9 @@ impl SasState { Err(SasState { inner: Arc::new(Mutex::new(OlmSas::new())), + creation_time: Arc::new(Instant::now()), + last_event_time: Arc::new(Instant::now()), + ids: SasIds { account, other_device, @@ -433,7 +469,7 @@ impl SasState { self, event: &mut ToDeviceEvent, ) -> Result, SasState> { - self.check_sender_and_txid(&event.sender, &event.content.transaction_id) + self.check_event(&event.sender, &event.content.transaction_id) .map_err(|c| self.clone().cancel(c))?; let accepted_protocols = AcceptedProtocols::default(); @@ -450,6 +486,8 @@ impl SasState { inner: self.inner, ids: self.ids, verification_flow_id: self.verification_flow_id, + creation_time: self.creation_time, + last_event_time: self.last_event_time, state: Arc::new(KeyReceived { we_started: false, their_pubkey, @@ -472,7 +510,7 @@ impl SasState { self, event: &mut ToDeviceEvent, ) -> Result, SasState> { - self.check_sender_and_txid(&event.sender, &event.content.transaction_id) + self.check_event(&event.sender, &event.content.transaction_id) .map_err(|c| self.clone().cancel(c))?; let utility = OlmUtility::new(); @@ -496,6 +534,8 @@ impl SasState { inner: self.inner, ids: self.ids, verification_flow_id: self.verification_flow_id, + creation_time: self.creation_time, + last_event_time: self.last_event_time, state: Arc::new(KeyReceived { their_pubkey, we_started: true, @@ -567,7 +607,7 @@ impl SasState { self, event: &ToDeviceEvent, ) -> Result, SasState> { - self.check_sender_and_txid(&event.sender, &event.content.transaction_id) + self.check_event(&event.sender, &event.content.transaction_id) .map_err(|c| self.clone().cancel(c))?; let (devices, master_keys) = receive_mac_event( @@ -581,6 +621,8 @@ impl SasState { Ok(SasState { inner: self.inner, verification_flow_id: self.verification_flow_id, + creation_time: self.creation_time, + last_event_time: self.last_event_time, ids: self.ids, state: Arc::new(MacReceived { we_started: self.state.we_started, @@ -599,6 +641,8 @@ impl SasState { SasState { inner: self.inner, verification_flow_id: self.verification_flow_id, + creation_time: self.creation_time, + last_event_time: self.last_event_time, ids: self.ids, state: Arc::new(Confirmed { accepted_protocols: self.state.accepted_protocols.clone(), @@ -619,7 +663,7 @@ impl SasState { self, event: &ToDeviceEvent, ) -> Result, SasState> { - self.check_sender_and_txid(&event.sender, &event.content.transaction_id) + self.check_event(&event.sender, &event.content.transaction_id) .map_err(|c| self.clone().cancel(c))?; let (devices, master_keys) = receive_mac_event( @@ -632,6 +676,8 @@ impl SasState { Ok(SasState { inner: self.inner, + creation_time: self.creation_time, + last_event_time: self.last_event_time, verification_flow_id: self.verification_flow_id, ids: self.ids, @@ -663,6 +709,8 @@ impl SasState { SasState { inner: self.inner, verification_flow_id: self.verification_flow_id, + creation_time: self.creation_time, + last_event_time: self.last_event_time, ids: self.ids, state: Arc::new(Done { verified_devices: self.state.verified_devices.clone(),