crypto: Make the Sas struct thread safe.

master
Damir Jelić 2020-07-24 11:26:45 +02:00
parent 8ff8ea1342
commit 2f28976694
1 changed files with 79 additions and 57 deletions

View File

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
use std::mem; use std::mem;
use std::sync::{Arc, Mutex};
use olm_rs::sas::OlmSas; use olm_rs::sas::OlmSas;
@ -35,6 +36,7 @@ use crate::{Account, Device};
/// Struct containing the protocols that were agreed to be used for the SAS /// Struct containing the protocols that were agreed to be used for the SAS
/// flow. /// flow.
#[derive(Clone, Debug)]
struct AcceptedProtocols { struct AcceptedProtocols {
method: VerificationMethod, method: VerificationMethod,
key_agreement_protocol: KeyAgreementProtocol, key_agreement_protocol: KeyAgreementProtocol,
@ -65,16 +67,16 @@ impl From<AcceptEventContent> for AcceptedProtocols {
/// and the specific state. /// and the specific state.
struct Sas<S> { struct Sas<S> {
/// The Olm SAS struct. /// The Olm SAS struct.
inner: OlmSas, inner: Arc<Mutex<OlmSas>>,
/// Struct holding the identities that are doing the SAS dance. /// Struct holding the identities that are doing the SAS dance.
ids: SasIds, ids: SasIds,
/// The unique identifier of this SAS flow. /// The unique identifier of this SAS flow.
/// ///
/// This will be the transaction id for to-device events and the relates_to /// This will be the transaction id for to-device events and the relates_to
/// field for in-room events. /// field for in-room events.
verification_flow_id: String, verification_flow_id: Arc<String>,
/// The SAS state we're in. /// The SAS state we're in.
state: S, state: Arc<S>,
} }
/// The initial SAS state. /// The initial SAS state.
@ -90,7 +92,7 @@ struct Started {
/// The SAS state we're going to be in after the other side accepted our /// The SAS state we're going to be in after the other side accepted our
/// verification start event. /// verification start event.
struct Accepted { struct Accepted {
accepted_protocols: AcceptedProtocols, accepted_protocols: Arc<AcceptedProtocols>,
commitment: String, commitment: String,
} }
@ -100,14 +102,14 @@ struct Accepted {
/// From now on we can show the short auth string to the user. /// From now on we can show the short auth string to the user.
struct KeyReceived { struct KeyReceived {
we_started: bool, we_started: bool,
accepted_protocols: AcceptedProtocols, accepted_protocols: Arc<AcceptedProtocols>,
} }
/// The SAS state we're going to be in after the user has confirmed that the /// The SAS state we're going to be in after the user has confirmed that the
/// short auth string matches. We still need to receive a MAC event from the /// short auth string matches. We still need to receive a MAC event from the
/// other side. /// other side.
struct Confirmed { struct Confirmed {
accepted_protocols: AcceptedProtocols, accepted_protocols: Arc<AcceptedProtocols>,
} }
/// The SAS state we're going to be in after we receive a MAC event from the /// The SAS state we're going to be in after we receive a MAC event from the
@ -115,8 +117,8 @@ struct Confirmed {
/// matches. /// matches.
struct MacReceived { struct MacReceived {
we_started: bool, we_started: bool,
verified_devices: Vec<Box<DeviceId>>, verified_devices: Arc<Vec<Box<DeviceId>>>,
verified_master_keys: Vec<String>, verified_master_keys: Arc<Vec<String>>,
} }
/// The SAS state indicating that the verification finished successfully. /// The SAS state indicating that the verification finished successfully.
@ -124,8 +126,8 @@ struct MacReceived {
/// We can now mark the device in our verified devices lits as verified and sign /// We can now mark the device in our verified devices lits as verified and sign
/// the master keys in the verified devices list. /// the master keys in the verified devices list.
struct Done { struct Done {
verified_devices: Vec<Box<DeviceId>>, verified_devices: Arc<Vec<Box<DeviceId>>>,
verified_master_keys: Vec<String>, verified_master_keys: Arc<Vec<String>>,
} }
impl<S> Sas<S> { impl<S> Sas<S> {
@ -153,14 +155,14 @@ impl Sas<Created> {
let from_device: Box<DeviceId> = account.device_id().into(); let from_device: Box<DeviceId> = account.device_id().into();
Sas { Sas {
inner: OlmSas::new(), inner: Arc::new(Mutex::new(OlmSas::new())),
ids: SasIds { ids: SasIds {
account, account,
other_device, other_device,
}, },
verification_flow_id: verification_flow_id.clone(), verification_flow_id: Arc::new(verification_flow_id.clone()),
state: Created { state: Arc::new(Created {
protocol_definitions: MSasV1ContentOptions { protocol_definitions: MSasV1ContentOptions {
transaction_id: verification_flow_id, transaction_id: verification_flow_id,
from_device, from_device,
@ -172,7 +174,7 @@ impl Sas<Created> {
message_authentication_codes: vec![MessageAuthenticationCode::HkdfHmacSha256], message_authentication_codes: vec![MessageAuthenticationCode::HkdfHmacSha256],
hashes: vec![HashAlgorithm::Sha256], hashes: vec![HashAlgorithm::Sha256],
}, },
}, }),
} }
} }
@ -202,10 +204,10 @@ impl Sas<Created> {
inner: self.inner, inner: self.inner,
ids: self.ids, ids: self.ids,
verification_flow_id: self.verification_flow_id, verification_flow_id: self.verification_flow_id,
state: Accepted { state: Arc::new(Accepted {
commitment: content.commitment.clone(), commitment: content.commitment.clone(),
accepted_protocols: content.clone().into(), accepted_protocols: Arc::new(content.clone().into()),
}, }),
} }
} }
} }
@ -238,18 +240,18 @@ impl Sas<Started> {
}; };
Sas { Sas {
inner: OlmSas::new(), inner: Arc::new(Mutex::new(OlmSas::new())),
ids: SasIds { ids: SasIds {
account, account,
other_device, other_device,
}, },
verification_flow_id: content.transaction_id.clone(), verification_flow_id: Arc::new(content.transaction_id.clone()),
state: Started { state: Arc::new(Started {
protocol_definitions: content.clone(), protocol_definitions: content.clone(),
}, }),
} }
} }
@ -285,9 +287,11 @@ impl Sas<Started> {
/// * `event` - The m.key.verification.key event that was sent to us by /// * `event` - The m.key.verification.key event that was sent to us by
/// the other side. The event will be modified so it doesn't contain any key /// the other side. The event will be modified so it doesn't contain any key
/// anymore. /// anymore.
fn into_key_received(mut self, event: &mut ToDeviceEvent<KeyEventContent>) -> Sas<KeyReceived> { fn into_key_received(self, event: &mut ToDeviceEvent<KeyEventContent>) -> Sas<KeyReceived> {
let accepted_protocols: AcceptedProtocols = self.get_accept_content().into(); let accepted_protocols: AcceptedProtocols = self.get_accept_content().into();
self.inner self.inner
.lock()
.unwrap()
.set_their_public_key(&mem::take(&mut event.content.key)) .set_their_public_key(&mem::take(&mut event.content.key))
.expect("Can't set public key"); .expect("Can't set public key");
@ -295,10 +299,10 @@ impl Sas<Started> {
inner: self.inner, inner: self.inner,
ids: self.ids, ids: self.ids,
verification_flow_id: self.verification_flow_id, verification_flow_id: self.verification_flow_id,
state: KeyReceived { state: Arc::new(KeyReceived {
we_started: false, we_started: false,
accepted_protocols, accepted_protocols: Arc::new(accepted_protocols),
}, }),
} }
} }
} }
@ -312,9 +316,11 @@ impl Sas<Accepted> {
/// * `event` - The m.key.verification.key event that was sent to us by /// * `event` - The m.key.verification.key event that was sent to us by
/// the other side. The event will be modified so it doesn't contain any key /// the other side. The event will be modified so it doesn't contain any key
/// anymore. /// anymore.
fn into_key_received(mut self, event: &mut ToDeviceEvent<KeyEventContent>) -> Sas<KeyReceived> { fn into_key_received(self, event: &mut ToDeviceEvent<KeyEventContent>) -> Sas<KeyReceived> {
// TODO check the commitment here since we started the SAS dance. // TODO check the commitment here since we started the SAS dance.
self.inner self.inner
.lock()
.unwrap()
.set_their_public_key(&mem::take(&mut event.content.key)) .set_their_public_key(&mem::take(&mut event.content.key))
.expect("Can't set public key"); .expect("Can't set public key");
@ -322,10 +328,10 @@ impl Sas<Accepted> {
inner: self.inner, inner: self.inner,
ids: self.ids, ids: self.ids,
verification_flow_id: self.verification_flow_id, verification_flow_id: self.verification_flow_id,
state: KeyReceived { state: Arc::new(KeyReceived {
we_started: true, we_started: true,
accepted_protocols: self.state.accepted_protocols, accepted_protocols: self.state.accepted_protocols.clone(),
}, }),
} }
} }
@ -335,7 +341,7 @@ impl Sas<Accepted> {
fn get_key_content(&self) -> KeyEventContent { fn get_key_content(&self) -> KeyEventContent {
KeyEventContent { KeyEventContent {
transaction_id: self.verification_flow_id.to_string(), transaction_id: self.verification_flow_id.to_string(),
key: self.inner.public_key(), key: self.inner.lock().unwrap().public_key(),
} }
} }
} }
@ -348,7 +354,7 @@ impl Sas<KeyReceived> {
fn get_key_content(&self) -> KeyEventContent { fn get_key_content(&self) -> KeyEventContent {
KeyEventContent { KeyEventContent {
transaction_id: self.verification_flow_id.to_string(), transaction_id: self.verification_flow_id.to_string(),
key: self.inner.public_key(), key: self.inner.lock().unwrap().public_key(),
} }
} }
@ -358,7 +364,7 @@ impl Sas<KeyReceived> {
/// second element the English description of the emoji. /// second element the English description of the emoji.
fn get_emoji(&self) -> Vec<(&'static str, &'static str)> { fn get_emoji(&self) -> Vec<(&'static str, &'static str)> {
get_emoji( get_emoji(
&self.inner, &self.inner.lock().unwrap(),
&self.ids, &self.ids,
&self.verification_flow_id, &self.verification_flow_id,
self.state.we_started, self.state.we_started,
@ -371,7 +377,7 @@ impl Sas<KeyReceived> {
/// the short auth string. /// the short auth string.
fn get_decimal(&self) -> (u32, u32, u32) { fn get_decimal(&self) -> (u32, u32, u32) {
get_decimal( get_decimal(
&self.inner, &self.inner.lock().unwrap(),
&self.ids, &self.ids,
&self.verification_flow_id, &self.verification_flow_id,
self.state.we_started, self.state.we_started,
@ -386,17 +392,21 @@ impl Sas<KeyReceived> {
/// * `event` - The m.key.verification.mac event that was sent to us by /// * `event` - The m.key.verification.mac event that was sent to us by
/// the other side. /// the other side.
fn into_mac_received(self, event: &ToDeviceEvent<MacEventContent>) -> Sas<MacReceived> { fn into_mac_received(self, event: &ToDeviceEvent<MacEventContent>) -> Sas<MacReceived> {
let (devices, master_keys) = let (devices, master_keys) = receive_mac_event(
receive_mac_event(&self.inner, &self.ids, &self.verification_flow_id, event); &self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
event,
);
Sas { Sas {
inner: self.inner, inner: self.inner,
verification_flow_id: self.verification_flow_id, verification_flow_id: self.verification_flow_id,
ids: self.ids, ids: self.ids,
state: MacReceived { state: Arc::new(MacReceived {
we_started: self.state.we_started, we_started: self.state.we_started,
verified_devices: devices, verified_devices: Arc::new(devices),
verified_master_keys: master_keys, verified_master_keys: Arc::new(master_keys),
}, }),
} }
} }
@ -409,9 +419,9 @@ impl Sas<KeyReceived> {
inner: self.inner, inner: self.inner,
verification_flow_id: self.verification_flow_id, verification_flow_id: self.verification_flow_id,
ids: self.ids, ids: self.ids,
state: Confirmed { state: Arc::new(Confirmed {
accepted_protocols: self.state.accepted_protocols, accepted_protocols: self.state.accepted_protocols.clone(),
}, }),
} }
} }
} }
@ -425,18 +435,22 @@ impl Sas<Confirmed> {
/// * `event` - The m.key.verification.mac event that was sent to us by /// * `event` - The m.key.verification.mac event that was sent to us by
/// the other side. /// the other side.
fn into_done(self, event: &ToDeviceEvent<MacEventContent>) -> Sas<Done> { fn into_done(self, event: &ToDeviceEvent<MacEventContent>) -> Sas<Done> {
let (devices, master_keys) = let (devices, master_keys) = receive_mac_event(
receive_mac_event(&self.inner, &self.ids, &self.verification_flow_id, event); &self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
event,
);
Sas { Sas {
inner: self.inner, inner: self.inner,
verification_flow_id: self.verification_flow_id, verification_flow_id: self.verification_flow_id,
ids: self.ids, ids: self.ids,
state: Done { state: Arc::new(Done {
verified_devices: devices, verified_devices: Arc::new(devices),
verified_master_keys: master_keys, verified_master_keys: Arc::new(master_keys),
}, }),
} }
} }
@ -444,7 +458,11 @@ impl Sas<Confirmed> {
/// ///
/// The content needs to be automatically sent to the other side. /// The content needs to be automatically sent to the other side.
fn get_mac_event_content(&self) -> MacEventContent { fn get_mac_event_content(&self) -> MacEventContent {
get_mac_content(&self.inner, &self.ids, &self.verification_flow_id) get_mac_content(
&self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
)
} }
} }
@ -458,10 +476,10 @@ impl Sas<MacReceived> {
inner: self.inner, inner: self.inner,
verification_flow_id: self.verification_flow_id, verification_flow_id: self.verification_flow_id,
ids: self.ids, ids: self.ids,
state: Done { state: Arc::new(Done {
verified_devices: self.state.verified_devices, verified_devices: self.state.verified_devices.clone(),
verified_master_keys: self.state.verified_master_keys, verified_master_keys: self.state.verified_master_keys.clone(),
}, }),
} }
} }
@ -471,7 +489,7 @@ impl Sas<MacReceived> {
/// second element the English description of the emoji. /// second element the English description of the emoji.
fn get_emoji(&self) -> Vec<(&'static str, &'static str)> { fn get_emoji(&self) -> Vec<(&'static str, &'static str)> {
get_emoji( get_emoji(
&self.inner, &self.inner.lock().unwrap(),
&self.ids, &self.ids,
&self.verification_flow_id, &self.verification_flow_id,
self.state.we_started, self.state.we_started,
@ -484,7 +502,7 @@ impl Sas<MacReceived> {
/// the short auth string. /// the short auth string.
fn get_decimal(&self) -> (u32, u32, u32) { fn get_decimal(&self) -> (u32, u32, u32) {
get_decimal( get_decimal(
&self.inner, &self.inner.lock().unwrap(),
&self.ids, &self.ids,
&self.verification_flow_id, &self.verification_flow_id,
self.state.we_started, self.state.we_started,
@ -498,11 +516,15 @@ impl Sas<Done> {
/// The content needs to be automatically sent to the other side if it /// The content needs to be automatically sent to the other side if it
/// wasn't already sent. /// wasn't already sent.
fn get_mac_event_content(&self) -> MacEventContent { fn get_mac_event_content(&self) -> MacEventContent {
get_mac_content(&self.inner, &self.ids, &self.verification_flow_id) get_mac_content(
&self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
)
} }
/// Get the list of verified devices. /// Get the list of verified devices.
fn verified_devices(&self) -> &Vec<Box<DeviceId>> { fn verified_devices(&self) -> &[Box<DeviceId>] {
&self.state.verified_devices &self.state.verified_devices
} }