crypto: Fix our tests now that we support in-room verifications.

master
Damir Jelić 2020-12-18 12:55:06 +01:00
parent f735107caf
commit 897c6abe92
8 changed files with 185 additions and 231 deletions

View File

@ -1826,7 +1826,7 @@ pub(crate) mod test {
let (alice_sas, request) = bob_device.start_verification().await.unwrap(); let (alice_sas, request) = bob_device.start_verification().await.unwrap();
let mut event = request_to_event(alice.user_id(), &request); let mut event = request_to_event(alice.user_id(), &request.into());
bob.handle_verification_event(&mut event).await; bob.handle_verification_event(&mut event).await;
let bob_sas = bob.get_verification(alice_sas.flow_id().as_str()).unwrap(); let bob_sas = bob.get_verification(alice_sas.flow_id().as_str()).unwrap();

View File

@ -85,7 +85,7 @@ impl VerificationMachine {
identity, identity,
); );
let request = match content { let request = match content.into() {
OutgoingContent::Room(r, c) => RoomMessageRequest { OutgoingContent::Room(r, c) => RoomMessageRequest {
room_id: r, room_id: r,
txn_id: Uuid::new_v4(), txn_id: Uuid::new_v4(),
@ -459,7 +459,6 @@ mod test {
}; };
use matrix_sdk_common::{ use matrix_sdk_common::{
events::AnyToDeviceEventContent,
identifiers::{DeviceId, UserId}, identifiers::{DeviceId, UserId},
locks::Mutex, locks::Mutex,
}; };
@ -511,10 +510,11 @@ mod test {
bob_store, bob_store,
None, None,
); );
machine machine
.receive_event(&mut wrap_any_to_device_content( .receive_event(&mut wrap_any_to_device_content(
bob_sas.user_id(), bob_sas.user_id(),
AnyToDeviceEventContent::KeyVerificationStart(start_content), start_content.into(),
)) ))
.await .await
.unwrap(); .unwrap();
@ -559,12 +559,13 @@ mod test {
let txn_id = *request.request_id(); let txn_id = *request.request_id();
let r = if let OutgoingRequests::ToDeviceRequest(r) = request.request() { let r = if let OutgoingRequests::ToDeviceRequest(r) = request.request() {
r r.clone()
} else { } else {
panic!("Invalid request type"); panic!("Invalid request type");
}; };
let mut event = wrap_any_to_device_content(alice.user_id(), get_content_from_request(r)); let mut event =
wrap_any_to_device_content(alice.user_id(), get_content_from_request(&r.into()));
drop(request); drop(request);
alice_machine.mark_request_as_sent(&txn_id); alice_machine.mark_request_as_sent(&txn_id);

View File

@ -22,7 +22,10 @@ pub use sas::{Sas, VerificationResult};
#[cfg(test)] #[cfg(test)]
pub(crate) mod test { pub(crate) mod test {
use crate::requests::{OutgoingRequest, OutgoingRequests, ToDeviceRequest}; use crate::{
requests::{OutgoingRequest, OutgoingRequests},
OutgoingVerificationRequest,
};
use serde_json::Value; use serde_json::Value;
use matrix_sdk_common::{ use matrix_sdk_common::{
@ -30,7 +33,12 @@ pub(crate) mod test {
identifiers::UserId, identifiers::UserId,
}; };
pub(crate) fn request_to_event(sender: &UserId, request: &ToDeviceRequest) -> AnyToDeviceEvent { use super::sas::OutgoingContent;
pub(crate) fn request_to_event(
sender: &UserId,
request: &OutgoingVerificationRequest,
) -> AnyToDeviceEvent {
let content = get_content_from_request(request); let content = get_content_from_request(request);
wrap_any_to_device_content(sender, content) wrap_any_to_device_content(sender, content)
} }
@ -40,15 +48,21 @@ pub(crate) mod test {
request: &OutgoingRequest, request: &OutgoingRequest,
) -> AnyToDeviceEvent { ) -> AnyToDeviceEvent {
match request.request() { match request.request() {
OutgoingRequests::ToDeviceRequest(r) => request_to_event(sender, r), OutgoingRequests::ToDeviceRequest(r) => request_to_event(sender, &r.clone().into()),
_ => panic!("Unsupported outgoing request"), _ => panic!("Unsupported outgoing request"),
} }
} }
pub(crate) fn wrap_any_to_device_content( pub(crate) fn wrap_any_to_device_content(
sender: &UserId, sender: &UserId,
content: AnyToDeviceEventContent, content: OutgoingContent,
) -> AnyToDeviceEvent { ) -> AnyToDeviceEvent {
let content = if let OutgoingContent::ToDevice(c) = content {
c
} else {
unreachable!()
};
match content { match content {
AnyToDeviceEventContent::KeyVerificationKey(c) => { AnyToDeviceEventContent::KeyVerificationKey(c) => {
AnyToDeviceEvent::KeyVerificationKey(ToDeviceEvent { AnyToDeviceEvent::KeyVerificationKey(ToDeviceEvent {
@ -79,7 +93,15 @@ pub(crate) mod test {
} }
} }
pub(crate) fn get_content_from_request(request: &ToDeviceRequest) -> AnyToDeviceEventContent { pub(crate) fn get_content_from_request(
request: &OutgoingVerificationRequest,
) -> OutgoingContent {
let request = if let OutgoingVerificationRequest::ToDevice(r) = request {
r
} else {
unreachable!()
};
let json: Value = serde_json::from_str( let json: Value = serde_json::from_str(
request request
.messages .messages
@ -111,5 +133,6 @@ pub(crate) mod test {
), ),
_ => unreachable!(), _ => unreachable!(),
} }
.into()
} }
} }

View File

@ -19,7 +19,7 @@ use std::{collections::BTreeMap, convert::TryInto};
use matrix_sdk_common::{ use matrix_sdk_common::{
events::{ events::{
key::verification::{ key::verification::{
accept::{AcceptEventContent, AcceptToDeviceEventContent}, accept::{AcceptEventContent, AcceptMethod, AcceptToDeviceEventContent},
cancel::{CancelEventContent, CancelToDeviceEventContent}, cancel::{CancelEventContent, CancelToDeviceEventContent},
done::DoneEventContent, done::DoneEventContent,
key::{KeyEventContent, KeyToDeviceEventContent}, key::{KeyEventContent, KeyToDeviceEventContent},
@ -86,6 +86,22 @@ pub enum AcceptContent {
Room(RoomId, AcceptEventContent), Room(RoomId, AcceptEventContent),
} }
impl AcceptContent {
pub fn flow_id(&self) -> FlowId {
match self {
AcceptContent::ToDevice(c) => FlowId::ToDevice(c.transaction_id.clone()),
AcceptContent::Room(r, c) => FlowId::InRoom(r.clone(), c.relation.event_id.clone()),
}
}
pub fn method(&self) -> &AcceptMethod {
match self {
AcceptContent::ToDevice(c) => &c.method,
AcceptContent::Room(_, c) => &c.method,
}
}
}
impl From<AcceptToDeviceEventContent> for AcceptContent { impl From<AcceptToDeviceEventContent> for AcceptContent {
fn from(content: AcceptToDeviceEventContent) -> Self { fn from(content: AcceptToDeviceEventContent) -> Self {
AcceptContent::ToDevice(content) AcceptContent::ToDevice(content)

View File

@ -570,7 +570,7 @@ mod test {
}); });
let content: StartToDeviceEventContent = serde_json::from_value(content).unwrap(); let content: StartToDeviceEventContent = serde_json::from_value(content).unwrap();
let calculated_commitment = calculate_commitment(public_key, &content); let calculated_commitment = calculate_commitment(public_key, content);
assert_eq!(commitment, &calculated_commitment); assert_eq!(commitment, &calculated_commitment);
} }

View File

@ -55,7 +55,7 @@ impl InnerSas {
account: ReadOnlyAccount, account: ReadOnlyAccount,
other_device: ReadOnlyDevice, other_device: ReadOnlyDevice,
other_identity: Option<UserIdentities>, other_identity: Option<UserIdentities>,
) -> (InnerSas, OutgoingContent) { ) -> (InnerSas, StartContent) {
let sas = SasState::<Created>::new(account, other_device, other_identity); let sas = SasState::<Created>::new(account, other_device, other_identity);
let content = sas.as_content(); let content = sas.as_content();
(InnerSas::Created(sas), content) (InnerSas::Created(sas), content)
@ -67,7 +67,7 @@ impl InnerSas {
account: ReadOnlyAccount, account: ReadOnlyAccount,
other_device: ReadOnlyDevice, other_device: ReadOnlyDevice,
other_identity: Option<UserIdentities>, other_identity: Option<UserIdentities>,
) -> (InnerSas, OutgoingContent) { ) -> (InnerSas, StartContent) {
let sas = SasState::<Created>::new_in_room( let sas = SasState::<Created>::new_in_room(
room_id, room_id,
event_id, event_id,
@ -245,7 +245,7 @@ impl InnerSas {
match event { match event {
AnyToDeviceEvent::KeyVerificationAccept(e) => { AnyToDeviceEvent::KeyVerificationAccept(e) => {
if let InnerSas::Created(s) = self { if let InnerSas::Created(s) = self {
match s.into_accepted(e) { match s.into_accepted(&e.sender, e.content.clone()) {
Ok(s) => { Ok(s) => {
let content = s.as_content(); let content = s.as_content();
(InnerSas::Accepted(s), Some(content.into())) (InnerSas::Accepted(s), Some(content.into()))

View File

@ -113,16 +113,15 @@ impl Sas {
fn start_helper( fn start_helper(
inner_sas: InnerSas, inner_sas: InnerSas,
content: OutgoingContent,
account: ReadOnlyAccount, account: ReadOnlyAccount,
private_identity: PrivateCrossSigningIdentity, private_identity: PrivateCrossSigningIdentity,
other_device: ReadOnlyDevice, other_device: ReadOnlyDevice,
store: Arc<Box<dyn CryptoStore>>, store: Arc<Box<dyn CryptoStore>>,
other_identity: Option<UserIdentities>, other_identity: Option<UserIdentities>,
) -> (Sas, OutgoingContent) { ) -> Sas {
let flow_id = inner_sas.verification_flow_id(); let flow_id = inner_sas.verification_flow_id();
let sas = Sas { Sas {
inner: Arc::new(Mutex::new(inner_sas)), inner: Arc::new(Mutex::new(inner_sas)),
account, account,
private_identity, private_identity,
@ -130,9 +129,7 @@ impl Sas {
other_device, other_device,
flow_id, flow_id,
other_identity, other_identity,
}; }
(sas, content)
} }
/// Start a new SAS auth flow with the given device. /// Start a new SAS auth flow with the given device.
@ -151,21 +148,23 @@ impl Sas {
other_device: ReadOnlyDevice, other_device: ReadOnlyDevice,
store: Arc<Box<dyn CryptoStore>>, store: Arc<Box<dyn CryptoStore>>,
other_identity: Option<UserIdentities>, other_identity: Option<UserIdentities>,
) -> (Sas, OutgoingContent) { ) -> (Sas, StartContent) {
let (inner, content) = InnerSas::start( let (inner, content) = InnerSas::start(
account.clone(), account.clone(),
other_device.clone(), other_device.clone(),
other_identity.clone(), other_identity.clone(),
); );
(
Self::start_helper( Self::start_helper(
inner, inner,
content,
account, account,
private_identity, private_identity,
other_device, other_device,
store, store,
other_identity, other_identity,
),
content,
) )
} }
@ -188,7 +187,7 @@ impl Sas {
other_device: ReadOnlyDevice, other_device: ReadOnlyDevice,
store: Arc<Box<dyn CryptoStore>>, store: Arc<Box<dyn CryptoStore>>,
other_identity: Option<UserIdentities>, other_identity: Option<UserIdentities>,
) -> (Sas, OutgoingContent) { ) -> (Sas, StartContent) {
let (inner, content) = InnerSas::start_in_room( let (inner, content) = InnerSas::start_in_room(
flow_id, flow_id,
room_id, room_id,
@ -197,14 +196,16 @@ impl Sas {
other_identity.clone(), other_identity.clone(),
); );
(
Self::start_helper( Self::start_helper(
inner, inner,
content,
account, account,
private_identity, private_identity,
other_device, other_device,
store, store,
other_identity, other_identity,
),
content,
) )
} }
@ -656,10 +657,7 @@ impl Sas {
mod test { mod test {
use std::{convert::TryFrom, sync::Arc}; use std::{convert::TryFrom, sync::Arc};
use matrix_sdk_common::{ use matrix_sdk_common::identifiers::{DeviceId, UserId};
events::{EventContent, ToDeviceEvent},
identifiers::{DeviceId, UserId},
};
use crate::{ use crate::{
olm::PrivateCrossSigningIdentity, olm::PrivateCrossSigningIdentity,
@ -668,10 +666,7 @@ mod test {
ReadOnlyAccount, ReadOnlyDevice, ReadOnlyAccount, ReadOnlyDevice,
}; };
use super::{ use super::Sas;
sas_state::{Accepted, Created, SasState, Started},
Sas,
};
fn alice_id() -> UserId { fn alice_id() -> UserId {
UserId::try_from("@alice:example.org").unwrap() UserId::try_from("@alice:example.org").unwrap()
@ -689,97 +684,6 @@ mod test {
"BOBDEVCIE".into() "BOBDEVCIE".into()
} }
fn wrap_to_device_event<C: EventContent>(sender: &UserId, content: C) -> ToDeviceEvent<C> {
ToDeviceEvent {
sender: sender.clone(),
content,
}
}
async fn get_sas_pair() -> (SasState<Created>, SasState<Started>) {
let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id());
let alice_device = ReadOnlyDevice::from_account(&alice).await;
let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id());
let bob_device = ReadOnlyDevice::from_account(&bob).await;
let alice_sas = SasState::<Created>::new(alice.clone(), bob_device, None);
let start_content = alice_sas.as_content();
let event = wrap_to_device_event(alice_sas.user_id(), start_content);
let bob_sas =
SasState::<Started>::from_start_event(bob.clone(), alice_device, &event, None);
(alice_sas, bob_sas.unwrap())
}
#[tokio::test]
async fn create_sas() {
let (_, _) = get_sas_pair().await;
}
#[tokio::test]
async fn sas_accept() {
let (alice, bob) = get_sas_pair().await;
let event = wrap_to_device_event(bob.user_id(), bob.as_content());
alice.into_accepted(&event).unwrap();
}
#[tokio::test]
async fn sas_key_share() {
let (alice, bob) = get_sas_pair().await;
let event = wrap_to_device_event(bob.user_id(), bob.as_content());
let alice: SasState<Accepted> = alice.into_accepted(&event).unwrap();
let mut event = wrap_to_device_event(alice.user_id(), alice.as_content());
let bob = bob.into_key_received(&mut event).unwrap();
let mut event = wrap_to_device_event(bob.user_id(), bob.as_content());
let alice = alice.into_key_received(&mut event).unwrap();
assert_eq!(alice.get_decimal(), bob.get_decimal());
assert_eq!(alice.get_emoji(), bob.get_emoji());
}
#[tokio::test]
async fn sas_full() {
let (alice, bob) = get_sas_pair().await;
let event = wrap_to_device_event(bob.user_id(), bob.as_content());
let alice: SasState<Accepted> = alice.into_accepted(&event).unwrap();
let mut event = wrap_to_device_event(alice.user_id(), alice.as_content());
let bob = bob.into_key_received(&mut event).unwrap();
let mut event = wrap_to_device_event(bob.user_id(), bob.as_content());
let alice = alice.into_key_received(&mut event).unwrap();
assert_eq!(alice.get_decimal(), bob.get_decimal());
assert_eq!(alice.get_emoji(), bob.get_emoji());
let bob = bob.confirm();
let event = wrap_to_device_event(bob.user_id(), bob.as_content());
let alice = alice.into_mac_received(&event).unwrap();
assert!(!alice.get_emoji().is_empty());
let alice = alice.confirm();
let event = wrap_to_device_event(alice.user_id(), alice.as_content());
let bob = bob.into_done(&event).unwrap();
assert!(bob.verified_devices().contains(&bob.other_device()));
assert!(alice.verified_devices().contains(&alice.other_device()));
}
#[tokio::test] #[tokio::test]
async fn sas_wrapper_full() { async fn sas_wrapper_full() {
let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id()); let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id());
@ -802,14 +706,13 @@ mod test {
alice_store, alice_store,
None, None,
); );
let event = wrap_to_device_event(alice.user_id(), content);
let bob = Sas::from_start_event( let bob = Sas::from_start_event(
bob, bob,
PrivateCrossSigningIdentity::empty(bob_id()), PrivateCrossSigningIdentity::empty(bob_id()),
alice_device, alice_device,
bob_store, bob_store,
&event, content,
None, None,
) )
.unwrap(); .unwrap();

View File

@ -22,8 +22,7 @@ use std::{
use olm_rs::sas::OlmSas; use olm_rs::sas::OlmSas;
use matrix_sdk_common::{ use matrix_sdk_common::{
events::{ events::key::verification::{
key::verification::{
accept::{ accept::{
AcceptEventContent, AcceptMethod, AcceptToDeviceEventContent, AcceptEventContent, AcceptMethod, AcceptToDeviceEventContent,
MSasV1Content as AcceptV1Content, MSasV1ContentInit as AcceptV1ContentInit, MSasV1Content as AcceptV1Content, MSasV1ContentInit as AcceptV1ContentInit,
@ -38,8 +37,6 @@ use matrix_sdk_common::{
HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode, Relation, HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode, Relation,
ShortAuthenticationString, VerificationMethod, ShortAuthenticationString, VerificationMethod,
}, },
ToDeviceEvent,
},
identifiers::{DeviceId, EventId, RoomId, UserId}, identifiers::{DeviceId, EventId, RoomId, UserId},
uuid::Uuid, uuid::Uuid,
}; };
@ -47,8 +44,7 @@ use tracing::error;
use super::{ use super::{
event_enums::{ event_enums::{
AcceptContent, CancelContent, DoneContent, KeyContent, MacContent, OutgoingContent, AcceptContent, CancelContent, DoneContent, KeyContent, MacContent, StartContent,
StartContent,
}, },
helpers::{ helpers::{
calculate_commitment, get_decimal, get_emoji, get_mac_content, receive_mac_event, SasIds, calculate_commitment, get_decimal, get_emoji, get_mac_content, receive_mac_event, SasIds,
@ -405,7 +401,7 @@ impl SasState<Created> {
} }
} }
pub fn as_start_content(&self) -> StartContent { pub fn as_content(&self) -> StartContent {
match self.verification_flow_id.as_ref() { match self.verification_flow_id.as_ref() {
FlowId::ToDevice(_) => StartContent::ToDevice(StartToDeviceEventContent { FlowId::ToDevice(_) => StartContent::ToDevice(StartToDeviceEventContent {
transaction_id: self.verification_flow_id.to_string(), transaction_id: self.verification_flow_id.to_string(),
@ -431,13 +427,6 @@ impl SasState<Created> {
} }
} }
/// Get the content for the start event.
///
/// The content needs to be sent to the other device.
pub fn as_content(&self) -> OutgoingContent {
self.as_start_content().into()
}
/// Receive a m.key.verification.accept event, changing the state into /// Receive a m.key.verification.accept event, changing the state into
/// an Accepted one. /// an Accepted one.
/// ///
@ -447,16 +436,18 @@ impl SasState<Created> {
/// the other side. /// the other side.
pub fn into_accepted( pub fn into_accepted(
self, self,
event: &ToDeviceEvent<AcceptToDeviceEventContent>, sender: &UserId,
content: impl Into<AcceptContent>,
) -> Result<SasState<Accepted>, SasState<Canceled>> { ) -> Result<SasState<Accepted>, SasState<Canceled>> {
self.check_event(&event.sender, &event.content.transaction_id) let content = content.into();
self.check_event(&sender, content.flow_id().as_str())
.map_err(|c| self.clone().cancel(c))?; .map_err(|c| self.clone().cancel(c))?;
if let AcceptMethod::MSasV1(content) = &event.content.method { if let AcceptMethod::MSasV1(content) = content.method() {
let accepted_protocols = let accepted_protocols =
AcceptedProtocols::try_from(content.clone()).map_err(|c| self.clone().cancel(c))?; AcceptedProtocols::try_from(content.clone()).map_err(|c| self.clone().cancel(c))?;
let start_content = self.as_start_content().into(); let start_content = self.as_content().into();
Ok(SasState { Ok(SasState {
inner: self.inner, inner: self.inner,
@ -713,14 +704,14 @@ impl SasState<Accepted> {
/// Get the content for the key event. /// Get the content for the key event.
/// ///
/// The content needs to be automatically sent to the other side. /// The content needs to be automatically sent to the other side.
pub fn as_content(&self) -> OutgoingContent { pub fn as_content(&self) -> KeyContent {
match &*self.verification_flow_id { match &*self.verification_flow_id {
FlowId::ToDevice(s) => KeyContent::ToDevice(KeyToDeviceEventContent { FlowId::ToDevice(s) => KeyToDeviceEventContent {
transaction_id: s.to_string(), transaction_id: s.to_string(),
key: self.inner.lock().unwrap().public_key(), key: self.inner.lock().unwrap().public_key(),
}) }
.into(), .into(),
FlowId::InRoom(r, e) => KeyContent::Room( FlowId::InRoom(r, e) => (
r.clone(), r.clone(),
KeyEventContent { KeyEventContent {
key: self.inner.lock().unwrap().public_key(), key: self.inner.lock().unwrap().public_key(),
@ -1169,15 +1160,15 @@ impl SasState<Canceled> {
mod test { mod test {
use std::convert::TryFrom; use std::convert::TryFrom;
use crate::{ReadOnlyAccount, ReadOnlyDevice}; use crate::{
verification::sas::{event_enums::AcceptContent, StartContent},
ReadOnlyAccount, ReadOnlyDevice,
};
use matrix_sdk_common::{ use matrix_sdk_common::{
events::{ events::key::verification::{
key::verification::{
accept::{AcceptMethod, CustomContent}, accept::{AcceptMethod, CustomContent},
start::{CustomContent as CustomStartContent, StartMethod}, start::{CustomContent as CustomStartContent, StartMethod},
}, },
EventContent, ToDeviceEvent,
},
identifiers::{DeviceId, UserId}, identifiers::{DeviceId, UserId},
}; };
@ -1199,13 +1190,6 @@ mod test {
"BOBDEVCIE".into() "BOBDEVCIE".into()
} }
fn wrap_to_device_event<C: EventContent>(sender: &UserId, content: C) -> ToDeviceEvent<C> {
ToDeviceEvent {
sender: sender.clone(),
content,
}
}
async fn get_sas_pair() -> (SasState<Created>, SasState<Started>) { async fn get_sas_pair() -> (SasState<Created>, SasState<Started>) {
let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id()); let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id());
let alice_device = ReadOnlyDevice::from_account(&alice).await; let alice_device = ReadOnlyDevice::from_account(&alice).await;
@ -1216,10 +1200,9 @@ mod test {
let alice_sas = SasState::<Created>::new(alice.clone(), bob_device, None); let alice_sas = SasState::<Created>::new(alice.clone(), bob_device, None);
let start_content = alice_sas.as_content(); let start_content = alice_sas.as_content();
let event = wrap_to_device_event(alice_sas.user_id(), start_content);
let bob_sas = let bob_sas =
SasState::<Started>::from_start_event(bob.clone(), alice_device, &event, None); SasState::<Started>::from_start_event(bob.clone(), alice_device, None, start_content);
(alice_sas, bob_sas.unwrap()) (alice_sas, bob_sas.unwrap())
} }
@ -1233,25 +1216,25 @@ mod test {
async fn sas_accept() { async fn sas_accept() {
let (alice, bob) = get_sas_pair().await; let (alice, bob) = get_sas_pair().await;
let event = wrap_to_device_event(bob.user_id(), bob.as_content()); let event = bob.as_content();
alice.into_accepted(&event).unwrap(); alice.into_accepted(bob.user_id(), event).unwrap();
} }
#[tokio::test] #[tokio::test]
async fn sas_key_share() { async fn sas_key_share() {
let (alice, bob) = get_sas_pair().await; let (alice, bob) = get_sas_pair().await;
let event = wrap_to_device_event(bob.user_id(), bob.as_content()); let content = bob.as_content();
let alice: SasState<Accepted> = alice.into_accepted(&event).unwrap(); let alice: SasState<Accepted> = alice.into_accepted(bob.user_id(), content).unwrap();
let mut event = wrap_to_device_event(alice.user_id(), alice.as_content()); let content = alice.as_content();
let bob = bob.into_key_received(&mut event).unwrap(); let bob = bob.into_key_received(alice.user_id(), content).unwrap();
let mut event = wrap_to_device_event(bob.user_id(), bob.as_content()); let content = bob.as_content();
let alice = alice.into_key_received(&mut event).unwrap(); let alice = alice.into_key_received(bob.user_id(), content).unwrap();
assert_eq!(alice.get_decimal(), bob.get_decimal()); assert_eq!(alice.get_decimal(), bob.get_decimal());
assert_eq!(alice.get_emoji(), bob.get_emoji()); assert_eq!(alice.get_emoji(), bob.get_emoji());
@ -1261,16 +1244,16 @@ mod test {
async fn sas_full() { async fn sas_full() {
let (alice, bob) = get_sas_pair().await; let (alice, bob) = get_sas_pair().await;
let event = wrap_to_device_event(bob.user_id(), bob.as_content()); let content = bob.as_content();
let alice: SasState<Accepted> = alice.into_accepted(&event).unwrap(); let alice: SasState<Accepted> = alice.into_accepted(bob.user_id(), content).unwrap();
let mut event = wrap_to_device_event(alice.user_id(), alice.as_content()); let content = alice.as_content();
let bob = bob.into_key_received(&mut event).unwrap(); let bob = bob.into_key_received(alice.user_id(), content).unwrap();
let mut event = wrap_to_device_event(bob.user_id(), bob.as_content()); let content = bob.as_content();
let alice = alice.into_key_received(&mut event).unwrap(); let alice = alice.into_key_received(bob.user_id(), content).unwrap();
assert_eq!(alice.get_decimal(), bob.get_decimal()); assert_eq!(alice.get_decimal(), bob.get_decimal());
assert_eq!(alice.get_emoji(), bob.get_emoji()); assert_eq!(alice.get_emoji(), bob.get_emoji());
@ -1279,15 +1262,15 @@ mod test {
let bob = bob.confirm(); let bob = bob.confirm();
let event = wrap_to_device_event(bob.user_id(), bob.as_content()); let content = bob.as_content();
let alice = alice.into_mac_received(&event).unwrap(); let alice = alice.into_mac_received(bob.user_id(), content).unwrap();
assert!(!alice.get_emoji().is_empty()); assert!(!alice.get_emoji().is_empty());
assert_eq!(alice.get_decimal(), bob_decimals); assert_eq!(alice.get_decimal(), bob_decimals);
let alice = alice.confirm(); let alice = alice.confirm();
let event = wrap_to_device_event(alice.user_id(), alice.as_content()); let content = alice.as_content();
let bob = bob.into_done(&event).unwrap(); let bob = bob.into_done(alice.user_id(), content).unwrap();
assert!(bob.verified_devices().contains(&bob.other_device())); assert!(bob.verified_devices().contains(&bob.other_device()));
assert!(alice.verified_devices().contains(&alice.other_device())); assert!(alice.verified_devices().contains(&alice.other_device()));
@ -1297,23 +1280,28 @@ mod test {
async fn sas_invalid_commitment() { async fn sas_invalid_commitment() {
let (alice, bob) = get_sas_pair().await; let (alice, bob) = get_sas_pair().await;
let mut event = wrap_to_device_event(bob.user_id(), bob.as_content()); let mut content = bob.as_content();
match &mut event.content.method { let mut method = match &mut content {
AcceptContent::ToDevice(c) => &mut c.method,
AcceptContent::Room(_, c) => &mut c.method,
};
match &mut method {
AcceptMethod::MSasV1(ref mut c) => { AcceptMethod::MSasV1(ref mut c) => {
c.commitment = "".to_string(); c.commitment = "".to_string();
} }
_ => panic!("Unknown accept event content"), _ => panic!("Unknown accept event content"),
} }
let alice: SasState<Accepted> = alice.into_accepted(&event).unwrap(); let alice: SasState<Accepted> = alice.into_accepted(bob.user_id(), content).unwrap();
let mut event = wrap_to_device_event(alice.user_id(), alice.as_content()); let content = alice.as_content();
let bob = bob.into_key_received(&mut event).unwrap(); let bob = bob.into_key_received(alice.user_id(), content).unwrap();
let mut event = wrap_to_device_event(bob.user_id(), bob.as_content()); let content = bob.as_content();
alice alice
.into_key_received(&mut event) .into_key_received(bob.user_id(), content)
.expect_err("Didn't cancel on invalid commitment"); .expect_err("Didn't cancel on invalid commitment");
} }
@ -1321,10 +1309,10 @@ mod test {
async fn sas_invalid_sender() { async fn sas_invalid_sender() {
let (alice, bob) = get_sas_pair().await; let (alice, bob) = get_sas_pair().await;
let mut event = wrap_to_device_event(bob.user_id(), bob.as_content()); let content = bob.as_content();
event.sender = UserId::try_from("@malory:example.org").unwrap(); let sender = UserId::try_from("@malory:example.org").unwrap();
alice alice
.into_accepted(&event) .into_accepted(&sender, content)
.expect_err("Didn't cancel on a invalid sender"); .expect_err("Didn't cancel on a invalid sender");
} }
@ -1332,9 +1320,14 @@ mod test {
async fn sas_unknown_sas_method() { async fn sas_unknown_sas_method() {
let (alice, bob) = get_sas_pair().await; let (alice, bob) = get_sas_pair().await;
let mut event = wrap_to_device_event(bob.user_id(), bob.as_content()); let mut content = bob.as_content();
match &mut event.content.method { let mut method = match &mut content {
AcceptContent::ToDevice(c) => &mut c.method,
AcceptContent::Room(_, c) => &mut c.method,
};
match &mut method {
AcceptMethod::MSasV1(ref mut c) => { AcceptMethod::MSasV1(ref mut c) => {
c.short_authentication_string = vec![]; c.short_authentication_string = vec![];
} }
@ -1342,7 +1335,7 @@ mod test {
} }
alice alice
.into_accepted(&event) .into_accepted(bob.user_id(), content)
.expect_err("Didn't cancel on an invalid SAS method"); .expect_err("Didn't cancel on an invalid SAS method");
} }
@ -1350,15 +1343,20 @@ mod test {
async fn sas_unknown_method() { async fn sas_unknown_method() {
let (alice, bob) = get_sas_pair().await; let (alice, bob) = get_sas_pair().await;
let mut event = wrap_to_device_event(bob.user_id(), bob.as_content()); let mut content = bob.as_content();
event.content.method = AcceptMethod::Custom(CustomContent { let method = match &mut content {
AcceptContent::ToDevice(c) => &mut c.method,
AcceptContent::Room(_, c) => &mut c.method,
};
*method = AcceptMethod::Custom(CustomContent {
method: "m.sas.custom".to_string(), method: "m.sas.custom".to_string(),
fields: vec![].into_iter().collect(), fields: vec![].into_iter().collect(),
}); });
alice alice
.into_accepted(&event) .into_accepted(bob.user_id(), content)
.expect_err("Didn't cancel on an unknown SAS method"); .expect_err("Didn't cancel on an unknown SAS method");
} }
@ -1374,26 +1372,39 @@ mod test {
let mut start_content = alice_sas.as_content(); let mut start_content = alice_sas.as_content();
match start_content.method { let method = match &mut start_content {
StartContent::ToDevice(c) => &mut c.method,
StartContent::Room(_, c) => &mut c.method,
};
match method {
StartMethod::MSasV1(ref mut c) => { StartMethod::MSasV1(ref mut c) => {
c.message_authentication_codes = vec![]; c.message_authentication_codes = vec![];
} }
_ => panic!("Unknown SAS start method"), _ => panic!("Unknown SAS start method"),
} }
let event = wrap_to_device_event(alice_sas.user_id(), start_content); SasState::<Started>::from_start_event(
SasState::<Started>::from_start_event(bob.clone(), alice_device.clone(), &event, None) bob.clone(),
alice_device.clone(),
None,
start_content,
)
.expect_err("Didn't cancel on invalid MAC method"); .expect_err("Didn't cancel on invalid MAC method");
let mut start_content = alice_sas.as_content(); let mut start_content = alice_sas.as_content();
start_content.method = StartMethod::Custom(CustomStartContent { let method = match &mut start_content {
StartContent::ToDevice(c) => &mut c.method,
StartContent::Room(_, c) => &mut c.method,
};
*method = StartMethod::Custom(CustomStartContent {
method: "m.sas.custom".to_string(), method: "m.sas.custom".to_string(),
fields: vec![].into_iter().collect(), fields: vec![].into_iter().collect(),
}); });
let event = wrap_to_device_event(alice_sas.user_id(), start_content); SasState::<Started>::from_start_event(bob.clone(), alice_device, None, start_content)
SasState::<Started>::from_start_event(bob.clone(), alice_device, &event, None)
.expect_err("Didn't cancel on unknown sas method"); .expect_err("Didn't cancel on unknown sas method");
} }
} }