crypto: Make the Sas struct thread safe.

master
Damir Jelić 2020-07-24 11:26:45 +02:00
parent 8ff8ea1342
commit 2f28976694
1 changed files with 79 additions and 57 deletions

View File

@ -13,6 +13,7 @@
// limitations under the License.
use std::mem;
use std::sync::{Arc, Mutex};
use olm_rs::sas::OlmSas;
@ -35,6 +36,7 @@ use crate::{Account, Device};
/// Struct containing the protocols that were agreed to be used for the SAS
/// flow.
#[derive(Clone, Debug)]
struct AcceptedProtocols {
method: VerificationMethod,
key_agreement_protocol: KeyAgreementProtocol,
@ -65,16 +67,16 @@ impl From<AcceptEventContent> for AcceptedProtocols {
/// and the specific state.
struct Sas<S> {
/// The Olm SAS struct.
inner: OlmSas,
inner: Arc<Mutex<OlmSas>>,
/// Struct holding the identities that are doing the SAS dance.
ids: SasIds,
/// The unique identifier of this SAS flow.
///
/// This will be the transaction id for to-device events and the relates_to
/// field for in-room events.
verification_flow_id: String,
verification_flow_id: Arc<String>,
/// The SAS state we're in.
state: S,
state: Arc<S>,
}
/// The initial SAS state.
@ -90,7 +92,7 @@ struct Started {
/// The SAS state we're going to be in after the other side accepted our
/// verification start event.
struct Accepted {
accepted_protocols: AcceptedProtocols,
accepted_protocols: Arc<AcceptedProtocols>,
commitment: String,
}
@ -100,14 +102,14 @@ struct Accepted {
/// From now on we can show the short auth string to the user.
struct KeyReceived {
we_started: bool,
accepted_protocols: AcceptedProtocols,
accepted_protocols: Arc<AcceptedProtocols>,
}
/// The SAS state we're going to be in after the user has confirmed that the
/// short auth string matches. We still need to receive a MAC event from the
/// other side.
struct Confirmed {
accepted_protocols: AcceptedProtocols,
accepted_protocols: Arc<AcceptedProtocols>,
}
/// The SAS state we're going to be in after we receive a MAC event from the
@ -115,8 +117,8 @@ struct Confirmed {
/// matches.
struct MacReceived {
we_started: bool,
verified_devices: Vec<Box<DeviceId>>,
verified_master_keys: Vec<String>,
verified_devices: Arc<Vec<Box<DeviceId>>>,
verified_master_keys: Arc<Vec<String>>,
}
/// The SAS state indicating that the verification finished successfully.
@ -124,8 +126,8 @@ struct MacReceived {
/// We can now mark the device in our verified devices lits as verified and sign
/// the master keys in the verified devices list.
struct Done {
verified_devices: Vec<Box<DeviceId>>,
verified_master_keys: Vec<String>,
verified_devices: Arc<Vec<Box<DeviceId>>>,
verified_master_keys: Arc<Vec<String>>,
}
impl<S> Sas<S> {
@ -153,14 +155,14 @@ impl Sas<Created> {
let from_device: Box<DeviceId> = account.device_id().into();
Sas {
inner: OlmSas::new(),
inner: Arc::new(Mutex::new(OlmSas::new())),
ids: SasIds {
account,
other_device,
},
verification_flow_id: verification_flow_id.clone(),
verification_flow_id: Arc::new(verification_flow_id.clone()),
state: Created {
state: Arc::new(Created {
protocol_definitions: MSasV1ContentOptions {
transaction_id: verification_flow_id,
from_device,
@ -172,7 +174,7 @@ impl Sas<Created> {
message_authentication_codes: vec![MessageAuthenticationCode::HkdfHmacSha256],
hashes: vec![HashAlgorithm::Sha256],
},
},
}),
}
}
@ -202,10 +204,10 @@ impl Sas<Created> {
inner: self.inner,
ids: self.ids,
verification_flow_id: self.verification_flow_id,
state: Accepted {
state: Arc::new(Accepted {
commitment: content.commitment.clone(),
accepted_protocols: content.clone().into(),
},
accepted_protocols: Arc::new(content.clone().into()),
}),
}
}
}
@ -238,18 +240,18 @@ impl Sas<Started> {
};
Sas {
inner: OlmSas::new(),
inner: Arc::new(Mutex::new(OlmSas::new())),
ids: SasIds {
account,
other_device,
},
verification_flow_id: content.transaction_id.clone(),
verification_flow_id: Arc::new(content.transaction_id.clone()),
state: Started {
state: Arc::new(Started {
protocol_definitions: content.clone(),
},
}),
}
}
@ -285,9 +287,11 @@ impl Sas<Started> {
/// * `event` - The m.key.verification.key event that was sent to us by
/// the other side. The event will be modified so it doesn't contain any key
/// anymore.
fn into_key_received(mut self, event: &mut ToDeviceEvent<KeyEventContent>) -> Sas<KeyReceived> {
fn into_key_received(self, event: &mut ToDeviceEvent<KeyEventContent>) -> Sas<KeyReceived> {
let accepted_protocols: AcceptedProtocols = self.get_accept_content().into();
self.inner
.lock()
.unwrap()
.set_their_public_key(&mem::take(&mut event.content.key))
.expect("Can't set public key");
@ -295,10 +299,10 @@ impl Sas<Started> {
inner: self.inner,
ids: self.ids,
verification_flow_id: self.verification_flow_id,
state: KeyReceived {
state: Arc::new(KeyReceived {
we_started: false,
accepted_protocols,
},
accepted_protocols: Arc::new(accepted_protocols),
}),
}
}
}
@ -312,9 +316,11 @@ impl Sas<Accepted> {
/// * `event` - The m.key.verification.key event that was sent to us by
/// the other side. The event will be modified so it doesn't contain any key
/// anymore.
fn into_key_received(mut self, event: &mut ToDeviceEvent<KeyEventContent>) -> Sas<KeyReceived> {
fn into_key_received(self, event: &mut ToDeviceEvent<KeyEventContent>) -> Sas<KeyReceived> {
// TODO check the commitment here since we started the SAS dance.
self.inner
.lock()
.unwrap()
.set_their_public_key(&mem::take(&mut event.content.key))
.expect("Can't set public key");
@ -322,10 +328,10 @@ impl Sas<Accepted> {
inner: self.inner,
ids: self.ids,
verification_flow_id: self.verification_flow_id,
state: KeyReceived {
state: Arc::new(KeyReceived {
we_started: true,
accepted_protocols: self.state.accepted_protocols,
},
accepted_protocols: self.state.accepted_protocols.clone(),
}),
}
}
@ -335,7 +341,7 @@ impl Sas<Accepted> {
fn get_key_content(&self) -> KeyEventContent {
KeyEventContent {
transaction_id: self.verification_flow_id.to_string(),
key: self.inner.public_key(),
key: self.inner.lock().unwrap().public_key(),
}
}
}
@ -348,7 +354,7 @@ impl Sas<KeyReceived> {
fn get_key_content(&self) -> KeyEventContent {
KeyEventContent {
transaction_id: self.verification_flow_id.to_string(),
key: self.inner.public_key(),
key: self.inner.lock().unwrap().public_key(),
}
}
@ -358,7 +364,7 @@ impl Sas<KeyReceived> {
/// second element the English description of the emoji.
fn get_emoji(&self) -> Vec<(&'static str, &'static str)> {
get_emoji(
&self.inner,
&self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
self.state.we_started,
@ -371,7 +377,7 @@ impl Sas<KeyReceived> {
/// the short auth string.
fn get_decimal(&self) -> (u32, u32, u32) {
get_decimal(
&self.inner,
&self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
self.state.we_started,
@ -386,17 +392,21 @@ impl Sas<KeyReceived> {
/// * `event` - The m.key.verification.mac event that was sent to us by
/// the other side.
fn into_mac_received(self, event: &ToDeviceEvent<MacEventContent>) -> Sas<MacReceived> {
let (devices, master_keys) =
receive_mac_event(&self.inner, &self.ids, &self.verification_flow_id, event);
let (devices, master_keys) = receive_mac_event(
&self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
event,
);
Sas {
inner: self.inner,
verification_flow_id: self.verification_flow_id,
ids: self.ids,
state: MacReceived {
state: Arc::new(MacReceived {
we_started: self.state.we_started,
verified_devices: devices,
verified_master_keys: master_keys,
},
verified_devices: Arc::new(devices),
verified_master_keys: Arc::new(master_keys),
}),
}
}
@ -409,9 +419,9 @@ impl Sas<KeyReceived> {
inner: self.inner,
verification_flow_id: self.verification_flow_id,
ids: self.ids,
state: Confirmed {
accepted_protocols: self.state.accepted_protocols,
},
state: Arc::new(Confirmed {
accepted_protocols: self.state.accepted_protocols.clone(),
}),
}
}
}
@ -425,18 +435,22 @@ impl Sas<Confirmed> {
/// * `event` - The m.key.verification.mac event that was sent to us by
/// the other side.
fn into_done(self, event: &ToDeviceEvent<MacEventContent>) -> Sas<Done> {
let (devices, master_keys) =
receive_mac_event(&self.inner, &self.ids, &self.verification_flow_id, event);
let (devices, master_keys) = receive_mac_event(
&self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
event,
);
Sas {
inner: self.inner,
verification_flow_id: self.verification_flow_id,
ids: self.ids,
state: Done {
verified_devices: devices,
verified_master_keys: master_keys,
},
state: Arc::new(Done {
verified_devices: Arc::new(devices),
verified_master_keys: Arc::new(master_keys),
}),
}
}
@ -444,7 +458,11 @@ impl Sas<Confirmed> {
///
/// The content needs to be automatically sent to the other side.
fn get_mac_event_content(&self) -> MacEventContent {
get_mac_content(&self.inner, &self.ids, &self.verification_flow_id)
get_mac_content(
&self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
)
}
}
@ -458,10 +476,10 @@ impl Sas<MacReceived> {
inner: self.inner,
verification_flow_id: self.verification_flow_id,
ids: self.ids,
state: Done {
verified_devices: self.state.verified_devices,
verified_master_keys: self.state.verified_master_keys,
},
state: Arc::new(Done {
verified_devices: self.state.verified_devices.clone(),
verified_master_keys: self.state.verified_master_keys.clone(),
}),
}
}
@ -471,7 +489,7 @@ impl Sas<MacReceived> {
/// second element the English description of the emoji.
fn get_emoji(&self) -> Vec<(&'static str, &'static str)> {
get_emoji(
&self.inner,
&self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
self.state.we_started,
@ -484,7 +502,7 @@ impl Sas<MacReceived> {
/// the short auth string.
fn get_decimal(&self) -> (u32, u32, u32) {
get_decimal(
&self.inner,
&self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
self.state.we_started,
@ -498,11 +516,15 @@ impl Sas<Done> {
/// The content needs to be automatically sent to the other side if it
/// wasn't already sent.
fn get_mac_event_content(&self) -> MacEventContent {
get_mac_content(&self.inner, &self.ids, &self.verification_flow_id)
get_mac_content(
&self.inner.lock().unwrap(),
&self.ids,
&self.verification_flow_id,
)
}
/// Get the list of verified devices.
fn verified_devices(&self) -> &Vec<Box<DeviceId>> {
fn verified_devices(&self) -> &[Box<DeviceId>] {
&self.state.verified_devices
}