crypto: Scope the verifications per sender
parent
533a5b92b0
commit
ada71586ac
|
@ -81,7 +81,7 @@ async fn login(
|
||||||
match event {
|
match event {
|
||||||
AnyToDeviceEvent::KeyVerificationStart(e) => {
|
AnyToDeviceEvent::KeyVerificationStart(e) => {
|
||||||
let sas = client
|
let sas = client
|
||||||
.get_verification(&e.content.transaction_id)
|
.get_verification(&e.sender, &e.content.transaction_id)
|
||||||
.await
|
.await
|
||||||
.expect("Sas object wasn't created");
|
.expect("Sas object wasn't created");
|
||||||
println!(
|
println!(
|
||||||
|
@ -95,7 +95,7 @@ async fn login(
|
||||||
|
|
||||||
AnyToDeviceEvent::KeyVerificationKey(e) => {
|
AnyToDeviceEvent::KeyVerificationKey(e) => {
|
||||||
let sas = client
|
let sas = client
|
||||||
.get_verification(&e.content.transaction_id)
|
.get_verification(&e.sender, &e.content.transaction_id)
|
||||||
.await
|
.await
|
||||||
.expect("Sas object wasn't created");
|
.expect("Sas object wasn't created");
|
||||||
|
|
||||||
|
@ -104,7 +104,7 @@ async fn login(
|
||||||
|
|
||||||
AnyToDeviceEvent::KeyVerificationMac(e) => {
|
AnyToDeviceEvent::KeyVerificationMac(e) => {
|
||||||
let sas = client
|
let sas = client
|
||||||
.get_verification(&e.content.transaction_id)
|
.get_verification(&e.sender, &e.content.transaction_id)
|
||||||
.await
|
.await
|
||||||
.expect("Sas object wasn't created");
|
.expect("Sas object wasn't created");
|
||||||
|
|
||||||
|
@ -141,7 +141,10 @@ async fn login(
|
||||||
}
|
}
|
||||||
AnySyncMessageEvent::KeyVerificationKey(e) => {
|
AnySyncMessageEvent::KeyVerificationKey(e) => {
|
||||||
let sas = client
|
let sas = client
|
||||||
.get_verification(e.content.relation.event_id.as_str())
|
.get_verification(
|
||||||
|
&e.sender,
|
||||||
|
e.content.relation.event_id.as_str(),
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.expect("Sas object wasn't created");
|
.expect("Sas object wasn't created");
|
||||||
|
|
||||||
|
@ -149,7 +152,10 @@ async fn login(
|
||||||
}
|
}
|
||||||
AnySyncMessageEvent::KeyVerificationMac(e) => {
|
AnySyncMessageEvent::KeyVerificationMac(e) => {
|
||||||
let sas = client
|
let sas = client
|
||||||
.get_verification(e.content.relation.event_id.as_str())
|
.get_verification(
|
||||||
|
&e.sender,
|
||||||
|
e.content.relation.event_id.as_str(),
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.expect("Sas object wasn't created");
|
.expect("Sas object wasn't created");
|
||||||
|
|
||||||
|
|
|
@ -2185,9 +2185,9 @@ impl Client {
|
||||||
/// Get a `Sas` verification object with the given flow id.
|
/// Get a `Sas` verification object with the given flow id.
|
||||||
#[cfg(feature = "encryption")]
|
#[cfg(feature = "encryption")]
|
||||||
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
|
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
|
||||||
pub async fn get_verification(&self, flow_id: &str) -> Option<Sas> {
|
pub async fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Sas> {
|
||||||
self.base_client
|
self.base_client
|
||||||
.get_verification(flow_id)
|
.get_verification(user_id, flow_id)
|
||||||
.await
|
.await
|
||||||
.map(|sas| Sas { inner: sas, client: self.clone() })
|
.map(|sas| Sas { inner: sas, client: self.clone() })
|
||||||
}
|
}
|
||||||
|
|
|
@ -1213,8 +1213,13 @@ impl BaseClient {
|
||||||
/// *m.key.verification.start* event.
|
/// *m.key.verification.start* event.
|
||||||
#[cfg(feature = "encryption")]
|
#[cfg(feature = "encryption")]
|
||||||
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
|
#[cfg_attr(feature = "docs", doc(cfg(encryption)))]
|
||||||
pub async fn get_verification(&self, flow_id: &str) -> Option<Sas> {
|
pub async fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Sas> {
|
||||||
self.olm.lock().await.as_ref().and_then(|o| o.get_verification(flow_id))
|
self.olm
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|o| o.get_verification(user_id, flow_id).map(|v| v.sas_v1()))
|
||||||
|
.flatten()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a specific device of a user.
|
/// Get a specific device of a user.
|
||||||
|
|
|
@ -59,7 +59,7 @@ use crate::{
|
||||||
Changes, CryptoStore, DeviceChanges, IdentityChanges, MemoryStore, Result as StoreResult,
|
Changes, CryptoStore, DeviceChanges, IdentityChanges, MemoryStore, Result as StoreResult,
|
||||||
Store,
|
Store,
|
||||||
},
|
},
|
||||||
verification::{Sas, VerificationMachine, VerificationRequest},
|
verification::{Verification, VerificationMachine, VerificationRequest},
|
||||||
ToDeviceRequest,
|
ToDeviceRequest,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -717,9 +717,9 @@ impl OlmMachine {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a `Sas` verification object with the given flow id.
|
/// Get a verification object for the given user id with the given flow id.
|
||||||
pub fn get_verification(&self, flow_id: &str) -> Option<Sas> {
|
pub fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Verification> {
|
||||||
self.verification_machine.get_sas(flow_id)
|
self.verification_machine.get_verification(user_id, flow_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a verification request object with the given flow id.
|
/// Get a verification request object with the given flow id.
|
||||||
|
@ -1761,7 +1761,11 @@ pub(crate) mod test {
|
||||||
let event = request_to_event(alice.user_id(), &request.into());
|
let event = request_to_event(alice.user_id(), &request.into());
|
||||||
bob.handle_verification_event(&event).await;
|
bob.handle_verification_event(&event).await;
|
||||||
|
|
||||||
let bob_sas = bob.get_verification(alice_sas.flow_id().as_str()).unwrap();
|
let bob_sas = bob
|
||||||
|
.get_verification(alice.user_id(), alice_sas.flow_id().as_str())
|
||||||
|
.unwrap()
|
||||||
|
.sas_v1()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert!(alice_sas.emoji().is_none());
|
assert!(alice_sas.emoji().is_none());
|
||||||
assert!(bob_sas.emoji().is_none());
|
assert!(bob_sas.emoji().is_none());
|
||||||
|
|
|
@ -23,7 +23,7 @@ use crate::{OutgoingRequest, RoomMessageRequest};
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct VerificationCache {
|
pub struct VerificationCache {
|
||||||
verification: Arc<DashMap<String, Verification>>,
|
verification: Arc<DashMap<UserId, DashMap<String, Verification>>>,
|
||||||
outgoing_requests: Arc<DashMap<Uuid, OutgoingRequest>>,
|
outgoing_requests: Arc<DashMap<Uuid, OutgoingRequest>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -35,11 +35,24 @@ impl VerificationCache {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
pub fn is_empty(&self) -> bool {
|
pub fn is_empty(&self) -> bool {
|
||||||
self.verification.is_empty()
|
self.verification.iter().all(|m| m.is_empty())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn insert(&self, verification: impl Into<Verification>) {
|
||||||
|
let verification = verification.into();
|
||||||
|
|
||||||
|
self.verification
|
||||||
|
.entry(verification.other_user().to_owned())
|
||||||
|
.or_insert_with(DashMap::new)
|
||||||
|
.insert(verification.flow_id().to_owned(), verification);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn insert_sas(&self, sas: Sas) {
|
pub fn insert_sas(&self, sas: Sas) {
|
||||||
self.verification.insert(sas.flow_id().as_str().to_string(), sas.into());
|
self.insert(sas);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(&self, sender: &UserId, flow_id: &str) -> Option<Verification> {
|
||||||
|
self.verification.get(sender).and_then(|m| m.get(flow_id).map(|v| v.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn outgoing_requests(&self) -> Vec<OutgoingRequest> {
|
pub fn outgoing_requests(&self) -> Vec<OutgoingRequest> {
|
||||||
|
@ -47,12 +60,19 @@ impl VerificationCache {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn garbage_collect(&self) -> Vec<OutgoingRequest> {
|
pub fn garbage_collect(&self) -> Vec<OutgoingRequest> {
|
||||||
self.verification.retain(|_, s| !(s.is_done() || s.is_cancelled()));
|
for user_verification in self.verification.iter() {
|
||||||
|
user_verification.retain(|_, s| !(s.is_done() || s.is_cancelled()));
|
||||||
|
}
|
||||||
|
|
||||||
|
self.verification.retain(|_, m| !m.is_empty());
|
||||||
|
|
||||||
self.verification
|
self.verification
|
||||||
|
.iter()
|
||||||
|
.flat_map(|v| {
|
||||||
|
let requests: Vec<OutgoingRequest> = v
|
||||||
|
.value()
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|s| {
|
.filter_map(|s| {
|
||||||
#[allow(irrefutable_let_patterns)]
|
|
||||||
if let Verification::SasV1(s) = s.value() {
|
if let Verification::SasV1(s) = s.value() {
|
||||||
s.cancel_if_timed_out().map(|r| OutgoingRequest {
|
s.cancel_if_timed_out().map(|r| OutgoingRequest {
|
||||||
request_id: r.request_id(),
|
request_id: r.request_id(),
|
||||||
|
@ -62,13 +82,16 @@ impl VerificationCache {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
requests
|
||||||
|
})
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_sas(&self, transaction_id: &str) -> Option<Sas> {
|
pub fn get_sas(&self, user_id: &UserId, flow_id: &str) -> Option<Sas> {
|
||||||
self.verification.get(transaction_id).and_then(|v| {
|
self.get(user_id, flow_id).and_then(|v| {
|
||||||
#[allow(irrefutable_let_patterns)]
|
if let Verification::SasV1(sas) = v {
|
||||||
if let Verification::SasV1(sas) = v.value() {
|
|
||||||
Some(sas.clone())
|
Some(sas.clone())
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
|
|
@ -24,7 +24,7 @@ use super::{
|
||||||
event_enums::{AnyEvent, AnyVerificationContent, OutgoingContent},
|
event_enums::{AnyEvent, AnyVerificationContent, OutgoingContent},
|
||||||
requests::VerificationRequest,
|
requests::VerificationRequest,
|
||||||
sas::{content_to_request, Sas},
|
sas::{content_to_request, Sas},
|
||||||
FlowId, VerificationResult,
|
FlowId, Verification, VerificationResult,
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
olm::PrivateCrossSigningIdentity,
|
olm::PrivateCrossSigningIdentity,
|
||||||
|
@ -94,8 +94,12 @@ impl VerificationMachine {
|
||||||
self.requests.get(flow_id.as_ref()).map(|s| s.clone())
|
self.requests.get(flow_id.as_ref()).map(|s| s.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_sas(&self, transaction_id: &str) -> Option<Sas> {
|
pub fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Verification> {
|
||||||
self.verifications.get_sas(transaction_id)
|
self.verifications.get(user_id, flow_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_sas(&self, user_id: &UserId, flow_id: &str) -> Option<Sas> {
|
||||||
|
self.verifications.get_sas(user_id, flow_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(target_arch = "wasm32"))]
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
|
@ -242,7 +246,7 @@ impl VerificationMachine {
|
||||||
verification.receive_cancel(event.sender(), c);
|
verification.receive_cancel(event.sender(), c);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(sas) = self.verifications.get_sas(flow_id.as_str()) {
|
if let Some(sas) = self.get_sas(event.sender(), flow_id.as_str()) {
|
||||||
// This won't produce an outgoing content
|
// This won't produce an outgoing content
|
||||||
let _ = sas.receive_any_event(event.sender(), &content);
|
let _ = sas.receive_any_event(event.sender(), &content);
|
||||||
}
|
}
|
||||||
|
@ -296,7 +300,7 @@ impl VerificationMachine {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
AnyVerificationContent::Accept(_) | AnyVerificationContent::Key(_) => {
|
AnyVerificationContent::Accept(_) | AnyVerificationContent::Key(_) => {
|
||||||
if let Some(sas) = self.verifications.get_sas(flow_id.as_str()) {
|
if let Some(sas) = self.get_sas(event.sender(), flow_id.as_str()) {
|
||||||
if sas.flow_id() == &flow_id {
|
if sas.flow_id() == &flow_id {
|
||||||
if let Some(content) = sas.receive_any_event(event.sender(), &content) {
|
if let Some(content) = sas.receive_any_event(event.sender(), &content) {
|
||||||
self.queue_up_content(
|
self.queue_up_content(
|
||||||
|
@ -311,7 +315,7 @@ impl VerificationMachine {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
AnyVerificationContent::Mac(_) => {
|
AnyVerificationContent::Mac(_) => {
|
||||||
if let Some(s) = self.verifications.get_sas(flow_id.as_str()) {
|
if let Some(s) = self.get_sas(event.sender(), flow_id.as_str()) {
|
||||||
if s.flow_id() == &flow_id {
|
if s.flow_id() == &flow_id {
|
||||||
let content = s.receive_any_event(event.sender(), &content);
|
let content = s.receive_any_event(event.sender(), &content);
|
||||||
|
|
||||||
|
@ -328,13 +332,16 @@ impl VerificationMachine {
|
||||||
verification.receive_done(event.sender(), c);
|
verification.receive_done(event.sender(), c);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(s) = self.verifications.get_sas(flow_id.as_str()) {
|
match self.get_verification(event.sender(), flow_id.as_str()) {
|
||||||
let content = s.receive_any_event(event.sender(), &content);
|
Some(Verification::SasV1(sas)) => {
|
||||||
|
let content = sas.receive_any_event(event.sender(), &content);
|
||||||
|
|
||||||
if s.is_done() {
|
if sas.is_done() {
|
||||||
self.mark_sas_as_done(s, content).await?;
|
self.mark_sas_as_done(sas, content).await?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
None => (),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -426,7 +433,7 @@ mod test {
|
||||||
async fn full_flow() {
|
async fn full_flow() {
|
||||||
let (alice_machine, bob) = setup_verification_machine().await;
|
let (alice_machine, bob) = setup_verification_machine().await;
|
||||||
|
|
||||||
let alice = alice_machine.get_sas(bob.flow_id().as_str()).unwrap();
|
let alice = alice_machine.get_sas(bob.user_id(), bob.flow_id().as_str()).unwrap();
|
||||||
|
|
||||||
let request = alice.accept().unwrap();
|
let request = alice.accept().unwrap();
|
||||||
|
|
||||||
|
@ -472,7 +479,7 @@ mod test {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn timing_out() {
|
async fn timing_out() {
|
||||||
let (alice_machine, bob) = setup_verification_machine().await;
|
let (alice_machine, bob) = setup_verification_machine().await;
|
||||||
let alice = alice_machine.get_sas(bob.flow_id().as_str()).unwrap();
|
let alice = alice_machine.get_sas(bob.user_id(), bob.flow_id().as_str()).unwrap();
|
||||||
|
|
||||||
assert!(!alice.timed_out());
|
assert!(!alice.timed_out());
|
||||||
assert!(alice_machine.verifications.outgoing_requests().is_empty());
|
assert!(alice_machine.verifications.outgoing_requests().is_empty());
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
#![allow(missing_docs)]
|
||||||
|
|
||||||
mod cache;
|
mod cache;
|
||||||
mod event_enums;
|
mod event_enums;
|
||||||
mod machine;
|
mod machine;
|
||||||
|
@ -57,11 +59,31 @@ impl Verification {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn sas_v1(self) -> Option<Sas> {
|
||||||
|
if let Verification::SasV1(sas) = self {
|
||||||
|
Some(sas)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn flow_id(&self) -> &str {
|
||||||
|
match self {
|
||||||
|
Verification::SasV1(s) => s.flow_id().as_str(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn is_cancelled(&self) -> bool {
|
pub fn is_cancelled(&self) -> bool {
|
||||||
match self {
|
match self {
|
||||||
Verification::SasV1(s) => s.is_cancelled(),
|
Verification::SasV1(s) => s.is_cancelled(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn other_user(&self) -> &UserId {
|
||||||
|
match self {
|
||||||
|
Verification::SasV1(s) => s.other_user_id(),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Sas> for Verification {
|
impl From<Sas> for Verification {
|
||||||
|
|
|
@ -810,7 +810,8 @@ mod test {
|
||||||
let content = StartContent::try_from(&start_content).unwrap();
|
let content = StartContent::try_from(&start_content).unwrap();
|
||||||
let flow_id = content.flow_id().to_owned();
|
let flow_id = content.flow_id().to_owned();
|
||||||
alice_request.receive_start(bob_device.user_id(), &content).await.unwrap();
|
alice_request.receive_start(bob_device.user_id(), &content).await.unwrap();
|
||||||
let alice_sas = alice_request.verification_cache.get_sas(&flow_id).unwrap();
|
let alice_sas =
|
||||||
|
alice_request.verification_cache.get_sas(bob_device.user_id(), &flow_id).unwrap();
|
||||||
|
|
||||||
assert!(!bob_sas.is_cancelled());
|
assert!(!bob_sas.is_cancelled());
|
||||||
assert!(!alice_sas.is_cancelled());
|
assert!(!alice_sas.is_cancelled());
|
||||||
|
@ -867,7 +868,8 @@ mod test {
|
||||||
let content = StartContent::try_from(&start_content).unwrap();
|
let content = StartContent::try_from(&start_content).unwrap();
|
||||||
let flow_id = content.flow_id().to_owned();
|
let flow_id = content.flow_id().to_owned();
|
||||||
alice_request.receive_start(bob_device.user_id(), &content).await.unwrap();
|
alice_request.receive_start(bob_device.user_id(), &content).await.unwrap();
|
||||||
let alice_sas = alice_request.verification_cache.get_sas(&flow_id).unwrap();
|
let alice_sas =
|
||||||
|
alice_request.verification_cache.get_sas(bob_device.user_id(), &flow_id).unwrap();
|
||||||
|
|
||||||
assert!(!bob_sas.is_cancelled());
|
assert!(!bob_sas.is_cancelled());
|
||||||
assert!(!alice_sas.is_cancelled());
|
assert!(!alice_sas.is_cancelled());
|
||||||
|
|
Loading…
Reference in New Issue