From 6c85d3e28f5225dc311f5fdb392a9dc8a8c799f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Tue, 11 Aug 2020 11:24:29 +0200 Subject: [PATCH] crypto: Use TryFrom to check the accepted SAS protocols. --- .../src/verification/sas/sas_state.rs | 76 ++++++++++--------- 1 file changed, 41 insertions(+), 35 deletions(-) diff --git a/matrix_sdk_crypto/src/verification/sas/sas_state.rs b/matrix_sdk_crypto/src/verification/sas/sas_state.rs index 7ee292c5..f53155d0 100644 --- a/matrix_sdk_crypto/src/verification/sas/sas_state.rs +++ b/matrix_sdk_crypto/src/verification/sas/sas_state.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::{ + convert::TryFrom, mem, sync::{Arc, Mutex}, time::{Duration, Instant}, @@ -70,14 +71,29 @@ struct AcceptedProtocols { short_auth_string: Vec, } -impl From for AcceptedProtocols { - fn from(content: AcceptV1Content) -> Self { - Self { - method: VerificationMethod::MSasV1, - hash: content.hash, - key_agreement_protocol: content.key_agreement_protocol, - message_auth_code: content.message_authentication_code, - short_auth_string: content.short_authentication_string, +impl TryFrom for AcceptedProtocols { + type Error = CancelCode; + + fn try_from(content: AcceptV1Content) -> Result { + if !KEY_AGREEMENT_PROTOCOLS.contains(&content.key_agreement_protocol) + || !HASHES.contains(&content.hash) + || !MACS.contains(&content.message_authentication_code) + || (!content + .short_authentication_string + .contains(&ShortAuthenticationString::Emoji) + && !content + .short_authentication_string + .contains(&ShortAuthenticationString::Decimal)) + { + Err(CancelCode::UnknownMethod) + } else { + Ok(Self { + method: VerificationMethod::MSasV1, + hash: content.hash, + key_agreement_protocol: content.key_agreement_protocol, + message_auth_code: content.message_authentication_code, + short_auth_string: content.short_authentication_string, + }) } } } @@ -319,34 +335,24 @@ impl SasState { .map_err(|c| self.clone().cancel(c))?; if let AcceptMethod::MSasV1(content) = &event.content.method { - if !KEY_AGREEMENT_PROTOCOLS.contains(&content.key_agreement_protocol) - || !HASHES.contains(&content.hash) - || !MACS.contains(&content.message_authentication_code) - || (!content - .short_authentication_string - .contains(&ShortAuthenticationString::Emoji) - && !content - .short_authentication_string - .contains(&ShortAuthenticationString::Decimal)) - { - Err(self.cancel(CancelCode::UnknownMethod)) - } else { - let json_start_content = cjson::to_string(&self.as_content()) - .expect("Can't deserialize start event content"); + let accepted_protocols = + AcceptedProtocols::try_from(content.clone()).map_err(|c| self.clone().cancel(c))?; - Ok(SasState { - inner: self.inner, - ids: self.ids, - verification_flow_id: self.verification_flow_id, - creation_time: self.creation_time, - last_event_time: self.last_event_time, - state: Arc::new(Accepted { - json_start_content, - commitment: content.commitment.clone(), - accepted_protocols: Arc::new(content.clone().into()), - }), - }) - } + let json_start_content = cjson::to_string(&self.as_content()) + .expect("Can't deserialize start event content"); + + Ok(SasState { + inner: self.inner, + ids: self.ids, + verification_flow_id: self.verification_flow_id, + creation_time: self.creation_time, + last_event_time: self.last_event_time, + state: Arc::new(Accepted { + json_start_content, + commitment: content.commitment.clone(), + accepted_protocols: Arc::new(accepted_protocols), + }), + }) } else { Err(self.cancel(CancelCode::UnknownMethod)) }