crypto: Move the verification cache into a separate module

master
Damir Jelić 2021-06-04 18:09:20 +02:00
parent 31e00eb434
commit 96d4566111
4 changed files with 134 additions and 101 deletions

View File

@ -0,0 +1,118 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use dashmap::DashMap;
use matrix_sdk_common::{
identifiers::{DeviceId, UserId},
uuid::Uuid,
};
use super::{event_enums::OutgoingContent, sas::content_to_request, Sas, Verification};
use crate::{OutgoingRequest, RoomMessageRequest};
#[derive(Clone, Debug)]
pub struct VerificationCache {
verification: Arc<DashMap<String, Verification>>,
outgoing_requests: Arc<DashMap<Uuid, OutgoingRequest>>,
}
impl VerificationCache {
pub fn new() -> Self {
Self { verification: DashMap::new().into(), outgoing_requests: DashMap::new().into() }
}
#[cfg(test)]
pub fn is_empty(&self) -> bool {
self.verification.is_empty()
}
pub fn insert_sas(&self, sas: Sas) {
self.verification.insert(sas.flow_id().as_str().to_string(), sas.into());
}
pub fn outgoing_requests(&self) -> Vec<OutgoingRequest> {
self.outgoing_requests.iter().map(|r| (*r).clone()).collect()
}
pub fn garbage_collect(&self) -> Vec<OutgoingRequest> {
self.verification.retain(|_, s| !(s.is_done() || s.is_cancelled()));
self.verification
.iter()
.filter_map(|s| {
#[allow(irrefutable_let_patterns)]
if let Verification::SasV1(s) = s.value() {
s.cancel_if_timed_out().map(|r| OutgoingRequest {
request_id: r.request_id(),
request: Arc::new(r.into()),
})
} else {
None
}
})
.collect()
}
pub fn get_sas(&self, transaction_id: &str) -> Option<Sas> {
self.verification.get(transaction_id).and_then(|v| {
#[allow(irrefutable_let_patterns)]
if let Verification::SasV1(sas) = v.value() {
Some(sas.clone())
} else {
None
}
})
}
pub fn add_request(&self, request: OutgoingRequest) {
self.outgoing_requests.insert(request.request_id, request);
}
pub fn queue_up_content(
&self,
recipient: &UserId,
recipient_device: &DeviceId,
content: OutgoingContent,
) {
match content {
OutgoingContent::ToDevice(c) => {
let request = content_to_request(recipient, recipient_device.to_owned(), c);
let request_id = request.txn_id;
let request = OutgoingRequest { request_id, request: Arc::new(request.into()) };
self.outgoing_requests.insert(request_id, request);
}
OutgoingContent::Room(r, c) => {
let request_id = Uuid::new_v4();
let request = OutgoingRequest {
request: Arc::new(
RoomMessageRequest { room_id: r, txn_id: request_id, content: c }.into(),
),
request_id,
};
self.outgoing_requests.insert(request_id, request);
}
}
}
pub fn mark_request_as_sent(&self, uuid: &Uuid) {
self.outgoing_requests.remove(uuid);
}
}

View File

