crypto: Add a VerificationCache struct

This commit is contained in:
Damir Jelić 2021-05-21 14:03:33 +02:00
parent 98c259dc1e
commit d928f39f68

View file

@ -37,13 +37,80 @@ use crate::{
OutgoingVerificationRequest, ReadOnlyAccount, ReadOnlyDevice, RoomMessageRequest,
};
#[derive(Clone, Debug)]
pub struct VerificationCache {
sas_verification: Arc<DashMap<String, Sas>>,
room_sas_verifications: Arc<DashMap<EventId, Sas>>,
}
impl VerificationCache {
fn new() -> Self {
Self {
sas_verification: DashMap::new().into(),
room_sas_verifications: DashMap::new().into(),
}
}
#[cfg(test)]
fn is_empty(&self) -> bool {
self.room_sas_verifications.is_empty() && self.sas_verification.is_empty()
}
pub fn get_room_sas(&self, event_id: &EventId) -> Option<Sas> {
self.room_sas_verifications.get(event_id).map(|s| s.clone())
}
pub fn garbage_collect(&self) -> Vec<OutgoingRequest> {
self.sas_verification.retain(|_, s| !(s.is_done() || s.is_canceled()));
self.room_sas_verifications.retain(|_, s| !(s.is_done() || s.is_canceled()));
let mut requests: Vec<OutgoingRequest> = self
.sas_verification
.iter()
.filter_map(|s| {
s.cancel_if_timed_out().map(|r| OutgoingRequest {
request_id: r.request_id(),
request: Arc::new(r.into()),
})
})
.collect();
let room_requests: Vec<OutgoingRequest> = self
.room_sas_verifications
.iter()
.filter_map(|s| {
s.cancel_if_timed_out().map(|r| OutgoingRequest {
request_id: r.request_id(),
request: Arc::new(r.into()),
})
})
.collect();
requests.extend(room_requests);
requests
}
pub fn get_sas(&self, transaction_id: &str) -> Option<Sas> {
let sas = if let Ok(e) = EventId::try_from(transaction_id) {
self.room_sas_verifications.get(&e).map(|s| s.clone())
} else {
None
};
if sas.is_some() {
sas
} else {
self.sas_verification.get(transaction_id).map(|s| s.clone())
}
}
}
#[derive(Clone, Debug)]
pub struct VerificationMachine {
account: ReadOnlyAccount,
private_identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
pub(crate) store: Arc<Box<dyn CryptoStore>>,
verifications: Arc<DashMap<String, Sas>>,
room_verifications: Arc<DashMap<EventId, Sas>>,
verifications: VerificationCache,
requests: Arc<DashMap<String, VerificationRequest>>,
outgoing_messages: Arc<DashMap<Uuid, OutgoingRequest>>,
}
@ -58,9 +125,8 @@ impl VerificationMachine {
account,
private_identity: identity,
store,
verifications: DashMap::new().into(),
verifications: VerificationCache::new(),
requests: DashMap::new().into(),
room_verifications: DashMap::new().into(),
outgoing_messages: DashMap::new().into(),
}
}
@ -89,7 +155,9 @@ impl VerificationMachine {
let request =
content_to_request(device.user_id(), device.device_id().to_owned(), c);
self.verifications.insert(sas.flow_id().as_str().to_owned(), sas.clone());
self.verifications
.sas_verification
.insert(sas.flow_id().as_str().to_owned(), sas.clone());
request.into()
}
@ -103,19 +171,7 @@ impl VerificationMachine {
}
pub fn get_sas(&self, transaction_id: &str) -> Option<Sas> {
let sas = if let Ok(e) = EventId::try_from(transaction_id) {
#[allow(clippy::map_clone)]
self.room_verifications.get(&e).map(|s| s.clone())
} else {
None
};
if sas.is_some() {
sas
} else {
#[allow(clippy::map_clone)]
self.verifications.get(transaction_id).map(|s| s.clone())
}
self.verifications.get_sas(transaction_id)
}
fn queue_up_content(
@ -170,15 +226,8 @@ impl VerificationMachine {
}
pub fn garbage_collect(&self) {
self.verifications.retain(|_, s| !(s.is_done() || s.is_canceled()));
for sas in self.verifications.iter() {
if let Some(r) = sas.cancel_if_timed_out() {
self.outgoing_messages.insert(
r.request_id(),
OutgoingRequest { request_id: r.request_id(), request: Arc::new(r.into()) },
);
}
for request in self.verifications.garbage_collect() {
self.outgoing_messages.insert(*request.request_id(), request);
}
}
@ -255,7 +304,8 @@ impl VerificationMachine {
// TODO remove this unwrap
let accept_request = s.accept().unwrap();
self.room_verifications
self.verifications
.room_sas_verifications
.insert(e.content.relation.event_id.clone(), s);
self.outgoing_messages
@ -273,7 +323,7 @@ impl VerificationMachine {
}
}
AnySyncMessageEvent::KeyVerificationKey(e) => {
if let Some(s) = self.room_verifications.get(&e.content.relation.event_id) {
if let Some(s) = self.verifications.get_room_sas(&e.content.relation.event_id) {
self.receive_room_event_helper(
&s,
&m.clone().into_full_event(room_id.clone()),
@ -281,7 +331,7 @@ impl VerificationMachine {
};
}
AnySyncMessageEvent::KeyVerificationMac(e) => {
if let Some(s) = self.room_verifications.get(&e.content.relation.event_id) {
if let Some(s) = self.verifications.get_room_sas(&e.content.relation.event_id) {
self.receive_room_event_helper(
&s,
&m.clone().into_full_event(room_id.clone()),
@ -290,7 +340,7 @@ impl VerificationMachine {
}
AnySyncMessageEvent::KeyVerificationDone(e) => {
if let Some(s) = self.room_verifications.get(&e.content.relation.event_id) {
if let Some(s) = self.verifications.get_room_sas(&e.content.relation.event_id) {
let content =
s.receive_room_event(&m.clone().into_full_event(room_id.clone()));
@ -373,7 +423,9 @@ impl VerificationMachine {
self.store.get_user_identity(&e.sender).await?,
) {
Ok(s) => {
self.verifications.insert(e.content.transaction_id.clone(), s);
self.verifications
.sas_verification
.insert(e.content.transaction_id.clone(), s);
}
Err(c) => {
warn!(
@ -391,7 +443,7 @@ impl VerificationMachine {
}
}
AnyToDeviceEvent::KeyVerificationCancel(e) => {
self.verifications.remove(&e.content.transaction_id);
self.verifications.sas_verification.remove(&e.content.transaction_id);
}
AnyToDeviceEvent::KeyVerificationAccept(e) => {
if let Some(s) = self.get_sas(&e.content.transaction_id) {