crypto: Use TryFrom to check the accepted SAS protocols.
parent
d5a853f3da
commit
6c85d3e28f
|
@ -13,6 +13,7 @@
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
|
convert::TryFrom,
|
||||||
mem,
|
mem,
|
||||||
sync::{Arc, Mutex},
|
sync::{Arc, Mutex},
|
||||||
time::{Duration, Instant},
|
time::{Duration, Instant},
|
||||||
|
@ -70,14 +71,29 @@ struct AcceptedProtocols {
|
||||||
short_auth_string: Vec<ShortAuthenticationString>,
|
short_auth_string: Vec<ShortAuthenticationString>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<AcceptV1Content> for AcceptedProtocols {
|
impl TryFrom<AcceptV1Content> for AcceptedProtocols {
|
||||||
fn from(content: AcceptV1Content) -> Self {
|
type Error = CancelCode;
|
||||||
Self {
|
|
||||||
|
fn try_from(content: AcceptV1Content) -> Result<Self, Self::Error> {
|
||||||
|
if !KEY_AGREEMENT_PROTOCOLS.contains(&content.key_agreement_protocol)
|
||||||
|
|| !HASHES.contains(&content.hash)
|
||||||
|
|| !MACS.contains(&content.message_authentication_code)
|
||||||
|
|| (!content
|
||||||
|
.short_authentication_string
|
||||||
|
.contains(&ShortAuthenticationString::Emoji)
|
||||||
|
&& !content
|
||||||
|
.short_authentication_string
|
||||||
|
.contains(&ShortAuthenticationString::Decimal))
|
||||||
|
{
|
||||||
|
Err(CancelCode::UnknownMethod)
|
||||||
|
} else {
|
||||||
|
Ok(Self {
|
||||||
method: VerificationMethod::MSasV1,
|
method: VerificationMethod::MSasV1,
|
||||||
hash: content.hash,
|
hash: content.hash,
|
||||||
key_agreement_protocol: content.key_agreement_protocol,
|
key_agreement_protocol: content.key_agreement_protocol,
|
||||||
message_auth_code: content.message_authentication_code,
|
message_auth_code: content.message_authentication_code,
|
||||||
short_auth_string: content.short_authentication_string,
|
short_auth_string: content.short_authentication_string,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -319,18 +335,9 @@ impl SasState<Created> {
|
||||||
.map_err(|c| self.clone().cancel(c))?;
|
.map_err(|c| self.clone().cancel(c))?;
|
||||||
|
|
||||||
if let AcceptMethod::MSasV1(content) = &event.content.method {
|
if let AcceptMethod::MSasV1(content) = &event.content.method {
|
||||||
if !KEY_AGREEMENT_PROTOCOLS.contains(&content.key_agreement_protocol)
|
let accepted_protocols =
|
||||||
|| !HASHES.contains(&content.hash)
|
AcceptedProtocols::try_from(content.clone()).map_err(|c| self.clone().cancel(c))?;
|
||||||
|| !MACS.contains(&content.message_authentication_code)
|
|
||||||
|| (!content
|
|
||||||
.short_authentication_string
|
|
||||||
.contains(&ShortAuthenticationString::Emoji)
|
|
||||||
&& !content
|
|
||||||
.short_authentication_string
|
|
||||||
.contains(&ShortAuthenticationString::Decimal))
|
|
||||||
{
|
|
||||||
Err(self.cancel(CancelCode::UnknownMethod))
|
|
||||||
} else {
|
|
||||||
let json_start_content = cjson::to_string(&self.as_content())
|
let json_start_content = cjson::to_string(&self.as_content())
|
||||||
.expect("Can't deserialize start event content");
|
.expect("Can't deserialize start event content");
|
||||||
|
|
||||||
|
@ -343,10 +350,9 @@ impl SasState<Created> {
|
||||||
state: Arc::new(Accepted {
|
state: Arc::new(Accepted {
|
||||||
json_start_content,
|
json_start_content,
|
||||||
commitment: content.commitment.clone(),
|
commitment: content.commitment.clone(),
|
||||||
accepted_protocols: Arc::new(content.clone().into()),
|
accepted_protocols: Arc::new(accepted_protocols),
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
Err(self.cancel(CancelCode::UnknownMethod))
|
Err(self.cancel(CancelCode::UnknownMethod))
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue