crypto: Use TryFrom to check the accepted SAS protocols.

master
Damir Jelić 2020-08-11 11:24:29 +02:00
parent d5a853f3da
commit 6c85d3e28f
1 changed files with 41 additions and 35 deletions

View File

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
use std::{ use std::{
convert::TryFrom,
mem, mem,
sync::{Arc, Mutex}, sync::{Arc, Mutex},
time::{Duration, Instant}, time::{Duration, Instant},
@ -70,14 +71,29 @@ struct AcceptedProtocols {
short_auth_string: Vec<ShortAuthenticationString>, short_auth_string: Vec<ShortAuthenticationString>,
} }
impl From<AcceptV1Content> for AcceptedProtocols { impl TryFrom<AcceptV1Content> for AcceptedProtocols {
fn from(content: AcceptV1Content) -> Self { type Error = CancelCode;
Self {
method: VerificationMethod::MSasV1, fn try_from(content: AcceptV1Content) -> Result<Self, Self::Error> {
hash: content.hash, if !KEY_AGREEMENT_PROTOCOLS.contains(&content.key_agreement_protocol)
key_agreement_protocol: content.key_agreement_protocol, || !HASHES.contains(&content.hash)
message_auth_code: content.message_authentication_code, || !MACS.contains(&content.message_authentication_code)
short_auth_string: content.short_authentication_string, || (!content
.short_authentication_string
.contains(&ShortAuthenticationString::Emoji)
&& !content
.short_authentication_string
.contains(&ShortAuthenticationString::Decimal))
{
Err(CancelCode::UnknownMethod)
} else {
Ok(Self {
method: VerificationMethod::MSasV1,
hash: content.hash,
key_agreement_protocol: content.key_agreement_protocol,
message_auth_code: content.message_authentication_code,
short_auth_string: content.short_authentication_string,
})
} }
} }
} }
@ -319,34 +335,24 @@ impl SasState<Created> {
.map_err(|c| self.clone().cancel(c))?; .map_err(|c| self.clone().cancel(c))?;
if let AcceptMethod::MSasV1(content) = &event.content.method { if let AcceptMethod::MSasV1(content) = &event.content.method {
if !KEY_AGREEMENT_PROTOCOLS.contains(&content.key_agreement_protocol) let accepted_protocols =
|| !HASHES.contains(&content.hash) AcceptedProtocols::try_from(content.clone()).map_err(|c| self.clone().cancel(c))?;
|| !MACS.contains(&content.message_authentication_code)
|| (!content
.short_authentication_string
.contains(&ShortAuthenticationString::Emoji)
&& !content
.short_authentication_string
.contains(&ShortAuthenticationString::Decimal))
{
Err(self.cancel(CancelCode::UnknownMethod))
} else {
let json_start_content = cjson::to_string(&self.as_content())
.expect("Can't deserialize start event content");
Ok(SasState { let json_start_content = cjson::to_string(&self.as_content())
inner: self.inner, .expect("Can't deserialize start event content");
ids: self.ids,
verification_flow_id: self.verification_flow_id, Ok(SasState {
creation_time: self.creation_time, inner: self.inner,
last_event_time: self.last_event_time, ids: self.ids,
state: Arc::new(Accepted { verification_flow_id: self.verification_flow_id,
json_start_content, creation_time: self.creation_time,
commitment: content.commitment.clone(), last_event_time: self.last_event_time,
accepted_protocols: Arc::new(content.clone().into()), state: Arc::new(Accepted {
}), json_start_content,
}) commitment: content.commitment.clone(),
} accepted_protocols: Arc::new(accepted_protocols),
}),
})
} else { } else {
Err(self.cancel(CancelCode::UnknownMethod)) Err(self.cancel(CancelCode::UnknownMethod))
} }