@ -23,10 +23,11 @@ use matrix_sdk_common::{
use tracing::{info, warn}; use tracing::{info, warn};
use super::{ use super::{
cache::VerificationCache,
event_enums::{AnyEvent, AnyVerificationContent, OutgoingContent}, event_enums::{AnyEvent, AnyVerificationContent, OutgoingContent},
requests::VerificationRequest, requests::VerificationRequest,
sas::{content_to_request, Sas}, sas::{content_to_request, Sas},
FlowId, Verification, VerificationResult, FlowId, VerificationResult,
}; };
use crate::{ use crate::{
olm::PrivateCrossSigningIdentity, olm::PrivateCrossSigningIdentity,
@ -35,96 +36,6 @@ use crate::{
OutgoingVerificationRequest, ReadOnlyAccount, ReadOnlyDevice, RoomMessageRequest, OutgoingVerificationRequest, ReadOnlyAccount, ReadOnlyDevice, RoomMessageRequest,
}; };
#[derive(Clone, Debug)]
pub struct VerificationCache {
verification: Arc<DashMap<String, Verification>>,
outgoing_requests: Arc<DashMap<Uuid, OutgoingRequest>>,
}
impl VerificationCache {
pub fn new() -> Self {
Self { verification: DashMap::new().into(), outgoing_requests: DashMap::new().into() }
}
#[cfg(test)]
fn is_empty(&self) -> bool {
self.verification.is_empty()
}
pub fn insert_sas(&self, sas: Sas) {
self.verification.insert(sas.flow_id().as_str().to_string(), sas.into());
}
pub fn garbage_collect(&self) -> Vec<OutgoingRequest> {
self.verification.retain(|_, s| !(s.is_done() || s.is_cancelled()));
self.verification
.iter()
.filter_map(|s| {
#[allow(irrefutable_let_patterns)]
if let Verification::SasV1(s) = s.value() {
s.cancel_if_timed_out().map(|r| OutgoingRequest {
request_id: r.request_id(),
request: Arc::new(r.into()),
})
} else {
None
}
})
.collect()
}
pub fn get_sas(&self, transaction_id: &str) -> Option<Sas> {
self.verification.get(transaction_id).and_then(|v| {
#[allow(irrefutable_let_patterns)]
if let Verification::SasV1(sas) = v.value() {
Some(sas.clone())
} else {
None
}
})
}
pub fn add_request(&self, request: OutgoingRequest) {
self.outgoing_requests.insert(request.request_id, request);
}
pub fn queue_up_content(
&self,
recipient: &UserId,
recipient_device: &DeviceId,
content: OutgoingContent,
) {
match content {
OutgoingContent::ToDevice(c) => {
let request = content_to_request(recipient, recipient_device.to_owned(), c);
let request_id = request.txn_id;
let request = OutgoingRequest { request_id, request: Arc::new(request.into()) };
self.outgoing_requests.insert(request_id, request);
}
OutgoingContent::Room(r, c) => {
let request_id = Uuid::new_v4();
let request = OutgoingRequest {
request: Arc::new(
RoomMessageRequest { room_id: r, txn_id: request_id, content: c }.into(),
),
request_id,
};
self.outgoing_requests.insert(request_id, request);
}
}
}
pub fn mark_request_as_sent(&self, uuid: &Uuid) {
self.outgoing_requests.remove(uuid);
}
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct VerificationMachine { pub struct VerificationMachine {
account: ReadOnlyAccount, account: ReadOnlyAccount,
@ -204,7 +115,7 @@ impl VerificationMachine {
} }
pub fn outgoing_messages(&self) -> Vec<OutgoingRequest> { pub fn outgoing_messages(&self) -> Vec<OutgoingRequest> {
self.verifications.outgoing_requests.iter().map(|r| (*r).clone()).collect() self.verifications.outgoing_requests()
} }
pub fn garbage_collect(&self) { pub fn garbage_collect(&self) {
@ -490,11 +401,12 @@ mod test {
let event = wrap_any_to_device_content(bob.user_id(), content); let event = wrap_any_to_device_content(bob.user_id(), content);
assert!(alice_machine.verifications.outgoing_requests.is_empty()); assert!(alice_machine.verifications.outgoing_requests().is_empty());
alice_machine.receive_any_event(&event).await.unwrap(); alice_machine.receive_any_event(&event).await.unwrap();
assert!(!alice_machine.verifications.outgoing_requests.is_empty()); assert!(!alice_machine.verifications.outgoing_requests().is_empty());
let request = alice_machine.verifications.outgoing_requests.iter().next().unwrap().clone(); let request =
alice_machine.verifications.outgoing_requests().iter().next().unwrap().clone();
let txn_id = *request.request_id(); let txn_id = *request.request_id();
let content = OutgoingContent::try_from(request).unwrap(); let content = OutgoingContent::try_from(request).unwrap();
let content = KeyContent::try_from(&content).unwrap().into(); let content = KeyContent::try_from(&content).unwrap().into();
@ -528,14 +440,14 @@ mod test {
let alice = alice_machine.get_sas(bob.flow_id().as_str()).unwrap(); let alice = alice_machine.get_sas(bob.flow_id().as_str()).unwrap();
assert!(!alice.timed_out()); assert!(!alice.timed_out());
assert!(alice_machine.verifications.outgoing_requests.is_empty()); assert!(alice_machine.verifications.outgoing_requests().is_empty());
// This line panics on macOS, so we're disabled for now. // This line panics on macOS, so we're disabled for now.
alice.set_creation_time(Instant::now() - Duration::from_secs(60 * 15)); alice.set_creation_time(Instant::now() - Duration::from_secs(60 * 15));
assert!(alice.timed_out()); assert!(alice.timed_out());
assert!(alice_machine.verifications.outgoing_requests.is_empty()); assert!(alice_machine.verifications.outgoing_requests().is_empty());
alice_machine.garbage_collect(); alice_machine.garbage_collect();
assert!(!alice_machine.verifications.outgoing_requests.is_empty()); assert!(!alice_machine.verifications.outgoing_requests().is_empty());
alice_machine.garbage_collect(); alice_machine.garbage_collect();
assert!(alice_machine.verifications.is_empty()); assert!(alice_machine.verifications.is_empty());
} }

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
mod cache;
mod event_enums; mod event_enums;
mod machine; mod machine;
mod requests; mod requests;
@ -20,7 +21,7 @@ mod sas;
use std::sync::Arc; use std::sync::Arc;
use event_enums::OutgoingContent; use event_enums::OutgoingContent;
pub use machine::{VerificationCache, VerificationMachine}; pub use machine::VerificationMachine;
use matrix_sdk_common::{ use matrix_sdk_common::{
api::r0::keys::upload_signatures::Request as SignatureUploadRequest, api::r0::keys::upload_signatures::Request as SignatureUploadRequest,
events::{ events::{

View File

@ -36,11 +36,12 @@ use matrix_sdk_common::{
use tracing::{info, warn}; use tracing::{info, warn};
use super::{ use super::{
cache::VerificationCache,
event_enums::{ event_enums::{
CancelContent, DoneContent, OutgoingContent, ReadyContent, RequestContent, StartContent, CancelContent, DoneContent, OutgoingContent, ReadyContent, RequestContent, StartContent,
}, },
sas::content_to_request, sas::content_to_request,
Cancelled, FlowId, VerificationCache, Cancelled, FlowId,
}; };
use crate::{ use crate::{
olm::{PrivateCrossSigningIdentity, ReadOnlyAccount}, olm::{PrivateCrossSigningIdentity, ReadOnlyAccount},
@ -684,8 +685,9 @@ mod test {
olm::{PrivateCrossSigningIdentity, ReadOnlyAccount}, olm::{PrivateCrossSigningIdentity, ReadOnlyAccount},
store::{Changes, CryptoStore, MemoryStore}, store::{Changes, CryptoStore, MemoryStore},
verification::{ verification::{
cache::VerificationCache,
event_enums::{OutgoingContent, ReadyContent, StartContent}, event_enums::{OutgoingContent, ReadyContent, StartContent},
FlowId, VerificationCache, FlowId,
}, },
ReadOnlyDevice, ReadOnlyDevice,
}; };