crypto: Scope the verifications per sender
This commit is contained in:
parent
533a5b92b0
commit
ada71586ac
8 changed files with 114 additions and 45 deletions
|
@ -81,7 +81,7 @@ async fn login(
|
|||
match event {
|
||||
AnyToDeviceEvent::KeyVerificationStart(e) => {
|
||||
let sas = client
|
||||
.get_verification(&e.content.transaction_id)
|
||||
.get_verification(&e.sender, &e.content.transaction_id)
|
||||
.await
|
||||
.expect("Sas object wasn't created");
|
||||
println!(
|
||||
|
@ -95,7 +95,7 @@ async fn login(
|
|||
|
||||
AnyToDeviceEvent::KeyVerificationKey(e) => {
|
||||
let sas = client
|
||||
.get_verification(&e.content.transaction_id)
|
||||
.get_verification(&e.sender, &e.content.transaction_id)
|
||||
.await
|
||||
.expect("Sas object wasn't created");
|
||||
|
||||
|
@ -104,7 +104,7 @@ async fn login(
|
|||
|
||||
AnyToDeviceEvent::KeyVerificationMac(e) => {
|
||||
let sas = client
|
||||
.get_verification(&e.content.transaction_id)
|
||||
.get_verification(&e.sender, &e.content.transaction_id)
|
||||
.await
|
||||
.expect("Sas object wasn't created");
|
||||
|
||||
|
@ -141,7 +141,10 @@ async fn login(
|
|||
}
|
||||
AnySyncMessageEvent::KeyVerificationKey(e) => {
|
||||
let sas = client
|
||||
.get_verification(e.content.relation.event_id.as_str())
|
||||
.get_verification(
|
||||
&e.sender,
|
||||
e.content.relation.event_id.as_str(),
|
||||
)
|
||||
.await
|
||||
.expect("Sas object wasn't created");
|
||||
|
||||
|
@ -149,7 +152,10 @@ async fn login(
|
|||
}
|
||||
AnySyncMessageEvent::KeyVerificationMac(e) => {
|
||||
let sas = client
|
||||
.get_verification(e.content.relation.event_id.as_str())
|
||||
.get_verification(
|
||||
&e.sender,
|
||||
e.content.relation.event_id.as_str(),
|
||||
)
|
||||
.await
|
||||
.expect("Sas object wasn't created");
|
||||
|
||||
|
|
|
@ -2185,9 +2185,9 @@ impl Client {
|
|||
/// Get a `Sas` verification object with the given flow id.
|
||||
#[cfg(feature = "encryption")]
|
||||
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
|
||||
pub async fn get_verification(&self, flow_id: &str) -> Option<Sas> {
|
||||
pub async fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Sas> {
|
||||
self.base_client
|
||||
.get_verification(flow_id)
|
||||
.get_verification(user_id, flow_id)
|
||||
.await
|
||||
.map(|sas| Sas { inner: sas, client: self.clone() })
|
||||
}
|
||||
|
|
|
@ -1213,8 +1213,13 @@ impl BaseClient {
|
|||
/// *m.key.verification.start* event.
|
||||
#[cfg(feature = "encryption")]
|
||||
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
|
||||
pub async fn get_verification(&self, flow_id: &str) -> Option<Sas> {
|
||||
self.olm.lock().await.as_ref().and_then(|o| o.get_verification(flow_id))
|
||||
pub async fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Sas> {
|
||||
self.olm
|
||||
.lock()
|
||||
.await
|
||||
.as_ref()
|
||||
.and_then(|o| o.get_verification(user_id, flow_id).map(|v| v.sas_v1()))
|
||||
.flatten()
|
||||
}
|
||||
|
||||
/// Get a specific device of a user.
|
||||
|
|
|
@ -59,7 +59,7 @@ use crate::{
|
|||
Changes, CryptoStore, DeviceChanges, IdentityChanges, MemoryStore, Result as StoreResult,
|
||||
Store,
|
||||
},
|
||||
verification::{Sas, VerificationMachine, VerificationRequest},
|
||||
verification::{Verification, VerificationMachine, VerificationRequest},
|
||||
ToDeviceRequest,
|
||||
};
|
||||
|
||||
|
@ -717,9 +717,9 @@ impl OlmMachine {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Get a `Sas` verification object with the given flow id.
|
||||
pub fn get_verification(&self, flow_id: &str) -> Option<Sas> {
|
||||
self.verification_machine.get_sas(flow_id)
|
||||
/// Get a verification object for the given user id with the given flow id.
|
||||
pub fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Verification> {
|
||||
self.verification_machine.get_verification(user_id, flow_id)
|
||||
}
|
||||
|
||||
/// Get a verification request object with the given flow id.
|
||||
|
@ -1761,7 +1761,11 @@ pub(crate) mod test {
|
|||
let event = request_to_event(alice.user_id(), &request.into());
|
||||
bob.handle_verification_event(&event).await;
|
||||
|
||||
let bob_sas = bob.get_verification(alice_sas.flow_id().as_str()).unwrap();
|
||||
let bob_sas = bob
|
||||
.get_verification(alice.user_id(), alice_sas.flow_id().as_str())
|
||||
.unwrap()
|
||||
.sas_v1()
|
||||
.unwrap();
|
||||
|
||||
assert!(alice_sas.emoji().is_none());
|
||||
assert!(bob_sas.emoji().is_none());
|
||||
|
|
|
@ -23,7 +23,7 @@ use crate::{OutgoingRequest, RoomMessageRequest};
|
|||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct VerificationCache {
|
||||
verification: Arc<DashMap<String, Verification>>,
|
||||
verification: Arc<DashMap<UserId, DashMap<String, Verification>>>,
|
||||
outgoing_requests: Arc<DashMap<Uuid, OutgoingRequest>>,
|
||||
}
|
||||
|
||||
|
@ -35,11 +35,24 @@ impl VerificationCache {
|
|||
#[cfg(test)]
|
||||
#[allow(dead_code)]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.verification.is_empty()
|
||||
self.verification.iter().all(|m| m.is_empty())
|
||||
}
|
||||
|
||||
pub fn insert(&self, verification: impl Into<Verification>) {
|
||||
let verification = verification.into();
|
||||
|
||||
self.verification
|
||||
.entry(verification.other_user().to_owned())
|
||||
.or_insert_with(DashMap::new)
|
||||
.insert(verification.flow_id().to_owned(), verification);
|
||||
}
|
||||
|
||||
pub fn insert_sas(&self, sas: Sas) {
|
||||
self.verification.insert(sas.flow_id().as_str().to_string(), sas.into());
|
||||
self.insert(sas);
|
||||
}
|
||||
|
||||
pub fn get(&self, sender: &UserId, flow_id: &str) -> Option<Verification> {
|
||||
self.verification.get(sender).and_then(|m| m.get(flow_id).map(|v| v.clone()))
|
||||
}
|
||||
|
||||
pub fn outgoing_requests(&self) -> Vec<OutgoingRequest> {
|
||||
|
@ -47,28 +60,38 @@ impl VerificationCache {
|
|||
}
|
||||
|
||||
pub fn garbage_collect(&self) -> Vec<OutgoingRequest> {
|
||||
self.verification.retain(|_, s| !(s.is_done() || s.is_cancelled()));
|
||||
for user_verification in self.verification.iter() {
|
||||
user_verification.retain(|_, s| !(s.is_done() || s.is_cancelled()));
|
||||
}
|
||||
|
||||
self.verification.retain(|_, m| !m.is_empty());
|
||||
|
||||
self.verification
|
||||
.iter()
|
||||
.filter_map(|s| {
|
||||
#[allow(irrefutable_let_patterns)]
|
||||
if let Verification::SasV1(s) = s.value() {
|
||||
s.cancel_if_timed_out().map(|r| OutgoingRequest {
|
||||
request_id: r.request_id(),
|
||||
request: Arc::new(r.into()),
|
||||
.flat_map(|v| {
|
||||
let requests: Vec<OutgoingRequest> = v
|
||||
.value()
|
||||
.iter()
|
||||
.filter_map(|s| {
|
||||
if let Verification::SasV1(s) = s.value() {
|
||||
s.cancel_if_timed_out().map(|r| OutgoingRequest {
|
||||
request_id: r.request_id(),
|
||||
request: Arc::new(r.into()),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
.collect();
|
||||
|
||||
requests
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn get_sas(&self, transaction_id: &str) -> Option<Sas> {
|
||||
self.verification.get(transaction_id).and_then(|v| {
|
||||
#[allow(irrefutable_let_patterns)]
|
||||
if let Verification::SasV1(sas) = v.value() {
|
||||
pub fn get_sas(&self, user_id: &UserId, flow_id: &str) -> Option<Sas> {
|
||||
self.get(user_id, flow_id).and_then(|v| {
|
||||
if let Verification::SasV1(sas) = v {
|
||||
Some(sas.clone())
|
||||
} else {
|
||||
None
|
||||
|
|
|
@ -24,7 +24,7 @@ use super::{
|
|||
event_enums::{AnyEvent, AnyVerificationContent, OutgoingContent},
|
||||
requests::VerificationRequest,
|
||||
sas::{content_to_request, Sas},
|
||||
FlowId, VerificationResult,
|
||||
FlowId, Verification, VerificationResult,
|
||||
};
|
||||
use crate::{
|
||||
olm::PrivateCrossSigningIdentity,
|
||||
|
@ -94,8 +94,12 @@ impl VerificationMachine {
|
|||
self.requests.get(flow_id.as_ref()).map(|s| s.clone())
|
||||
}
|
||||
|
||||
pub fn get_sas(&self, transaction_id: &str) -> Option<Sas> {
|
||||
self.verifications.get_sas(transaction_id)
|
||||
pub fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Verification> {
|
||||
self.verifications.get(user_id, flow_id)
|
||||
}
|
||||
|
||||
pub fn get_sas(&self, user_id: &UserId, flow_id: &str) -> Option<Sas> {
|
||||
self.verifications.get_sas(user_id, flow_id)
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
|
@ -242,7 +246,7 @@ impl VerificationMachine {
|
|||
verification.receive_cancel(event.sender(), c);
|
||||
}
|
||||
|
||||
if let Some(sas) = self.verifications.get_sas(flow_id.as_str()) {
|
||||
if let Some(sas) = self.get_sas(event.sender(), flow_id.as_str()) {
|
||||
// This won't produce an outgoing content
|
||||
let _ = sas.receive_any_event(event.sender(), &content);
|
||||
}
|
||||
|
@ -296,7 +300,7 @@ impl VerificationMachine {
|
|||
}
|
||||
}
|
||||
AnyVerificationContent::Accept(_) | AnyVerificationContent::Key(_) => {
|
||||
if let Some(sas) = self.verifications.get_sas(flow_id.as_str()) {
|
||||
if let Some(sas) = self.get_sas(event.sender(), flow_id.as_str()) {
|
||||
if sas.flow_id() == &flow_id {
|
||||
if let Some(content) = sas.receive_any_event(event.sender(), &content) {
|
||||
self.queue_up_content(
|
||||
|
@ -311,7 +315,7 @@ impl VerificationMachine {
|
|||
}
|
||||
}
|
||||
AnyVerificationContent::Mac(_) => {
|
||||
if let Some(s) = self.verifications.get_sas(flow_id.as_str()) {
|
||||
if let Some(s) = self.get_sas(event.sender(), flow_id.as_str()) {
|
||||
if s.flow_id() == &flow_id {
|
||||
let content = s.receive_any_event(event.sender(), &content);
|
||||
|
||||
|
@ -328,12 +332,15 @@ impl VerificationMachine {
|
|||
verification.receive_done(event.sender(), c);
|
||||
}
|
||||
|
||||
if let Some(s) = self.verifications.get_sas(flow_id.as_str()) {
|
||||
let content = s.receive_any_event(event.sender(), &content);
|
||||
match self.get_verification(event.sender(), flow_id.as_str()) {
|
||||
Some(Verification::SasV1(sas)) => {
|
||||
let content = sas.receive_any_event(event.sender(), &content);
|
||||
|
||||
if s.is_done() {
|
||||
self.mark_sas_as_done(s, content).await?;
|
||||
if sas.is_done() {
|
||||
self.mark_sas_as_done(sas, content).await?;
|
||||
}
|
||||
}
|
||||
None => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -426,7 +433,7 @@ mod test {
|
|||
async fn full_flow() {
|
||||
let (alice_machine, bob) = setup_verification_machine().await;
|
||||
|
||||
let alice = alice_machine.get_sas(bob.flow_id().as_str()).unwrap();
|
||||
let alice = alice_machine.get_sas(bob.user_id(), bob.flow_id().as_str()).unwrap();
|
||||
|
||||
let request = alice.accept().unwrap();
|
||||
|
||||
|
@ -472,7 +479,7 @@ mod test {
|
|||
#[tokio::test]
|
||||
async fn timing_out() {
|
||||
let (alice_machine, bob) = setup_verification_machine().await;
|
||||
let alice = alice_machine.get_sas(bob.flow_id().as_str()).unwrap();
|
||||
let alice = alice_machine.get_sas(bob.user_id(), bob.flow_id().as_str()).unwrap();
|
||||
|
||||
assert!(!alice.timed_out());
|
||||
assert!(alice_machine.verifications.outgoing_requests().is_empty());
|
||||
|
|
|
@ -12,6 +12,8 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#![allow(missing_docs)]
|
||||
|
||||
mod cache;
|
||||
mod event_enums;
|
||||
mod machine;
|
||||
|
@ -57,11 +59,31 @@ impl Verification {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn sas_v1(self) -> Option<Sas> {
|
||||
if let Verification::SasV1(sas) = self {
|
||||
Some(sas)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn flow_id(&self) -> &str {
|
||||
match self {
|
||||
Verification::SasV1(s) => s.flow_id().as_str(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_cancelled(&self) -> bool {
|
||||
match self {
|
||||
Verification::SasV1(s) => s.is_cancelled(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn other_user(&self) -> &UserId {
|
||||
match self {
|
||||
Verification::SasV1(s) => s.other_user_id(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Sas> for Verification {
|
||||
|
|
|
@ -810,7 +810,8 @@ mod test {
|
|||
let content = StartContent::try_from(&start_content).unwrap();
|
||||
let flow_id = content.flow_id().to_owned();
|
||||
alice_request.receive_start(bob_device.user_id(), &content).await.unwrap();
|
||||
let alice_sas = alice_request.verification_cache.get_sas(&flow_id).unwrap();
|
||||
let alice_sas =
|
||||
alice_request.verification_cache.get_sas(bob_device.user_id(), &flow_id).unwrap();
|
||||
|
||||
assert!(!bob_sas.is_cancelled());
|
||||
assert!(!alice_sas.is_cancelled());
|
||||
|
@ -867,7 +868,8 @@ mod test {
|
|||
let content = StartContent::try_from(&start_content).unwrap();
|
||||
let flow_id = content.flow_id().to_owned();
|
||||
alice_request.receive_start(bob_device.user_id(), &content).await.unwrap();
|
||||
let alice_sas = alice_request.verification_cache.get_sas(&flow_id).unwrap();
|
||||
let alice_sas =
|
||||
alice_request.verification_cache.get_sas(bob_device.user_id(), &flow_id).unwrap();
|
||||
|
||||
assert!(!bob_sas.is_cancelled());
|
||||
assert!(!alice_sas.is_cancelled());
|
||||
|
|
Loading…
Reference in a new issue