diff --git a/matrix_sdk_crypto/src/verification/sas.rs b/matrix_sdk_crypto/src/verification/sas.rs index 134ab1bc..2c97d8f2 100644 --- a/matrix_sdk_crypto/src/verification/sas.rs +++ b/matrix_sdk_crypto/src/verification/sas.rs @@ -1,13 +1,15 @@ +use std::mem; + use crate::Device; use olm_rs::sas::OlmSas; use matrix_sdk_common::events::{ key::verification::{ - accept::AcceptEvent, - key::KeyEvent, + accept::AcceptEventContent, + key::KeyEventContent, mac::MacEvent, - start::{MSasV1Content, MSasV1ContentOptions, StartEvent, StartEventContent}, + start::{MSasV1Content, MSasV1ContentOptions, StartEventContent}, HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode, ShortAuthenticationString, VerificationMethod, }, @@ -30,10 +32,22 @@ struct AcceptedProtocols { short_auth_string: Vec, } +impl From for AcceptedProtocols { + fn from(content: AcceptEventContent) -> Self { + Self { + method: content.method, + hash: content.hash, + key_agreement_protocol: content.key_agreement_protocol, + message_auth_code: content.message_authentication_code, + short_auth_string: content.short_authentication_string.clone(), + } + } +} + struct Sas { inner: OlmSas, ids: SasIds, - verification_flow_id: Uuid, + verification_flow_id: String, state: S, } @@ -45,7 +59,7 @@ impl Sas { impl Sas { fn new(own_user_id: UserId, own_device_id: &DeviceId, other_device: Device) -> Sas { - let verification_flow_id = Uuid::new_v4(); + let verification_flow_id = Uuid::new_v4().to_string(); Sas { inner: OlmSas::new(), @@ -54,11 +68,11 @@ impl Sas { own_device_id: own_device_id.into(), other_device, }, - verification_flow_id, + verification_flow_id: verification_flow_id.clone(), state: Created { protocol_definitions: MSasV1ContentOptions { - transaction_id: verification_flow_id.to_string(), + transaction_id: verification_flow_id, from_device: own_device_id.into(), short_authentication_string: vec![ ShortAuthenticationString::Decimal, @@ -79,7 +93,7 @@ impl Sas { ) } - fn into_accepted(self, event: &AcceptEvent) -> Sas { + fn into_accepted(self, event: &ToDeviceEvent) -> Sas { let content = &event.content; Sas { @@ -88,13 +102,7 @@ impl Sas { verification_flow_id: self.verification_flow_id, state: Accepted { commitment: content.commitment.clone(), - accepted_protocols: AcceptedProtocols { - method: content.method, - hash: content.hash, - key_agreement_protocol: content.key_agreement_protocol, - message_auth_code: content.message_authentication_code, - short_auth_string: content.short_authentication_string.clone(), - }, + accepted_protocols: content.clone().into(), }, } } @@ -130,7 +138,7 @@ impl Sas { other_device, }, - verification_flow_id: Uuid::new_v4(), + verification_flow_id: content.transaction_id.clone(), state: Started { protocol_definitions: content.clone(), @@ -138,8 +146,37 @@ impl Sas { } } - fn into_key_received(self, event: &KeyEvent) -> Sas { - todo!() + fn get_accept_content(&self) -> AcceptEventContent { + AcceptEventContent { + method: VerificationMethod::MSasV1, + transaction_id: self.verification_flow_id.to_string(), + commitment: "".to_owned(), + hash: HashAlgorithm::Sha256, + key_agreement_protocol: KeyAgreementProtocol::Curve25519HkdfSha256, + message_authentication_code: MessageAuthenticationCode::HkdfHmacSha256, + short_authentication_string: self + .state + .protocol_definitions + .short_authentication_string + .clone(), + } + } + + fn into_key_received(mut self, event: &mut ToDeviceEvent) -> Sas { + let accepted_protocols: AcceptedProtocols = self.get_accept_content().into(); + self.inner + .set_their_public_key(&mem::take(&mut event.content.key)) + .expect("Can't set public key"); + + Sas { + inner: self.inner, + ids: self.ids, + verification_flow_id: self.verification_flow_id, + state: KeyReceived { + we_started: false, + accepted_protocols, + }, + } } } @@ -149,16 +186,87 @@ struct Accepted { } impl Sas { - fn into_key_received(self, event: &KeyEvent) -> Sas { - todo!() + fn into_key_received(mut self, event: &mut ToDeviceEvent) -> Sas { + self.inner + .set_their_public_key(&mem::take(&mut event.content.key)) + .expect("Can't set public key"); + + Sas { + inner: self.inner, + ids: self.ids, + verification_flow_id: self.verification_flow_id, + state: KeyReceived { + we_started: true, + accepted_protocols: self.state.accepted_protocols, + }, + } + } + + fn get_key_content(&self) -> KeyEventContent { + KeyEventContent { + transaction_id: self.verification_flow_id.to_string(), + key: self.inner.public_key(), + } } } struct KeyReceived { + we_started: bool, accepted_protocols: AcceptedProtocols, } impl Sas { + fn get_key_content(&self) -> KeyEventContent { + KeyEventContent { + transaction_id: self.verification_flow_id.to_string(), + key: self.inner.public_key(), + } + } + + fn extra_info(&self) -> String { + if self.state.we_started { + format!( + "MATRIX_KEY_VERIFICATION_SAS{first_user}{first_device}\ + {second_user}{second_device}{transaction_id}", + first_user = self.ids.own_user_id, + first_device = self.ids.own_device_id, + second_user = self.ids.other_device.user_id(), + second_device = self.ids.other_device.device_id(), + transaction_id = self.verification_flow_id, + ) + } else { + format!( + "MATRIX_KEY_VERIFICATION_SAS{first_user}{first_device}\ + {second_user}{second_device}{transaction_id}", + first_user = self.ids.other_device.user_id(), + first_device = self.ids.other_device.device_id(), + second_user = self.ids.own_user_id, + second_device = self.ids.own_device_id, + transaction_id = self.verification_flow_id, + ) + } + } + + fn get_emoji(&self) -> Vec<(String, String)> { + todo!() + } + + fn get_decimal(&self) -> (i32, i32, i32) { + let bytes: Vec = self + .inner + .generate_bytes(&self.extra_info(), 5) + .expect("Can't generate bytes") + .into_iter() + .map(|b| b as i32) + .collect(); + + let first = (bytes[0] << 5 | bytes[1] >> 3) + 1000; + let second = ((bytes[1] & 0x7) << 10 | bytes[2] << 2 | bytes[3] >> 6) + 1000; + let third = ((bytes[3] & 0x3F) << 7 | bytes[4] >> 1) + 1000; + + (first, second, third) + } + fn into_mac_received(self, event: &MacEvent) -> Sas { todo!() } @@ -199,16 +307,10 @@ mod test { use std::convert::TryFrom; use crate::{Account, Device}; - use matrix_sdk_common::events::key::verification::{ - accept::AcceptEvent, - key::KeyEvent, - mac::MacEvent, - start::{MSasV1Content, MSasV1ContentOptions, StartEvent, StartEventContent}, - }; - use matrix_sdk_common::events::{AnyToDeviceEvent, ToDeviceEvent}; + use matrix_sdk_common::events::{EventContent, ToDeviceEvent}; use matrix_sdk_common::identifiers::{DeviceId, UserId}; - use super::{Created, Sas, Started}; + use super::{Accepted, Created, Sas, Started}; fn alice_id() -> UserId { UserId::try_from("@alice:example.org").unwrap() @@ -226,18 +328,14 @@ mod test { "BOBDEVCIE".into() } - fn wrap_start_event( - sender: &UserId, - content: StartEventContent, - ) -> ToDeviceEvent { + fn wrap_to_device_event(sender: &UserId, content: C) -> ToDeviceEvent { ToDeviceEvent { sender: sender.clone(), content, } } - #[tokio::test] - async fn create_sas() { + async fn get_sas_pair() -> (Sas, Sas) { let alice = Account::new(&alice_id(), &alice_device_id()); let alice_device = Device::from_account(&alice).await; @@ -247,9 +345,43 @@ mod test { let alice_sas = Sas::::new(alice_id(), &alice_device_id(), bob_device); let start_content = alice_sas.get_start_event(); - let event = wrap_start_event(alice_sas.user_id(), start_content); + let event = wrap_to_device_event(alice_sas.user_id(), start_content); let bob_sas = Sas::::from_start_event(bob.user_id(), bob.device_id(), alice_device, &event); + + (alice_sas, bob_sas) + } + + #[tokio::test] + async fn create_sas() { + let (_, _) = get_sas_pair().await; + } + + #[tokio::test] + async fn sas_accept() { + let (alice, bob) = get_sas_pair().await; + + let event = wrap_to_device_event(bob.user_id(), bob.get_accept_content()); + + alice.into_accepted(&event); + } + + #[tokio::test] + async fn sas_key_share() { + let (alice, bob) = get_sas_pair().await; + + let event = wrap_to_device_event(bob.user_id(), bob.get_accept_content()); + + let alice: Sas = alice.into_accepted(&event); + let mut event = wrap_to_device_event(alice.user_id(), alice.get_key_content()); + + let bob = bob.into_key_received(&mut event); + + let mut event = wrap_to_device_event(bob.user_id(), bob.get_key_content()); + + let alice = alice.into_key_received(&mut event); + + assert_eq!(alice.get_decimal(), bob.get_decimal()); } }