crypto: Scope the verifications per sender

This commit is contained in:
Damir Jelić 2021-06-08 16:13:14 +02:00
parent 533a5b92b0
commit ada71586ac
8 changed files with 114 additions and 45 deletions

View file

@ -81,7 +81,7 @@ async fn login(
match event {
AnyToDeviceEvent::KeyVerificationStart(e) => {
let sas = client
.get_verification(&e.content.transaction_id)
.get_verification(&e.sender, &e.content.transaction_id)
.await
.expect("Sas object wasn't created");
println!(
@ -95,7 +95,7 @@ async fn login(
AnyToDeviceEvent::KeyVerificationKey(e) => {
let sas = client
.get_verification(&e.content.transaction_id)
.get_verification(&e.sender, &e.content.transaction_id)
.await
.expect("Sas object wasn't created");
@ -104,7 +104,7 @@ async fn login(
AnyToDeviceEvent::KeyVerificationMac(e) => {
let sas = client
.get_verification(&e.content.transaction_id)
.get_verification(&e.sender, &e.content.transaction_id)
.await
.expect("Sas object wasn't created");
@ -141,7 +141,10 @@ async fn login(
}
AnySyncMessageEvent::KeyVerificationKey(e) => {
let sas = client
.get_verification(e.content.relation.event_id.as_str())
.get_verification(
&e.sender,
e.content.relation.event_id.as_str(),
)
.await
.expect("Sas object wasn't created");
@ -149,7 +152,10 @@ async fn login(
}
AnySyncMessageEvent::KeyVerificationMac(e) => {
let sas = client
.get_verification(e.content.relation.event_id.as_str())
.get_verification(
&e.sender,
e.content.relation.event_id.as_str(),
)
.await
.expect("Sas object wasn't created");

View file

@ -2185,9 +2185,9 @@ impl Client {
/// Get a `Sas` verification object with the given flow id.
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn get_verification(&self, flow_id: &str) -> Option<Sas> {
pub async fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Sas> {
self.base_client
.get_verification(flow_id)
.get_verification(user_id, flow_id)
.await
.map(|sas| Sas { inner: sas, client: self.clone() })
}

View file

@ -1213,8 +1213,13 @@ impl BaseClient {
/// *m.key.verification.start* event.
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn get_verification(&self, flow_id: &str) -> Option<Sas> {
self.olm.lock().await.as_ref().and_then(|o| o.get_verification(flow_id))
pub async fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Sas> {
self.olm
.lock()
.await
.as_ref()
.and_then(|o| o.get_verification(user_id, flow_id).map(|v| v.sas_v1()))
.flatten()
}
/// Get a specific device of a user.

View file

@ -59,7 +59,7 @@ use crate::{
Changes, CryptoStore, DeviceChanges, IdentityChanges, MemoryStore, Result as StoreResult,
Store,
},
verification::{Sas, VerificationMachine, VerificationRequest},
verification::{Verification, VerificationMachine, VerificationRequest},
ToDeviceRequest,
};
@ -717,9 +717,9 @@ impl OlmMachine {
Ok(())
}
/// Get a `Sas` verification object with the given flow id.
pub fn get_verification(&self, flow_id: &str) -> Option<Sas> {
self.verification_machine.get_sas(flow_id)
/// Get a verification object for the given user id with the given flow id.
pub fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Verification> {
self.verification_machine.get_verification(user_id, flow_id)
}
/// Get a verification request object with the given flow id.
@ -1761,7 +1761,11 @@ pub(crate) mod test {
let event = request_to_event(alice.user_id(), &request.into());
bob.handle_verification_event(&event).await;
let bob_sas = bob.get_verification(alice_sas.flow_id().as_str()).unwrap();
let bob_sas = bob
.get_verification(alice.user_id(), alice_sas.flow_id().as_str())
.unwrap()
.sas_v1()
.unwrap();
assert!(alice_sas.emoji().is_none());
assert!(bob_sas.emoji().is_none());

View file

@ -23,7 +23,7 @@ use crate::{OutgoingRequest, RoomMessageRequest};
#[derive(Clone, Debug)]
pub struct VerificationCache {
verification: Arc<DashMap<String, Verification>>,
verification: Arc<DashMap<UserId, DashMap<String, Verification>>>,
outgoing_requests: Arc<DashMap<Uuid, OutgoingRequest>>,
}
@ -35,11 +35,24 @@ impl VerificationCache {
#[cfg(test)]
#[allow(dead_code)]
pub fn is_empty(&self) -> bool {
self.verification.is_empty()
self.verification.iter().all(|m| m.is_empty())
}
pub fn insert(&self, verification: impl Into<Verification>) {
let verification = verification.into();
self.verification
.entry(verification.other_user().to_owned())
.or_insert_with(DashMap::new)
.insert(verification.flow_id().to_owned(), verification);
}
pub fn insert_sas(&self, sas: Sas) {
self.verification.insert(sas.flow_id().as_str().to_string(), sas.into());
self.insert(sas);
}
pub fn get(&self, sender: &UserId, flow_id: &str) -> Option<Verification> {
self.verification.get(sender).and_then(|m| m.get(flow_id).map(|v| v.clone()))
}
pub fn outgoing_requests(&self) -> Vec<OutgoingRequest> {
@ -47,28 +60,38 @@ impl VerificationCache {
}
pub fn garbage_collect(&self) -> Vec<OutgoingRequest> {
self.verification.retain(|_, s| !(s.is_done() || s.is_cancelled()));
for user_verification in self.verification.iter() {
user_verification.retain(|_, s| !(s.is_done() || s.is_cancelled()));
}
self.verification.retain(|_, m| !m.is_empty());
self.verification
.iter()
.filter_map(|s| {
#[allow(irrefutable_let_patterns)]
if let Verification::SasV1(s) = s.value() {
s.cancel_if_timed_out().map(|r| OutgoingRequest {
request_id: r.request_id(),
request: Arc::new(r.into()),
.flat_map(|v| {
let requests: Vec<OutgoingRequest> = v
.value()
.iter()
.filter_map(|s| {
if let Verification::SasV1(s) = s.value() {
s.cancel_if_timed_out().map(|r| OutgoingRequest {
request_id: r.request_id(),
request: Arc::new(r.into()),
})
} else {
None
}
})
} else {
None
}
.collect();
requests
})
.collect()
}
pub fn get_sas(&self, transaction_id: &str) -> Option<Sas> {
self.verification.get(transaction_id).and_then(|v| {
#[allow(irrefutable_let_patterns)]
if let Verification::SasV1(sas) = v.value() {
pub fn get_sas(&self, user_id: &UserId, flow_id: &str) -> Option<Sas> {
self.get(user_id, flow_id).and_then(|v| {
if let Verification::SasV1(sas) = v {
Some(sas.clone())
} else {
None

View file

@ -24,7 +24,7 @@ use super::{
event_enums::{AnyEvent, AnyVerificationContent, OutgoingContent},
requests::VerificationRequest,
sas::{content_to_request, Sas},
FlowId, VerificationResult,
FlowId, Verification, VerificationResult,
};
use crate::{
olm::PrivateCrossSigningIdentity,
@ -94,8 +94,12 @@ impl VerificationMachine {
self.requests.get(flow_id.as_ref()).map(|s| s.clone())
}
pub fn get_sas(&self, transaction_id: &str) -> Option<Sas> {
self.verifications.get_sas(transaction_id)
pub fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Verification> {
self.verifications.get(user_id, flow_id)
}
pub fn get_sas(&self, user_id: &UserId, flow_id: &str) -> Option<Sas> {
self.verifications.get_sas(user_id, flow_id)
}
#[cfg(not(target_arch = "wasm32"))]
@ -242,7 +246,7 @@ impl VerificationMachine {
verification.receive_cancel(event.sender(), c);
}
if let Some(sas) = self.verifications.get_sas(flow_id.as_str()) {
if let Some(sas) = self.get_sas(event.sender(), flow_id.as_str()) {
// This won't produce an outgoing content
let _ = sas.receive_any_event(event.sender(), &content);
}
@ -296,7 +300,7 @@ impl VerificationMachine {
}
}
AnyVerificationContent::Accept(_) | AnyVerificationContent::Key(_) => {
if let Some(sas) = self.verifications.get_sas(flow_id.as_str()) {
if let Some(sas) = self.get_sas(event.sender(), flow_id.as_str()) {
if sas.flow_id() == &flow_id {
if let Some(content) = sas.receive_any_event(event.sender(), &content) {
self.queue_up_content(
@ -311,7 +315,7 @@ impl VerificationMachine {
}
}
AnyVerificationContent::Mac(_) => {
if let Some(s) = self.verifications.get_sas(flow_id.as_str()) {
if let Some(s) = self.get_sas(event.sender(), flow_id.as_str()) {
if s.flow_id() == &flow_id {
let content = s.receive_any_event(event.sender(), &content);
@ -328,12 +332,15 @@ impl VerificationMachine {
verification.receive_done(event.sender(), c);
}
if let Some(s) = self.verifications.get_sas(flow_id.as_str()) {
let content = s.receive_any_event(event.sender(), &content);
match self.get_verification(event.sender(), flow_id.as_str()) {
Some(Verification::SasV1(sas)) => {
let content = sas.receive_any_event(event.sender(), &content);
if s.is_done() {
self.mark_sas_as_done(s, content).await?;
if sas.is_done() {
self.mark_sas_as_done(sas, content).await?;
}
}
None => (),
}
}
}
@ -426,7 +433,7 @@ mod test {
async fn full_flow() {
let (alice_machine, bob) = setup_verification_machine().await;
let alice = alice_machine.get_sas(bob.flow_id().as_str()).unwrap();
let alice = alice_machine.get_sas(bob.user_id(), bob.flow_id().as_str()).unwrap();
let request = alice.accept().unwrap();
@ -472,7 +479,7 @@ mod test {
#[tokio::test]
async fn timing_out() {
let (alice_machine, bob) = setup_verification_machine().await;
let alice = alice_machine.get_sas(bob.flow_id().as_str()).unwrap();
let alice = alice_machine.get_sas(bob.user_id(), bob.flow_id().as_str()).unwrap();
assert!(!alice.timed_out());
assert!(alice_machine.verifications.outgoing_requests().is_empty());

View file

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#![allow(missing_docs)]
mod cache;
mod event_enums;
mod machine;
@ -57,11 +59,31 @@ impl Verification {
}
}
pub fn sas_v1(self) -> Option<Sas> {
if let Verification::SasV1(sas) = self {
Some(sas)
} else {
None
}
}
pub fn flow_id(&self) -> &str {
match self {
Verification::SasV1(s) => s.flow_id().as_str(),
}
}
pub fn is_cancelled(&self) -> bool {
match self {
Verification::SasV1(s) => s.is_cancelled(),
}
}
pub fn other_user(&self) -> &UserId {
match self {
Verification::SasV1(s) => s.other_user_id(),
}
}
}
impl From<Sas> for Verification {

View file

@ -810,7 +810,8 @@ mod test {
let content = StartContent::try_from(&start_content).unwrap();
let flow_id = content.flow_id().to_owned();
alice_request.receive_start(bob_device.user_id(), &content).await.unwrap();
let alice_sas = alice_request.verification_cache.get_sas(&flow_id).unwrap();
let alice_sas =
alice_request.verification_cache.get_sas(bob_device.user_id(), &flow_id).unwrap();
assert!(!bob_sas.is_cancelled());
assert!(!alice_sas.is_cancelled());
@ -867,7 +868,8 @@ mod test {
let content = StartContent::try_from(&start_content).unwrap();
let flow_id = content.flow_id().to_owned();
alice_request.receive_start(bob_device.user_id(), &content).await.unwrap();
let alice_sas = alice_request.verification_cache.get_sas(&flow_id).unwrap();
let alice_sas =
alice_request.verification_cache.get_sas(bob_device.user_id(), &flow_id).unwrap();
assert!(!bob_sas.is_cancelled());
assert!(!alice_sas.is_cancelled());