diff --git a/matrix_sdk_crypto/src/verification/mod.rs b/matrix_sdk_crypto/src/verification/mod.rs index 8ad1561a..d5c3a206 100644 --- a/matrix_sdk_crypto/src/verification/mod.rs +++ b/matrix_sdk_crypto/src/verification/mod.rs @@ -26,6 +26,7 @@ use crate::{Account, Device}; #[allow(dead_code)] mod sas; +#[derive(Clone)] struct SasIds { account: Account, other_device: Device, diff --git a/matrix_sdk_crypto/src/verification/sas.rs b/matrix_sdk_crypto/src/verification/sas.rs index 26d18419..5c42453e 100644 --- a/matrix_sdk_crypto/src/verification/sas.rs +++ b/matrix_sdk_crypto/src/verification/sas.rs @@ -26,7 +26,7 @@ use matrix_sdk_common::events::{ HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode, ShortAuthenticationString, VerificationMethod, }, - ToDeviceEvent, + AnyToDeviceEvent, AnyToDeviceEventContent, ToDeviceEvent, }; use matrix_sdk_common::identifiers::{DeviceId, UserId}; use matrix_sdk_common::uuid::Uuid; @@ -34,6 +34,230 @@ use matrix_sdk_common::uuid::Uuid; use super::{get_decimal, get_emoji, get_mac_content, receive_mac_event, SasIds}; use crate::{Account, Device}; +#[derive(Clone)] +struct Sas { + inner: Arc>, + account: Account, + other_device: Device, +} + +impl Sas { + fn user_id(&self) -> &UserId { + self.account.user_id() + } + + fn device_id(&self) -> &DeviceId { + self.account.device_id() + } + + fn start(account: Account, other_device: Device) -> (Sas, StartEventContent) { + let (inner, content) = InnerSas::start(account.clone(), other_device.clone()); + + let sas = Sas { + inner: Arc::new(Mutex::new(inner)), + account, + other_device, + }; + + (sas, content) + } + + fn from_start_event( + account: Account, + other_device: Device, + event: &ToDeviceEvent, + ) -> Sas { + let inner = InnerSas::from_start_event(account.clone(), other_device.clone(), event); + Sas { + inner: Arc::new(Mutex::new(inner)), + account, + other_device, + } + } + + fn accept(&self) -> Option { + self.inner.lock().unwrap().accept() + } + + fn confirm(&self) -> Option { + let mut guard = self.inner.lock().unwrap(); + let sas: InnerSas = (*guard).clone(); + let (sas, content) = sas.confirm(); + *guard = sas; + content + } + + fn can_be_presented(&self) -> bool { + self.inner.lock().unwrap().can_be_presented() + } + + fn is_done(&self) -> bool { + self.inner.lock().unwrap().is_done() + } + + fn emoji(&self) -> Option> { + self.inner.lock().unwrap().emoji() + } + + fn decimals(&self) -> Option<(u32, u32, u32)> { + self.inner.lock().unwrap().decimals() + } + + fn receive_event(&self, event: &mut AnyToDeviceEvent) -> Option { + let mut guard = self.inner.lock().unwrap(); + let sas: InnerSas = (*guard).clone(); + let (sas, content) = sas.receive_event(event); + *guard = sas; + + content + } + + fn verified_devices(&self) -> Option>>> { + self.inner.lock().unwrap().verified_devices() + } +} + +#[derive(Clone)] +enum InnerSas { + Created(SasState), + Started(SasState), + Accepted(SasState), + KeyRecieved(SasState), + Confirmed(SasState), + MacReceived(SasState), + Done(SasState), +} + +impl InnerSas { + fn start(account: Account, other_device: Device) -> (InnerSas, StartEventContent) { + let sas = SasState::::new(account, other_device); + let content = sas.get_start_event(); + (InnerSas::Created(sas), content) + } + + fn from_start_event( + account: Account, + other_device: Device, + event: &ToDeviceEvent, + ) -> InnerSas { + InnerSas::Started(SasState::::from_start_event( + account, + other_device, + event, + )) + } + + fn accept(&self) -> Option { + if let InnerSas::Started(s) = self { + Some(s.get_accept_content()) + } else { + None + } + } + + fn confirm(self) -> (InnerSas, Option) { + match self { + InnerSas::KeyRecieved(s) => { + let sas = s.confirm(); + let content = sas.get_mac_event_content(); + (InnerSas::Confirmed(sas), Some(content)) + } + InnerSas::MacReceived(s) => { + let sas = s.confirm(); + let content = sas.get_mac_event_content(); + (InnerSas::Done(sas), Some(content)) + } + _ => (self, None), + } + } + + fn receive_event( + self, + event: &mut AnyToDeviceEvent, + ) -> (InnerSas, Option) { + match event { + AnyToDeviceEvent::KeyVerificationAccept(e) => { + if let InnerSas::Created(s) = self { + let sas = s.into_accepted(e); + let content = sas.get_key_content(); + ( + InnerSas::Accepted(sas), + Some(AnyToDeviceEventContent::KeyVerificationKey(content)), + ) + } else { + (self, None) + } + } + AnyToDeviceEvent::KeyVerificationKey(e) => match self { + InnerSas::Accepted(s) => (InnerSas::KeyRecieved(s.into_key_received(e)), None), + InnerSas::Started(s) => { + let sas = s.into_key_received(e); + let content = sas.get_key_content(); + ( + InnerSas::KeyRecieved(sas), + Some(AnyToDeviceEventContent::KeyVerificationKey(content)), + ) + } + _ => (self, None), + }, + AnyToDeviceEvent::KeyVerificationMac(e) => match self { + InnerSas::KeyRecieved(s) => (InnerSas::MacReceived(s.into_mac_received(e)), None), + InnerSas::Confirmed(s) => (InnerSas::Done(s.into_done(e)), None), + _ => (self, None), + }, + _ => (self, None), + } + } + + fn can_be_presented(&self) -> bool { + match self { + InnerSas::KeyRecieved(_) => true, + InnerSas::MacReceived(_) => true, + _ => false, + } + } + + fn is_done(&self) -> bool { + if let InnerSas::Done(_) = self { + true + } else { + false + } + } + + fn emoji(&self) -> Option> { + match self { + InnerSas::KeyRecieved(s) => Some(s.get_emoji()), + InnerSas::MacReceived(s) => Some(s.get_emoji()), + _ => None, + } + } + + fn decimals(&self) -> Option<(u32, u32, u32)> { + match self { + InnerSas::KeyRecieved(s) => Some(s.get_decimal()), + InnerSas::MacReceived(s) => Some(s.get_decimal()), + _ => None, + } + } + + fn verified_devices(&self) -> Option>>> { + if let InnerSas::Done(s) = self { + Some(s.verified_devices()) + } else { + None + } + } + + fn verified_master_keys(&self) -> Option>> { + if let InnerSas::Done(s) = self { + Some(s.verified_master_keys()) + } else { + None + } + } +} + /// Struct containing the protocols that were agreed to be used for the SAS /// flow. #[derive(Clone, Debug)] @@ -65,7 +289,8 @@ impl From for AcceptedProtocols { /// /// This is the generic struc holding common data between the different states /// and the specific state. -struct SasState { +#[derive(Clone)] +struct SasState { /// The Olm SAS struct. inner: Arc>, /// Struct holding the identities that are doing the SAS dance. @@ -80,17 +305,20 @@ struct SasState { } /// The initial SAS state. +#[derive(Clone)] struct Created { protocol_definitions: MSasV1ContentOptions, } /// The initial SAS state if the other side started the SAS verification. +#[derive(Clone)] struct Started { protocol_definitions: MSasV1Content, } /// The SAS state we're going to be in after the other side accepted our /// verification start event. +#[derive(Clone)] struct Accepted { accepted_protocols: Arc, commitment: String, @@ -100,6 +328,7 @@ struct Accepted { /// other participant. /// /// From now on we can show the short auth string to the user. +#[derive(Clone)] struct KeyReceived { we_started: bool, accepted_protocols: Arc, @@ -108,6 +337,7 @@ struct KeyReceived { /// The SAS state we're going to be in after the user has confirmed that the /// short auth string matches. We still need to receive a MAC event from the /// other side. +#[derive(Clone)] struct Confirmed { accepted_protocols: Arc, } @@ -115,6 +345,7 @@ struct Confirmed { /// The SAS state we're going to be in after we receive a MAC event from the /// other side. Our own user still needs to confirm that the short auth string /// matches. +#[derive(Clone)] struct MacReceived { we_started: bool, verified_devices: Arc>>, @@ -125,12 +356,13 @@ struct MacReceived { /// /// We can now mark the device in our verified devices lits as verified and sign /// the master keys in the verified devices list. +#[derive(Clone)] struct Done { verified_devices: Arc>>, verified_master_keys: Arc>, } -impl SasState { +impl SasState { /// Get our own user id. pub fn user_id(&self) -> &UserId { &self.ids.account.user_id() @@ -530,13 +762,13 @@ impl SasState { } /// Get the list of verified devices. - fn verified_devices(&self) -> &[Box] { - &self.state.verified_devices + fn verified_devices(&self) -> Arc>> { + self.state.verified_devices.clone() } /// Get the list of verified master keys. - fn verified_master_keys(&self) -> &[String] { - &self.state.verified_master_keys + fn verified_master_keys(&self) -> Arc> { + self.state.verified_master_keys.clone() } } @@ -545,10 +777,12 @@ mod test { use std::convert::TryFrom; use crate::{Account, Device}; - use matrix_sdk_common::events::{EventContent, ToDeviceEvent}; + use matrix_sdk_common::events::{ + AnyToDeviceEvent, AnyToDeviceEventContent, EventContent, ToDeviceEvent, + }; use matrix_sdk_common::identifiers::{DeviceId, UserId}; - use super::{Accepted, Created, SasState, Started}; + use super::{Accepted, Created, Sas, SasState, Started}; fn alice_id() -> UserId { UserId::try_from("@alice:example.org").unwrap() @@ -573,6 +807,21 @@ mod test { } } + fn wrap_any_to_device_content( + sender: &UserId, + content: AnyToDeviceEventContent, + ) -> AnyToDeviceEvent { + match content { + AnyToDeviceEventContent::KeyVerificationKey(c) => { + AnyToDeviceEvent::KeyVerificationKey(ToDeviceEvent { + sender: sender.clone(), + content: c, + }) + } + _ => unreachable!(), + } + } + async fn get_sas_pair() -> (SasState, SasState) { let alice = Account::new(&alice_id(), &alice_device_id()); let alice_device = Device::from_account(&alice).await; @@ -655,4 +904,51 @@ mod test { assert!(bob.verified_devices().contains(&alice.device_id().into())); assert!(alice.verified_devices().contains(&bob.device_id().into())); } + + #[tokio::test] + async fn sas_wrapper_full() { + let alice = Account::new(&alice_id(), &alice_device_id()); + let alice_device = Device::from_account(&alice).await; + + let bob = Account::new(&bob_id(), &bob_device_id()); + let bob_device = Device::from_account(&bob).await; + + let (alice, content) = Sas::start(alice, bob_device); + let mut event = wrap_to_device_event(alice.user_id(), content); + + let bob = Sas::from_start_event(bob, alice_device, &mut event); + let event = wrap_to_device_event(bob.user_id(), bob.accept().unwrap()); + + let content = alice.receive_event(&mut AnyToDeviceEvent::KeyVerificationAccept(event)); + + assert!(!alice.can_be_presented()); + assert!(!bob.can_be_presented()); + + let mut event = wrap_any_to_device_content(alice.user_id(), content.unwrap()); + let mut event = + wrap_any_to_device_content(bob.user_id(), bob.receive_event(&mut event).unwrap()); + + assert!(bob.can_be_presented()); + + alice.receive_event(&mut event); + assert!(alice.can_be_presented()); + + assert_eq!(alice.emoji().unwrap(), bob.emoji().unwrap()); + assert_eq!(alice.decimals().unwrap(), bob.decimals().unwrap()); + + let event = wrap_to_device_event(alice.user_id(), alice.confirm().unwrap()); + bob.receive_event(&mut AnyToDeviceEvent::KeyVerificationMac(event)); + + let event = wrap_to_device_event(bob.user_id(), bob.confirm().unwrap()); + alice.receive_event(&mut AnyToDeviceEvent::KeyVerificationMac(event)); + + assert!(alice + .verified_devices() + .unwrap() + .contains(&bob.device_id().into())); + assert!(bob + .verified_devices() + .unwrap() + .contains(&alice.device_id().into())); + } }