crypto: Scope the verification requests behind the other user id
parent
58d3b42a60
commit
5d38bc3802
|
@ -129,7 +129,7 @@ async fn login(
|
|||
if let MessageType::VerificationRequest(_) = &m.content.msgtype
|
||||
{
|
||||
let request = client
|
||||
.get_verification_request(&m.event_id)
|
||||
.get_verification_request(&m.sender, &m.event_id)
|
||||
.await
|
||||
.expect("Request object wasn't created");
|
||||
|
||||
|
|
|
@ -2192,16 +2192,18 @@ impl Client {
|
|||
.map(|sas| Sas { inner: sas, client: self.clone() })
|
||||
}
|
||||
|
||||
/// Get a `VerificationRequest` object with the given flow id.
|
||||
/// Get a `VerificationRequest` object for the given user with the given
|
||||
/// flow id.
|
||||
#[cfg(feature = "encryption")]
|
||||
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
|
||||
pub async fn get_verification_request(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
flow_id: impl AsRef<str>,
|
||||
) -> Option<VerificationRequest> {
|
||||
let olm = self.base_client.olm_machine().await?;
|
||||
|
||||
olm.get_verification_request(flow_id)
|
||||
olm.get_verification_request(user_id, flow_id)
|
||||
.map(|r| VerificationRequest { inner: r, client: self.clone() })
|
||||
}
|
||||
|
||||
|
|
|
@ -725,9 +725,10 @@ impl OlmMachine {
|
|||
/// Get a verification request object with the given flow id.
|
||||
pub fn get_verification_request(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
flow_id: impl AsRef<str>,
|
||||
) -> Option<VerificationRequest> {
|
||||
self.verification_machine.get_request(flow_id)
|
||||
self.verification_machine.get_request(user_id, flow_id)
|
||||
}
|
||||
|
||||
async fn update_one_time_key_count(&self, key_count: &BTreeMap<DeviceKeyAlgorithm, UInt>) {
|
||||
|
|
|
@ -40,7 +40,7 @@ pub struct VerificationMachine {
|
|||
private_identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
|
||||
pub(crate) store: Arc<dyn CryptoStore>,
|
||||
verifications: VerificationCache,
|
||||
requests: Arc<DashMap<String, VerificationRequest>>,
|
||||
requests: Arc<DashMap<UserId, DashMap<String, VerificationRequest>>>,
|
||||
}
|
||||
|
||||
impl VerificationMachine {
|
||||
|
@ -91,8 +91,19 @@ impl VerificationMachine {
|
|||
Ok((sas, request))
|
||||
}
|
||||
|
||||
pub fn get_request(&self, flow_id: impl AsRef<str>) -> Option<VerificationRequest> {
|
||||
self.requests.get(flow_id.as_ref()).map(|s| s.clone())
|
||||
pub fn get_request(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
flow_id: impl AsRef<str>,
|
||||
) -> Option<VerificationRequest> {
|
||||
self.requests.get(user_id).and_then(|v| v.get(flow_id.as_ref()).map(|s| s.clone()))
|
||||
}
|
||||
|
||||
fn insert_request(&self, request: VerificationRequest) {
|
||||
self.requests
|
||||
.entry(request.other_user().to_owned())
|
||||
.or_insert_with(DashMap::new)
|
||||
.insert(request.flow_id().as_str().to_owned(), request);
|
||||
}
|
||||
|
||||
pub fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Verification> {
|
||||
|
@ -145,7 +156,10 @@ impl VerificationMachine {
|
|||
}
|
||||
|
||||
pub fn garbage_collect(&self) {
|
||||
self.requests.retain(|_, r| !(r.is_done() || r.is_cancelled()));
|
||||
for user_verification in self.requests.iter() {
|
||||
user_verification.retain(|_, v| !(v.is_done() || v.is_cancelled()));
|
||||
}
|
||||
self.requests.retain(|_, v| !v.is_empty());
|
||||
|
||||
for request in self.verifications.garbage_collect() {
|
||||
self.verifications.add_request(request)
|
||||
|
@ -234,8 +248,7 @@ impl VerificationMachine {
|
|||
r,
|
||||
);
|
||||
|
||||
self.requests
|
||||
.insert(request.flow_id().as_str().to_owned(), request);
|
||||
self.insert_request(request);
|
||||
} else {
|
||||
trace!(
|
||||
sender = event.sender().as_str(),
|
||||
|
@ -260,7 +273,7 @@ impl VerificationMachine {
|
|||
}
|
||||
}
|
||||
AnyVerificationContent::Cancel(c) => {
|
||||
if let Some(verification) = self.get_request(flow_id.as_str()) {
|
||||
if let Some(verification) = self.get_request(event.sender(), flow_id.as_str()) {
|
||||
verification.receive_cancel(event.sender(), c);
|
||||
}
|
||||
|
||||
|
@ -270,7 +283,7 @@ impl VerificationMachine {
|
|||
}
|
||||
}
|
||||
AnyVerificationContent::Ready(c) => {
|
||||
if let Some(request) = self.requests.get(flow_id.as_str()) {
|
||||
if let Some(request) = self.get_request(event.sender(), flow_id.as_str()) {
|
||||
if request.flow_id() == &flow_id {
|
||||
request.receive_ready(event.sender(), c);
|
||||
} else {
|
||||
|
@ -279,7 +292,7 @@ impl VerificationMachine {
|
|||
}
|
||||
}
|
||||
AnyVerificationContent::Start(c) => {
|
||||
if let Some(request) = self.requests.get(flow_id.as_str()) {
|
||||
if let Some(request) = self.get_request(event.sender(), flow_id.as_str()) {
|
||||
if request.flow_id() == &flow_id {
|
||||
request.receive_start(event.sender(), c).await?
|
||||
} else {
|
||||
|
@ -345,7 +358,7 @@ impl VerificationMachine {
|
|||
}
|
||||
}
|
||||
AnyVerificationContent::Done(c) => {
|
||||
if let Some(verification) = self.get_request(flow_id.as_str()) {
|
||||
if let Some(verification) = self.get_request(event.sender(), flow_id.as_str()) {
|
||||
verification.receive_done(event.sender(), c);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue