crypto: Add a VerificationCache struct
parent
98c259dc1e
commit
d928f39f68
|
@ -37,13 +37,80 @@ use crate::{
|
||||||
OutgoingVerificationRequest, ReadOnlyAccount, ReadOnlyDevice, RoomMessageRequest,
|
OutgoingVerificationRequest, ReadOnlyAccount, ReadOnlyDevice, RoomMessageRequest,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct VerificationCache {
|
||||||
|
sas_verification: Arc<DashMap<String, Sas>>,
|
||||||
|
room_sas_verifications: Arc<DashMap<EventId, Sas>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VerificationCache {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
sas_verification: DashMap::new().into(),
|
||||||
|
room_sas_verifications: DashMap::new().into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
fn is_empty(&self) -> bool {
|
||||||
|
self.room_sas_verifications.is_empty() && self.sas_verification.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_room_sas(&self, event_id: &EventId) -> Option<Sas> {
|
||||||
|
self.room_sas_verifications.get(event_id).map(|s| s.clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn garbage_collect(&self) -> Vec<OutgoingRequest> {
|
||||||
|
self.sas_verification.retain(|_, s| !(s.is_done() || s.is_canceled()));
|
||||||
|
self.room_sas_verifications.retain(|_, s| !(s.is_done() || s.is_canceled()));
|
||||||
|
|
||||||
|
let mut requests: Vec<OutgoingRequest> = self
|
||||||
|
.sas_verification
|
||||||
|
.iter()
|
||||||
|
.filter_map(|s| {
|
||||||
|
s.cancel_if_timed_out().map(|r| OutgoingRequest {
|
||||||
|
request_id: r.request_id(),
|
||||||
|
request: Arc::new(r.into()),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let room_requests: Vec<OutgoingRequest> = self
|
||||||
|
.room_sas_verifications
|
||||||
|
.iter()
|
||||||
|
.filter_map(|s| {
|
||||||
|
s.cancel_if_timed_out().map(|r| OutgoingRequest {
|
||||||
|
request_id: r.request_id(),
|
||||||
|
request: Arc::new(r.into()),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
requests.extend(room_requests);
|
||||||
|
|
||||||
|
requests
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_sas(&self, transaction_id: &str) -> Option<Sas> {
|
||||||
|
let sas = if let Ok(e) = EventId::try_from(transaction_id) {
|
||||||
|
self.room_sas_verifications.get(&e).map(|s| s.clone())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
if sas.is_some() {
|
||||||
|
sas
|
||||||
|
} else {
|
||||||
|
self.sas_verification.get(transaction_id).map(|s| s.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct VerificationMachine {
|
pub struct VerificationMachine {
|
||||||
account: ReadOnlyAccount,
|
account: ReadOnlyAccount,
|
||||||
private_identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
|
private_identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
|
||||||
pub(crate) store: Arc<Box<dyn CryptoStore>>,
|
pub(crate) store: Arc<Box<dyn CryptoStore>>,
|
||||||
verifications: Arc<DashMap<String, Sas>>,
|
verifications: VerificationCache,
|
||||||
room_verifications: Arc<DashMap<EventId, Sas>>,
|
|
||||||
requests: Arc<DashMap<String, VerificationRequest>>,
|
requests: Arc<DashMap<String, VerificationRequest>>,
|
||||||
outgoing_messages: Arc<DashMap<Uuid, OutgoingRequest>>,
|
outgoing_messages: Arc<DashMap<Uuid, OutgoingRequest>>,
|
||||||
}
|
}
|
||||||
|
@ -58,9 +125,8 @@ impl VerificationMachine {
|
||||||
account,
|
account,
|
||||||
private_identity: identity,
|
private_identity: identity,
|
||||||
store,
|
store,
|
||||||
verifications: DashMap::new().into(),
|
verifications: VerificationCache::new(),
|
||||||
requests: DashMap::new().into(),
|
requests: DashMap::new().into(),
|
||||||
room_verifications: DashMap::new().into(),
|
|
||||||
outgoing_messages: DashMap::new().into(),
|
outgoing_messages: DashMap::new().into(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -89,7 +155,9 @@ impl VerificationMachine {
|
||||||
let request =
|
let request =
|
||||||
content_to_request(device.user_id(), device.device_id().to_owned(), c);
|
content_to_request(device.user_id(), device.device_id().to_owned(), c);
|
||||||
|
|
||||||
self.verifications.insert(sas.flow_id().as_str().to_owned(), sas.clone());
|
self.verifications
|
||||||
|
.sas_verification
|
||||||
|
.insert(sas.flow_id().as_str().to_owned(), sas.clone());
|
||||||
|
|
||||||
request.into()
|
request.into()
|
||||||
}
|
}
|
||||||
|
@ -103,19 +171,7 @@ impl VerificationMachine {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_sas(&self, transaction_id: &str) -> Option<Sas> {
|
pub fn get_sas(&self, transaction_id: &str) -> Option<Sas> {
|
||||||
let sas = if let Ok(e) = EventId::try_from(transaction_id) {
|
self.verifications.get_sas(transaction_id)
|
||||||
#[allow(clippy::map_clone)]
|
|
||||||
self.room_verifications.get(&e).map(|s| s.clone())
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
if sas.is_some() {
|
|
||||||
sas
|
|
||||||
} else {
|
|
||||||
#[allow(clippy::map_clone)]
|
|
||||||
self.verifications.get(transaction_id).map(|s| s.clone())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn queue_up_content(
|
fn queue_up_content(
|
||||||
|
@ -170,15 +226,8 @@ impl VerificationMachine {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn garbage_collect(&self) {
|
pub fn garbage_collect(&self) {
|
||||||
self.verifications.retain(|_, s| !(s.is_done() || s.is_canceled()));
|
for request in self.verifications.garbage_collect() {
|
||||||
|
self.outgoing_messages.insert(*request.request_id(), request);
|
||||||
for sas in self.verifications.iter() {
|
|
||||||
if let Some(r) = sas.cancel_if_timed_out() {
|
|
||||||
self.outgoing_messages.insert(
|
|
||||||
r.request_id(),
|
|
||||||
OutgoingRequest { request_id: r.request_id(), request: Arc::new(r.into()) },
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -255,7 +304,8 @@ impl VerificationMachine {
|
||||||
// TODO remove this unwrap
|
// TODO remove this unwrap
|
||||||
let accept_request = s.accept().unwrap();
|
let accept_request = s.accept().unwrap();
|
||||||
|
|
||||||
self.room_verifications
|
self.verifications
|
||||||
|
.room_sas_verifications
|
||||||
.insert(e.content.relation.event_id.clone(), s);
|
.insert(e.content.relation.event_id.clone(), s);
|
||||||
|
|
||||||
self.outgoing_messages
|
self.outgoing_messages
|
||||||
|
@ -273,7 +323,7 @@ impl VerificationMachine {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
AnySyncMessageEvent::KeyVerificationKey(e) => {
|
AnySyncMessageEvent::KeyVerificationKey(e) => {
|
||||||
if let Some(s) = self.room_verifications.get(&e.content.relation.event_id) {
|
if let Some(s) = self.verifications.get_room_sas(&e.content.relation.event_id) {
|
||||||
self.receive_room_event_helper(
|
self.receive_room_event_helper(
|
||||||
&s,
|
&s,
|
||||||
&m.clone().into_full_event(room_id.clone()),
|
&m.clone().into_full_event(room_id.clone()),
|
||||||
|
@ -281,7 +331,7 @@ impl VerificationMachine {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
AnySyncMessageEvent::KeyVerificationMac(e) => {
|
AnySyncMessageEvent::KeyVerificationMac(e) => {
|
||||||
if let Some(s) = self.room_verifications.get(&e.content.relation.event_id) {
|
if let Some(s) = self.verifications.get_room_sas(&e.content.relation.event_id) {
|
||||||
self.receive_room_event_helper(
|
self.receive_room_event_helper(
|
||||||
&s,
|
&s,
|
||||||
&m.clone().into_full_event(room_id.clone()),
|
&m.clone().into_full_event(room_id.clone()),
|
||||||
|
@ -290,7 +340,7 @@ impl VerificationMachine {
|
||||||
}
|
}
|
||||||
|
|
||||||
AnySyncMessageEvent::KeyVerificationDone(e) => {
|
AnySyncMessageEvent::KeyVerificationDone(e) => {
|
||||||
if let Some(s) = self.room_verifications.get(&e.content.relation.event_id) {
|
if let Some(s) = self.verifications.get_room_sas(&e.content.relation.event_id) {
|
||||||
let content =
|
let content =
|
||||||
s.receive_room_event(&m.clone().into_full_event(room_id.clone()));
|
s.receive_room_event(&m.clone().into_full_event(room_id.clone()));
|
||||||
|
|
||||||
|
@ -373,7 +423,9 @@ impl VerificationMachine {
|
||||||
self.store.get_user_identity(&e.sender).await?,
|
self.store.get_user_identity(&e.sender).await?,
|
||||||
) {
|
) {
|
||||||
Ok(s) => {
|
Ok(s) => {
|
||||||
self.verifications.insert(e.content.transaction_id.clone(), s);
|
self.verifications
|
||||||
|
.sas_verification
|
||||||
|
.insert(e.content.transaction_id.clone(), s);
|
||||||
}
|
}
|
||||||
Err(c) => {
|
Err(c) => {
|
||||||
warn!(
|
warn!(
|
||||||
|
@ -391,7 +443,7 @@ impl VerificationMachine {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
AnyToDeviceEvent::KeyVerificationCancel(e) => {
|
AnyToDeviceEvent::KeyVerificationCancel(e) => {
|
||||||
self.verifications.remove(&e.content.transaction_id);
|
self.verifications.sas_verification.remove(&e.content.transaction_id);
|
||||||
}
|
}
|
||||||
AnyToDeviceEvent::KeyVerificationAccept(e) => {
|
AnyToDeviceEvent::KeyVerificationAccept(e) => {
|
||||||
if let Some(s) = self.get_sas(&e.content.transaction_id) {
|
if let Some(s) = self.get_sas(&e.content.transaction_id) {
|
||||||
|
|
Loading…
Reference in New Issue