crypto: Use TryFrom to check the accepted SAS protocols.

master
Damir Jelić 2020-08-11 11:24:29 +02:00
parent d5a853f3da
commit 6c85d3e28f
1 changed files with 41 additions and 35 deletions

View File

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
use std::{ use std::{
convert::TryFrom,
mem, mem,
sync::{Arc, Mutex}, sync::{Arc, Mutex},
time::{Duration, Instant}, time::{Duration, Instant},
@ -70,14 +71,29 @@ struct AcceptedProtocols {
short_auth_string: Vec<ShortAuthenticationString>, short_auth_string: Vec<ShortAuthenticationString>,
} }
impl From<AcceptV1Content> for AcceptedProtocols { impl TryFrom<AcceptV1Content> for AcceptedProtocols {
fn from(content: AcceptV1Content) -> Self { type Error = CancelCode;
Self {
fn try_from(content: AcceptV1Content) -> Result<Self, Self::Error> {
if !KEY_AGREEMENT_PROTOCOLS.contains(&content.key_agreement_protocol)
|| !HASHES.contains(&content.hash)
|| !MACS.contains(&content.message_authentication_code)
|| (!content
.short_authentication_string
.contains(&ShortAuthenticationString::Emoji)
&& !content
.short_authentication_string
.contains(&ShortAuthenticationString::Decimal))
{
Err(CancelCode::UnknownMethod)
} else {
Ok(Self {
method: VerificationMethod::MSasV1, method: VerificationMethod::MSasV1,
hash: content.hash, hash: content.hash,
key_agreement_protocol: content.key_agreement_protocol, key_agreement_protocol: content.key_agreement_protocol,
message_auth_code: content.message_authentication_code, message_auth_code: content.message_authentication_code,
short_auth_string: content.short_authentication_string, short_auth_string: content.short_authentication_string,
})
} }
} }
} }
@ -319,18 +335,9 @@ impl SasState<Created> {
.map_err(|c| self.clone().cancel(c))?; .map_err(|c| self.clone().cancel(c))?;
if let AcceptMethod::MSasV1(content) = &event.content.method { if let AcceptMethod::MSasV1(content) = &event.content.method {
if !KEY_AGREEMENT_PROTOCOLS.contains(&content.key_agreement_protocol) let accepted_protocols =
|| !HASHES.contains(&content.hash) AcceptedProtocols::try_from(content.clone()).map_err(|c| self.clone().cancel(c))?;
|| !MACS.contains(&content.message_authentication_code)
|| (!content
.short_authentication_string
.contains(&ShortAuthenticationString::Emoji)
&& !content
.short_authentication_string
.contains(&ShortAuthenticationString::Decimal))
{
Err(self.cancel(CancelCode::UnknownMethod))
} else {
let json_start_content = cjson::to_string(&self.as_content()) let json_start_content = cjson::to_string(&self.as_content())
.expect("Can't deserialize start event content"); .expect("Can't deserialize start event content");
@ -343,10 +350,9 @@ impl SasState<Created> {
state: Arc::new(Accepted { state: Arc::new(Accepted {
json_start_content, json_start_content,
commitment: content.commitment.clone(), commitment: content.commitment.clone(),
accepted_protocols: Arc::new(content.clone().into()), accepted_protocols: Arc::new(accepted_protocols),
}), }),
}) })
}
} else { } else {
Err(self.cancel(CancelCode::UnknownMethod)) Err(self.cancel(CancelCode::UnknownMethod))
} }