diff --git a/matrix_sdk/src/sas.rs b/matrix_sdk/src/sas.rs index 2f995ef9..3cff3faa 100644 --- a/matrix_sdk/src/sas.rs +++ b/matrix_sdk/src/sas.rs @@ -63,7 +63,7 @@ impl Sas { } /// Get the decimal version of the short auth string. - pub fn decimals(&self) -> Option<(u32, u32, u32)> { + pub fn decimals(&self) -> Option<(u16, u16, u16)> { self.inner.decimals() } diff --git a/matrix_sdk_crypto/Cargo.toml b/matrix_sdk_crypto/Cargo.toml index ea1fc330..509a3ee3 100644 --- a/matrix_sdk_crypto/Cargo.toml +++ b/matrix_sdk_crypto/Cargo.toml @@ -45,8 +45,9 @@ default-features = false features = ["runtime-tokio", "sqlite"] [dev-dependencies] -tokio = { version = "0.2.21", features = ["rt-threaded", "macros"] } -serde_json = "1.0.56" +tokio = { version = "0.2.22", features = ["rt-threaded", "macros"] } +proptest = "0.10.0" +serde_json = "1.0.57" tempfile = "3.1.0" http = "0.2.1" matrix-sdk-test = { version = "0.1.0", path = "../matrix_sdk_test" } diff --git a/matrix_sdk_crypto/src/verification/sas/helpers.rs b/matrix_sdk_crypto/src/verification/sas/helpers.rs index 066d51f3..7934901b 100644 --- a/matrix_sdk_crypto/src/verification/sas/helpers.rs +++ b/matrix_sdk_crypto/src/verification/sas/helpers.rs @@ -318,13 +318,15 @@ pub fn get_emoji( flow_id: &str, we_started: bool, ) -> Vec<(&'static str, &'static str)> { - let bytes: Vec = sas + let bytes = sas .generate_bytes(&extra_info_sas(&ids, &flow_id, we_started), 6) - .expect("Can't generate bytes") - .into_iter() - .map(|b| b as u64) - .collect(); + .expect("Can't generate bytes"); + bytes_to_emoji(bytes) +} + +fn bytes_to_emoji_index(bytes: Vec) -> Vec { + let bytes: Vec = bytes.iter().map(|b| *b as u64).collect(); // Join the 6 bytes into one 64 bit unsigned int. This u64 will contain 48 // bits from our 6 bytes. let mut num: u64 = bytes[0] << 40; @@ -336,7 +338,7 @@ pub fn get_emoji( // Take the top 42 bits of our 48 bits from the u64 and convert each 6 bits // into a 6 bit number. - let numbers = vec![ + vec![ ((num >> 42) & 63) as u8, ((num >> 36) & 63) as u8, ((num >> 30) & 63) as u8, @@ -344,7 +346,11 @@ pub fn get_emoji( ((num >> 18) & 63) as u8, ((num >> 12) & 63) as u8, ((num >> 6) & 63) as u8, - ]; + ] +} + +fn bytes_to_emoji(bytes: Vec) -> Vec<(&'static str, &'static str)> { + let numbers = bytes_to_emoji_index(bytes); // Convert the 6 bit number into a emoji/description tuple. numbers.into_iter().map(emoji_from_index).collect() @@ -369,13 +375,16 @@ pub fn get_emoji( /// # Panics /// /// This will panic if the public key of the other side wasn't set. -pub fn get_decimal(sas: &OlmSas, ids: &SasIds, flow_id: &str, we_started: bool) -> (u32, u32, u32) { - let bytes: Vec = sas +pub fn get_decimal(sas: &OlmSas, ids: &SasIds, flow_id: &str, we_started: bool) -> (u16, u16, u16) { + let bytes = sas .generate_bytes(&extra_info_sas(&ids, &flow_id, we_started), 5) - .expect("Can't generate bytes") - .into_iter() - .map(|b| b as u32) - .collect(); + .expect("Can't generate bytes"); + + bytes_to_decimal(bytes) +} + +fn bytes_to_decimal(bytes: Vec) -> (u16, u16, u16) { + let bytes: Vec = bytes.into_iter().map(|b| b as u16).collect(); // This bitwise operation is taken from the [spec] // [spec]: https://matrix.org/docs/spec/client_server/latest#sas-method-decimal @@ -415,3 +424,62 @@ pub fn content_to_request( messages, } } + +#[cfg(test)] +mod test { + use proptest::prelude::*; + + use super::{bytes_to_decimal, bytes_to_emoji, bytes_to_emoji_index, emoji_from_index}; + + #[test] + fn test_emoji_generation() { + let bytes = vec![0, 0, 0, 0, 0, 0]; + let index: Vec<(&'static str, &'static str)> = vec![0, 0, 0, 0, 0, 0, 0] + .into_iter() + .map(emoji_from_index) + .collect(); + assert_eq!(bytes_to_emoji(bytes), index); + + let bytes = vec![0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF]; + + let index: Vec<(&'static str, &'static str)> = vec![63, 63, 63, 63, 63, 63, 63] + .into_iter() + .map(emoji_from_index) + .collect(); + assert_eq!(bytes_to_emoji(bytes), index); + } + + #[test] + fn test_decimal_generation() { + let bytes = vec![0, 0, 0, 0, 0]; + let result = bytes_to_decimal(bytes); + + assert_eq!(result, (1000, 1000, 1000)); + + let bytes = vec![0xFF, 0xFF, 0xFF, 0xFF, 0xFF]; + let result = bytes_to_decimal(bytes); + assert_eq!(result, (9191, 9191, 9191)); + } + + proptest! { + #[test] + fn proptest_emoji(bytes in prop::array::uniform6(0u8..)) { + let numbers = bytes_to_emoji_index(bytes.to_vec()); + + for number in numbers { + prop_assert!(number < 64); + } + } + } + + proptest! { + #[test] + fn proptest_decimals(bytes in prop::array::uniform5(0u8..)) { + let (first, second, third) = bytes_to_decimal(bytes.to_vec()); + + prop_assert!(first <= 9191 && first >= 1000); + prop_assert!(second <= 9191 && second >= 1000); + prop_assert!(third <= 9191 && third >= 1000); + } + } +} diff --git a/matrix_sdk_crypto/src/verification/sas/mod.rs b/matrix_sdk_crypto/src/verification/sas/mod.rs index 6ab639b6..ae6442bf 100644 --- a/matrix_sdk_crypto/src/verification/sas/mod.rs +++ b/matrix_sdk_crypto/src/verification/sas/mod.rs @@ -257,7 +257,7 @@ impl Sas { /// Returns None if we can't yet present the short auth string, otherwise a /// tuple containing three 4-digit integers that represent the short auth /// string. - pub fn decimals(&self) -> Option<(u32, u32, u32)> { + pub fn decimals(&self) -> Option<(u16, u16, u16)> { self.inner.lock().unwrap().decimals() } @@ -464,7 +464,7 @@ impl InnerSas { } } - fn decimals(&self) -> Option<(u32, u32, u32)> { + fn decimals(&self) -> Option<(u16, u16, u16)> { match self { InnerSas::KeyRecieved(s) => Some(s.get_decimal()), InnerSas::MacReceived(s) => Some(s.get_decimal()), diff --git a/matrix_sdk_crypto/src/verification/sas/sas_state.rs b/matrix_sdk_crypto/src/verification/sas/sas_state.rs index 9b638bf0..4a382c6c 100644 --- a/matrix_sdk_crypto/src/verification/sas/sas_state.rs +++ b/matrix_sdk_crypto/src/verification/sas/sas_state.rs @@ -530,7 +530,7 @@ impl SasState { /// /// Returns a tuple containing three 4 digit integer numbers that represent /// the short auth string. - pub fn get_decimal(&self) -> (u32, u32, u32) { + pub fn get_decimal(&self) -> (u16, u16, u16) { get_decimal( &self.inner.lock().unwrap(), &self.ids, @@ -667,7 +667,7 @@ impl SasState { /// /// Returns a tuple containing three 4 digit integer numbers that represent /// the short auth string. - pub fn get_decimal(&self) -> (u32, u32, u32) { + pub fn get_decimal(&self) -> (u16, u16, u16) { get_decimal( &self.inner.lock().unwrap(), &self.ids,