crypto: Add a VerificationCache struct
This commit is contained in:
parent
98c259dc1e
commit
d928f39f68
1 changed files with 85 additions and 33 deletions
|
@ -37,13 +37,80 @@ use crate::{
|
|||
OutgoingVerificationRequest, ReadOnlyAccount, ReadOnlyDevice, RoomMessageRequest,
|
||||
};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct VerificationCache {
|
||||
sas_verification: Arc<DashMap<String, Sas>>,
|
||||
room_sas_verifications: Arc<DashMap<EventId, Sas>>,
|
||||
}
|
||||
|
||||
impl VerificationCache {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
sas_verification: DashMap::new().into(),
|
||||
room_sas_verifications: DashMap::new().into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn is_empty(&self) -> bool {
|
||||
self.room_sas_verifications.is_empty() && self.sas_verification.is_empty()
|
||||
}
|
||||
|
||||
pub fn get_room_sas(&self, event_id: &EventId) -> Option<Sas> {
|
||||
self.room_sas_verifications.get(event_id).map(|s| s.clone())
|
||||
}
|
||||
|
||||
pub fn garbage_collect(&self) -> Vec<OutgoingRequest> {
|
||||
self.sas_verification.retain(|_, s| !(s.is_done() || s.is_canceled()));
|
||||
self.room_sas_verifications.retain(|_, s| !(s.is_done() || s.is_canceled()));
|
||||
|
||||
let mut requests: Vec<OutgoingRequest> = self
|
||||
.sas_verification
|
||||
.iter()
|
||||
.filter_map(|s| {
|
||||
s.cancel_if_timed_out().map(|r| OutgoingRequest {
|
||||
request_id: r.request_id(),
|
||||
request: Arc::new(r.into()),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let room_requests: Vec<OutgoingRequest> = self
|
||||
.room_sas_verifications
|
||||
.iter()
|
||||
.filter_map(|s| {
|
||||
s.cancel_if_timed_out().map(|r| OutgoingRequest {
|
||||
request_id: r.request_id(),
|
||||
request: Arc::new(r.into()),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
requests.extend(room_requests);
|
||||
|
||||
requests
|
||||
}
|
||||
|
||||
pub fn get_sas(&self, transaction_id: &str) -> Option<Sas> {
|
||||
let sas = if let Ok(e) = EventId::try_from(transaction_id) {
|
||||
self.room_sas_verifications.get(&e).map(|s| s.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if sas.is_some() {
|
||||
sas
|
||||
} else {
|
||||
self.sas_verification.get(transaction_id).map(|s| s.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct VerificationMachine {
|
||||
account: ReadOnlyAccount,
|
||||
private_identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
|
||||
pub(crate) store: Arc<Box<dyn CryptoStore>>,
|
||||
verifications: Arc<DashMap<String, Sas>>,
|
||||
room_verifications: Arc<DashMap<EventId, Sas>>,
|
||||
verifications: VerificationCache,
|
||||
requests: Arc<DashMap<String, VerificationRequest>>,
|
||||
outgoing_messages: Arc<DashMap<Uuid, OutgoingRequest>>,
|
||||
}
|
||||
|
@ -58,9 +125,8 @@ impl VerificationMachine {
|
|||
account,
|
||||
private_identity: identity,
|
||||
store,
|
||||
verifications: DashMap::new().into(),
|
||||
verifications: VerificationCache::new(),
|
||||
requests: DashMap::new().into(),
|
||||
room_verifications: DashMap::new().into(),
|
||||
outgoing_messages: DashMap::new().into(),
|
||||
}
|
||||
}
|
||||
|
@ -89,7 +155,9 @@ impl VerificationMachine {
|
|||
let request =
|
||||
content_to_request(device.user_id(), device.device_id().to_owned(), c);
|
||||
|
||||
self.verifications.insert(sas.flow_id().as_str().to_owned(), sas.clone());
|
||||
self.verifications
|
||||
.sas_verification
|
||||
.insert(sas.flow_id().as_str().to_owned(), sas.clone());
|
||||
|
||||
request.into()
|
||||
}
|
||||
|
@ -103,19 +171,7 @@ impl VerificationMachine {
|
|||
}
|
||||
|
||||
pub fn get_sas(&self, transaction_id: &str) -> Option<Sas> {
|
||||
let sas = if let Ok(e) = EventId::try_from(transaction_id) {
|
||||
#[allow(clippy::map_clone)]
|
||||
self.room_verifications.get(&e).map(|s| s.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if sas.is_some() {
|
||||
sas
|
||||
} else {
|
||||
#[allow(clippy::map_clone)]
|
||||
self.verifications.get(transaction_id).map(|s| s.clone())
|
||||
}
|
||||
self.verifications.get_sas(transaction_id)
|
||||
}
|
||||
|
||||
fn queue_up_content(
|
||||
|
@ -170,15 +226,8 @@ impl VerificationMachine {
|
|||
}
|
||||
|
||||
pub fn garbage_collect(&self) {
|
||||
self.verifications.retain(|_, s| !(s.is_done() || s.is_canceled()));
|
||||
|
||||
for sas in self.verifications.iter() {
|
||||
if let Some(r) = sas.cancel_if_timed_out() {
|
||||
self.outgoing_messages.insert(
|
||||
r.request_id(),
|
||||
OutgoingRequest { request_id: r.request_id(), request: Arc::new(r.into()) },
|
||||
);
|
||||
}
|
||||
for request in self.verifications.garbage_collect() {
|
||||
self.outgoing_messages.insert(*request.request_id(), request);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -255,7 +304,8 @@ impl VerificationMachine {
|
|||
// TODO remove this unwrap
|
||||
let accept_request = s.accept().unwrap();
|
||||
|
||||
self.room_verifications
|
||||
self.verifications
|
||||
.room_sas_verifications
|
||||
.insert(e.content.relation.event_id.clone(), s);
|
||||
|
||||
self.outgoing_messages
|
||||
|
@ -273,7 +323,7 @@ impl VerificationMachine {
|
|||
}
|
||||
}
|
||||
AnySyncMessageEvent::KeyVerificationKey(e) => {
|
||||
if let Some(s) = self.room_verifications.get(&e.content.relation.event_id) {
|
||||
if let Some(s) = self.verifications.get_room_sas(&e.content.relation.event_id) {
|
||||
self.receive_room_event_helper(
|
||||
&s,
|
||||
&m.clone().into_full_event(room_id.clone()),
|
||||
|
@ -281,7 +331,7 @@ impl VerificationMachine {
|
|||
};
|
||||
}
|
||||
AnySyncMessageEvent::KeyVerificationMac(e) => {
|
||||
if let Some(s) = self.room_verifications.get(&e.content.relation.event_id) {
|
||||
if let Some(s) = self.verifications.get_room_sas(&e.content.relation.event_id) {
|
||||
self.receive_room_event_helper(
|
||||
&s,
|
||||
&m.clone().into_full_event(room_id.clone()),
|
||||
|
@ -290,7 +340,7 @@ impl VerificationMachine {
|
|||
}
|
||||
|
||||
AnySyncMessageEvent::KeyVerificationDone(e) => {
|
||||
if let Some(s) = self.room_verifications.get(&e.content.relation.event_id) {
|
||||
if let Some(s) = self.verifications.get_room_sas(&e.content.relation.event_id) {
|
||||
let content =
|
||||
s.receive_room_event(&m.clone().into_full_event(room_id.clone()));
|
||||
|
||||
|
@ -373,7 +423,9 @@ impl VerificationMachine {
|
|||
self.store.get_user_identity(&e.sender).await?,
|
||||
) {
|
||||
Ok(s) => {
|
||||
self.verifications.insert(e.content.transaction_id.clone(), s);
|
||||
self.verifications
|
||||
.sas_verification
|
||||
.insert(e.content.transaction_id.clone(), s);
|
||||
}
|
||||
Err(c) => {
|
||||
warn!(
|
||||
|
@ -391,7 +443,7 @@ impl VerificationMachine {
|
|||
}
|
||||
}
|
||||
AnyToDeviceEvent::KeyVerificationCancel(e) => {
|
||||
self.verifications.remove(&e.content.transaction_id);
|
||||
self.verifications.sas_verification.remove(&e.content.transaction_id);
|
||||
}
|
||||
AnyToDeviceEvent::KeyVerificationAccept(e) => {
|
||||
if let Some(s) = self.get_sas(&e.content.transaction_id) {
|
||||
|
|
Loading…
Reference in a new issue