diff --git a/matrix_sdk_crypto/src/verification/machine.rs b/matrix_sdk_crypto/src/verification/machine.rs index d57c38f3..7e24b6de 100644 --- a/matrix_sdk_crypto/src/verification/machine.rs +++ b/matrix_sdk_crypto/src/verification/machine.rs @@ -44,7 +44,7 @@ pub struct VerificationCache { } impl VerificationCache { - fn new() -> Self { + pub fn new() -> Self { Self { sas_verification: DashMap::new().into(), room_sas_verifications: DashMap::new().into(), @@ -257,6 +257,7 @@ impl VerificationMachine { ); let request = VerificationRequest::from_room_request( + self.verifications.clone(), self.account.clone(), self.private_identity.lock().await.clone(), self.store.clone(), @@ -291,7 +292,7 @@ impl VerificationMachine { self.store.get_device(&e.sender, &e.content.from_device).await? { match request.into_started_sas( - e, + &e.content, d, self.store.get_user_identity(&e.sender).await?, ) { @@ -388,6 +389,7 @@ impl VerificationMachine { match event { AnyToDeviceEvent::KeyVerificationRequest(e) => { let request = VerificationRequest::from_request( + self.verifications.clone(), self.account.clone(), self.private_identity.lock().await.clone(), self.store.clone(), diff --git a/matrix_sdk_crypto/src/verification/mod.rs b/matrix_sdk_crypto/src/verification/mod.rs index c9868d3c..4e014bd0 100644 --- a/matrix_sdk_crypto/src/verification/mod.rs +++ b/matrix_sdk_crypto/src/verification/mod.rs @@ -16,7 +16,7 @@ mod machine; mod requests; mod sas; -pub use machine::VerificationMachine; +pub use machine::{VerificationCache, VerificationMachine}; use matrix_sdk_common::identifiers::{EventId, RoomId}; pub use requests::VerificationRequest; pub use sas::{AcceptSettings, Sas, VerificationResult}; diff --git a/matrix_sdk_crypto/src/verification/requests.rs b/matrix_sdk_crypto/src/verification/requests.rs index 5bb1a444..3e8ebb6e 100644 --- a/matrix_sdk_crypto/src/verification/requests.rs +++ b/matrix_sdk_crypto/src/verification/requests.rs @@ -25,19 +25,20 @@ use matrix_sdk_common::{ key::verification::{ ready::{ReadyEventContent, ReadyToDeviceEventContent}, request::RequestToDeviceEventContent, - start::StartEventContent, + start::{StartEventContent, StartMethod, StartToDeviceEventContent}, Relation, VerificationMethod, }, room::message::KeyVerificationRequestEventContent, - AnyMessageEventContent, AnyToDeviceEventContent, MessageEvent, SyncMessageEvent, + AnyMessageEventContent, AnyToDeviceEventContent, }, identifiers::{DeviceId, DeviceIdBox, EventId, RoomId, UserId}, uuid::Uuid, + MilliSecondsSinceUnixEpoch, }; use super::{ - sas::{content_to_request, OutgoingContent, StartContent}, - FlowId, + sas::{content_to_request, OutgoingContent, StartContent as OwnedStartContent}, + FlowId, VerificationCache, }; use crate::{ olm::{PrivateCrossSigningIdentity, ReadOnlyAccount}, @@ -137,9 +138,74 @@ impl<'a> TryFrom<&'a OutgoingContent> for ReadyContent<'a> { } } +pub enum StartContent<'a> { + ToDevice(&'a StartToDeviceEventContent), + Room(&'a StartEventContent), +} + +impl<'a> StartContent<'a> { + pub fn from_device(&self) -> &DeviceId { + match self { + StartContent::ToDevice(c) => &c.from_device, + StartContent::Room(c) => &c.from_device, + } + } + + pub fn flow_id(&self) -> &str { + match self { + StartContent::ToDevice(c) => &c.transaction_id, + StartContent::Room(c) => &c.relation.event_id.as_str(), + } + } + + pub fn methods(&self) -> &StartMethod { + match self { + StartContent::ToDevice(c) => &c.method, + StartContent::Room(c) => &c.method, + } + } +} + +impl<'a> From<&'a StartEventContent> for StartContent<'a> { + fn from(c: &'a StartEventContent) -> Self { + Self::Room(c) + } +} + +impl<'a> From<&'a StartToDeviceEventContent> for StartContent<'a> { + fn from(c: &'a StartToDeviceEventContent) -> Self { + Self::ToDevice(c) + } +} + +impl<'a> TryFrom<&'a OutgoingContent> for StartContent<'a> { + type Error = (); + + fn try_from(value: &'a OutgoingContent) -> Result { + match value { + OutgoingContent::Room(_, c) => { + if let AnyMessageEventContent::KeyVerificationStart(c) = c { + Ok(StartContent::Room(c)) + } else { + Err(()) + } + } + OutgoingContent::ToDevice(c) => { + if let AnyToDeviceEventContent::KeyVerificationStart(c) = c { + Ok(StartContent::ToDevice(c)) + } else { + Err(()) + } + } + } + } +} + #[derive(Clone, Debug)] /// TODO pub struct VerificationRequest { + verification_cache: VerificationCache, + account: ReadOnlyAccount, flow_id: Arc, other_user_id: Arc, inner: Arc>, @@ -147,7 +213,8 @@ pub struct VerificationRequest { impl VerificationRequest { /// TODO - pub fn new( + pub(crate) fn new( + cache: VerificationCache, account: ReadOnlyAccount, private_cross_signing_identity: PrivateCrossSigningIdentity, store: Arc>, @@ -158,7 +225,7 @@ impl VerificationRequest { let flow_id = (room_id.to_owned(), event_id.to_owned()).into(); let inner = Mutex::new(InnerRequest::Created(RequestState::new( - account, + account.clone(), private_cross_signing_identity, store, other_user, @@ -166,7 +233,23 @@ impl VerificationRequest { ))) .into(); - Self { flow_id: flow_id.into(), inner, other_user_id: other_user.to_owned().into() } + Self { + account, + verification_cache: cache, + flow_id: flow_id.into(), + inner, + other_user_id: other_user.to_owned().into(), + } + } + + /// TODO + pub fn request_to_device(&self) -> RequestToDeviceEventContent { + RequestToDeviceEventContent::new( + self.account.device_id().into(), + self.flow_id().as_str().to_string(), + SUPPORTED_METHODS.to_vec(), + MilliSecondsSinceUnixEpoch::now(), + ) } /// TODO @@ -200,6 +283,7 @@ impl VerificationRequest { } pub(crate) fn from_room_request( + cache: VerificationCache, account: ReadOnlyAccount, private_cross_signing_identity: PrivateCrossSigningIdentity, store: Arc>, @@ -210,6 +294,7 @@ impl VerificationRequest { ) -> Self { let flow_id = FlowId::from((room_id.to_owned(), event_id.to_owned())); Self::from_helper( + cache, account, private_cross_signing_identity, store, @@ -220,6 +305,7 @@ impl VerificationRequest { } pub(crate) fn from_request( + cache: VerificationCache, account: ReadOnlyAccount, private_cross_signing_identity: PrivateCrossSigningIdentity, store: Arc>, @@ -228,6 +314,7 @@ impl VerificationRequest { ) -> Self { let flow_id = FlowId::from(content.transaction_id.to_owned()); Self::from_helper( + cache, account, private_cross_signing_identity, store, @@ -238,6 +325,7 @@ impl VerificationRequest { } fn from_helper( + cache: VerificationCache, account: ReadOnlyAccount, private_cross_signing_identity: PrivateCrossSigningIdentity, store: Arc>, @@ -246,14 +334,16 @@ impl VerificationRequest { content: RequestContent, ) -> Self { Self { + verification_cache: cache, inner: Arc::new(Mutex::new(InnerRequest::Requested(RequestState::from_request_event( - account, + account.clone(), private_cross_signing_identity, store, sender, &flow_id, content, )))), + account, other_user_id: sender.to_owned().into(), flow_id: flow_id.into(), } @@ -294,24 +384,21 @@ impl VerificationRequest { matches!(&*self.inner.lock().unwrap(), InnerRequest::Ready(_)) } - pub(crate) fn into_started_sas( + pub(crate) fn into_started_sas<'a>( self, - event: &SyncMessageEvent, + content: impl Into>, device: ReadOnlyDevice, user_identity: Option, ) -> Result { match &*self.inner.lock().unwrap() { - InnerRequest::Ready(s) => match &s.state.flow_id { - FlowId::ToDevice(_) => todo!(), - FlowId::InRoom(r, _) => s.clone().into_started_sas( - &event.clone().into_full_event(r.to_owned()), - s.store.clone(), - s.account.clone(), - s.private_cross_signing_identity.clone(), - device, - user_identity, - ), - }, + InnerRequest::Ready(s) => s.clone().into_started_sas( + content, + s.store.clone(), + s.account.clone(), + s.private_cross_signing_identity.clone(), + device, + user_identity, + ), // TODO cancel here since we got a missmatched message or do // nothing? _ => todo!(), @@ -322,18 +409,15 @@ impl VerificationRequest { &self, device: ReadOnlyDevice, user_identity: Option, - ) -> Option<(Sas, StartContent)> { + ) -> Option<(Sas, OutgoingContent)> { match &*self.inner.lock().unwrap() { - InnerRequest::Ready(s) => match &s.state.flow_id { - FlowId::ToDevice(_) => todo!(), - FlowId::InRoom(_, _) => Some(s.clone().start_sas( - s.store.clone(), - s.account.clone(), - s.private_cross_signing_identity.clone(), - device, - user_identity, - )), - }, + InnerRequest::Ready(s) => Some(s.clone().start_sas( + s.store.clone(), + s.account.clone(), + s.private_cross_signing_identity.clone(), + device, + user_identity, + )), _ => None, } } @@ -387,9 +471,9 @@ impl InnerRequest { } } - fn into_started_sas( + fn into_started_sas<'a>( self, - event: &MessageEvent, + content: impl Into>, store: Arc>, account: ReadOnlyAccount, private_identity: PrivateCrossSigningIdentity, @@ -398,7 +482,7 @@ impl InnerRequest { ) -> Result, OutgoingContent> { if let InnerRequest::Ready(s) = self { Ok(Some(s.into_started_sas( - event, + content, store, account, private_identity, @@ -559,21 +643,33 @@ struct Ready { } impl RequestState { - fn into_started_sas( + fn into_started_sas<'a>( self, - event: &MessageEvent, + content: impl Into>, store: Arc>, account: ReadOnlyAccount, private_identity: PrivateCrossSigningIdentity, other_device: ReadOnlyDevice, other_identity: Option, ) -> Result { + let content: OwnedStartContent = match content.into() { + StartContent::Room(c) => { + if let FlowId::InRoom(r, _) = &*self.flow_id { + (r.to_owned(), c.to_owned()).into() + } else { + // TODO cancel here + panic!("Missmatch between content and flow id"); + } + } + StartContent::ToDevice(c) => c.clone().into(), + }; + Sas::from_start_event( account, private_identity, other_device, store, - (event.room_id.clone(), event.content.clone()), + content, other_identity, ) } @@ -585,20 +681,31 @@ impl RequestState { private_identity: PrivateCrossSigningIdentity, other_device: ReadOnlyDevice, other_identity: Option, - ) -> (Sas, StartContent) { + ) -> (Sas, OutgoingContent) { match self.state.flow_id { FlowId::ToDevice(t) => { - Sas::start(account, private_identity, other_device, store, other_identity, Some(t)) + let (sas, content) = Sas::start( + account, + private_identity, + other_device, + store, + other_identity, + Some(t), + ); + (sas, content.into()) + } + FlowId::InRoom(r, e) => { + let (sas, content) = Sas::start_in_room( + e, + r, + account, + private_identity, + other_device, + store, + other_identity, + ); + (sas, content.into()) } - FlowId::InRoom(r, e) => Sas::start_in_room( - e, - r, - account, - private_identity, - other_device, - store, - other_identity, - ), } } } @@ -617,21 +724,14 @@ struct Passive { mod test { use std::convert::TryFrom; - use matrix_sdk_common::{ - events::{SyncMessageEvent, Unsigned}, - identifiers::{event_id, room_id, DeviceIdBox, UserId}, - MilliSecondsSinceUnixEpoch, - }; + use matrix_sdk_common::identifiers::{event_id, room_id, DeviceIdBox, UserId}; use matrix_sdk_test::async_test; - use super::VerificationRequest; + use super::{StartContent, VerificationRequest}; use crate::{ olm::{PrivateCrossSigningIdentity, ReadOnlyAccount}, store::{CryptoStore, MemoryStore}, - verification::{ - requests::ReadyContent, - sas::{OutgoingContent, StartContent}, - }, + verification::{requests::ReadyContent, sas::OutgoingContent, VerificationCache}, ReadOnlyDevice, }; @@ -667,6 +767,7 @@ mod test { let content = VerificationRequest::request(bob.user_id(), bob.device_id(), &alice_id()); let bob_request = VerificationRequest::new( + VerificationCache::new(), bob, bob_identity, bob_store.into(), @@ -676,6 +777,7 @@ mod test { ); let alice_request = VerificationRequest::from_room_request( + VerificationCache::new(), alice, alice_identity, alice_store.into(), @@ -713,6 +815,7 @@ mod test { let content = VerificationRequest::request(bob.user_id(), bob.device_id(), &alice_id()); let bob_request = VerificationRequest::new( + VerificationCache::new(), bob, bob_identity, bob_store.into(), @@ -722,6 +825,7 @@ mod test { ); let alice_request = VerificationRequest::from_room_request( + VerificationCache::new(), alice, alice_identity, alice_store.into(), @@ -741,19 +845,8 @@ mod test { let (bob_sas, start_content) = bob_request.start(alice_device, None).unwrap(); - let event = if let StartContent::Room(_, c) = start_content { - SyncMessageEvent { - content: c, - event_id: event_id.clone(), - sender: bob_id(), - origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - unsigned: Unsigned::default(), - } - } else { - panic!("Invalid start event content type"); - }; - - let alice_sas = alice_request.into_started_sas(&event, bob_device, None).unwrap(); + let content = StartContent::try_from(&start_content).unwrap(); + let alice_sas = alice_request.into_started_sas(content, bob_device, None).unwrap(); assert!(!bob_sas.is_canceled()); assert!(!alice_sas.is_canceled());