crypto: Use TryFrom to check the accepted SAS protocols.

This commit is contained in:
Damir Jelić 2020-08-11 11:24:29 +02:00
parent d5a853f3da
commit 6c85d3e28f

View file

@ -13,6 +13,7 @@
// limitations under the License.
use std::{
convert::TryFrom,
mem,
sync::{Arc, Mutex},
time::{Duration, Instant},
@ -70,14 +71,29 @@ struct AcceptedProtocols {
short_auth_string: Vec<ShortAuthenticationString>,
}
impl From<AcceptV1Content> for AcceptedProtocols {
fn from(content: AcceptV1Content) -> Self {
Self {
method: VerificationMethod::MSasV1,
hash: content.hash,
key_agreement_protocol: content.key_agreement_protocol,
message_auth_code: content.message_authentication_code,
short_auth_string: content.short_authentication_string,
impl TryFrom<AcceptV1Content> for AcceptedProtocols {
type Error = CancelCode;
fn try_from(content: AcceptV1Content) -> Result<Self, Self::Error> {
if !KEY_AGREEMENT_PROTOCOLS.contains(&content.key_agreement_protocol)
|| !HASHES.contains(&content.hash)
|| !MACS.contains(&content.message_authentication_code)
|| (!content
.short_authentication_string
.contains(&ShortAuthenticationString::Emoji)
&& !content
.short_authentication_string
.contains(&ShortAuthenticationString::Decimal))
{
Err(CancelCode::UnknownMethod)
} else {
Ok(Self {
method: VerificationMethod::MSasV1,
hash: content.hash,
key_agreement_protocol: content.key_agreement_protocol,
message_auth_code: content.message_authentication_code,
short_auth_string: content.short_authentication_string,
})
}
}
}
@ -319,34 +335,24 @@ impl SasState<Created> {
.map_err(|c| self.clone().cancel(c))?;
if let AcceptMethod::MSasV1(content) = &event.content.method {
if !KEY_AGREEMENT_PROTOCOLS.contains(&content.key_agreement_protocol)
|| !HASHES.contains(&content.hash)
|| !MACS.contains(&content.message_authentication_code)
|| (!content
.short_authentication_string
.contains(&ShortAuthenticationString::Emoji)
&& !content
.short_authentication_string
.contains(&ShortAuthenticationString::Decimal))
{
Err(self.cancel(CancelCode::UnknownMethod))
} else {
let json_start_content = cjson::to_string(&self.as_content())
.expect("Can't deserialize start event content");
let accepted_protocols =
AcceptedProtocols::try_from(content.clone()).map_err(|c| self.clone().cancel(c))?;
Ok(SasState {
inner: self.inner,
ids: self.ids,
verification_flow_id: self.verification_flow_id,
creation_time: self.creation_time,
last_event_time: self.last_event_time,
state: Arc::new(Accepted {
json_start_content,
commitment: content.commitment.clone(),
accepted_protocols: Arc::new(content.clone().into()),
}),
})
}
let json_start_content = cjson::to_string(&self.as_content())
.expect("Can't deserialize start event content");
Ok(SasState {
inner: self.inner,
ids: self.ids,
verification_flow_id: self.verification_flow_id,
creation_time: self.creation_time,
last_event_time: self.last_event_time,
state: Arc::new(Accepted {
json_start_content,
commitment: content.commitment.clone(),
accepted_protocols: Arc::new(accepted_protocols),
}),
})
} else {
Err(self.cancel(CancelCode::UnknownMethod))
}