crypto: Correctly remember our chosen SAS methods

master
Damir Jelić 2021-07-19 09:43:35 +02:00
parent ff8089912e
commit cf30c42563
3 changed files with 20 additions and 29 deletions

View File

@ -165,9 +165,12 @@ impl InnerSas {
}
}
pub fn accept(self) -> Option<(InnerSas, OwnedAcceptContent)> {
pub fn accept(
self,
methods: Vec<ShortAuthenticationString>,
) -> Option<(InnerSas, OwnedAcceptContent)> {
if let InnerSas::Started(s) = self {
let sas = s.into_accepted();
let sas = s.into_accepted(methods);
let content = sas.as_content();
Some((InnerSas::WeAccepted(sas), content))
} else {

View File

@ -25,11 +25,7 @@ use matrix_sdk_common::uuid::Uuid;
use ruma::{
api::client::r0::keys::upload_signatures::Request as SignatureUploadRequest,
events::{
key::verification::{
accept::{AcceptEventContent, AcceptMethod, AcceptToDeviceEventContent},
cancel::CancelCode,
ShortAuthenticationString,
},
key::verification::{cancel::CancelCode, ShortAuthenticationString},
AnyMessageEventContent, AnyToDeviceEventContent,
},
DeviceId, EventId, RoomId, UserId,
@ -327,10 +323,10 @@ impl Sas {
) -> Option<OutgoingVerificationRequest> {
let mut guard = self.inner.lock().unwrap();
let sas: InnerSas = (*guard).clone();
let methods = settings.allowed_methods;
if let Some((sas, content)) = sas.accept() {
if let Some((sas, content)) = sas.accept(methods) {
*guard = sas;
let content = settings.apply(content);
Some(match content {
OwnedAcceptContent::ToDevice(c) => {
@ -554,23 +550,6 @@ impl AcceptSettings {
pub fn with_allowed_methods(methods: Vec<ShortAuthenticationString>) -> Self {
Self { allowed_methods: methods }
}
fn apply(self, mut content: OwnedAcceptContent) -> OwnedAcceptContent {
match &mut content {
OwnedAcceptContent::ToDevice(AcceptToDeviceEventContent {
method: AcceptMethod::SasV1(c),
..
})
| OwnedAcceptContent::Room(
_,
AcceptEventContent { method: AcceptMethod::SasV1(c), .. },
) => {
c.short_authentication_string.retain(|sas| self.allowed_methods.contains(sas));
content
}
_ => content,
}
}
}
#[cfg(test)]

View File

@ -557,7 +557,15 @@ impl SasState<Started> {
}
}
pub fn into_accepted(self) -> SasState<WeAccepted> {
pub fn into_accepted(self, methods: Vec<ShortAuthenticationString>) -> SasState<WeAccepted> {
let mut accepted_protocols = self.state.accepted_protocols.as_ref().to_owned();
accepted_protocols.short_auth_string = methods;
// Decimal is required per spec.
if !accepted_protocols.short_auth_string.contains(&ShortAuthenticationString::Decimal) {
accepted_protocols.short_auth_string.push(ShortAuthenticationString::Decimal);
}
SasState {
inner: self.inner,
ids: self.ids,
@ -567,7 +575,7 @@ impl SasState<Started> {
started_from_request: self.started_from_request,
state: Arc::new(WeAccepted {
we_started: false,
accepted_protocols: self.state.accepted_protocols.clone(),
accepted_protocols: accepted_protocols.into(),
commitment: self.state.commitment.clone(),
}),
}
@ -1115,6 +1123,7 @@ mod test {
events::key::verification::{
accept::{AcceptMethod, AcceptToDeviceEventContent},
start::{StartMethod, StartToDeviceEventContent},
ShortAuthenticationString,
},
DeviceId, UserId,
};
@ -1162,7 +1171,7 @@ mod test {
&start_content.as_start_content(),
false,
);
let bob_sas = bob_sas.unwrap().into_accepted();
let bob_sas = bob_sas.unwrap().into_accepted(vec![ShortAuthenticationString::Emoji]);
(alice_sas, bob_sas)
}