crypto: Scope the verification requests behind the other user id

This commit is contained in:
Damir Jelić 2021-06-15 21:14:12 +02:00
parent 58d3b42a60
commit 5d38bc3802
4 changed files with 30 additions and 14 deletions

View file

@ -129,7 +129,7 @@ async fn login(
if let MessageType::VerificationRequest(_) = &m.content.msgtype
{
let request = client
.get_verification_request(&m.event_id)
.get_verification_request(&m.sender, &m.event_id)
.await
.expect("Request object wasn't created");

View file

@ -2192,16 +2192,18 @@ impl Client {
.map(|sas| Sas { inner: sas, client: self.clone() })
}
/// Get a `VerificationRequest` object with the given flow id.
/// Get a `VerificationRequest` object for the given user with the given
/// flow id.
#[cfg(feature = "encryption")]
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
pub async fn get_verification_request(
&self,
user_id: &UserId,
flow_id: impl AsRef<str>,
) -> Option<VerificationRequest> {
let olm = self.base_client.olm_machine().await?;
olm.get_verification_request(flow_id)
olm.get_verification_request(user_id, flow_id)
.map(|r| VerificationRequest { inner: r, client: self.clone() })
}

View file

@ -725,9 +725,10 @@ impl OlmMachine {
/// Get a verification request object with the given flow id.
pub fn get_verification_request(
&self,
user_id: &UserId,
flow_id: impl AsRef<str>,
) -> Option<VerificationRequest> {
self.verification_machine.get_request(flow_id)
self.verification_machine.get_request(user_id, flow_id)
}
async fn update_one_time_key_count(&self, key_count: &BTreeMap<DeviceKeyAlgorithm, UInt>) {

View file

@ -40,7 +40,7 @@ pub struct VerificationMachine {
private_identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
pub(crate) store: Arc<dyn CryptoStore>,
verifications: VerificationCache,
requests: Arc<DashMap<String, VerificationRequest>>,
requests: Arc<DashMap<UserId, DashMap<String, VerificationRequest>>>,
}
impl VerificationMachine {
@ -91,8 +91,19 @@ impl VerificationMachine {
Ok((sas, request))
}
pub fn get_request(&self, flow_id: impl AsRef<str>) -> Option<VerificationRequest> {
self.requests.get(flow_id.as_ref()).map(|s| s.clone())
pub fn get_request(
&self,
user_id: &UserId,
flow_id: impl AsRef<str>,
) -> Option<VerificationRequest> {
self.requests.get(user_id).and_then(|v| v.get(flow_id.as_ref()).map(|s| s.clone()))
}
fn insert_request(&self, request: VerificationRequest) {
self.requests
.entry(request.other_user().to_owned())
.or_insert_with(DashMap::new)
.insert(request.flow_id().as_str().to_owned(), request);
}
pub fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Verification> {
@ -145,7 +156,10 @@ impl VerificationMachine {
}
pub fn garbage_collect(&self) {
self.requests.retain(|_, r| !(r.is_done() || r.is_cancelled()));
for user_verification in self.requests.iter() {
user_verification.retain(|_, v| !(v.is_done() || v.is_cancelled()));
}
self.requests.retain(|_, v| !v.is_empty());
for request in self.verifications.garbage_collect() {
self.verifications.add_request(request)
@ -234,8 +248,7 @@ impl VerificationMachine {
r,
);
self.requests
.insert(request.flow_id().as_str().to_owned(), request);
self.insert_request(request);
} else {
trace!(
sender = event.sender().as_str(),
@ -260,7 +273,7 @@ impl VerificationMachine {
}
}
AnyVerificationContent::Cancel(c) => {
if let Some(verification) = self.get_request(flow_id.as_str()) {
if let Some(verification) = self.get_request(event.sender(), flow_id.as_str()) {
verification.receive_cancel(event.sender(), c);
}
@ -270,7 +283,7 @@ impl VerificationMachine {
}
}
AnyVerificationContent::Ready(c) => {
if let Some(request) = self.requests.get(flow_id.as_str()) {
if let Some(request) = self.get_request(event.sender(), flow_id.as_str()) {
if request.flow_id() == &flow_id {
request.receive_ready(event.sender(), c);
} else {
@ -279,7 +292,7 @@ impl VerificationMachine {
}
}
AnyVerificationContent::Start(c) => {
if let Some(request) = self.requests.get(flow_id.as_str()) {
if let Some(request) = self.get_request(event.sender(), flow_id.as_str()) {
if request.flow_id() == &flow_id {
request.receive_start(event.sender(), c).await?
} else {
@ -345,7 +358,7 @@ impl VerificationMachine {
}
}
AnyVerificationContent::Done(c) => {
if let Some(verification) = self.get_request(flow_id.as_str()) {
if let Some(verification) = self.get_request(event.sender(), flow_id.as_str()) {
verification.receive_done(event.sender(), c);
}