diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index 585962a2..fcb00c86 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -1826,7 +1826,7 @@ pub(crate) mod test { let (alice_sas, request) = bob_device.start_verification().await.unwrap(); - let mut event = request_to_event(alice.user_id(), &request); + let mut event = request_to_event(alice.user_id(), &request.into()); bob.handle_verification_event(&mut event).await; let bob_sas = bob.get_verification(alice_sas.flow_id().as_str()).unwrap(); diff --git a/matrix_sdk_crypto/src/verification/machine.rs b/matrix_sdk_crypto/src/verification/machine.rs index eaf55850..64b52ab7 100644 --- a/matrix_sdk_crypto/src/verification/machine.rs +++ b/matrix_sdk_crypto/src/verification/machine.rs @@ -85,7 +85,7 @@ impl VerificationMachine { identity, ); - let request = match content { + let request = match content.into() { OutgoingContent::Room(r, c) => RoomMessageRequest { room_id: r, txn_id: Uuid::new_v4(), @@ -459,7 +459,6 @@ mod test { }; use matrix_sdk_common::{ - events::AnyToDeviceEventContent, identifiers::{DeviceId, UserId}, locks::Mutex, }; @@ -511,10 +510,11 @@ mod test { bob_store, None, ); + machine .receive_event(&mut wrap_any_to_device_content( bob_sas.user_id(), - AnyToDeviceEventContent::KeyVerificationStart(start_content), + start_content.into(), )) .await .unwrap(); @@ -559,12 +559,13 @@ mod test { let txn_id = *request.request_id(); let r = if let OutgoingRequests::ToDeviceRequest(r) = request.request() { - r + r.clone() } else { panic!("Invalid request type"); }; - let mut event = wrap_any_to_device_content(alice.user_id(), get_content_from_request(r)); + let mut event = + wrap_any_to_device_content(alice.user_id(), get_content_from_request(&r.into())); drop(request); alice_machine.mark_request_as_sent(&txn_id); diff --git a/matrix_sdk_crypto/src/verification/mod.rs b/matrix_sdk_crypto/src/verification/mod.rs index 0e170a35..eeed8de5 100644 --- a/matrix_sdk_crypto/src/verification/mod.rs +++ b/matrix_sdk_crypto/src/verification/mod.rs @@ -22,7 +22,10 @@ pub use sas::{Sas, VerificationResult}; #[cfg(test)] pub(crate) mod test { - use crate::requests::{OutgoingRequest, OutgoingRequests, ToDeviceRequest}; + use crate::{ + requests::{OutgoingRequest, OutgoingRequests}, + OutgoingVerificationRequest, + }; use serde_json::Value; use matrix_sdk_common::{ @@ -30,7 +33,12 @@ pub(crate) mod test { identifiers::UserId, }; - pub(crate) fn request_to_event(sender: &UserId, request: &ToDeviceRequest) -> AnyToDeviceEvent { + use super::sas::OutgoingContent; + + pub(crate) fn request_to_event( + sender: &UserId, + request: &OutgoingVerificationRequest, + ) -> AnyToDeviceEvent { let content = get_content_from_request(request); wrap_any_to_device_content(sender, content) } @@ -40,15 +48,21 @@ pub(crate) mod test { request: &OutgoingRequest, ) -> AnyToDeviceEvent { match request.request() { - OutgoingRequests::ToDeviceRequest(r) => request_to_event(sender, r), + OutgoingRequests::ToDeviceRequest(r) => request_to_event(sender, &r.clone().into()), _ => panic!("Unsupported outgoing request"), } } pub(crate) fn wrap_any_to_device_content( sender: &UserId, - content: AnyToDeviceEventContent, + content: OutgoingContent, ) -> AnyToDeviceEvent { + let content = if let OutgoingContent::ToDevice(c) = content { + c + } else { + unreachable!() + }; + match content { AnyToDeviceEventContent::KeyVerificationKey(c) => { AnyToDeviceEvent::KeyVerificationKey(ToDeviceEvent { @@ -79,7 +93,15 @@ pub(crate) mod test { } } - pub(crate) fn get_content_from_request(request: &ToDeviceRequest) -> AnyToDeviceEventContent { + pub(crate) fn get_content_from_request( + request: &OutgoingVerificationRequest, + ) -> OutgoingContent { + let request = if let OutgoingVerificationRequest::ToDevice(r) = request { + r + } else { + unreachable!() + }; + let json: Value = serde_json::from_str( request .messages @@ -111,5 +133,6 @@ pub(crate) mod test { ), _ => unreachable!(), } + .into() } } diff --git a/matrix_sdk_crypto/src/verification/sas/event_enums.rs b/matrix_sdk_crypto/src/verification/sas/event_enums.rs index 7b6b2629..c21b0e8b 100644 --- a/matrix_sdk_crypto/src/verification/sas/event_enums.rs +++ b/matrix_sdk_crypto/src/verification/sas/event_enums.rs @@ -19,7 +19,7 @@ use std::{collections::BTreeMap, convert::TryInto}; use matrix_sdk_common::{ events::{ key::verification::{ - accept::{AcceptEventContent, AcceptToDeviceEventContent}, + accept::{AcceptEventContent, AcceptMethod, AcceptToDeviceEventContent}, cancel::{CancelEventContent, CancelToDeviceEventContent}, done::DoneEventContent, key::{KeyEventContent, KeyToDeviceEventContent}, @@ -86,6 +86,22 @@ pub enum AcceptContent { Room(RoomId, AcceptEventContent), } +impl AcceptContent { + pub fn flow_id(&self) -> FlowId { + match self { + AcceptContent::ToDevice(c) => FlowId::ToDevice(c.transaction_id.clone()), + AcceptContent::Room(r, c) => FlowId::InRoom(r.clone(), c.relation.event_id.clone()), + } + } + + pub fn method(&self) -> &AcceptMethod { + match self { + AcceptContent::ToDevice(c) => &c.method, + AcceptContent::Room(_, c) => &c.method, + } + } +} + impl From for AcceptContent { fn from(content: AcceptToDeviceEventContent) -> Self { AcceptContent::ToDevice(content) diff --git a/matrix_sdk_crypto/src/verification/sas/helpers.rs b/matrix_sdk_crypto/src/verification/sas/helpers.rs index db6eec84..587982f6 100644 --- a/matrix_sdk_crypto/src/verification/sas/helpers.rs +++ b/matrix_sdk_crypto/src/verification/sas/helpers.rs @@ -570,7 +570,7 @@ mod test { }); let content: StartToDeviceEventContent = serde_json::from_value(content).unwrap(); - let calculated_commitment = calculate_commitment(public_key, &content); + let calculated_commitment = calculate_commitment(public_key, content); assert_eq!(commitment, &calculated_commitment); } diff --git a/matrix_sdk_crypto/src/verification/sas/inner_sas.rs b/matrix_sdk_crypto/src/verification/sas/inner_sas.rs index 6a8872b1..3f8e08ad 100644 --- a/matrix_sdk_crypto/src/verification/sas/inner_sas.rs +++ b/matrix_sdk_crypto/src/verification/sas/inner_sas.rs @@ -55,7 +55,7 @@ impl InnerSas { account: ReadOnlyAccount, other_device: ReadOnlyDevice, other_identity: Option, - ) -> (InnerSas, OutgoingContent) { + ) -> (InnerSas, StartContent) { let sas = SasState::::new(account, other_device, other_identity); let content = sas.as_content(); (InnerSas::Created(sas), content) @@ -67,7 +67,7 @@ impl InnerSas { account: ReadOnlyAccount, other_device: ReadOnlyDevice, other_identity: Option, - ) -> (InnerSas, OutgoingContent) { + ) -> (InnerSas, StartContent) { let sas = SasState::::new_in_room( room_id, event_id, @@ -245,7 +245,7 @@ impl InnerSas { match event { AnyToDeviceEvent::KeyVerificationAccept(e) => { if let InnerSas::Created(s) = self { - match s.into_accepted(e) { + match s.into_accepted(&e.sender, e.content.clone()) { Ok(s) => { let content = s.as_content(); (InnerSas::Accepted(s), Some(content.into())) diff --git a/matrix_sdk_crypto/src/verification/sas/mod.rs b/matrix_sdk_crypto/src/verification/sas/mod.rs index 9e5a7de2..de23dc64 100644 --- a/matrix_sdk_crypto/src/verification/sas/mod.rs +++ b/matrix_sdk_crypto/src/verification/sas/mod.rs @@ -113,16 +113,15 @@ impl Sas { fn start_helper( inner_sas: InnerSas, - content: OutgoingContent, account: ReadOnlyAccount, private_identity: PrivateCrossSigningIdentity, other_device: ReadOnlyDevice, store: Arc>, other_identity: Option, - ) -> (Sas, OutgoingContent) { + ) -> Sas { let flow_id = inner_sas.verification_flow_id(); - let sas = Sas { + Sas { inner: Arc::new(Mutex::new(inner_sas)), account, private_identity, @@ -130,9 +129,7 @@ impl Sas { other_device, flow_id, other_identity, - }; - - (sas, content) + } } /// Start a new SAS auth flow with the given device. @@ -151,21 +148,23 @@ impl Sas { other_device: ReadOnlyDevice, store: Arc>, other_identity: Option, - ) -> (Sas, OutgoingContent) { + ) -> (Sas, StartContent) { let (inner, content) = InnerSas::start( account.clone(), other_device.clone(), other_identity.clone(), ); - Self::start_helper( - inner, + ( + Self::start_helper( + inner, + account, + private_identity, + other_device, + store, + other_identity, + ), content, - account, - private_identity, - other_device, - store, - other_identity, ) } @@ -188,7 +187,7 @@ impl Sas { other_device: ReadOnlyDevice, store: Arc>, other_identity: Option, - ) -> (Sas, OutgoingContent) { + ) -> (Sas, StartContent) { let (inner, content) = InnerSas::start_in_room( flow_id, room_id, @@ -197,14 +196,16 @@ impl Sas { other_identity.clone(), ); - Self::start_helper( - inner, + ( + Self::start_helper( + inner, + account, + private_identity, + other_device, + store, + other_identity, + ), content, - account, - private_identity, - other_device, - store, - other_identity, ) } @@ -656,10 +657,7 @@ impl Sas { mod test { use std::{convert::TryFrom, sync::Arc}; - use matrix_sdk_common::{ - events::{EventContent, ToDeviceEvent}, - identifiers::{DeviceId, UserId}, - }; + use matrix_sdk_common::identifiers::{DeviceId, UserId}; use crate::{ olm::PrivateCrossSigningIdentity, @@ -668,10 +666,7 @@ mod test { ReadOnlyAccount, ReadOnlyDevice, }; - use super::{ - sas_state::{Accepted, Created, SasState, Started}, - Sas, - }; + use super::Sas; fn alice_id() -> UserId { UserId::try_from("@alice:example.org").unwrap() @@ -689,97 +684,6 @@ mod test { "BOBDEVCIE".into() } - fn wrap_to_device_event(sender: &UserId, content: C) -> ToDeviceEvent { - ToDeviceEvent { - sender: sender.clone(), - content, - } - } - - async fn get_sas_pair() -> (SasState, SasState) { - let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id()); - let alice_device = ReadOnlyDevice::from_account(&alice).await; - - let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id()); - let bob_device = ReadOnlyDevice::from_account(&bob).await; - - let alice_sas = SasState::::new(alice.clone(), bob_device, None); - - let start_content = alice_sas.as_content(); - let event = wrap_to_device_event(alice_sas.user_id(), start_content); - - let bob_sas = - SasState::::from_start_event(bob.clone(), alice_device, &event, None); - - (alice_sas, bob_sas.unwrap()) - } - - #[tokio::test] - async fn create_sas() { - let (_, _) = get_sas_pair().await; - } - - #[tokio::test] - async fn sas_accept() { - let (alice, bob) = get_sas_pair().await; - - let event = wrap_to_device_event(bob.user_id(), bob.as_content()); - - alice.into_accepted(&event).unwrap(); - } - - #[tokio::test] - async fn sas_key_share() { - let (alice, bob) = get_sas_pair().await; - - let event = wrap_to_device_event(bob.user_id(), bob.as_content()); - - let alice: SasState = alice.into_accepted(&event).unwrap(); - let mut event = wrap_to_device_event(alice.user_id(), alice.as_content()); - - let bob = bob.into_key_received(&mut event).unwrap(); - - let mut event = wrap_to_device_event(bob.user_id(), bob.as_content()); - - let alice = alice.into_key_received(&mut event).unwrap(); - - assert_eq!(alice.get_decimal(), bob.get_decimal()); - assert_eq!(alice.get_emoji(), bob.get_emoji()); - } - - #[tokio::test] - async fn sas_full() { - let (alice, bob) = get_sas_pair().await; - - let event = wrap_to_device_event(bob.user_id(), bob.as_content()); - - let alice: SasState = alice.into_accepted(&event).unwrap(); - let mut event = wrap_to_device_event(alice.user_id(), alice.as_content()); - - let bob = bob.into_key_received(&mut event).unwrap(); - - let mut event = wrap_to_device_event(bob.user_id(), bob.as_content()); - - let alice = alice.into_key_received(&mut event).unwrap(); - - assert_eq!(alice.get_decimal(), bob.get_decimal()); - assert_eq!(alice.get_emoji(), bob.get_emoji()); - - let bob = bob.confirm(); - - let event = wrap_to_device_event(bob.user_id(), bob.as_content()); - - let alice = alice.into_mac_received(&event).unwrap(); - assert!(!alice.get_emoji().is_empty()); - let alice = alice.confirm(); - - let event = wrap_to_device_event(alice.user_id(), alice.as_content()); - let bob = bob.into_done(&event).unwrap(); - - assert!(bob.verified_devices().contains(&bob.other_device())); - assert!(alice.verified_devices().contains(&alice.other_device())); - } - #[tokio::test] async fn sas_wrapper_full() { let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id()); @@ -802,14 +706,13 @@ mod test { alice_store, None, ); - let event = wrap_to_device_event(alice.user_id(), content); let bob = Sas::from_start_event( bob, PrivateCrossSigningIdentity::empty(bob_id()), alice_device, bob_store, - &event, + content, None, ) .unwrap(); diff --git a/matrix_sdk_crypto/src/verification/sas/sas_state.rs b/matrix_sdk_crypto/src/verification/sas/sas_state.rs index d4cb9c8a..d3eec749 100644 --- a/matrix_sdk_crypto/src/verification/sas/sas_state.rs +++ b/matrix_sdk_crypto/src/verification/sas/sas_state.rs @@ -22,23 +22,20 @@ use std::{ use olm_rs::sas::OlmSas; use matrix_sdk_common::{ - events::{ - key::verification::{ - accept::{ - AcceptEventContent, AcceptMethod, AcceptToDeviceEventContent, - MSasV1Content as AcceptV1Content, MSasV1ContentInit as AcceptV1ContentInit, - }, - cancel::{CancelCode, CancelEventContent, CancelToDeviceEventContent}, - done::DoneEventContent, - key::{KeyEventContent, KeyToDeviceEventContent}, - start::{ - MSasV1Content, MSasV1ContentInit, StartEventContent, StartMethod, - StartToDeviceEventContent, - }, - HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode, Relation, - ShortAuthenticationString, VerificationMethod, + events::key::verification::{ + accept::{ + AcceptEventContent, AcceptMethod, AcceptToDeviceEventContent, + MSasV1Content as AcceptV1Content, MSasV1ContentInit as AcceptV1ContentInit, }, - ToDeviceEvent, + cancel::{CancelCode, CancelEventContent, CancelToDeviceEventContent}, + done::DoneEventContent, + key::{KeyEventContent, KeyToDeviceEventContent}, + start::{ + MSasV1Content, MSasV1ContentInit, StartEventContent, StartMethod, + StartToDeviceEventContent, + }, + HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode, Relation, + ShortAuthenticationString, VerificationMethod, }, identifiers::{DeviceId, EventId, RoomId, UserId}, uuid::Uuid, @@ -47,8 +44,7 @@ use tracing::error; use super::{ event_enums::{ - AcceptContent, CancelContent, DoneContent, KeyContent, MacContent, OutgoingContent, - StartContent, + AcceptContent, CancelContent, DoneContent, KeyContent, MacContent, StartContent, }, helpers::{ calculate_commitment, get_decimal, get_emoji, get_mac_content, receive_mac_event, SasIds, @@ -405,7 +401,7 @@ impl SasState { } } - pub fn as_start_content(&self) -> StartContent { + pub fn as_content(&self) -> StartContent { match self.verification_flow_id.as_ref() { FlowId::ToDevice(_) => StartContent::ToDevice(StartToDeviceEventContent { transaction_id: self.verification_flow_id.to_string(), @@ -431,13 +427,6 @@ impl SasState { } } - /// Get the content for the start event. - /// - /// The content needs to be sent to the other device. - pub fn as_content(&self) -> OutgoingContent { - self.as_start_content().into() - } - /// Receive a m.key.verification.accept event, changing the state into /// an Accepted one. /// @@ -447,16 +436,18 @@ impl SasState { /// the other side. pub fn into_accepted( self, - event: &ToDeviceEvent, + sender: &UserId, + content: impl Into, ) -> Result, SasState> { - self.check_event(&event.sender, &event.content.transaction_id) + let content = content.into(); + self.check_event(&sender, content.flow_id().as_str()) .map_err(|c| self.clone().cancel(c))?; - if let AcceptMethod::MSasV1(content) = &event.content.method { + if let AcceptMethod::MSasV1(content) = content.method() { let accepted_protocols = AcceptedProtocols::try_from(content.clone()).map_err(|c| self.clone().cancel(c))?; - let start_content = self.as_start_content().into(); + let start_content = self.as_content().into(); Ok(SasState { inner: self.inner, @@ -713,14 +704,14 @@ impl SasState { /// Get the content for the key event. /// /// The content needs to be automatically sent to the other side. - pub fn as_content(&self) -> OutgoingContent { + pub fn as_content(&self) -> KeyContent { match &*self.verification_flow_id { - FlowId::ToDevice(s) => KeyContent::ToDevice(KeyToDeviceEventContent { + FlowId::ToDevice(s) => KeyToDeviceEventContent { transaction_id: s.to_string(), key: self.inner.lock().unwrap().public_key(), - }) + } .into(), - FlowId::InRoom(r, e) => KeyContent::Room( + FlowId::InRoom(r, e) => ( r.clone(), KeyEventContent { key: self.inner.lock().unwrap().public_key(), @@ -729,7 +720,7 @@ impl SasState { }, }, ) - .into(), + .into(), } } } @@ -1169,14 +1160,14 @@ impl SasState { mod test { use std::convert::TryFrom; - use crate::{ReadOnlyAccount, ReadOnlyDevice}; + use crate::{ + verification::sas::{event_enums::AcceptContent, StartContent}, + ReadOnlyAccount, ReadOnlyDevice, + }; use matrix_sdk_common::{ - events::{ - key::verification::{ - accept::{AcceptMethod, CustomContent}, - start::{CustomContent as CustomStartContent, StartMethod}, - }, - EventContent, ToDeviceEvent, + events::key::verification::{ + accept::{AcceptMethod, CustomContent}, + start::{CustomContent as CustomStartContent, StartMethod}, }, identifiers::{DeviceId, UserId}, }; @@ -1199,13 +1190,6 @@ mod test { "BOBDEVCIE".into() } - fn wrap_to_device_event(sender: &UserId, content: C) -> ToDeviceEvent { - ToDeviceEvent { - sender: sender.clone(), - content, - } - } - async fn get_sas_pair() -> (SasState, SasState) { let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id()); let alice_device = ReadOnlyDevice::from_account(&alice).await; @@ -1216,10 +1200,9 @@ mod test { let alice_sas = SasState::::new(alice.clone(), bob_device, None); let start_content = alice_sas.as_content(); - let event = wrap_to_device_event(alice_sas.user_id(), start_content); let bob_sas = - SasState::::from_start_event(bob.clone(), alice_device, &event, None); + SasState::::from_start_event(bob.clone(), alice_device, None, start_content); (alice_sas, bob_sas.unwrap()) } @@ -1233,25 +1216,25 @@ mod test { async fn sas_accept() { let (alice, bob) = get_sas_pair().await; - let event = wrap_to_device_event(bob.user_id(), bob.as_content()); + let event = bob.as_content(); - alice.into_accepted(&event).unwrap(); + alice.into_accepted(bob.user_id(), event).unwrap(); } #[tokio::test] async fn sas_key_share() { let (alice, bob) = get_sas_pair().await; - let event = wrap_to_device_event(bob.user_id(), bob.as_content()); + let content = bob.as_content(); - let alice: SasState = alice.into_accepted(&event).unwrap(); - let mut event = wrap_to_device_event(alice.user_id(), alice.as_content()); + let alice: SasState = alice.into_accepted(bob.user_id(), content).unwrap(); + let content = alice.as_content(); - let bob = bob.into_key_received(&mut event).unwrap(); + let bob = bob.into_key_received(alice.user_id(), content).unwrap(); - let mut event = wrap_to_device_event(bob.user_id(), bob.as_content()); + let content = bob.as_content(); - let alice = alice.into_key_received(&mut event).unwrap(); + let alice = alice.into_key_received(bob.user_id(), content).unwrap(); assert_eq!(alice.get_decimal(), bob.get_decimal()); assert_eq!(alice.get_emoji(), bob.get_emoji()); @@ -1261,16 +1244,16 @@ mod test { async fn sas_full() { let (alice, bob) = get_sas_pair().await; - let event = wrap_to_device_event(bob.user_id(), bob.as_content()); + let content = bob.as_content(); - let alice: SasState = alice.into_accepted(&event).unwrap(); - let mut event = wrap_to_device_event(alice.user_id(), alice.as_content()); + let alice: SasState = alice.into_accepted(bob.user_id(), content).unwrap(); + let content = alice.as_content(); - let bob = bob.into_key_received(&mut event).unwrap(); + let bob = bob.into_key_received(alice.user_id(), content).unwrap(); - let mut event = wrap_to_device_event(bob.user_id(), bob.as_content()); + let content = bob.as_content(); - let alice = alice.into_key_received(&mut event).unwrap(); + let alice = alice.into_key_received(bob.user_id(), content).unwrap(); assert_eq!(alice.get_decimal(), bob.get_decimal()); assert_eq!(alice.get_emoji(), bob.get_emoji()); @@ -1279,15 +1262,15 @@ mod test { let bob = bob.confirm(); - let event = wrap_to_device_event(bob.user_id(), bob.as_content()); + let content = bob.as_content(); - let alice = alice.into_mac_received(&event).unwrap(); + let alice = alice.into_mac_received(bob.user_id(), content).unwrap(); assert!(!alice.get_emoji().is_empty()); assert_eq!(alice.get_decimal(), bob_decimals); let alice = alice.confirm(); - let event = wrap_to_device_event(alice.user_id(), alice.as_content()); - let bob = bob.into_done(&event).unwrap(); + let content = alice.as_content(); + let bob = bob.into_done(alice.user_id(), content).unwrap(); assert!(bob.verified_devices().contains(&bob.other_device())); assert!(alice.verified_devices().contains(&alice.other_device())); @@ -1297,23 +1280,28 @@ mod test { async fn sas_invalid_commitment() { let (alice, bob) = get_sas_pair().await; - let mut event = wrap_to_device_event(bob.user_id(), bob.as_content()); + let mut content = bob.as_content(); - match &mut event.content.method { + let mut method = match &mut content { + AcceptContent::ToDevice(c) => &mut c.method, + AcceptContent::Room(_, c) => &mut c.method, + }; + + match &mut method { AcceptMethod::MSasV1(ref mut c) => { c.commitment = "".to_string(); } _ => panic!("Unknown accept event content"), } - let alice: SasState = alice.into_accepted(&event).unwrap(); + let alice: SasState = alice.into_accepted(bob.user_id(), content).unwrap(); - let mut event = wrap_to_device_event(alice.user_id(), alice.as_content()); - let bob = bob.into_key_received(&mut event).unwrap(); - let mut event = wrap_to_device_event(bob.user_id(), bob.as_content()); + let content = alice.as_content(); + let bob = bob.into_key_received(alice.user_id(), content).unwrap(); + let content = bob.as_content(); alice - .into_key_received(&mut event) + .into_key_received(bob.user_id(), content) .expect_err("Didn't cancel on invalid commitment"); } @@ -1321,10 +1309,10 @@ mod test { async fn sas_invalid_sender() { let (alice, bob) = get_sas_pair().await; - let mut event = wrap_to_device_event(bob.user_id(), bob.as_content()); - event.sender = UserId::try_from("@malory:example.org").unwrap(); + let content = bob.as_content(); + let sender = UserId::try_from("@malory:example.org").unwrap(); alice - .into_accepted(&event) + .into_accepted(&sender, content) .expect_err("Didn't cancel on a invalid sender"); } @@ -1332,9 +1320,14 @@ mod test { async fn sas_unknown_sas_method() { let (alice, bob) = get_sas_pair().await; - let mut event = wrap_to_device_event(bob.user_id(), bob.as_content()); + let mut content = bob.as_content(); - match &mut event.content.method { + let mut method = match &mut content { + AcceptContent::ToDevice(c) => &mut c.method, + AcceptContent::Room(_, c) => &mut c.method, + }; + + match &mut method { AcceptMethod::MSasV1(ref mut c) => { c.short_authentication_string = vec![]; } @@ -1342,7 +1335,7 @@ mod test { } alice - .into_accepted(&event) + .into_accepted(bob.user_id(), content) .expect_err("Didn't cancel on an invalid SAS method"); } @@ -1350,15 +1343,20 @@ mod test { async fn sas_unknown_method() { let (alice, bob) = get_sas_pair().await; - let mut event = wrap_to_device_event(bob.user_id(), bob.as_content()); + let mut content = bob.as_content(); - event.content.method = AcceptMethod::Custom(CustomContent { + let method = match &mut content { + AcceptContent::ToDevice(c) => &mut c.method, + AcceptContent::Room(_, c) => &mut c.method, + }; + + *method = AcceptMethod::Custom(CustomContent { method: "m.sas.custom".to_string(), fields: vec![].into_iter().collect(), }); alice - .into_accepted(&event) + .into_accepted(bob.user_id(), content) .expect_err("Didn't cancel on an unknown SAS method"); } @@ -1374,26 +1372,39 @@ mod test { let mut start_content = alice_sas.as_content(); - match start_content.method { + let method = match &mut start_content { + StartContent::ToDevice(c) => &mut c.method, + StartContent::Room(_, c) => &mut c.method, + }; + + match method { StartMethod::MSasV1(ref mut c) => { c.message_authentication_codes = vec![]; } _ => panic!("Unknown SAS start method"), } - let event = wrap_to_device_event(alice_sas.user_id(), start_content); - SasState::::from_start_event(bob.clone(), alice_device.clone(), &event, None) - .expect_err("Didn't cancel on invalid MAC method"); + SasState::::from_start_event( + bob.clone(), + alice_device.clone(), + None, + start_content, + ) + .expect_err("Didn't cancel on invalid MAC method"); let mut start_content = alice_sas.as_content(); - start_content.method = StartMethod::Custom(CustomStartContent { + let method = match &mut start_content { + StartContent::ToDevice(c) => &mut c.method, + StartContent::Room(_, c) => &mut c.method, + }; + + *method = StartMethod::Custom(CustomStartContent { method: "m.sas.custom".to_string(), fields: vec![].into_iter().collect(), }); - let event = wrap_to_device_event(alice_sas.user_id(), start_content); - SasState::::from_start_event(bob.clone(), alice_device, &event, None) + SasState::::from_start_event(bob.clone(), alice_device, None, start_content) .expect_err("Didn't cancel on unknown sas method"); } }