diff --git a/matrix_sdk_crypto/src/verification/sas.rs b/matrix_sdk_crypto/src/verification/sas.rs index 97a6fd20..1e23281a 100644 --- a/matrix_sdk_crypto/src/verification/sas.rs +++ b/matrix_sdk_crypto/src/verification/sas.rs @@ -196,12 +196,12 @@ impl InnerSas { match self { InnerSas::KeyRecieved(s) => { let sas = s.confirm(); - let content = sas.get_mac_event_content(); + let content = sas.as_content(); (InnerSas::Confirmed(sas), Some(content)) } InnerSas::MacReceived(s) => { let sas = s.confirm(); - let content = sas.get_mac_event_content(); + let content = sas.as_content(); (InnerSas::Done(sas), Some(content)) } _ => (self, None), @@ -215,21 +215,32 @@ impl InnerSas { match event { AnyToDeviceEvent::KeyVerificationAccept(e) => { if let InnerSas::Created(s) = self { - let sas = s.into_accepted(e); - let content = sas.get_key_content(); - ( - InnerSas::Accepted(sas), - Some(AnyToDeviceEventContent::KeyVerificationKey(content)), - ) + match s.into_accepted(e) { + Ok(s) => { + let content = s.as_content(); + ( + InnerSas::Accepted(s), + Some(AnyToDeviceEventContent::KeyVerificationKey(content)), + ) + } + Err(s) => { + let content = + AnyToDeviceEventContent::KeyVerificationCancel(s.as_content()); + (InnerSas::Canceled(s), Some(content)) + } + } } else { (self, None) } } AnyToDeviceEvent::KeyVerificationKey(e) => match self { - InnerSas::Accepted(s) => (InnerSas::KeyRecieved(s.into_key_received(e)), None), + InnerSas::Accepted(s) => match s.into_key_received(e) { + Ok(s) => (InnerSas::KeyRecieved(s), None), + Err(s) => (InnerSas::Canceled(s), None), + }, InnerSas::Started(s) => { let sas = s.into_key_received(e); - let content = sas.get_key_content(); + let content = sas.as_content(); ( InnerSas::KeyRecieved(sas), Some(AnyToDeviceEventContent::KeyVerificationKey(content)), @@ -369,6 +380,7 @@ struct Started { #[derive(Clone, Debug)] struct Accepted { accepted_protocols: Arc, + json_start_content: String, commitment: String, } @@ -490,19 +502,39 @@ impl SasState { /// /// * `event` - The m.key.verification.accept event that was sent to us by /// the other side. - fn into_accepted(self, event: &ToDeviceEvent) -> SasState { + fn into_accepted( + self, + event: &ToDeviceEvent, + ) -> Result, SasState> { let content = &event.content; - // TODO check that we support the agreed upon protocols, cancel if not. + if !Sas::KEY_AGREEMENT_PROTOCOLS.contains(&event.content.key_agreement_protocol) + || !Sas::HASHES.contains(&event.content.hash) + || !Sas::MACS.contains(&event.content.message_authentication_code) + || (!event + .content + .short_authentication_string + .contains(&ShortAuthenticationString::Emoji) + && !event + .content + .short_authentication_string + .contains(&ShortAuthenticationString::Decimal)) + { + Err(self.cancel(CancelCode::UnknownMethod)) + } else { + let json_start_content = cjson::to_string(&self.as_content()) + .expect("Can't deserialize start event content"); - SasState { - inner: self.inner, - ids: self.ids, - verification_flow_id: self.verification_flow_id, - state: Arc::new(Accepted { - commitment: content.commitment.clone(), - accepted_protocols: Arc::new(content.clone().into()), - }), + Ok(SasState { + inner: self.inner, + ids: self.ids, + verification_flow_id: self.verification_flow_id, + state: Arc::new(Accepted { + json_start_content, + commitment: content.commitment.clone(), + accepted_protocols: Arc::new(content.clone().into()), + }), + }) } } } @@ -652,29 +684,38 @@ impl SasState { fn into_key_received( self, event: &mut ToDeviceEvent, - ) -> SasState { - // TODO check the commitment here since we started the SAS dance. - self.inner - .lock() - .unwrap() - .set_their_public_key(&mem::take(&mut event.content.key)) - .expect("Can't set public key"); + ) -> Result, SasState> { + let utility = OlmUtility::new(); + let commitment = utility.sha256_utf8_msg(&format!( + "{}{}", + event.content.key, self.state.json_start_content + )); - SasState { - inner: self.inner, - ids: self.ids, - verification_flow_id: self.verification_flow_id, - state: Arc::new(KeyReceived { - we_started: true, - accepted_protocols: self.state.accepted_protocols.clone(), - }), + if self.state.commitment != commitment { + Err(self.cancel(CancelCode::InvalidMessage)) + } else { + self.inner + .lock() + .unwrap() + .set_their_public_key(&mem::take(&mut event.content.key)) + .expect("Can't set public key"); + + Ok(SasState { + inner: self.inner, + ids: self.ids, + verification_flow_id: self.verification_flow_id, + state: Arc::new(KeyReceived { + we_started: true, + accepted_protocols: self.state.accepted_protocols.clone(), + }), + }) } } /// Get the content for the key event. /// /// The content needs to be automatically sent to the other side. - fn get_key_content(&self) -> KeyEventContent { + fn as_content(&self) -> KeyEventContent { KeyEventContent { transaction_id: self.verification_flow_id.to_string(), key: self.inner.lock().unwrap().public_key(), @@ -687,7 +728,7 @@ impl SasState { /// /// The content needs to be automatically sent to the other side if and only /// if we_started is false. - fn get_key_content(&self) -> KeyEventContent { + fn as_content(&self) -> KeyEventContent { KeyEventContent { transaction_id: self.verification_flow_id.to_string(), key: self.inner.lock().unwrap().public_key(), @@ -793,7 +834,7 @@ impl SasState { /// Get the content for the mac event. /// /// The content needs to be automatically sent to the other side. - fn get_mac_event_content(&self) -> MacEventContent { + fn as_content(&self) -> MacEventContent { get_mac_content( &self.inner.lock().unwrap(), &self.ids, @@ -851,7 +892,7 @@ impl SasState { /// /// The content needs to be automatically sent to the other side if it /// wasn't already sent. - fn get_mac_event_content(&self) -> MacEventContent { + fn as_content(&self) -> MacEventContent { get_mac_content( &self.inner.lock().unwrap(), &self.ids, @@ -986,7 +1027,7 @@ mod test { let event = wrap_to_device_event(bob.user_id(), bob.as_content()); - alice.into_accepted(&event); + alice.into_accepted(&event).unwrap(); } #[tokio::test] @@ -995,14 +1036,14 @@ mod test { let event = wrap_to_device_event(bob.user_id(), bob.as_content()); - let alice: SasState = alice.into_accepted(&event); - let mut event = wrap_to_device_event(alice.user_id(), alice.get_key_content()); + let alice: SasState = alice.into_accepted(&event).unwrap(); + let mut event = wrap_to_device_event(alice.user_id(), alice.as_content()); let bob = bob.into_key_received(&mut event); - let mut event = wrap_to_device_event(bob.user_id(), bob.get_key_content()); + let mut event = wrap_to_device_event(bob.user_id(), bob.as_content()); - let alice = alice.into_key_received(&mut event); + let alice = alice.into_key_received(&mut event).unwrap(); assert_eq!(alice.get_decimal(), bob.get_decimal()); assert_eq!(alice.get_emoji(), bob.get_emoji()); @@ -1014,27 +1055,27 @@ mod test { let event = wrap_to_device_event(bob.user_id(), bob.as_content()); - let alice: SasState = alice.into_accepted(&event); - let mut event = wrap_to_device_event(alice.user_id(), alice.get_key_content()); + let alice: SasState = alice.into_accepted(&event).unwrap(); + let mut event = wrap_to_device_event(alice.user_id(), alice.as_content()); let bob = bob.into_key_received(&mut event); - let mut event = wrap_to_device_event(bob.user_id(), bob.get_key_content()); + let mut event = wrap_to_device_event(bob.user_id(), bob.as_content()); - let alice = alice.into_key_received(&mut event); + let alice = alice.into_key_received(&mut event).unwrap(); assert_eq!(alice.get_decimal(), bob.get_decimal()); assert_eq!(alice.get_emoji(), bob.get_emoji()); let bob = bob.confirm(); - let event = wrap_to_device_event(bob.user_id(), bob.get_mac_event_content()); + let event = wrap_to_device_event(bob.user_id(), bob.as_content()); let alice = alice.into_mac_received(&event); assert!(!alice.get_emoji().is_empty()); let alice = alice.confirm(); - let event = wrap_to_device_event(alice.user_id(), alice.get_mac_event_content()); + let event = wrap_to_device_event(alice.user_id(), alice.as_content()); let bob = bob.into_done(&event); assert!(bob.verified_devices().contains(&alice.device_id().into()));