crypto: Add initial SAS canceling.

master
Damir Jelić 2020-07-27 13:16:56 +02:00
parent 7128505768
commit da3734ffc7
1 changed files with 69 additions and 38 deletions

View File

@ -12,25 +12,29 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::mem; use std::{
use std::sync::{Arc, Mutex}; mem,
sync::{Arc, Mutex},
};
use olm_rs::sas::OlmSas; use olm_rs::{sas::OlmSas, utility::OlmUtility};
use matrix_sdk_common::events::{ use matrix_sdk_common::{
events::{
key::verification::{ key::verification::{
accept::AcceptEventContent, accept::AcceptEventContent,
cancel::CancelCode, cancel::{CancelCode, CancelEventContent},
key::KeyEventContent, key::KeyEventContent,
mac::MacEventContent, mac::MacEventContent,
start::{MSasV1Content, MSasV1ContentOptions, StartEventContent}, start::{MSasV1Content, MSasV1ContentOptions, StartEventContent},
HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode, ShortAuthenticationString, HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode,
VerificationMethod, ShortAuthenticationString, VerificationMethod,
}, },
AnyToDeviceEvent, AnyToDeviceEventContent, ToDeviceEvent, AnyToDeviceEvent, AnyToDeviceEventContent, ToDeviceEvent,
},
identifiers::{DeviceId, UserId},
uuid::Uuid,
}; };
use matrix_sdk_common::identifiers::{DeviceId, UserId};
use matrix_sdk_common::uuid::Uuid;
use super::{get_decimal, get_emoji, get_mac_content, receive_mac_event, SasIds}; use super::{get_decimal, get_emoji, get_mac_content, receive_mac_event, SasIds};
use crate::{Account, Device}; use crate::{Account, Device};
@ -44,6 +48,15 @@ struct Sas {
} }
impl Sas { impl Sas {
const KEY_AGREEMENT_PROTOCOLS: &'static [KeyAgreementProtocol] =
&[KeyAgreementProtocol::Curve25519HkdfSha256];
const HASHES: &'static [HashAlgorithm] = &[HashAlgorithm::Sha256];
const MACS: &'static [MessageAuthenticationCode] = &[MessageAuthenticationCode::HkdfHmacSha256];
const STRINGS: &'static [ShortAuthenticationString] = &[
ShortAuthenticationString::Decimal,
ShortAuthenticationString::Emoji,
];
/// Get our own user id. /// Get our own user id.
fn user_id(&self) -> &UserId { fn user_id(&self) -> &UserId {
self.account.user_id() self.account.user_id()
@ -90,13 +103,13 @@ impl Sas {
account: Account, account: Account,
other_device: Device, other_device: Device,
event: &ToDeviceEvent<StartEventContent>, event: &ToDeviceEvent<StartEventContent>,
) -> Sas { ) -> Result<Sas, CancelEventContent> {
let inner = InnerSas::from_start_event(account.clone(), other_device.clone(), event); let inner = InnerSas::from_start_event(account.clone(), other_device.clone(), event)?;
Sas { Ok(Sas {
inner: Arc::new(Mutex::new(inner)), inner: Arc::new(Mutex::new(inner)),
account, account,
other_device, other_device,
} })
} }
fn accept(&self) -> Option<AcceptEventContent> { fn accept(&self) -> Option<AcceptEventContent> {
@ -156,7 +169,7 @@ enum InnerSas {
impl InnerSas { impl InnerSas {
fn start(account: Account, other_device: Device) -> (InnerSas, StartEventContent) { fn start(account: Account, other_device: Device) -> (InnerSas, StartEventContent) {
let sas = SasState::<Created>::new(account, other_device); let sas = SasState::<Created>::new(account, other_device);
let content = sas.get_start_event(); let content = sas.as_content();
(InnerSas::Created(sas), content) (InnerSas::Created(sas), content)
} }
@ -164,16 +177,16 @@ impl InnerSas {
account: Account, account: Account,
other_device: Device, other_device: Device,
event: &ToDeviceEvent<StartEventContent>, event: &ToDeviceEvent<StartEventContent>,
) -> InnerSas { ) -> Result<InnerSas, CancelEventContent> {
match SasState::<Started>::from_start_event(account, other_device, event) { match SasState::<Started>::from_start_event(account, other_device, event) {
Ok(s) => InnerSas::Started(s), Ok(s) => Ok(InnerSas::Started(s)),
Err(s) => InnerSas::Canceled(s), Err(s) => Err(s.as_content()),
} }
} }
fn accept(&self) -> Option<AcceptEventContent> { fn accept(&self) -> Option<AcceptEventContent> {
if let InnerSas::Started(s) = self { if let InnerSas::Started(s) = self {
Some(s.get_accept_content()) Some(s.as_content())
} else { } else {
None None
} }
@ -347,6 +360,7 @@ struct Created {
/// The initial SAS state if the other side started the SAS verification. /// The initial SAS state if the other side started the SAS verification.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
struct Started { struct Started {
commitment: String,
protocol_definitions: MSasV1Content, protocol_definitions: MSasV1Content,
} }
@ -462,7 +476,7 @@ impl SasState<Created> {
/// Get the content for the start event. /// Get the content for the start event.
/// ///
/// The content needs to be sent to the other device. /// The content needs to be sent to the other device.
fn get_start_event(&self) -> StartEventContent { fn as_content(&self) -> StartEventContent {
StartEventContent::MSasV1( StartEventContent::MSasV1(
MSasV1Content::new(self.state.protocol_definitions.clone()) MSasV1Content::new(self.state.protocol_definitions.clone())
.expect("Invalid initial protocol definitions."), .expect("Invalid initial protocol definitions."),
@ -513,8 +527,15 @@ impl SasState<Started> {
event: &ToDeviceEvent<StartEventContent>, event: &ToDeviceEvent<StartEventContent>,
) -> Result<SasState<Started>, SasState<Canceled>> { ) -> Result<SasState<Started>, SasState<Canceled>> {
if let StartEventContent::MSasV1(content) = &event.content { if let StartEventContent::MSasV1(content) = &event.content {
let sas = OlmSas::new();
let utility = OlmUtility::new();
let json_content = cjson::to_string(&event.content).expect("Can't serialize content");
let pubkey = sas.public_key();
let commitment = utility.sha256_utf8_msg(&format!("{}{}", pubkey, json_content));
let sas = SasState { let sas = SasState {
inner: Arc::new(Mutex::new(OlmSas::new())), inner: Arc::new(Mutex::new(sas)),
ids: SasIds { ids: SasIds {
account, account,
@ -525,6 +546,7 @@ impl SasState<Started> {
state: Arc::new(Started { state: Arc::new(Started {
protocol_definitions: content.clone(), protocol_definitions: content.clone(),
commitment,
}), }),
}; };
@ -571,12 +593,11 @@ impl SasState<Started> {
/// This should be sent out automatically if the SAS verification flow has /// This should be sent out automatically if the SAS verification flow has
/// been started because of a /// been started because of a
/// m.key.verification.request -> m.key.verification.ready flow. /// m.key.verification.request -> m.key.verification.ready flow.
fn get_accept_content(&self) -> AcceptEventContent { fn as_content(&self) -> AcceptEventContent {
AcceptEventContent { AcceptEventContent {
method: VerificationMethod::MSasV1, method: VerificationMethod::MSasV1,
transaction_id: self.verification_flow_id.to_string(), transaction_id: self.verification_flow_id.to_string(),
// TODO calculate the commitment. commitment: self.state.commitment.clone(),
commitment: "".to_owned(),
hash: HashAlgorithm::Sha256, hash: HashAlgorithm::Sha256,
key_agreement_protocol: KeyAgreementProtocol::Curve25519HkdfSha256, key_agreement_protocol: KeyAgreementProtocol::Curve25519HkdfSha256,
message_authentication_code: MessageAuthenticationCode::HkdfHmacSha256, message_authentication_code: MessageAuthenticationCode::HkdfHmacSha256,
@ -600,7 +621,7 @@ impl SasState<Started> {
self, self,
event: &mut ToDeviceEvent<KeyEventContent>, event: &mut ToDeviceEvent<KeyEventContent>,
) -> SasState<KeyReceived> { ) -> SasState<KeyReceived> {
let accepted_protocols: AcceptedProtocols = self.get_accept_content().into(); let accepted_protocols: AcceptedProtocols = self.as_content().into();
self.inner self.inner
.lock() .lock()
.unwrap() .unwrap()
@ -877,6 +898,16 @@ impl Canceled {
} }
} }
impl SasState<Canceled> {
fn as_content(&self) -> CancelEventContent {
CancelEventContent {
transaction_id: self.verification_flow_id.to_string(),
reason: self.state.reason.to_string(),
code: self.state.cancel_code.clone(),
}
}
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use std::convert::TryFrom; use std::convert::TryFrom;
@ -936,7 +967,7 @@ mod test {
let alice_sas = SasState::<Created>::new(alice.clone(), bob_device); let alice_sas = SasState::<Created>::new(alice.clone(), bob_device);
let start_content = alice_sas.get_start_event(); let start_content = alice_sas.as_content();
let event = wrap_to_device_event(alice_sas.user_id(), start_content); let event = wrap_to_device_event(alice_sas.user_id(), start_content);
let bob_sas = SasState::<Started>::from_start_event(bob.clone(), alice_device, &event); let bob_sas = SasState::<Started>::from_start_event(bob.clone(), alice_device, &event);
@ -953,7 +984,7 @@ mod test {
async fn sas_accept() { async fn sas_accept() {
let (alice, bob) = get_sas_pair().await; let (alice, bob) = get_sas_pair().await;
let event = wrap_to_device_event(bob.user_id(), bob.get_accept_content()); let event = wrap_to_device_event(bob.user_id(), bob.as_content());
alice.into_accepted(&event); alice.into_accepted(&event);
} }
@ -962,7 +993,7 @@ mod test {
async fn sas_key_share() { async fn sas_key_share() {
let (alice, bob) = get_sas_pair().await; let (alice, bob) = get_sas_pair().await;
let event = wrap_to_device_event(bob.user_id(), bob.get_accept_content()); let event = wrap_to_device_event(bob.user_id(), bob.as_content());
let alice: SasState<Accepted> = alice.into_accepted(&event); let alice: SasState<Accepted> = alice.into_accepted(&event);
let mut event = wrap_to_device_event(alice.user_id(), alice.get_key_content()); let mut event = wrap_to_device_event(alice.user_id(), alice.get_key_content());
@ -981,7 +1012,7 @@ mod test {
async fn sas_full() { async fn sas_full() {
let (alice, bob) = get_sas_pair().await; let (alice, bob) = get_sas_pair().await;
let event = wrap_to_device_event(bob.user_id(), bob.get_accept_content()); let event = wrap_to_device_event(bob.user_id(), bob.as_content());
let alice: SasState<Accepted> = alice.into_accepted(&event); let alice: SasState<Accepted> = alice.into_accepted(&event);
let mut event = wrap_to_device_event(alice.user_id(), alice.get_key_content()); let mut event = wrap_to_device_event(alice.user_id(), alice.get_key_content());
@ -1021,7 +1052,7 @@ mod test {
let (alice, content) = Sas::start(alice, bob_device); let (alice, content) = Sas::start(alice, bob_device);
let event = wrap_to_device_event(alice.user_id(), content); let event = wrap_to_device_event(alice.user_id(), content);
let bob = Sas::from_start_event(bob, alice_device, &event); let bob = Sas::from_start_event(bob, alice_device, &event).unwrap();
let event = wrap_to_device_event(bob.user_id(), bob.accept().unwrap()); let event = wrap_to_device_event(bob.user_id(), bob.accept().unwrap());
let content = alice.receive_event(&mut AnyToDeviceEvent::KeyVerificationAccept(event)); let content = alice.receive_event(&mut AnyToDeviceEvent::KeyVerificationAccept(event));