crypto: Correctly remember our chosen SAS methods

master
Damir Jelić 2021-07-19 09:43:35 +02:00
parent ff8089912e
commit cf30c42563
3 changed files with 20 additions and 29 deletions

View File

@ -165,9 +165,12 @@ impl InnerSas {
} }
} }
pub fn accept(self) -> Option<(InnerSas, OwnedAcceptContent)> { pub fn accept(
self,
methods: Vec<ShortAuthenticationString>,
) -> Option<(InnerSas, OwnedAcceptContent)> {
if let InnerSas::Started(s) = self { if let InnerSas::Started(s) = self {
let sas = s.into_accepted(); let sas = s.into_accepted(methods);
let content = sas.as_content(); let content = sas.as_content();
Some((InnerSas::WeAccepted(sas), content)) Some((InnerSas::WeAccepted(sas), content))
} else { } else {

View File

@ -25,11 +25,7 @@ use matrix_sdk_common::uuid::Uuid;
use ruma::{ use ruma::{
api::client::r0::keys::upload_signatures::Request as SignatureUploadRequest, api::client::r0::keys::upload_signatures::Request as SignatureUploadRequest,
events::{ events::{
key::verification::{ key::verification::{cancel::CancelCode, ShortAuthenticationString},
accept::{AcceptEventContent, AcceptMethod, AcceptToDeviceEventContent},
cancel::CancelCode,
ShortAuthenticationString,
},
AnyMessageEventContent, AnyToDeviceEventContent, AnyMessageEventContent, AnyToDeviceEventContent,
}, },
DeviceId, EventId, RoomId, UserId, DeviceId, EventId, RoomId, UserId,
@ -327,10 +323,10 @@ impl Sas {
) -> Option<OutgoingVerificationRequest> { ) -> Option<OutgoingVerificationRequest> {
let mut guard = self.inner.lock().unwrap(); let mut guard = self.inner.lock().unwrap();
let sas: InnerSas = (*guard).clone(); let sas: InnerSas = (*guard).clone();
let methods = settings.allowed_methods;
if let Some((sas, content)) = sas.accept() { if let Some((sas, content)) = sas.accept(methods) {
*guard = sas; *guard = sas;
let content = settings.apply(content);
Some(match content { Some(match content {
OwnedAcceptContent::ToDevice(c) => { OwnedAcceptContent::ToDevice(c) => {
@ -554,23 +550,6 @@ impl AcceptSettings {
pub fn with_allowed_methods(methods: Vec<ShortAuthenticationString>) -> Self { pub fn with_allowed_methods(methods: Vec<ShortAuthenticationString>) -> Self {
Self { allowed_methods: methods } Self { allowed_methods: methods }
} }
fn apply(self, mut content: OwnedAcceptContent) -> OwnedAcceptContent {
match &mut content {
OwnedAcceptContent::ToDevice(AcceptToDeviceEventContent {
method: AcceptMethod::SasV1(c),
..
})
| OwnedAcceptContent::Room(
_,
AcceptEventContent { method: AcceptMethod::SasV1(c), .. },
) => {
c.short_authentication_string.retain(|sas| self.allowed_methods.contains(sas));
content
}
_ => content,
}
}
} }
#[cfg(test)] #[cfg(test)]

View File

@ -557,7 +557,15 @@ impl SasState<Started> {
} }
} }
pub fn into_accepted(self) -> SasState<WeAccepted> { pub fn into_accepted(self, methods: Vec<ShortAuthenticationString>) -> SasState<WeAccepted> {
let mut accepted_protocols = self.state.accepted_protocols.as_ref().to_owned();
accepted_protocols.short_auth_string = methods;
// Decimal is required per spec.
if !accepted_protocols.short_auth_string.contains(&ShortAuthenticationString::Decimal) {
accepted_protocols.short_auth_string.push(ShortAuthenticationString::Decimal);
}
SasState { SasState {
inner: self.inner, inner: self.inner,
ids: self.ids, ids: self.ids,
@ -567,7 +575,7 @@ impl SasState<Started> {
started_from_request: self.started_from_request, started_from_request: self.started_from_request,
state: Arc::new(WeAccepted { state: Arc::new(WeAccepted {
we_started: false, we_started: false,
accepted_protocols: self.state.accepted_protocols.clone(), accepted_protocols: accepted_protocols.into(),
commitment: self.state.commitment.clone(), commitment: self.state.commitment.clone(),
}), }),
} }
@ -1115,6 +1123,7 @@ mod test {
events::key::verification::{ events::key::verification::{
accept::{AcceptMethod, AcceptToDeviceEventContent}, accept::{AcceptMethod, AcceptToDeviceEventContent},
start::{StartMethod, StartToDeviceEventContent}, start::{StartMethod, StartToDeviceEventContent},
ShortAuthenticationString,
}, },
DeviceId, UserId, DeviceId, UserId,
}; };
@ -1162,7 +1171,7 @@ mod test {
&start_content.as_start_content(), &start_content.as_start_content(),
false, false,
); );
let bob_sas = bob_sas.unwrap().into_accepted(); let bob_sas = bob_sas.unwrap().into_accepted(vec![ShortAuthenticationString::Emoji]);
(alice_sas, bob_sas) (alice_sas, bob_sas)
} }