use std::mem; use crate::Device; use olm_rs::sas::OlmSas; use matrix_sdk_common::events::{ key::verification::{ accept::AcceptEventContent, key::KeyEventContent, mac::MacEvent, start::{MSasV1Content, MSasV1ContentOptions, StartEventContent}, HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode, ShortAuthenticationString, VerificationMethod, }, ToDeviceEvent, }; use matrix_sdk_common::identifiers::{DeviceId, UserId}; use matrix_sdk_common::uuid::Uuid; struct SasIds { own_user_id: UserId, own_device_id: Box, other_device: Device, } struct AcceptedProtocols { method: VerificationMethod, key_agreement_protocol: KeyAgreementProtocol, hash: HashAlgorithm, message_auth_code: MessageAuthenticationCode, short_auth_string: Vec, } impl From for AcceptedProtocols { fn from(content: AcceptEventContent) -> Self { Self { method: content.method, hash: content.hash, key_agreement_protocol: content.key_agreement_protocol, message_auth_code: content.message_authentication_code, short_auth_string: content.short_authentication_string.clone(), } } } struct Sas { inner: OlmSas, ids: SasIds, verification_flow_id: String, state: S, } impl Sas { pub fn user_id(&self) -> &UserId { &self.ids.own_user_id } } fn get_emoji(index: u8) -> (&'static str, &'static str) { match index { 0 => ("🐶", "Dog"), 1 => ("🐱", "Cat"), 2 => ("🦁", "Lion"), 3 => ("🐎", "Horse"), 4 => ("🦄", "Unicorn"), 5 => ("🐷", "Pig"), 6 => ("🐘", "Elephant"), 7 => ("🐰", "Rabbit"), 8 => ("🐼", "Panda"), 9 => ("🐓", "Rooster"), 10 => ("🐧", "Penguin"), 11 => ("🐢", "Turtle"), 12 => ("🐟", "Fish"), 13 => ("🐙", "Octopus"), 14 => ("🦋", "Butterfly"), 15 => ("🌷", "Flower"), 16 => ("🌳", "Tree"), 17 => ("🌵", "Cactus"), 18 => ("🍄", "Mushroom"), 19 => ("🌏", "Globe"), 20 => ("🌙", "Moon"), 21 => ("☁️", "Cloud"), 22 => ("🔥", "Fire"), 23 => ("🍌", "Banana"), 24 => ("🍎", "Apple"), 25 => ("🍓", "Strawberry"), 26 => ("🌽", "Corn"), 27 => ("🍕", "Pizza"), 28 => ("🎂", "Cake"), 29 => ("❤️", "Heart"), 30 => ("😀", "Smiley"), 31 => ("🤖", "Robot"), 32 => ("🎩", "Hat"), 33 => ("👓", "Glasses"), 34 => ("🔧", "Spanner"), 35 => ("🎅", "Santa"), 36 => ("👍", "Thumbs up"), 37 => ("☂️", "Umbrella"), 38 => ("⌛", "Hourglass"), 39 => ("⏰", "Clock"), 40 => ("🎁", "Gift"), 41 => ("💡", "Light Bulb"), 42 => ("📕", "Book"), 43 => ("✏️", "Pencil"), 44 => ("📎", "Paperclip"), 45 => ("✂️", "Scissors"), 46 => ("🔒", "Lock"), 47 => ("🔑", "Key"), 48 => ("🔨", "Hammer"), 49 => ("☎️", "Telephone"), 50 => ("🏁", "Flag"), 51 => ("🚂", "Train"), 52 => ("🚲", "Bicycle"), 53 => ("✈️", "Airplane"), 54 => ("🚀", "Rocket"), 55 => ("🏆", "Trophy"), 56 => ("⚽", "Ball"), 57 => ("🎸", "Guitar"), 58 => ("🎺", "Trumpet"), 59 => ("🔔", "Bell"), 60 => ("⚓", "Anchor"), 61 => ("🎧", "Headphones"), 62 => ("📁", "Folder"), 63 => ("📌", "Pin"), _ => panic!("Trying to fetch an SAS emoji outside the allowed range"), } } impl Sas { fn new(own_user_id: UserId, own_device_id: &DeviceId, other_device: Device) -> Sas { let verification_flow_id = Uuid::new_v4().to_string(); Sas { inner: OlmSas::new(), ids: SasIds { own_user_id, own_device_id: own_device_id.into(), other_device, }, verification_flow_id: verification_flow_id.clone(), state: Created { protocol_definitions: MSasV1ContentOptions { transaction_id: verification_flow_id, from_device: own_device_id.into(), short_authentication_string: vec![ ShortAuthenticationString::Decimal, ShortAuthenticationString::Emoji, ], key_agreement_protocols: vec![KeyAgreementProtocol::Curve25519HkdfSha256], message_authentication_codes: vec![MessageAuthenticationCode::HkdfHmacSha256], hashes: vec![HashAlgorithm::Sha256], }, }, } } fn get_start_event(&self) -> StartEventContent { StartEventContent::MSasV1( MSasV1Content::new(self.state.protocol_definitions.clone()) .expect("Invalid initial protocol definitions."), ) } fn into_accepted(self, event: &ToDeviceEvent) -> Sas { let content = &event.content; Sas { inner: self.inner, ids: self.ids, verification_flow_id: self.verification_flow_id, state: Accepted { commitment: content.commitment.clone(), accepted_protocols: content.clone().into(), }, } } } struct Created { protocol_definitions: MSasV1ContentOptions, } struct Started { protocol_definitions: MSasV1Content, } impl Sas { fn from_start_event( own_user_id: &UserId, own_device_id: &DeviceId, other_device: Device, event: &ToDeviceEvent, ) -> Sas { let content = if let StartEventContent::MSasV1(content) = &event.content { content } else { panic!("Invalid sas version") }; Sas { inner: OlmSas::new(), ids: SasIds { own_user_id: own_user_id.clone(), own_device_id: own_device_id.into(), other_device, }, verification_flow_id: content.transaction_id.clone(), state: Started { protocol_definitions: content.clone(), }, } } fn get_accept_content(&self) -> AcceptEventContent { AcceptEventContent { method: VerificationMethod::MSasV1, transaction_id: self.verification_flow_id.to_string(), commitment: "".to_owned(), hash: HashAlgorithm::Sha256, key_agreement_protocol: KeyAgreementProtocol::Curve25519HkdfSha256, message_authentication_code: MessageAuthenticationCode::HkdfHmacSha256, short_authentication_string: self .state .protocol_definitions .short_authentication_string .clone(), } } fn into_key_received(mut self, event: &mut ToDeviceEvent) -> Sas { let accepted_protocols: AcceptedProtocols = self.get_accept_content().into(); self.inner .set_their_public_key(&mem::take(&mut event.content.key)) .expect("Can't set public key"); Sas { inner: self.inner, ids: self.ids, verification_flow_id: self.verification_flow_id, state: KeyReceived { we_started: false, accepted_protocols, }, } } } struct Accepted { accepted_protocols: AcceptedProtocols, commitment: String, } impl Sas { fn into_key_received(mut self, event: &mut ToDeviceEvent) -> Sas { self.inner .set_their_public_key(&mem::take(&mut event.content.key)) .expect("Can't set public key"); Sas { inner: self.inner, ids: self.ids, verification_flow_id: self.verification_flow_id, state: KeyReceived { we_started: true, accepted_protocols: self.state.accepted_protocols, }, } } fn get_key_content(&self) -> KeyEventContent { KeyEventContent { transaction_id: self.verification_flow_id.to_string(), key: self.inner.public_key(), } } } struct KeyReceived { we_started: bool, accepted_protocols: AcceptedProtocols, } impl Sas { fn get_key_content(&self) -> KeyEventContent { KeyEventContent { transaction_id: self.verification_flow_id.to_string(), key: self.inner.public_key(), } } fn extra_info(&self) -> String { if self.state.we_started { format!( "MATRIX_KEY_VERIFICATION_SAS{first_user}{first_device}\ {second_user}{second_device}{transaction_id}", first_user = self.ids.own_user_id, first_device = self.ids.own_device_id, second_user = self.ids.other_device.user_id(), second_device = self.ids.other_device.device_id(), transaction_id = self.verification_flow_id, ) } else { format!( "MATRIX_KEY_VERIFICATION_SAS{first_user}{first_device}\ {second_user}{second_device}{transaction_id}", first_user = self.ids.other_device.user_id(), first_device = self.ids.other_device.device_id(), second_user = self.ids.own_user_id, second_device = self.ids.own_device_id, transaction_id = self.verification_flow_id, ) } } fn get_emoji(&self) -> Vec<(&'static str, &'static str)> { let bytes: Vec = self .inner .generate_bytes(&self.extra_info(), 6) .expect("Can't generate bytes") .into_iter() .map(|b| b as u64) .collect(); let mut num: u64 = bytes[0] << 40; num += bytes[1] << 32; num += bytes[2] << 24; num += bytes[3] << 16; num += bytes[4] << 8; num += bytes[5]; let numbers = vec![ ((num >> 42) & 63) as u8, ((num >> 36) & 63) as u8, ((num >> 30) & 63) as u8, ((num >> 24) & 63) as u8, ((num >> 18) & 63) as u8, ((num >> 12) & 63) as u8, ((num >> 6) & 63) as u8, ]; numbers.into_iter().map(get_emoji).collect() } fn get_decimal(&self) -> (u32, u32, u32) { let bytes: Vec = self .inner .generate_bytes(&self.extra_info(), 5) .expect("Can't generate bytes") .into_iter() .map(|b| b as u32) .collect(); let first = (bytes[0] << 5 | bytes[1] >> 3) + 1000; let second = ((bytes[1] & 0x7) << 10 | bytes[2] << 2 | bytes[3] >> 6) + 1000; let third = ((bytes[3] & 0x3F) << 7 | bytes[4] >> 1) + 1000; (first, second, third) } fn into_mac_received(self, event: &MacEvent) -> Sas { todo!() } fn confirm(self) -> Sas { todo!() } } struct Confirmed { accepted_protocols: AcceptedProtocols, } impl Sas { fn confirm(self) -> Sas { todo!() } } struct MacReceived { verified_devices: Vec, verified_master_keys: Vec, } impl Sas { fn into_done(self, event: &MacEvent) -> Sas { todo!() } } struct Done { verified_devices: Vec, verified_master_keys: Vec, } #[cfg(test)] mod test { use std::convert::TryFrom; use crate::{Account, Device}; use matrix_sdk_common::events::{EventContent, ToDeviceEvent}; use matrix_sdk_common::identifiers::{DeviceId, UserId}; use super::{Accepted, Created, Sas, Started}; fn alice_id() -> UserId { UserId::try_from("@alice:example.org").unwrap() } fn alice_device_id() -> Box { "JLAFKJWSCS".into() } fn bob_id() -> UserId { UserId::try_from("@bob:example.org").unwrap() } fn bob_device_id() -> Box { "BOBDEVCIE".into() } fn wrap_to_device_event(sender: &UserId, content: C) -> ToDeviceEvent { ToDeviceEvent { sender: sender.clone(), content, } } async fn get_sas_pair() -> (Sas, Sas) { let alice = Account::new(&alice_id(), &alice_device_id()); let alice_device = Device::from_account(&alice).await; let bob = Account::new(&bob_id(), &bob_device_id()); let bob_device = Device::from_account(&bob).await; let alice_sas = Sas::::new(alice_id(), &alice_device_id(), bob_device); let start_content = alice_sas.get_start_event(); let event = wrap_to_device_event(alice_sas.user_id(), start_content); let bob_sas = Sas::::from_start_event(bob.user_id(), bob.device_id(), alice_device, &event); (alice_sas, bob_sas) } #[tokio::test] async fn create_sas() { let (_, _) = get_sas_pair().await; } #[tokio::test] async fn sas_accept() { let (alice, bob) = get_sas_pair().await; let event = wrap_to_device_event(bob.user_id(), bob.get_accept_content()); alice.into_accepted(&event); } #[tokio::test] async fn sas_key_share() { let (alice, bob) = get_sas_pair().await; let event = wrap_to_device_event(bob.user_id(), bob.get_accept_content()); let alice: Sas = alice.into_accepted(&event); let mut event = wrap_to_device_event(alice.user_id(), alice.get_key_content()); let bob = bob.into_key_received(&mut event); let mut event = wrap_to_device_event(bob.user_id(), bob.get_key_content()); let alice = alice.into_key_received(&mut event); assert_eq!(alice.get_decimal(), bob.get_decimal()); assert_eq!(alice.get_emoji(), bob.get_emoji()); } }