crypto: Scope the verifications per sender

master
Damir Jelić 2021-06-08 16:13:14 +02:00
parent 533a5b92b0
commit ada71586ac
8 changed files with 114 additions and 45 deletions

View File

@ -81,7 +81,7 @@ async fn login(
match event { match event {
AnyToDeviceEvent::KeyVerificationStart(e) => { AnyToDeviceEvent::KeyVerificationStart(e) => {
let sas = client let sas = client
.get_verification(&e.content.transaction_id) .get_verification(&e.sender, &e.content.transaction_id)
.await .await
.expect("Sas object wasn't created"); .expect("Sas object wasn't created");
println!( println!(
@ -95,7 +95,7 @@ async fn login(
AnyToDeviceEvent::KeyVerificationKey(e) => { AnyToDeviceEvent::KeyVerificationKey(e) => {
let sas = client let sas = client
.get_verification(&e.content.transaction_id) .get_verification(&e.sender, &e.content.transaction_id)
.await .await
.expect("Sas object wasn't created"); .expect("Sas object wasn't created");
@ -104,7 +104,7 @@ async fn login(
AnyToDeviceEvent::KeyVerificationMac(e) => { AnyToDeviceEvent::KeyVerificationMac(e) => {
let sas = client let sas = client
.get_verification(&e.content.transaction_id) .get_verification(&e.sender, &e.content.transaction_id)
.await .await
.expect("Sas object wasn't created"); .expect("Sas object wasn't created");
@ -141,7 +141,10 @@ async fn login(
} }
AnySyncMessageEvent::KeyVerificationKey(e) => { AnySyncMessageEvent::KeyVerificationKey(e) => {
let sas = client let sas = client
.get_verification(e.content.relation.event_id.as_str()) .get_verification(
&e.sender,
e.content.relation.event_id.as_str(),
)
.await .await
.expect("Sas object wasn't created"); .expect("Sas object wasn't created");
@ -149,7 +152,10 @@ async fn login(
} }
AnySyncMessageEvent::KeyVerificationMac(e) => { AnySyncMessageEvent::KeyVerificationMac(e) => {
let sas = client let sas = client
.get_verification(e.content.relation.event_id.as_str()) .get_verification(
&e.sender,
e.content.relation.event_id.as_str(),
)
.await .await
.expect("Sas object wasn't created"); .expect("Sas object wasn't created");

View File

@ -2185,9 +2185,9 @@ impl Client {
/// Get a `Sas` verification object with the given flow id. /// Get a `Sas` verification object with the given flow id.
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn get_verification(&self, flow_id: &str) -> Option<Sas> { pub async fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Sas> {
self.base_client self.base_client
.get_verification(flow_id) .get_verification(user_id, flow_id)
.await .await
.map(|sas| Sas { inner: sas, client: self.clone() }) .map(|sas| Sas { inner: sas, client: self.clone() })
} }

View File

@ -1213,8 +1213,13 @@ impl BaseClient {
/// *m.key.verification.start* event. /// *m.key.verification.start* event.
#[cfg(feature = "encryption")] #[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))] #[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn get_verification(&self, flow_id: &str) -> Option<Sas> { pub async fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Sas> {
self.olm.lock().await.as_ref().and_then(|o| o.get_verification(flow_id)) self.olm
.lock()
.await
.as_ref()
.and_then(|o| o.get_verification(user_id, flow_id).map(|v| v.sas_v1()))
.flatten()
} }
/// Get a specific device of a user. /// Get a specific device of a user.

View File

@ -59,7 +59,7 @@ use crate::{
Changes, CryptoStore, DeviceChanges, IdentityChanges, MemoryStore, Result as StoreResult, Changes, CryptoStore, DeviceChanges, IdentityChanges, MemoryStore, Result as StoreResult,
Store, Store,
}, },
verification::{Sas, VerificationMachine, VerificationRequest}, verification::{Verification, VerificationMachine, VerificationRequest},
ToDeviceRequest, ToDeviceRequest,
}; };
@ -717,9 +717,9 @@ impl OlmMachine {
Ok(()) Ok(())
} }
/// Get a `Sas` verification object with the given flow id. /// Get a verification object for the given user id with the given flow id.
pub fn get_verification(&self, flow_id: &str) -> Option<Sas> { pub fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Verification> {
self.verification_machine.get_sas(flow_id) self.verification_machine.get_verification(user_id, flow_id)
} }
/// Get a verification request object with the given flow id. /// Get a verification request object with the given flow id.
@ -1761,7 +1761,11 @@ pub(crate) mod test {
let event = request_to_event(alice.user_id(), &request.into()); let event = request_to_event(alice.user_id(), &request.into());
bob.handle_verification_event(&event).await; bob.handle_verification_event(&event).await;
let bob_sas = bob.get_verification(alice_sas.flow_id().as_str()).unwrap(); let bob_sas = bob
.get_verification(alice.user_id(), alice_sas.flow_id().as_str())
.unwrap()
.sas_v1()
.unwrap();
assert!(alice_sas.emoji().is_none()); assert!(alice_sas.emoji().is_none());
assert!(bob_sas.emoji().is_none()); assert!(bob_sas.emoji().is_none());

View File

@ -23,7 +23,7 @@ use crate::{OutgoingRequest, RoomMessageRequest};
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct VerificationCache { pub struct VerificationCache {
verification: Arc<DashMap<String, Verification>>, verification: Arc<DashMap<UserId, DashMap<String, Verification>>>,
outgoing_requests: Arc<DashMap<Uuid, OutgoingRequest>>, outgoing_requests: Arc<DashMap<Uuid, OutgoingRequest>>,
} }
@ -35,11 +35,24 @@ impl VerificationCache {
#[cfg(test)] #[cfg(test)]
#[allow(dead_code)] #[allow(dead_code)]
pub fn is_empty(&self) -> bool { pub fn is_empty(&self) -> bool {
self.verification.is_empty() self.verification.iter().all(|m| m.is_empty())
}
pub fn insert(&self, verification: impl Into<Verification>) {
let verification = verification.into();
self.verification
.entry(verification.other_user().to_owned())
.or_insert_with(DashMap::new)
.insert(verification.flow_id().to_owned(), verification);
} }
pub fn insert_sas(&self, sas: Sas) { pub fn insert_sas(&self, sas: Sas) {
self.verification.insert(sas.flow_id().as_str().to_string(), sas.into()); self.insert(sas);
}
pub fn get(&self, sender: &UserId, flow_id: &str) -> Option<Verification> {
self.verification.get(sender).and_then(|m| m.get(flow_id).map(|v| v.clone()))
} }
pub fn outgoing_requests(&self) -> Vec<OutgoingRequest> { pub fn outgoing_requests(&self) -> Vec<OutgoingRequest> {
@ -47,28 +60,38 @@ impl VerificationCache {
} }
pub fn garbage_collect(&self) -> Vec<OutgoingRequest> { pub fn garbage_collect(&self) -> Vec<OutgoingRequest> {
self.verification.retain(|_, s| !(s.is_done() || s.is_cancelled())); for user_verification in self.verification.iter() {
user_verification.retain(|_, s| !(s.is_done() || s.is_cancelled()));
}
self.verification.retain(|_, m| !m.is_empty());
self.verification self.verification
.iter() .iter()
.filter_map(|s| { .flat_map(|v| {
#[allow(irrefutable_let_patterns)] let requests: Vec<OutgoingRequest> = v
if let Verification::SasV1(s) = s.value() { .value()
s.cancel_if_timed_out().map(|r| OutgoingRequest { .iter()
request_id: r.request_id(), .filter_map(|s| {
request: Arc::new(r.into()), if let Verification::SasV1(s) = s.value() {
s.cancel_if_timed_out().map(|r| OutgoingRequest {
request_id: r.request_id(),
request: Arc::new(r.into()),
})
} else {
None
}
}) })
} else { .collect();
None
} requests
}) })
.collect() .collect()
} }
pub fn get_sas(&self, transaction_id: &str) -> Option<Sas> { pub fn get_sas(&self, user_id: &UserId, flow_id: &str) -> Option<Sas> {
self.verification.get(transaction_id).and_then(|v| { self.get(user_id, flow_id).and_then(|v| {
#[allow(irrefutable_let_patterns)] if let Verification::SasV1(sas) = v {
if let Verification::SasV1(sas) = v.value() {
Some(sas.clone()) Some(sas.clone())
} else { } else {
None None

View File

@ -24,7 +24,7 @@ use super::{
event_enums::{AnyEvent, AnyVerificationContent, OutgoingContent}, event_enums::{AnyEvent, AnyVerificationContent, OutgoingContent},
requests::VerificationRequest, requests::VerificationRequest,
sas::{content_to_request, Sas}, sas::{content_to_request, Sas},
FlowId, VerificationResult, FlowId, Verification, VerificationResult,
}; };
use crate::{ use crate::{
olm::PrivateCrossSigningIdentity, olm::PrivateCrossSigningIdentity,
@ -94,8 +94,12 @@ impl VerificationMachine {
self.requests.get(flow_id.as_ref()).map(|s| s.clone()) self.requests.get(flow_id.as_ref()).map(|s| s.clone())
} }
pub fn get_sas(&self, transaction_id: &str) -> Option<Sas> { pub fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Verification> {
self.verifications.get_sas(transaction_id) self.verifications.get(user_id, flow_id)
}
pub fn get_sas(&self, user_id: &UserId, flow_id: &str) -> Option<Sas> {
self.verifications.get_sas(user_id, flow_id)
} }
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
@ -242,7 +246,7 @@ impl VerificationMachine {
verification.receive_cancel(event.sender(), c); verification.receive_cancel(event.sender(), c);
} }
if let Some(sas) = self.verifications.get_sas(flow_id.as_str()) { if let Some(sas) = self.get_sas(event.sender(), flow_id.as_str()) {
// This won't produce an outgoing content // This won't produce an outgoing content
let _ = sas.receive_any_event(event.sender(), &content); let _ = sas.receive_any_event(event.sender(), &content);
} }
@ -296,7 +300,7 @@ impl VerificationMachine {
} }
} }
AnyVerificationContent::Accept(_) | AnyVerificationContent::Key(_) => { AnyVerificationContent::Accept(_) | AnyVerificationContent::Key(_) => {
if let Some(sas) = self.verifications.get_sas(flow_id.as_str()) { if let Some(sas) = self.get_sas(event.sender(), flow_id.as_str()) {
if sas.flow_id() == &flow_id { if sas.flow_id() == &flow_id {
if let Some(content) = sas.receive_any_event(event.sender(), &content) { if let Some(content) = sas.receive_any_event(event.sender(), &content) {
self.queue_up_content( self.queue_up_content(
@ -311,7 +315,7 @@ impl VerificationMachine {
} }
} }
AnyVerificationContent::Mac(_) => { AnyVerificationContent::Mac(_) => {
if let Some(s) = self.verifications.get_sas(flow_id.as_str()) { if let Some(s) = self.get_sas(event.sender(), flow_id.as_str()) {
if s.flow_id() == &flow_id { if s.flow_id() == &flow_id {
let content = s.receive_any_event(event.sender(), &content); let content = s.receive_any_event(event.sender(), &content);
@ -328,12 +332,15 @@ impl VerificationMachine {
verification.receive_done(event.sender(), c); verification.receive_done(event.sender(), c);
} }
if let Some(s) = self.verifications.get_sas(flow_id.as_str()) { match self.get_verification(event.sender(), flow_id.as_str()) {
let content = s.receive_any_event(event.sender(), &content); Some(Verification::SasV1(sas)) => {
let content = sas.receive_any_event(event.sender(), &content);
if s.is_done() { if sas.is_done() {
self.mark_sas_as_done(s, content).await?; self.mark_sas_as_done(sas, content).await?;
}
} }
None => (),
} }
} }
} }
@ -426,7 +433,7 @@ mod test {
async fn full_flow() { async fn full_flow() {
let (alice_machine, bob) = setup_verification_machine().await; let (alice_machine, bob) = setup_verification_machine().await;
let alice = alice_machine.get_sas(bob.flow_id().as_str()).unwrap(); let alice = alice_machine.get_sas(bob.user_id(), bob.flow_id().as_str()).unwrap();
let request = alice.accept().unwrap(); let request = alice.accept().unwrap();
@ -472,7 +479,7 @@ mod test {
#[tokio::test] #[tokio::test]
async fn timing_out() { async fn timing_out() {
let (alice_machine, bob) = setup_verification_machine().await; let (alice_machine, bob) = setup_verification_machine().await;
let alice = alice_machine.get_sas(bob.flow_id().as_str()).unwrap(); let alice = alice_machine.get_sas(bob.user_id(), bob.flow_id().as_str()).unwrap();
assert!(!alice.timed_out()); assert!(!alice.timed_out());
assert!(alice_machine.verifications.outgoing_requests().is_empty()); assert!(alice_machine.verifications.outgoing_requests().is_empty());

View File

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#![allow(missing_docs)]
mod cache; mod cache;
mod event_enums; mod event_enums;
mod machine; mod machine;
@ -57,11 +59,31 @@ impl Verification {
} }
} }
pub fn sas_v1(self) -> Option<Sas> {
if let Verification::SasV1(sas) = self {
Some(sas)
} else {
None
}
}
pub fn flow_id(&self) -> &str {
match self {
Verification::SasV1(s) => s.flow_id().as_str(),
}
}
pub fn is_cancelled(&self) -> bool { pub fn is_cancelled(&self) -> bool {
match self { match self {
Verification::SasV1(s) => s.is_cancelled(), Verification::SasV1(s) => s.is_cancelled(),
} }
} }
pub fn other_user(&self) -> &UserId {
match self {
Verification::SasV1(s) => s.other_user_id(),
}
}
} }
impl From<Sas> for Verification { impl From<Sas> for Verification {

View File

@ -810,7 +810,8 @@ mod test {
let content = StartContent::try_from(&start_content).unwrap(); let content = StartContent::try_from(&start_content).unwrap();
let flow_id = content.flow_id().to_owned(); let flow_id = content.flow_id().to_owned();
alice_request.receive_start(bob_device.user_id(), &content).await.unwrap(); alice_request.receive_start(bob_device.user_id(), &content).await.unwrap();
let alice_sas = alice_request.verification_cache.get_sas(&flow_id).unwrap(); let alice_sas =
alice_request.verification_cache.get_sas(bob_device.user_id(), &flow_id).unwrap();
assert!(!bob_sas.is_cancelled()); assert!(!bob_sas.is_cancelled());
assert!(!alice_sas.is_cancelled()); assert!(!alice_sas.is_cancelled());
@ -867,7 +868,8 @@ mod test {
let content = StartContent::try_from(&start_content).unwrap(); let content = StartContent::try_from(&start_content).unwrap();
let flow_id = content.flow_id().to_owned(); let flow_id = content.flow_id().to_owned();
alice_request.receive_start(bob_device.user_id(), &content).await.unwrap(); alice_request.receive_start(bob_device.user_id(), &content).await.unwrap();
let alice_sas = alice_request.verification_cache.get_sas(&flow_id).unwrap(); let alice_sas =
alice_request.verification_cache.get_sas(bob_device.user_id(), &flow_id).unwrap();
assert!(!bob_sas.is_cancelled()); assert!(!bob_sas.is_cancelled());
assert!(!alice_sas.is_cancelled()); assert!(!alice_sas.is_cancelled());